めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

S - Digit Sum の解説(その1)

何の話かと言うと

atcoder.jp

この問題をネタにして、「桁DP」の考え方を説明します。

なお、問題文では「1 以上 K 以下の整数」となっていますが、この手の問題では、0 を含めて計算する方が簡単なので、「0 以上 K 以下の整数」として問題を解いておき、得られた答えから 1 を引きます。(0 は必ず D の倍数になっているので。)

「しらみつぶし」で考える

「しらみつぶし」の場合、考え方は簡単です。

c = 0
for n in range(0, K+1):
  if digit_sum(n) % D == 0:
    c += 1
print(c-1)

0 から K のそれぞれについて、関数 digit_sum() で各桁の合計を計算してチェックしています。ただし、関数 digit_sum() は、各桁を個別に取り出していくので、桁数に比例する計算量になります。数値 K の桁数は \log K ですので、全体の計算量は O(K\log K) になります。

それほど悪くない気もしますが、実は、この手の問題では、嫌がらせの様に大きな K が登場します。問題文をよく読むと、K は最大で 10^{10000} にもなります。もっともっと効率的なアルゴリズムを発見しないと実行時間制限を守ることはできません。

桁ごとの処理(キリのよいKの場合)

桁ごとの数字を見る問題ですので、たとえば、桁ごとに処理をすればどうでしょうか? ここでは、下の桁から上の桁に向かって処理する場合で考えます。

N = # K の桁数
for n in range(1, N+1):
  #(下から)n 桁目についての計算

仮に「(下から)n 桁目についての計算」が O(1) でできれば、全体として、O(\log K) という圧倒的な高速化が実現できることになります。

特に K=9999 といった( 10^n - 1 で表される)キリの良い数字であれば、

N = # K の桁数
for n in range(1, N+1):
  for i in range(10):
    # n 桁目が i の場合の計算

という感じで、0 以上 K 以下のすべての場合を網羅するループができそうです。

では、このループの前提で、DP の考え方を適用してみましょう。

    # n 桁目が i の場合の計算

という部分では、dp[n-1] に n-1 桁の場合(K=10^{n-1}-1 の場合)の計算結果が入っているとして、ここから、頭に新しい数字 i を付け加えた場合の結果を求める必要があります。

うーん。抽象的でわかりづらいですよね。。。。。

具体的にいきましょう。今、1桁目の計算が終わって、2桁目のループを回しているとします。i を 0 から 9 に変化させながら計算を進めるので、

・00 〜 09 の範囲での D の倍数の個数 ---- (0)
・10 〜 19 の範囲での D の倍数の個数 ---- (1)
 \vdots
・90 〜 99 の範囲での D の倍数の個数

という感じに、2桁目の数字による場合分けになります。まず、(0) は 1 桁目でもとめた答えそのままになります。うん。簡単。

(1) はどうでしょうか・・・? これは、1 桁の数字の中に、「1 を足した上でDの倍数になる数」、つまり、「Dの倍数-1」、つまーり、「D で割った余りが D-1」の数が何個あるかを知る必要があります。

ということは・・・・

そうです! DP を活用する時は、問題で与えらた条件そのままではなくて、条件を色々と変えた問題をまとめて解いていく、という発想が必要だったのです。

問題では、「D の倍数」、すなわち、「D で割った余りが 0」の個数を数えさせているのですが、ここでは、

・「D で割った余りが 0」の個数
・「D で割った余りが 1」の個数
 \vdots
・「D で割った余りが D-1」の個数

という D 種類の問題をまとめて解けばよいのです。一般に dp[n][r] を「n 桁以下の整数の中で、D で割った余りが r の個数」として、これを順番に埋めていくことにしましょう。n=1 の場合は直接計算(しらみつぶし)ですぐに埋まります。

N = # K の桁数。つまり K = 10**N - 1
dp = [[0] * D for _ in range(N+1)]

# n = 1 の場合
for i in range(10):
  dp[1][i%D] += 1

これを踏まえて、あらためて、n=2 の場合を考えると、次の様になります。

for i in range(10):
  for r in range(D):
    dp[2][(r+i)%D] += dp[1][r] # 配るDP

ここでは、「1桁目で余りが r だったものは、2桁目の i を足すことで、余りが (r+i)%D に変化する」という風に、n=1 の結果を n=2 に配っていく発想のループになっています。(こういうのを「配るDP」と呼ぶ人もいるようです。)

ここから先は同じことのくり返しですね。

mod = 10**9 + 7
N = # K の桁数。つまり K = 10**N - 1
dp = [[0] * D for _ in range(N+1)]

# n = 1 の場合
for i in range(10):
  dp[1][i%D] += 1

# n = 2,..., N の場合
for n in range(2, N+1):
  for i in range(10):
    for r in range(D):
      dp[2][(r+i)%D] += dp[1][r] # 配るDP
      dp[2] %= mod

print((dp[N][0] - 1) % mod)

最後に、dp[N][0] を見れば答えが得られるというわけです。(「答えが大きくなりすぎるため mod = 10**9 + 7 で割った余りを答えなさい」という指示があるので、計算途中でも(値が大きくなりすぎない様に)適宜 %= mod を演算しています。)

意外と簡単でしたね!

・・・というのはまだ早くて・・・・・

ここまでは、K=9999 といったキリの良い数字に限定した話です。一般の K=4727364 などの場合は、このままではうまく行きません。一般の場合への拡張は、次回の記事に譲ります。