めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

D - Cooking の解説

何の話かと言うと

atcoder.jp

上記の問題をネタに、ナップサック問題の計算量に関するちょっとした考察をしてみます。

「しらみつぶし」がナップサック問題の基本

問題をパッとみて、ナップサック問題に似ていますよね。まず、下記の記事で、ディクショナリーを使ったナップサック問題の解法を確認してください。

enakai00.hatenablog.com

この記事では、

今回のケースでは、それぞれの荷物について、「入れる」「入れない」の二択の判断が入ることになります。実際にナップサックに入り切るかどうかは別にして、

・1個目の荷物を入れない場合
・1個目の荷物を入れる場合

の結果を記録して、その結果に基づいて、

・2個目の荷物を入れない場合
・2個目の荷物を入れる場合

の結果を記録して、その結果に基づいて、

・3個目の荷物を入れない場合
・3個目の荷物を入れる場合

・・・ということを N 個目の荷物まで繰り返せば、あらゆる組み合わせパターンを記録することができます。

とあるように、基本的にはすべてのパターンを記録していくという「しらみつぶし」をベースにした解法を紹介しています。

今回の問題についても、同様に考えると、

・1個目の料理をレンジ1に入れる場合(対称性から、1個目についてはレンジ2に入れる場合は考えなくてよい)

の結果を記録して、その結果に基づいて、

・2個目の料理をレンジ1に入れる場合
・2個目の料理をレンジ2に入れる場合

の結果を記録して・・・

という繰り返しになります。それぞれの場合で変化するのは、2個のレンジそれぞれの使用時間合計ですので、これらの組 (t1, t2) が Key になります。で、ナップサック問題であれば、それで実現できる合計価値が Value になりますが、今は、max(t1, t2) が最小になる組を知りたいわけですので、max(t1, t2) を Value にしておきます。そうすると、次の様な非常にシンプルなループが書けます。

  dp = {}

  dp[(T[1], 0)] = T[1] # dp[(t1, t2)] = max(t1, t2)
  for n in range(2, N+1):
    dp_n = {}
    for t1, t2 in dp.keys():
      dp_n[(t1 + T[n], t2)] = max(t1 + T[n], t2)
      dp_n[(t1, t2 + T[n])] = max(t1, t2 + T[n])
    dp = dp_n

最後に dp[ ] に記録された Value の最小値を選べばOKです。

  print(min(list(dp.values())))

で、実は、このシンプルな解法ですべてのケースがパスできます。えっ?!

atcoder.jp

計算量の考察

上記の解法は、基本、「しらみつぶし」なので、状態 (t1, t2) の数が膨大になって計算量が大変なことになる気がします。単純計算で、2^N 通りの組み合わせですよね。

しかしながら、実際には、t1 のとり得る値は 0 \le \sum T_i \le 10^5 に限定されます。t2 は t1 から一意に決まることに注意すると、状態数は高々 10^5 で抑えられるのです。したがって、一般に、全体の計算量は O(N \sum T_i) で収まります。なるほどー。

ちなみに、この問題でも、ナップサック問題と同様に刈り込みを入れることができます。ナップサック問題での刈り込みについては、下記を参照。

enakai00.hatenablog.com

具体的には、「max(t1, t2) が大きくなると、min(t1, t2) は小さくなるべき」という条件です。で、dp[ ] への情報の持たせ方をちょっと工夫して、刈り込みをいれてみたのですが、なんと、逆に TLE が発生してしまいます。

atcoder.jp

刈り込みの際は、Key のソートとスキャンという処理が走るので、こちらのオーバーヘッドが刈り込みによる計算量の削減を上回るということなのでしょう。なるほどー。

Union-find をグラフ探索で代替できる例

何の話かと言うと

要素をグループ分けする問題で、よく Union-Find がテクニックとして取り上げられますが、もうちょっと単純に、同じグループの要素に双方向リンクを貼っていって、最後に、グラフ探索で同じグループの要素をまとめるという手法で間に合う場合もあります。

例1 : D - Equals

atcoder.jp

互いに入れ替え可能な位置を同一グループと見なします。その後、「整数 P_i の位置 i が位置 P_i と同じグループに入っている」という条件を満たす P_i の個数を数えればOKです。双方向リンクとグラフ探索による実装はこちら。

atcoder.jp

例2 : D - KAIBUNsyo

atcoder.jp

数列に現れる数字をノードとして、ペアになる数字を同じグループに属するものとします。それぞれのグループについて「ノード数 - 1」を計算して、これらを合計したものが答えになります。双方向リンクとグラフ探索による実装はこちら。

atcoder.jp

Union-Find でないと間に合わない例

とは言え、同一グループの探索が何度も走る問題では、Union-Find でないと間に合いません。例えばこちら。

atcoder.jp

グラフ探索による結果

atcoder.jp

Union-Find による結果

atcoder.jp

Q - Flowers の解説

何の話かと言うと

atcoder.jp

この問題をネタにして、「ソートされたリストへの挿入位置をバイナリーサーチで高速に検索する」というテクニックを紹介します。

ぶっちゃけ、これ難問かも?

ナップサック問題との類似性で考える

「花を取り除く」という立て付けですが、「残しておく花を選ぶ」と考えると、「価値が最大になるように荷物を選ぶ」というナップサック問題に類似していることに気がつきます。

まずは、下記の記事でナップサック問題の(ディクショナリーを使った)解法を確認してください。

enakai00.hatenablog.com
enakai00.hatenablog.com

これを踏まえると、

dp[h] = 最大の高さ h で実現できる最大価値

というディクショナリーを用意して、花を順番に加えていくという解法を思いつきます。ナップサック問題とほぼ同じロジックなので、詳細は省いてコードをお見せするとこちらになります。

import sys, copy
from collections import defaultdict, deque

def main(f):
  N = int(f.readline())
  H = [None] + list(map(int, f.readline().split()))
  A = [None] + list(map(int, f.readline().split()))

  # n = 1
  dp = defaultdict(int) # dp[h] : 最大の高さ h で実現できる最大価値
  dp[0] = 0
  dp[H[1]] = A[1]

  # n = 2
  for n in range(2, N+1):
    # 刈り込み
    # h が増えれば価値も増えるべき
    hs = list(dp.keys())
    hs.sort()
    pre = -1
    for h in hs:
      if dp[h] <= pre:
        del dp[h]
      else:
        pre = dp[h]
      
    dp_n = copy.copy(dp) # n 個目を加えない場合
    for h in dp.keys():
      if H[n] > h:
        dp_n[H[n]] = max(dp_n[H[n]], dp[h] + A[n])
    dp = dp_n
    #print(dp)

  print(max(dp.values()))

main(sys.stdin)

なのですが、このコードを提出すると、半数強のケースで TLE します。

atcoder.jp

方向性としては悪くなさそうですが、さらなる効率化が必要です。

ループ処理の無駄を見つける

上記のコードのどこに無駄があるのでしょうか・・・。無駄なループ(本当はループしなくても実現できる処理をわざわざループで処理している部分)を探すのがテクニックの一つです。

たとえば、このループを考えます。

    for h in dp.keys():
      if H[n] > h:
        dp_n[H[n]] = max(dp_n[H[n]], dp[h] + A[n])

新しく追加する花の高さ H[n] に対して、H[n] > h となる h をループで探していますが、最終的に残るのは、dp[h] + A[n](すなわち dp[h])が最大になるものだけです。この「本当に必要な h」をループを回さずに一発で発見できないものでしょうか。

で・・・ここが絶妙なところなのですが・・・

上記のコードでは、新しい花を追加するごとに、dp に刈り込み処理を入れていますが、これにより、

・dp.keys() を昇順にソートすれば、対応する dp の値も昇順ソートされる

という著しい特徴があります。

したがって、hs = 「dp.keys() をソートしたリスト」として、(ソート状態を保ったまま)H[n] を挿入できる位置、すなわち、hs[index-1] < H[n] < hs[index] を満たす index を発見すれば、h = hs[index-1] が求める h になります。このような挿入位置は、バイナリーサーチ(Python であれば、bisect.bisect 関数)で高速に検索することができます。

バイナリーサーチを活用した実装

これで、愚直なループをより高速なバイナリーサーチに置き換えられることがわかりましたが、このためには、dp.keys() をソートしたリスト hs をメンテナンスする必要があります。

そこで、ディクショナリー dp のキーとバリューを別々のリストに分解してしまいます。

  #  最大の高さを hs[i] とした場合の最大価値が dp[i]
  hs = [] # 昇順ソートを保つ
  dp = [] # hs がソートされていると(刈り込みによって)こちらもソートされるはず

たとえば、n = 1 のケースのデータは次の様に代入されることになります。

  # n = 1
  hs.append(0)
  hs.append(H[1])
  dp.append(0)
  dp.append(A[1])

次に、n = 2,...,N では、H[n] > h を満たす最大の h を探して、最大の高さを H[n] とする場合の最大価値として、dp[h] + A[n] を記録することができます。具体的には、次のようなコードになります。

  # n = 2,...,N
  for n in range(2, N+1):
    index = bisect.bisect(hs, H[n]) # 高さによる挿入位置
    val_new = dp[index-1] + A[n]

    hs.insert(index, H[n])    # 高さのオーダーで挿入 ---- (1)
    dp.insert(index, val_new) # 高さのオーダーで挿入 ---- (2)
...

bisect.bisect(hs, H[n]) は、hs[index-1] < H[n] < hs[index] を満たす index を返します。

この時点では、hs はソート状態が保たれていますが、ds はソート状態が保たれなくなります。ただし、この後、刈り込みを入れる事で ds もソート状態にもどります。高さが上がれば、最大価値も上がるべきなので、今、ds に挿入した val_new に対して、これよりも後ろに val_new 未満の値があれば、それらは削除して構いません。これにより、ds もソートされた状態になります。

では、削除するべき要素はどうやって発見するかというと・・・

これもうまくできているのですが、(1)(2) の挿入を行う前に、dp に対して、バイナリサーチをかけると、

・dp[index2-1] < val_new < dp[index2]

を満たす index2 が発見できます。したがって、リスト dp の中で、「実際に val_new を挿入した位置から index2-1 の範囲」にある値が削除対象になります。

というわけで、刈り込みについても、ループを回さずに実施できてしまいます。

  # n = 2,...,N
  for n in range(2, N+1):
    index = bisect.bisect(hs, H[n]) # 高さによる挿入位置
    val_new = dp[index-1] + A[n]
    index2 = bisect.bisect(dp, val_new) # 価値による挿入位置

    hs.insert(index, H[n])    # 高さのオーダーで挿入
    dp.insert(index, val_new) # 高さのオーダーで挿入

    if index+1 < index2+1: # index+1 〜 index2 のデータは不要(刈り込み)
      del hs[index+1:index2+1]
      del dp[index+1:index2+1] # これで dp も sorted になる

index2 を検索した時点では、index2-1 までのデータが不要でしたが、その後、val_new を挿入しているので、実際に削除するのは、val_new の挿入位置 index に対して、index+1 〜 index2 の範囲になります。

できてしまえば、実にあっさりですが、これですべてのケースがパスできます。(これに気づくまで、実際には数時間かかってますが。。。。)

atcoder.jp

おまけ

もう少しシンプルなバイナリーサーチの練習問題は、こちらなどがよいでしょう。

atcoder.jp

解答例はこちら

atcoder.jp