何の話かというと
(とある事情があって)競技プログラミングの過去問をいろいろ調べていたところ、上記の問題に遭遇して、(公式解説もチラ見しながら)なんとか Accept されるコードにたどり着いたのですが、この問題、かなり色々な要素(テクニック)が含まれている気がして、懇切丁寧に解説してみたくなった次第です。
部分和の計算
まず、基本的なところですが、数列 の一部を切り出した の和(部分和)が登場する問題では、必要になる度に部分和を計算するのは無駄が多いので、事前に
を計算しておいて、
と の引き算に置き換えてしまいます。ちなみに、 を事前計算する際も、それぞれを個別に計算するのは無駄なので、
という関係を使って、 回の足し算に置き換えます。コードにするとこんな感じ。
# A = [A_0, A_1, A_2, ..., A_N] S = [] S.append(0) # S_0 = 0 for k in range(1, N+1): S.append(S[k-1] + A[k])
本当は は存在しない値ですが、リストのインデックスと数列の添字を一致させるために(つまり A[k] が に対応する様に)ダミー値として加えてあります。
計算量的にいうと、部分和の計算を毎回やると、毎回 の計算が必要になるところが、 を利用した場合は、事前に の計算を一度だけやっておけば、それ以降の部分和の計算はすべて に計算量を減らすことができます。
「どっちにしても じゃん!」と思う方もいるかも知れませんが、「事前に一度やっておけばOK」という点がポイントになります。仮に 回のループの中で、毎回、部分和を計算した場合、全体の計算量は に膨れ上がります。一般に、似た様な計算を何度も繰り返す問題の場合、「事前にまとめて計算できる部分はないかしら?」と考えるのは良いことです。
しらみつぶしの場合を考える
競技プログラミングでは、2つの意味での「時間制限」があります。1つは、コードの「実行時間」の制限で、もう一つは、競技の開催時間、つまり、「回答時間」の制限です。
回答時間の制限を考えると、「実行時間」が間に合わないと分かっている非効率な解法をあえてコードにする余裕はありませんが、回答時間の制限がない場合(勉強目的、あるいは、実務でコードを書く場合)は、まずは、「しらみつぶし」の解法をコードにするのは無意味ではありません。「しらみつぶし」のコードで(サンプル入力に対する)正解が得られれば、少なくとも、自分の「問題の理解」が間違っていないことがチェックできることになります。
で、実際にやるとこんな感じですね。
import sys N = 0 A, S = [], [] mod = 10**9+7 def count(i=1, l=1): # A[i] 以降を l の倍数、l+1 の倍数・・・のグループに分ける global N, A, S, mod if i == N+1: return 1 result = 0 for k in range(i, N+1): # (A[i] ... A[k]) が l の倍数かチェックする if (S[k] - S[i-1]) % l == 0: result += count(k+1, l+1) # A[k+1] 以降を l+1 の倍数・・・のグループに分ける result %= mod # 結果が大きくなりすぎない様に mod で割ったあまりに変換しておく return result def main(f): global N, A, S, mod # 入力データ読み込み lines = f.readlines() N = int(lines.pop(0).strip()) A = [0] + list(map(int, lines.pop(0).strip().split())) # A = [A_0, A_1, A_2, ..., A_N] S.append(0) # S_0 = 0 for k in range(1, N+1): S.append(S[k-1] + A[k]) print(count(1, 1)) #main(sys.stdin) with open('input.txt', 'r') as file: main(file)
コードを読めば分かる様に、「1本目の区切りをどこにおくか」を でスキャン、「2本目の区切りをどこにおくか」を でスキャン・・・ ということを再帰的に行なっているので、全体の計算量は という非常に大きなものになっています。
ここから計算量をどうやって減らすかを考えるわけですが、実際に計算量が減らせるかどうかは、冗長な計算が含まれているかどうか、つまり、「コードの実行時にまったく同じ計算を何度も繰り返しているか」に依存します。もしも同じ計算を繰り返していれば、それを1回にまとめるようなアルゴリズムを考えれば、計算量を減らせることになります。
例えば、いまの場合、再帰的に呼び出される関数 count(i, l) は、「まったく同じ引数で何度も呼ばれる」可能性があることに気が付きます。例として、1本目と2本目の区切りが次のように決まり、どちらも前提条件を満たしたとします。
A_0 | A_1 A_2 | A_3 ...
A_0 A_1 | A_2 | A_3 ...
いずれの状態からも、「 以降を3の倍数、4の倍数・・・のグループに分ける」という処理、すなわち、count(3, 3) がコールされます。確かに改善の余地があるようです。
ちなみに、このような状況でお手軽に高速化するテクニックに、「関数のメモ化」があります。次の様に関数の実行結果を保持するキャッシュを用意して、同じ引数で呼ばれた場合は、キャッシュから値を返すという手法です。
cache = {} def count(i=1, l=1): global N, A, S, mod, cache if (i, l) in cache.keys(): return cache[(i, l)] ... cache[(i, l)] = result return result
ただし、今回のケースでは、大元が というツライ状況なので、この程度では実行時間制限は突破できません。
サイズに対する漸化的な解法
このような冗長な再起処理を改善する方法に、「サイズが までの問題(つまり を対象とした問題)の答えが求まっていると仮定して、サイズが の問題( を対象とした問題)の答えを導く」という手法があります。(小難しくいうと、「DP(動的計画法)」なんですが、ここでは、まずは「考え方」にフォーカスしましょう。)
今回の場合は、「を(条件を満たす様に) 分割するパターン数 を求める」という問題を について解いた後に、 を計算するという流れになるので、
が分かっているとして、
を計算する、ということを考えます。この計算を段階的にすすめて行き、最終的に、
にたどり着けば、これを合計して問題の答えが得られることになります。
全部で (程度)の を計算するので、仮に、「(前段の計算結果を用いて)次の を計算する」という処理(これを「遷移計算」と呼びます)が で実行できれば、全体の計算量は になります。あるいは、遷移計算が であれば、全体の計算量は まで減ります。
このように、DP(動的計画法)では、遷移計算をいかに効率化するか(あるいは、遷移計算が簡単になるような、計算の「段取り」をうまく見つけられるか)がポイントになります。
遷移計算
で・・・仮に の遷移計算が(簡単に)実現できれば、普通の DP の問題ということになるのですが、実は、この問題の場合、ここが一筋縄ではいきません。まずは、愚直に考えてみます。
までの問題(つまり、 を対象とする問題)が解けているという前提で、新たに を追加して、 に該当する分割を考えます。この時、新たに追加された は、必ず、最後のグループ、つまり、 の倍数グループに入ります。そこで、最後のグループを決める区切り線の場所を の直前から1つずつ順番に動かしていって、
「 が の倍数グループになる」
(つまり、 を満たす)位置 を探していきます。見つかった に対して、その区切り線の前にある は、全部で 個のグループに分かれていないといけないので、そのパターン数は、 として決まります。これらを全部足しあげれば、 が得られます。コードで書くならこんな感じ。(x==0 の場合だけちょっと別枠なのに注意。)
c[k][l] = 0 for x in range(k, -1, -1): if (S[k] - S[x]) % l == 0: if x == 0: c[k][l] += 1 else: c[k][l] += c[x][l-1]
しかしながら、これは 回の繰り返しを含む処理なので、( は最大で を取るので)大雑把にいうと の計算量になります。前述の議論を思い出すと、アルゴリズム全体の計算量は になります。当初の に比べると格段に進歩しましたが、残念ながら、競技プログラミングとしての冒頭の問題の回答にはなりません。このコードを提出すると、実行時間オーバーになってしまいます。
いったいどうすればよいのでしょうか・・・
で、ここで、公式解説を「ちらっ」とだけ覗いてみます。
これによると、最後のグループを決める区切り線の場所を の直前から1つずつ動かすという処理は、最後まで全部実行する必要はなくて、
「 が の倍数グループになる」-------- (1)
という条件を満たす最初の (つまり の範囲で最大の )を発見すれば十分なのです。解説の図を見ると理解できる様に、この を用いて、
と遷移計算が実行できます。あら素敵。
(ただし、例によって の場合は別枠で、 の場合は 、 の場合は となります。)
なのですが・・・これだけでは問題は解決したことになりません。最初の を発見するために愚直にスキャンした場合、最悪ケースで のループになることに変わりありません。この をもっと高速に で発見する必要があるのです。
とはいえ、そろそろ道に迷いつつある方もいるやも知れませんので、いったん、ここまでの解法をコードにまとめておきます。
import sys N = 0 A, S = [], [] mod = 10**9+7 def main(f): global N, A, S, mod # 入力データ読み込み N = int(f.readline()) A = [0] + list(map(int, f.readline().split())) # A = [A_0, A_1, A_2, ..., A_N] S.append(0) # S_0 = 0 for k in range(1, N+1): S.append(S[k-1] + A[k]) C = [] for _ in range(N+1): C.append([0] * (N+1)) for k in range(1, N+1): C[k][1] = 1 for l in range(2, k+1): for x in range(k-1, 0, -1): if (S[k] - S[x]) % l == 0: C[k][l] = C[x][l] + C[x][l-1] C[k][l] %= mod break # 条件を満たす x がない場合、C[k][l] は初期値 0 のまま # x == 0 の別枠は C[k][1] = 1 および、その他の場合の初期値 0 で満たされている print(sum(C[N]) % mod) #main(sys.stdin) with open('input.txt', 'r') as file: main(file)
DPで遷移計算をする際は、最初の は自力で計算する必要がありますが、今の場合、任意の について と分かるので、上記のコードではこれを利用しています。
の事前計算
ここで利用するのが、「部分和の計算」で説明した「事前計算」のテクニックです。前述のDPでは、遷移計算の回数が になるので、それぞれの遷移計算において、毎回、 のループを回すと、全体で になるわけですが、遷移計算のループを回す前に事前に、必要なすべての を計算しておきます。これは1回限りの計算のなので、最悪、 で済ますことができればOKです。遷移計算のループの中では、計算済みの を参照するだけなので、1回の遷移を で済ますことができます。
では、実際に求める必要がある はどのような集合でしょうか? さきほどの条件 (1) を見返すと、 と を決めると対応する がひとつ決まるので、x[k][l] という リストにまとめることができます。全部で 個の要素があるので、1つの要素を で決定できれば、リスト全体を で用意できることになります。
ただしここで、この要素の決め方に工夫が必要です。各 に対して、(1) を満たす直近の を発見する必要があるわけですが、これを愚直に実装すると、次の様になるでしょう。
X = [] for _ in range(N+1): X.append([-1] * (N+1)) for k in range(1, N+1): for l in range(1, N+1): for x in range(k-1, -1, -1): if (S[k] - S[x]) % l == 0: X[k][l] = x break
これは、本質的には、以前のコードと同じで、 の計算量になってしまいます。さあどうすればよいのでしょうか・・・。ここで注目したいのは、次の条件式です。
if (S[k] - S[x]) % l == 0:
内容的には、先程の (1) の条件を表すものですが、これだけをみていると、
if S[k]%l == S[x]%l:
という、数列 に対する条件式に見えてきます。つまり、 を固定した場合に、
という数列を考えて、 で割った余りが に一致する直近の を発見する問題になっているのです。ここで、話を理解しやすくするために、次の簡単化した例題を考えてみます。
[例題]
次のリストの各要素について、自分より前にある直近の同じ値のインデックスを求めて、これらを並べたリストを作成しなさい。
A = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 4, 8, 9, 7, 9, 3]
たとえば、A[3] = 1 については、A[1] = 1 が直近の同じ値になります。あるいは、A[15] = 3 については、A[9] = 3 が該当します。自分より前に同じ値がない場合は、インデックスとして -1 を与えるものとします。
実は、この問題は前からスキャンすることにより、 で解くことができます。
たとえば、最初の A[0] = 3 を考えると、この要素のインデックス 0 は、この後に来る 3 に対する直近の値の「候補」となります。そこで、3 に対する答えの候補として、
candidate[A[0]] = 0
と記録しておきます。スキャンを進めると、A[9] = 3 が登場したところで、これに対する答えは、
candidate[A[9]]
として読み出すことができます。ただし、この時点で、インデックス 0 は、これより後に来る 3 に対する答えにはなりません。いま登場した A[9] = 3 のインデックス 9 を新たな候補として再登録する必要があります。
candidate[A[9]] = 9
このようにして、過去に記録した「候補」の読み出しと、候補のアップデートを繰り返すことで、すべての要素に対する答えを得ることができます。具体的には、次の様な実装が可能です。
candidate = [-1] * 10 A = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 4, 8, 9, 7, 9, 3] idx = [-1] * len(A) for i in range(len(A)): idx[i] = candidate[A[i]] candidate[A[i]] = i print(idx) ############### [-1, -1, -1, 1, -1, -1, -1, -1, 4, 0, 2, -1, 5, -1, 12, 9]
これと同じテクニックが今の場合にも利用できます。
たとえば、 は、()に対する を満たす の候補となります。
candidate[S[1]%l] = 1 # ----- (2)
次に、 の場合を考えると、 を満たす は何でしょうか?
仮に、 であれば、直前の候補 candidate[S[1]%l] が答えになります。
X[2][l] = candidate[S[2]%l] # ----- (3)
もしも でなければ、条件を満たす は存在しないことになりますが、その場合、リスト candidate の要素は、初期値 -1 で初期化されていることにすれば、(3) の代入によって、X[2][l] = -1 となります。-1 を「対応する は存在しない」という意味のシンボルと解釈すれば、いずれにしても (3) の代入は正しい計算になります。
ここでさらに、 を に対するあらたな候補として記録します。
candidate[S[2]%l] = 2
そして、次に、 の場合を考えると・・・。もうお分かりのように、
X[3][l] = candidate[S[3]%l]
とすることで、(条件を満たす が存在しない場合を含めた)適切な答えが得られます。つまり、直前までの候補から答えを決定して、次の候補を記録する、という についての のループを回すことで、(特定の に対する)X[k][l] が決められるのです。
これを についてもループすれば、全体として、 で X[k][l] を事前計算することができるのです。コードとしては、こんな感じになります。
X = [] for _ in range(N+1): X.append([-1] * (N+1)) for l in range(1, N+1): candidate = [-1] * (N+1) # -1 で初期化 candidate[S[1]%l] = 1 # k=2 以降に対する S[k]%l=S[1]%l となる x の候補 for k in range(2, N+1): X[k][l] = candidate[S[k]%l] # これまでの候補から決まる candidate[S[k]%l] = k # k'=k+1 以降に対する S[k']%l=S[k]%l となる x の候補
厳密には、X[1][l] については、別枠で計算して、0 もしくは -1 をセットしないといけないのですが、ここでは初期値 0 のままにしてあります。実際にリスト X を使う部分は次の様になりますので、 の場合だけを計算しておけば十分です。
for k in range(1, N+1): C[k][1] = 1 for l in range(2, k+1): x = X[k][l] if x > 0: C[k][l] = C[x][l] + C[x][l-1] C[k][l] %= mod # x == 0 の別枠は C[k][1] = 1 および、その他の場合の初期値 0 で満たされている
完成品
というわけで、なかなか大変でしたが、無事に完成したコードがこちらになります。(それでも、Python だと実行時間が苦しいです。PyPy3 で通りました。)
最後の の事前計算は、簡単化した例題を知らないとなかなか思いつけない気もしますが・・・。こういう複雑な問題の一部として利用されることが多い、「簡単化した例題」を集めた問題集とかあるといいですね。
おまけ(関数のメモ化)について
「しらみつぶしの場合を考える」で触れた関数のメモ化ですが、「しらみつぶし」ケースを単純にメモ化してもうまくいきませんでした。これは、メモ化された関数を呼び出す順序に理由があります。もともと のアルゴリズムですので、関数の実行時間そのものはメモ化で短縮されたとしても、関数を呼び出す回数そのものが減っているかどうかは別問題です。仮に呼び出し回数が のままだとすると、何もしないループを 回まわすのと同じことで、それなりの実行時間がかかります。
一方、メモ化によって、再帰的な関数の再呼び出しが削減されて、関数の呼び出し回数自体が などに減少したとすれば、これはメモ化でうまくいくケースになります。
じゃあ、どうやってそのケースを見分けるの・・・? という事になりますが、うまくいくケースは、ほとんどの場合、「DPと同等のアルゴリズムをメモ化によって実現している」というオチになります。なので、はじめからメモ化にたよるのではなく、DPによる解法を見つけた後に、それを実装する方法として、(配列を用意して順番に埋めていくのではなく)メモ化を利用するという流れになります。
具体的には、DPの遷移処理を再帰的に実装して、その関数にメモ化を適用すれば、実質的に、配列を用いたDPと同等の処理になります。
ただし、関数呼び出しのオーバーヘッドと、ネストした呼び出しの上限があるため、サイズの大きい問題では、メモ化を使うのではなく、愚直に配列を用いた方が安全です。たとえば、先ほどの回答をメモ化で実装すると、こちらになります。が・・・まず、
sys.setrecursionlimit(5000)
でネストの上限を上げないと、ランタイムエラーが発生します。また、残念ながらこの例では、いくつかのケースで実行時間制限をオーバーしていることがわかります。