めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

E -Mod i の解説

何の話かというと

atcoder.jp

(とある事情があって)競技プログラミングの過去問をいろいろ調べていたところ、上記の問題に遭遇して、(公式解説もチラ見しながら)なんとか Accept されるコードにたどり着いたのですが、この問題、かなり色々な要素(テクニック)が含まれている気がして、懇切丁寧に解説してみたくなった次第です。

部分和の計算

まず、基本的なところですが、数列 \{A_1,A_2,\cdots,A_N\} の一部を切り出した \{A_i,A_{i+1},\cdots,A_j\} の和(部分和)が登場する問題では、必要になる度に部分和を計算するのは無駄が多いので、事前に

 S_0 = 0
 S_1 = A_1
 S_2 = A_1 + A_2
  \vdots
 S_N = A_1+\cdots+A_N

を計算しておいて、

 A_i+A_{i+1}+\cdots+A_j = S_{j} - S_{i-1}

S_k の引き算に置き換えてしまいます。ちなみに、S_0, S_1, \cdots, S_N を事前計算する際も、それぞれを個別に計算するのは無駄なので、

 S_k = S_{k-1} + A_k

という関係を使って、N 回の足し算に置き換えます。コードにするとこんな感じ。

# A = [A_0, A_1, A_2, ..., A_N]
S = []
S.append(0) # S_0 = 0
for k in range(1, N+1):
  S.append(S[k-1] + A[k])

本当は A_0 は存在しない値ですが、リストのインデックスと数列の添字を一致させるために(つまり A[k] が A_k に対応する様に)ダミー値として加えてあります。

計算量的にいうと、部分和の計算を毎回やると、毎回 O(N) の計算が必要になるところが、S_k を利用した場合は、事前に O(N) の計算を一度だけやっておけば、それ以降の部分和の計算はすべて O(1) に計算量を減らすことができます。

「どっちにしても O(N) じゃん!」と思う方もいるかも知れませんが、「事前に一度やっておけばOK」という点がポイントになります。仮に N 回のループの中で、毎回、部分和を計算した場合、全体の計算量は O(N^2) に膨れ上がります。一般に、似た様な計算を何度も繰り返す問題の場合、「事前にまとめて計算できる部分はないかしら?」と考えるのは良いことです。

しらみつぶしの場合を考える

競技プログラミングでは、2つの意味での「時間制限」があります。1つは、コードの「実行時間」の制限で、もう一つは、競技の開催時間、つまり、「回答時間」の制限です。

回答時間の制限を考えると、「実行時間」が間に合わないと分かっている非効率な解法をあえてコードにする余裕はありませんが、回答時間の制限がない場合(勉強目的、あるいは、実務でコードを書く場合)は、まずは、「しらみつぶし」の解法をコードにするのは無意味ではありません。「しらみつぶし」のコードで(サンプル入力に対する)正解が得られれば、少なくとも、自分の「問題の理解」が間違っていないことがチェックできることになります。

で、実際にやるとこんな感じですね。

import sys

N = 0
A, S = [], []
mod = 10**9+7

def count(i=1, l=1): # A[i] 以降を l の倍数、l+1 の倍数・・・のグループに分ける
  global N, A, S, mod
  if i == N+1:
    return 1

  result = 0
  for k in range(i, N+1): # (A[i] ... A[k]) が l の倍数かチェックする
    if (S[k] - S[i-1]) % l == 0:
      result += count(k+1, l+1) # A[k+1] 以降を l+1 の倍数・・・のグループに分ける
      result %= mod # 結果が大きくなりすぎない様に mod で割ったあまりに変換しておく
  return result

def main(f):
  global N, A, S, mod

  # 入力データ読み込み
  lines = f.readlines()
  N = int(lines.pop(0).strip())
  A = [0] + list(map(int, lines.pop(0).strip().split()))
  # A = [A_0, A_1, A_2, ..., A_N]

  S.append(0) # S_0 = 0
  for k in range(1, N+1):
    S.append(S[k-1] + A[k])

  print(count(1, 1))

#main(sys.stdin)
with open('input.txt', 'r') as file:
  main(file)

コードを読めば分かる様に、「1本目の区切りをどこにおくか」を 1\sim N でスキャン、「2本目の区切りをどこにおくか」を x\sim N でスキャン・・・ ということを再帰的に行なっているので、全体の計算量は O(N!) という非常に大きなものになっています。

ここから計算量をどうやって減らすかを考えるわけですが、実際に計算量が減らせるかどうかは、冗長な計算が含まれているかどうか、つまり、「コードの実行時にまったく同じ計算を何度も繰り返しているか」に依存します。もしも同じ計算を繰り返していれば、それを1回にまとめるようなアルゴリズムを考えれば、計算量を減らせることになります。

例えば、いまの場合、再帰的に呼び出される関数 count(i, l) は、「まったく同じ引数で何度も呼ばれる」可能性があることに気が付きます。例として、1本目と2本目の区切りが次のように決まり、どちらも前提条件を満たしたとします。

A_0 | A_1 A_2 | A_3 ...
A_0 A_1 | A_2 | A_3 ...

いずれの状態からも、「A_3 以降を3の倍数、4の倍数・・・のグループに分ける」という処理、すなわち、count(3, 3) がコールされます。確かに改善の余地があるようです。

ちなみに、このような状況でお手軽に高速化するテクニックに、「関数のメモ化」があります。次の様に関数の実行結果を保持するキャッシュを用意して、同じ引数で呼ばれた場合は、キャッシュから値を返すという手法です。

cache = {}

def count(i=1, l=1):
  global N, A, S, mod, cache

  if (i, l) in cache.keys():
    return cache[(i, l)]

...

  cache[(i, l)] = result
  return result

ただし、今回のケースでは、大元が O(N!) というツライ状況なので、この程度では実行時間制限は突破できません。

サイズに対する漸化的な解法

このような冗長な再起処理を改善する方法に、「サイズが k-1 までの問題(つまり A_1,\cdots,A_{k-1} を対象とした問題)の答えが求まっていると仮定して、サイズが k の問題(A_1,\cdots,A_k を対象とした問題)の答えを導く」という手法があります。(小難しくいうと、「DP(動的計画法)」なんですが、ここでは、まずは「考え方」にフォーカスしましょう。)

今回の場合は、「A_1,\cdots,A_Nを(条件を満たす様に)l 分割するパターン数 c(N, l) を求める」という問題を l=1,\cdots,N について解いた後に、c(N, 1) + \cdots + C(N,N) を計算するという流れになるので、

c(1, 1)
c(2, 1), c(2, 2)
c(3, 1), c(3, 2), c(3, 3)
\vdots
c(k-1, 1), c(k-1, 2), \cdots, c(k-1,k-1)

が分かっているとして、

c(k, 1), c(k-1, 2), \cdots, c(k,k)

を計算する、ということを考えます。この計算を段階的にすすめて行き、最終的に、

c(N, 1), c(N, 2), \cdots, c(N,N)

にたどり着けば、これを合計して問題の答えが得られることになります。

全部で N^2 / 2 (程度)の c(k, l) を計算するので、仮に、「(前段の計算結果を用いて)次の c(k, l) を計算する」という処理(これを「遷移計算」と呼びます)が O(N) で実行できれば、全体の計算量は O(N^3) になります。あるいは、遷移計算が O(1) であれば、全体の計算量は O(N^2) まで減ります。

このように、DP(動的計画法)では、遷移計算をいかに効率化するか(あるいは、遷移計算が簡単になるような、計算の「段取り」をうまく見つけられるか)がポイントになります。

遷移計算

で・・・仮に O(1) の遷移計算が(簡単に)実現できれば、普通の DP の問題ということになるのですが、実は、この問題の場合、ここが一筋縄ではいきません。まずは、愚直に考えてみます。

k-1 までの問題(つまり、A_1,\cdots,A_{k-1} を対象とする問題)が解けているという前提で、新たに A_k を追加して、c(k, l) に該当する分割を考えます。この時、新たに追加された A_k は、必ず、最後のグループ、つまり、l の倍数グループに入ります。そこで、最後のグループを決める区切り線の場所を A_k の直前から1つずつ順番に動かしていって、

 「A_{x+1},\cdots, A_kl の倍数グループになる」

(つまり、(A_{x+1}+\cdots+A_k)\%l = (S_k - S_x)\%l = 0 を満たす)位置 x を探していきます。見つかった x に対して、その区切り線の前にある A_1,\cdots,A_x は、全部で l-1 個のグループに分かれていないといけないので、そのパターン数は、c(x, l-1) として決まります。これらを全部足しあげれば、c(k, l) が得られます。コードで書くならこんな感じ。(x==0 の場合だけちょっと別枠なのに注意。)

c[k][l] = 0
for x in range(k, -1, -1):
  if (S[k] - S[x]) % l == 0:
    if x == 0:
      c[k][l] += 1
    else:
      c[k][l] += c[x][l-1]

しかしながら、これは k 回の繰り返しを含む処理なので、(k は最大で N を取るので)大雑把にいうと O(N) の計算量になります。前述の議論を思い出すと、アルゴリズム全体の計算量は O(N^3) になります。当初の O(N!) に比べると格段に進歩しましたが、残念ながら、競技プログラミングとしての冒頭の問題の回答にはなりません。このコードを提出すると、実行時間オーバーになってしまいます。

いったいどうすればよいのでしょうか・・・

で、ここで、公式解説を「ちらっ」とだけ覗いてみます。

これによると、最後のグループを決める区切り線の場所を A_k の直前から1つずつ動かすという処理は、最後まで全部実行する必要はなくて、

 「A_{x+1},\cdots, A_kl の倍数グループになる」-------- (1)

という条件を満たす最初の x (つまり 0\le x\le k-1 の範囲で最大の x)を発見すれば十分なのです。解説の図を見ると理解できる様に、この x を用いて、

 c(k, l) = c(x, l) + c(x, l-1)

と遷移計算が実行できます。あら素敵。

(ただし、例によって x=0 の場合は別枠で、x=0, l=1 の場合は c(k, 1)=1x=0, l>=2 の場合は c(k, l)=0 となります。)

なのですが・・・これだけでは問題は解決したことになりません。最初の x を発見するために愚直にスキャンした場合、最悪ケースで O(N) のループになることに変わりありません。この x をもっと高速に O(1) で発見する必要があるのです。

とはいえ、そろそろ道に迷いつつある方もいるやも知れませんので、いったん、ここまでの解法をコードにまとめておきます。

import sys

N = 0
A, S = [], []
mod = 10**9+7

def main(f):
  global N, A, S, mod

  # 入力データ読み込み
  N = int(f.readline())
  A = [0] + list(map(int, f.readline().split()))
  # A = [A_0, A_1, A_2, ..., A_N]

  S.append(0) # S_0 = 0
  for k in range(1, N+1):
    S.append(S[k-1] + A[k])

  C = []
  for _ in range(N+1):
    C.append([0] * (N+1))

  for k in range(1, N+1):
    C[k][1] = 1
    for l in range(2, k+1):
      for x in range(k-1, 0, -1):
        if (S[k] - S[x]) % l == 0:
          C[k][l] = C[x][l] + C[x][l-1]
          C[k][l] %= mod
          break
      # 条件を満たす x がない場合、C[k][l] は初期値 0 のまま
      # x == 0 の別枠は C[k][1] = 1 および、その他の場合の初期値 0 で満たされている  

  print(sum(C[N]) % mod)

#main(sys.stdin)
with open('input.txt', 'r') as file:
  main(file)

DPで遷移計算をする際は、最初の c(1, 1) は自力で計算する必要がありますが、今の場合、任意の k について c(k, 1) = 1 と分かるので、上記のコードではこれを利用しています。

x の事前計算

ここで利用するのが、「部分和の計算」で説明した「事前計算」のテクニックです。前述のDPでは、遷移計算の回数が O(N^2) になるので、それぞれの遷移計算において、毎回、O(N) のループを回すと、全体で O(N^3) になるわけですが、遷移計算のループを回す前に事前に、必要なすべての x を計算しておきます。これは1回限りの計算のなので、最悪、O(N^2) で済ますことができればOKです。遷移計算のループの中では、計算済みの x を参照するだけなので、1回の遷移を O(1) で済ますことができます。

では、実際に求める必要がある x はどのような集合でしょうか? さきほどの条件 (1) を見返すと、k\ (1\le k\le N)l\ (1\le l\le N) を決めると対応する x がひとつ決まるので、x[k][l] という N\times N リストにまとめることができます。全部で N^2 個の要素があるので、1つの要素を O(1) で決定できれば、リスト全体を O(N^2) で用意できることになります。

ただしここで、この要素の決め方に工夫が必要です。各 k に対して、(1) を満たす直近の x \le k-1) を発見する必要があるわけですが、これを愚直に実装すると、次の様になるでしょう。

  X = []
  for _ in range(N+1):
    X.append([-1] * (N+1))

  for k in range(1, N+1):
    for l in range(1, N+1):
      for x in range(k-1, -1, -1):
        if (S[k] - S[x]) % l == 0:
          X[k][l] = x
          break

これは、本質的には、以前のコードと同じで、O(X^3) の計算量になってしまいます。さあどうすればよいのでしょうか・・・。ここで注目したいのは、次の条件式です。

 if (S[k] - S[x]) % l == 0:

内容的には、先程の (1) の条件を表すものですが、これだけをみていると、

if S[k]%l == S[x]%l:

という、数列 S_1,S_2,\cdots,S_N に対する条件式に見えてきます。つまり、k を固定した場合に、

S_1,\cdots,S_k

という数列を考えて、l で割った余りが S_k に一致する直近の S_x を発見する問題になっているのです。ここで、話を理解しやすくするために、次の簡単化した例題を考えてみます。

[例題]
次のリストの各要素について、自分より前にある直近の同じ値のインデックスを求めて、これらを並べたリストを作成しなさい。

A = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 4, 8, 9, 7, 9, 3]

たとえば、A[3] = 1 については、A[1] = 1 が直近の同じ値になります。あるいは、A[15] = 3 については、A[9] = 3 が該当します。自分より前に同じ値がない場合は、インデックスとして -1 を与えるものとします。


実は、この問題は前からスキャンすることにより、O(N) で解くことができます。

たとえば、最初の A[0] = 3 を考えると、この要素のインデックス 0 は、この後に来る 3 に対する直近の値の「候補」となります。そこで、3 に対する答えの候補として、

candidate[A[0]] = 0

と記録しておきます。スキャンを進めると、A[9] = 3 が登場したところで、これに対する答えは、

candidate[A[9]]

として読み出すことができます。ただし、この時点で、インデックス 0 は、これより後に来る 3 に対する答えにはなりません。いま登場した A[9] = 3 のインデックス 9 を新たな候補として再登録する必要があります。

candidate[A[9]] = 9

このようにして、過去に記録した「候補」の読み出しと、候補のアップデートを繰り返すことで、すべての要素に対する答えを得ることができます。具体的には、次の様な実装が可能です。

candidate = [-1] * 10
A = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 4, 8, 9, 7, 9, 3]
idx = [-1] * len(A)

for i in range(len(A)):
  idx[i] = candidate[A[i]]
  candidate[A[i]] = i

print(idx)
###############

[-1, -1, -1, 1, -1, -1, -1, -1, 4, 0, 2, -1, 5, -1, 12, 9]

これと同じテクニックが今の場合にも利用できます。

たとえば、x=1 は、(k=2,3,\cdots)に対する S_k\%l == S_1\%l を満たす x の候補となります。

candidate[S[1]%l] = 1 # ----- (2)

次に、k=2 の場合を考えると、S_k\%l == S_x\%l を満たす x は何でしょうか? 

仮に、S_2\%l == S_1\%l であれば、直前の候補 candidate[S[1]%l] が答えになります。

X[2][l] = candidate[S[2]%l] # ----- (3)

もしも S_2\%l == S_1\%l でなければ、条件を満たす x は存在しないことになりますが、その場合、リスト candidate の要素は、初期値 -1 で初期化されていることにすれば、(3) の代入によって、X[2][l] = -1 となります。-1 を「対応する x は存在しない」という意味のシンボルと解釈すれば、いずれにしても (3) の代入は正しい計算になります。

ここでさらに、x=2S_k\%l == S_2\%l に対するあらたな候補として記録します。

candidate[S[2]%l] = 2

そして、次に、k=3 の場合を考えると・・・。もうお分かりのように、

X[3][l] = candidate[S[3]%l]

とすることで、(条件を満たす x が存在しない場合を含めた)適切な答えが得られます。つまり、直前までの候補から答えを決定して、次の候補を記録する、という k についての O(N) のループを回すことで、(特定の l に対する)X[k][l] が決められるのです。

これを l についてもループすれば、全体として、O(N^2) で X[k][l] を事前計算することができるのです。コードとしては、こんな感じになります。

  X = []
  for _ in range(N+1):
    X.append([-1] * (N+1))

  for l in range(1, N+1):
    candidate = [-1] * (N+1) # -1 で初期化
    candidate[S[1]%l] = 1 # k=2 以降に対する S[k]%l=S[1]%l となる x の候補 
    for k in range(2, N+1):
      X[k][l] = candidate[S[k]%l] # これまでの候補から決まる
      candidate[S[k]%l] = k # k'=k+1 以降に対する S[k']%l=S[k]%l となる x の候補

厳密には、X[1][l] については、別枠で計算して、0 もしくは -1 をセットしないといけないのですが、ここでは初期値 0 のままにしてあります。実際にリスト X を使う部分は次の様になりますので、x > 0 の場合だけを計算しておけば十分です。

  for k in range(1, N+1):
    C[k][1] = 1
    for l in range(2, k+1):
      x = X[k][l]
      if x > 0:
        C[k][l] = C[x][l] + C[x][l-1]
        C[k][l] %= mod
      # x == 0 の別枠は C[k][1] = 1 および、その他の場合の初期値 0 で満たされている

完成品

というわけで、なかなか大変でしたが、無事に完成したコードがこちらになります。(それでも、Python だと実行時間が苦しいです。PyPy3 で通りました。)

atcoder.jp

最後の x の事前計算は、簡単化した例題を知らないとなかなか思いつけない気もしますが・・・。こういう複雑な問題の一部として利用されることが多い、「簡単化した例題」を集めた問題集とかあるといいですね。

おまけ(関数のメモ化)について

「しらみつぶしの場合を考える」で触れた関数のメモ化ですが、「しらみつぶし」ケースを単純にメモ化してもうまくいきませんでした。これは、メモ化された関数を呼び出す順序に理由があります。もともと O(N!) のアルゴリズムですので、関数の実行時間そのものはメモ化で短縮されたとしても、関数を呼び出す回数そのものが減っているかどうかは別問題です。仮に呼び出し回数が O(N!) のままだとすると、何もしないループを N! 回まわすのと同じことで、それなりの実行時間がかかります。

一方、メモ化によって、再帰的な関数の再呼び出しが削減されて、関数の呼び出し回数自体が O(N^2) などに減少したとすれば、これはメモ化でうまくいくケースになります。

じゃあ、どうやってそのケースを見分けるの・・・? という事になりますが、うまくいくケースは、ほとんどの場合、「DPと同等のアルゴリズムをメモ化によって実現している」というオチになります。なので、はじめからメモ化にたよるのではなく、DPによる解法を見つけた後に、それを実装する方法として、(配列を用意して順番に埋めていくのではなく)メモ化を利用するという流れになります。

具体的には、DPの遷移処理を再帰的に実装して、その関数にメモ化を適用すれば、実質的に、配列を用いたDPと同等の処理になります。

ただし、関数呼び出しのオーバーヘッドと、ネストした呼び出しの上限があるため、サイズの大きい問題では、メモ化を使うのではなく、愚直に配列を用いた方が安全です。たとえば、先ほどの回答をメモ化で実装すると、こちらになります。が・・・まず、

sys.setrecursionlimit(5000)

でネストの上限を上げないと、ランタイムエラーが発生します。また、残念ながらこの例では、いくつかのケースで実行時間制限をオーバーしていることがわかります。