めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

S - Digit Sum の解説(その2)

何の話かと言うと

enakai00.hatenablog.com

こちらの記事の続きです。

一般の K に拡張する

前回は、K = 99999 と言った(各桁が 0 〜 9 を回る)キリのよい数字という前提で問題を解きました。

今回は、K = 38463 と言った一般の場合を考えます。

この場合は、各桁の数字を回すループの範囲がややこしくなります。

例によって、小さな数字の簡単な場合で考えましょう。

例えば、K=35 とすると、(下からみて)2 桁目の数字は当然ながら、0, 1, 2, 3 の4種類の値しかとりません。

一方、1桁目の数字は・・・。実は、場合分けが発生します。

・2桁目が3の場合:1 桁目は 0 〜 5 の範囲 --- (1)
・2桁目が0, 1, 2 の場合: 0 〜 9 のすべての範囲 --- (2)

しかしながら、DPの基本として、1桁目の計算をする時に、「今、2桁目がいくらの場合を考えているのか」という次のステップの情報を利用することはできません。

どうすればよいのでしょうか・・・・

実は、これは、(1)(2) の両方の場合を計算しておけばよいのです。リスト dp[ ][ ] を2つ用意して、

・dp1[n][r]:上の桁が実際の「桁値」(上記の例では 3)より小さい場合の答え
・dp2[n][r]:上の桁が実際の「桁値」(上記の例では 3)の場合の答え

とします。抽象的でわかりにくいという方のために、具体的なコードにしてしまいましょう。

K = '35'

dp1 = [[0] * D for _ in range(len(K)+1)]
dp2 = [[0] * D for _ in range(len(K)+1)]

#  n = 1 の時
limit = int(K[-1]) # 1桁目の「桁値」5
for i in range(0, limit+1):
  dp2[1][i%D] += 1
for i in range(0, 10):
  dp1[1][i%D] += 1

なお上記のコードでは、与えられた K を文字列にしています。前回の冒頭で触れた様に、いやがらせのように大きな K がやってくるので、整数値として扱うと桁溢れを起こす可能性があります。また、文字列にしておけば、「n 桁目の値」を取り出すのも簡単です。int(K[-n]) が下から n 桁目の数値になります。

このように準備しておけば、2桁目の計算を考える際に、上記の場合分けが簡単になります。

#  n = 2 の時
limit = int(K[-2]) # 2桁目の「桁値」3
for i in range(0, limit): # --- (2) の場合
  for r in range(0, D):
    dp2[2][(r+i)%D] += dp1[1][r] # 配るDP
i = limit # --- (1) の場合
  for r in range(0, D):
    dp2[2][(r+i)%D] += dp2[1][r] # 配るDP

2桁の数字の問題であれば、これで終了です。dp2[2][0] から答えが読み出せます。

ただし、3桁目もある問題の場合は、dp1[2][ ] の計算も必要です。たとえば、K = 435 とすると、この次の n = 3 に対しては、

・3 桁目が 4 の場合:2桁目以降は 00 〜 35 に制限されるので dp2[2][ ] から配る --- (1)'
・3 桁目が 0,1,2,3 の場合:2桁目以降は 00 〜 99 のすべての場合を含むので dp1[2][ ] から配る --- (2)'

という手順になるからです。dp1[2][ ] は、(2)' からわかる様に数字の範囲に制限がない場合ですので、素朴に(?)(もしくは、前回と同様に)

#  n = 2 の時
for i in range(0, 10):
  for r in range(0, D):
    dp1[2][(r+i)%D] += dp1[1][r] # 配るDP

でOKです。あとはこれと同じことを繰り返していきます。

  K = # 与えられた数値
  N = len(K)

  dp1 = [[0] * D for _ in range(len(K)+1)]
  dp2 = [[0] * D for _ in range(len(K)+1)]

  #  n = 1 の時
  limit = int(K[-1]) # 1 桁目の「桁値」
  for i in range(0, limit+1):
    dp2[1][i%D] += 1
  for i in range(0, 10):
    dp1[1][i%D] += 1

  # n = 2, 3, ..., N の時
  for n in range(2, N+1):
    limit = int(K[-n]) # n 桁目の「桁値」
    for i in range(0, limit):
      for r in range(0, D):
        dp2[n][(r+i)%D] += dp1[n-1][r] # 配るDP
        dp2[n][(r+i)%D] %= mod
    i = limit
    for r in range(0, D):
      dp2[n][(r+i)%D] += dp2[n-1][r] # 配るDP
      dp2[n][(r+i)%D] %= mod

    for i in range(0, 10):
      for r in range(0, D):
        dp1[n][(r+i)%D] += dp1[n-1][r] # 配るDP
        dp1[n][(r+i)%D] %= mod

  print((dp2[N][0]-1) % mod)

これで無事にパスするコードになりました。やったー。

atcoder.jp

というわけで、(上位桁が桁値ぎりぎりの場合のための)上限を設けた答え dp2[n][ ] と、上限を設けない答え dp1[n][ ] の 2 種類を計算するというのが、「桁DP」のポイントということになります。

上の桁から計算する場合(「最大値トレース」の考え方)

上記の考え方は、アルゴリズムとして正しいものですが、場合によっては、上の桁から順番に計算した方が都合のよい場合があります。(どちらかと言うと、こちらのケースが多い様な気がします。)そこで、上の桁から考えた場合に、どのような遷移が可能かを説明します。

結論から言うと、この場合は、

・各桁の値が K にぴったり一致するものだけを選んで行った場合:dp2[n]
・それ以外の場合:dp1[n]

という2種類の場合について、個別のリストに情報を格納していきます。この後のロジックを追うとわかる様に、この場合、dp1[ ] と dp2[ ] を合わせることで、0 〜 K の値をすべてカバーすることになるので、最後の答えは、dp1[ ] と dp2[ ] の両方を読み出して合計する必要があります。

まずは、具体例で考えてみましょう。

K = 345

とする場合、n=1、つまり、最上位の桁についての計算は、0, 1, 2 の場合(dp1[1] に格納)と 3 の場合(dp2[1] に格納)で場合分けされます。

  #  n = 1 の時
  limit = int(K[0])
  for i in range(0, limit): # K[0] 未満
    dp1[1][i%D] += 1
  i = limit
  dp2[1][i%D] += 1 # K[0] に一致

ここでは、

・dp1[1][r]:{0, 1, 2} の中で D で割った余りが r の個数
・dp2[1][r]:{3} の中で D で割った余りが r の個数

という情報が格納されますので、「0 〜 3 の中で D で割った余りが r の個数」が欲しければ、dp1[1][r] + dp2[1][r] を計算することになります。

次に、n=2、つまり、上から2桁目を考えますが、この際、

・上位の桁全体が K[:n] に一致する場合(今の場合は、1 桁目が 3 の場合) --- (1) :この計算には dp2[n-1] が利用できる。
・上位の桁全体が K[:n] 未満の場合(今の場合は、1 桁目が 0, 1, 2 の場合) --- (2):この計算には dp1[n-1] が利用できる。

を分けて考えます。

(1) の中でも特に、2 桁目が K[2](つまり、4)に一致する場合の情報を dp2[2] に格納します。2 桁目が K[2] 未満(つまり、0, 1, 2, 3)の場合の情報を dp1[2] に格納します。

    # (1) 上位桁が K[:n] に一致する場合
    
    # dp2[2] の更新
    i = limit 
    for r in range(0, D):
      dp2[2][(r+i)%D] += dp2[1][r]
      dp2[2][(r+i)%D] %= mod

    # dp1[2] の更新
    for i in range(0, limit):
      for r in range(0, D):
        dp1[2][(r+i)%D] += dp2[1][r]
        dp1[2][(r+i)%D] %= mod

(2) については、dp2[2] は関係ないので、dp1[2] のみを更新します。上位の桁が K[2] 未満の場合なので、2桁目は、0 〜 9 のすべてが取れます。

    # (2) 上位桁が K[:n] 未満の場合

    # dp1[2] の更新
    for i in range(0, 10):
      for r in range(0, D):
        dp1[2][(r+i)%D] += dp1[2-1][r]
        dp1[2][(r+i)%D] %= mod

同様の計算は、n = 3 以降も繰り返すことができて、次の様になります。

  N = len(K)

  dp1 = [[0] * D for _ in range(len(K)+1)]  # K 未満の場合すべて
  dp2 = [[0] * D for _ in range(len(K)+1)]  # 全桁が K に一致する場合

  #  n = 1 の時
  limit = int(K[0])
  for i in range(0, limit): # K[0] 未満
    dp1[1][i%D] += 1
  i = limit
  dp2[1][i%D] += 1 # K[0] に一致

  # n = 2, 3, ..., N の時
  for n in range(2, N+1):
    limit = int(K[n-1])

    # (1) 上位桁が K[:n] に一致する場合
    
    # dp2[n] の更新
    i = limit 
    for r in range(0, D):
      dp2[n][(r+i)%D] += dp2[n-1][r]
      dp2[n][(r+i)%D] %= mod

    # dp1[n] の更新
    for i in range(0, limit):
      for r in range(0, D):
        dp1[n][(r+i)%D] += dp2[n-1][r]
        dp1[n][(r+i)%D] %= mod

    # (2) 上位桁が K[:n] 未満の場合

    # dp1[n] の更新
    for i in range(0, 10):
      for r in range(0, D):
        dp1[n][(r+i)%D] += dp1[n-1][r]
        dp1[n][(r+i)%D] %= mod

  print((dp1[N][0]+dp2[N][0]-1) % mod)

これで無事にパスするコードになりました。

atcoder.jp

下から攻める場合と上から攻める場合、それぞれの考え方の違いをよーーーーく整理しておいてください。

上記のコードでは、

    # (1) 上位桁が K[:n] に一致する場合
    # (2) 上位桁が K[:n] 未満の場合

という場合分けで整理しましたが、慣れてくれば、dp2、dp1 のそれぞれを個別に更新すると考えて、次の様に整理してもよいでしょう。

  # n = 2, 3, ..., N の時
  for n in range(2, N+1):
    limit = int(K[n-1])

    # dp2[n] の更新
    i = limit 
    for r in range(0, D):
      dp2[n][(r+i)%D] += dp2[n-1][r]
      dp2[n][(r+i)%D] %= mod

    # dp1[n] の更新
    for i in range(0, limit):    # 上位桁が K[:n] に一致する場合
      for r in range(0, D):
        dp1[n][(r+i)%D] += dp2[n-1][r]
        dp1[n][(r+i)%D] %= mod

    for i in range(0, 10):    # 上位桁が K[:n] 未満の場合
      for r in range(0, D):
        dp1[n][(r+i)%D] += dp1[n-1][r]
        dp1[n][(r+i)%D] %= mod

  print((dp1[N][0]+dp2[N][0]-1) % mod)

次回は、上から攻めないとうまく解けない問題を紹介します。