めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

E - Digit Products の解説

何の話かと言うと

enakai00.hatenablog.com

上記の記事では、「桁DP」について、下から攻める方法と、上から攻める方法を紹介しました。

今回は、「上から攻めないとうまくいかない」問題例として、こちらを取り上げます。

atcoder.jp

何が難しいの?

問題のパターンは、前回の「Digit Sum」とほとんど同じに見えます。桁ごとの足し算が桁ごとの掛け算に変わっただけです。

なのですが!!!!

桁ごとの掛け算の場合、0 の取り扱いが微妙になります。たとえば、K = 250 の場合、

・101 → 各桁の積は 0

という計算はあっていますが、

・098 → 各桁の積は 0

とすると、もちろん間違いです。上位にならんだ 0 は無視する必要があります。しかしながら、下から攻めていった場合、最上位から 0 が並ぶ場合とそうでない場合は、事前に予見することは困難です。

では一方、上から攻めた場合、最上位から 0 が連続して並ぶ状況は捉えられるのでしょうか?

なんとかなります。

冒頭の記事で、上から攻める桁DPを説明した際に、

・各桁の値が K にぴったり一致するものだけを選んで行った場合:dp2[n]
・それ以外の場合:dp1[n]

という風に、各桁が K にしたがって変化する特別な場合を分けてトレースすることに成功しました。今回はさらに、各桁に 0 ばかり並ぶ場合も別途、トレースすることにします。

・各桁の値が 0 に一致するものだけを選んで行った場合:dp0
・各桁の値が K に一致するものだけを選んで行った場合:dp2
・それ以外の場合:dp1

なお、冒頭の問題では、「Nの各桁の積がK以下のものの個数」となっていますが、これまでの説明に記号を合わせて、

・「Kの各桁の積がD以下のものの個数」

を求めることにします。ここでは、dpX をディクショナリーとして、

・dpX[d]:各桁の積が d の個数(d <= D)
・dpX['inf']:各桁の積が D+1 以上の個数

という情報を埋めていきます。

おっと、桁数を表す添字(dpX[n][d] の [n])が消えましたが、まちがいではありません。この問題では、n の計算をする際は、直前の n-1 の情報だけあれば十分なので、(メモリー容量を節約するために)それ以前の古い情報は保持しないことにします。dpX[ ] を使って、dpX_n[ ] を計算して、計算が終われば、dpX[ ] = dpX_n[ ] と置き換える作戦でいきます。

まずは、n = 1 の場合です。

  dp0 = defaultdict(int) # dp0[k] : 積が k の個数(上位桁がすべて 0 の場合)
  dp1 = defaultdict(int) # dp1[k] : 積が k の個数(最大値未満スキャン)
  dp2 = defaultdict(int) # dp2[k] : 積が k の個数(最大値トレース)

  #  n = 1 の時
  limit = int(K[0])
 
  dp0[1] = 1  # 上位桁 0 は x 1 倍で扱う
  for i in range(1, limit):
    if i <= D:
      dp1[i] += 1
    else:
      dp1['inf'] += 1
  i = limit
  if i <= D:
    dp2[i] += 1
  else:
    dp2['inf'] += 1

dp0 に保存する値に注意してください。上位に連続して並ぶ 0 は、桁ごとの掛け算をする際は、0 ではなくて、1 として扱う必要があるのです。

そして、この後の遷移は次の様になります。ここでは、dp2 の更新、dp0 の更新、dp1 の更新という流れで整理しました。

  # n = 2, 3, ..., N の時
  for n in range(2, N+1):
    dp0_n = defaultdict(int) 
    dp1_n = defaultdict(int) 
    dp2_n = defaultdict(int) 
    limit = int(K[n-1])

    #  dp_0 の更新
    dp0_n[1] = 1  # 上位の 0 の積は 1 と見なす

    # dp_2 の更新
    i = limit
    for d in dp2.keys():
      dp2_n[prod(i, d, D)] += dp2[d]

    # dp_1 の更新
    for i in range(1, 10):
      for d in dp0.keys():  # 上位桁がすべて 0 の場合
        dp1_n[prod(i, d, D)] += dp0[d]

    for i in range(0, limit): # 上位桁が K[:n] に一致する場合
      for d in dp2.keys():
        dp1_n[prod(i, d, D)] += dp2[d]

    for i in range(0, 10): # その他の場合
      for d in dp1.keys():  
        dp1_n[prod(i, d, D)] += dp1[d]

    dp0 = dp0_n
    dp1 = dp1_n
    dp2 = dp2_n

おっと、ここで、prod() というのは、'inf' の場合を考慮した掛け算です。

def prod(i, d, D):
  if d == 'inf':
    if i == 0:
      prod = 0
    else:
      prod = 'inf'
    return prod

  prod = i*d
  if prod > D:
    prod = 'inf'
  return prod

これでOKです。一番最後は、dp1、dp2 の中から、キーが 'inf' 以外の要素の合計を計算すれば答えになります。

  print(sum(dp1.values()) - dp1['inf'] + sum(dp2.values()) - dp2['inf'])

0 に対応する答えは、今回の場合、dp0 に(0 × 0 × 0 × 0 .... = 1 という間違った計算結果で)収められていますが、「1 以上の整数」という条件があるので、これは捨ててしまえばOKです。

無事にパスした結果がこちらになります。

atcoder.jp