めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

M - Candies の解説

何の話かと言うと

atcoder.jp

上記の問題は、「前のループで事前に計算できる部分を見つける」というちょっとした工夫で高速化ができるよく知られた問題です。

「子供に飴を配る」という立て付けですが、端的には、

0\le x_i\le a_i\ (i=1,\cdots,N) という条件の下に、x_1+\cdots+x_N = K を満たす \{x_1,\cdots,x_N\} の組み合わせの数を見つける

という数学の問題です。

考え方

基本的には「しらみつぶし」でOKです。x_1 の値を1つ固定すると、残りの x_2,\cdots,x_N を用いて合計を K-x_1 にするという問題に帰着します。逆向きに考えると、

x_1 で合計を k にする場合の数 dp[1][k] (k=0,1,\cdots,K)を求める
x_1,x_2 で合計を k にする場合の数 dp[2][k] (k=0,1,\cdots,K)を求める
 \vdots
x_1,\cdots,x_N で合計を k にする場合の数 dp[N][k] (k=0,1,\cdots,K)を求める

という計算を順番に行って、最後に dp[N][K] を答えとします。

import sys

def main(f):
  mod = 10**9 + 7
  N, K = list(map(int, f.readline().split()))
  A = [None] + list(map(int, f.readline().split()))

  dp = [[0] * (K+1) for _ in range(N+1)] # dp[n][k] : x_1〜x_n の合計を k にする場合の数

  # n = 1 の場合  
  for k in range(0, K+1):
    if k <= A[1]:
      dp[1][k] = 1
  
  # n = 2,...,N の場合
  for n in range(2, N+1):
    for k in range(0, K+1):
      for i in range(0, A[n]+1):  # x_n = i の場合
        if k - i < 0:
          break
        dp[n][k] += dp[n-1][k-i]
        dp[n][k] %= mod

  print(dp[N][K])

main(sys.stdin)

計算の効率化

上記のコードは、ロジックとしては正しいのですが、3重のループがちょっとつらいですね。(提出すると TLE します。)

このループを減らす工夫が必要なのですが、今の場合、3重目の i のループは、n-1 のケースの答えを足し合わせていることに気がつきます。次の様に書き直してもよいでしょう。

  # n = 2,...,N の場合
  for n in range(2, N+1):
    for k in range(0, K+1):
      dp[n][k] = sum(dp[n-1][max(0, k-A[n]):k+1])
      dp[n][k] %= mod

この和 sum(dp[n-1][max(0, k-A[n]):k+1]) は、dp[n-1][0], dp[n-1][1], ..., dp[n-1][K] の部分和ですので、部分和計算のテクニックを思い出すと、

・S[n-1][k] = dp[n-1][0] + dp[n-1][1] + ... + dp[n-1][k] (k 項までの和)として、

・sum(dp[n-1][max(0, k-A[n]):k+1]) = S[n-1][k] (max(0, k-A[n]) == 0 の時)
・sum(dp[n-1][max(0, k-A[n]):k+1]) = S[n-1][k] - S[n-1][max(0, k-A[n])-1](max(0, k-A[n]) > 0 の時)

で計算できます。そして、上記の S[n-1][k] は・・・、そう、1 つ前のループで、dp[n-1] の計算をしながら同時に計算できてしまいますね。たとえば、n=1 のケースであれば、

  S = [[0] * (K+1) for _ in range(N+1)]
  # n = 1 の場合  
  for k in range(0, K+1):
    if k <= A[1]:
      dp[1][k] = 1
    if k == 0:
      S[1][k] = dp[1][k]
    else:
      S[1][k] = S[1][k-1] + dp[1][k]

でOKです。n=2 以降では、ここで計算した S[n-1] を利用すると同時に、次に使う S[n] も同様に計算します。

  # n = 2,...,N の場合
  for n in range(2, N+1):
    for k in range(0, K+1):
      if max(0, k-A[n]) == 0:
        dp[n][k] = S[n-1][k]
      else:
        dp[n][k] = S[n-1][k] - S[n-1][max(0, k-A[n])-1]
      dp[n][k] %= mod
      if k == 0:
        S[n][k] = dp[n][k]
      else:
        S[n][k] = S[n][k-1] + dp[n][k]
    S[n][k] %= mod

これで無事にパスできます。

atcoder.jp

まとめ

なかなか言われないと気づかないパターンかもしれませんが、dp[n-1] から dp[n] を計算する際に、dp[n-1] の部分和が必要な際は、このように事前計算ができる事を覚えておきましょう。

部分和の基本問題は、下記などがあります。

atcoder.jp

(解答例)
atcoder.jp