めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

D - ナップサック問題の解説(その3)

何の話かと言うと

enakai00.hatenablog.com

上記のエントリーの続きです。はい。

不要な記録をさらに削ぎ落とす(カッコよく言うと刈り込みをする的な何か)

前回の記事の最後に

ちなにみ、サブタスク3と言うのは、1\le N\le 200 かつ 0\le v_1\le 1000 という条件があるもので、重さが巨大になる可能性があります。

一方、今回パスしたサブタスク2は、 1\le N\le 200 かつ 0\le w_1\le 1000 という条件があるもので、重さの範囲が 1,000 以下に制限されています。

さて・・・単純に記憶するべきデータ量 2^N はどちらも同じなのに、サブタスク2はオッケーで、サブタスク3が通らないのはなぜでしょう????

という謎のメッセージを残しました。

実際の所、データ量は同じなのに、なんでサブタスク2は時間がかからないんでしょう? どこかで妖精さんが勝手にデータを減らしてくれている???

実は前回の記事の中に、妖精さんの居場所を示したヒントがありました。

そう、繰り返しを進めるごとにディクショナリーのキーの数は、(残容量の偶然の一致を除いて)倍々に増えていくので、最終的には、2^N 個ものキーを持った巨大なディクショナリーができあがってしまいます。

ここです!

(残容量の偶然の一致を除いて)

ここ。

サブタスク2は重さの範囲が 1,000 以下に限定されているので、さまざまな重さの組み合わせをした結果、残容量が被る可能性が高くなります。データ数が最大 200 なので、総容量の値は、論理的に考えて高々 1,000 * 200 = 200,000 通りですので、残容量の値もこれと同じく 200,000 通りに限定されます。一般的な組み合わせの数 2^{200} よりは圧倒的に小さいことがわかります。

つまり、サブタスク2は、残容量の値が被りまくることで、結果的にディクショナリーに記録されるデータ量が削減されていたのです。

なるほど!

という感じですが、では、このような都合のよい条件がないサブタスク3に対応するにはどうすればよいのでしょうか・・・?

ここで、残容量が被った場合の処理を思い出してみます。

結論からいうと、値が大きい方を残すべきです。今、解こうとしている問題は、価値をどこまで上げられるか、という問題ですので、同じ残容量であれば、当然ながら、より高い価値を実現するケースの方が後々で必要になるはずです。

つまり、問題のゴールを考えれば、すべての情報をまじめに記録する必要はなく、「これよりも有利な状況があきらかに存在する」と分かっている組み合わせの情報は捨ててしまっても問題ないのです。残容量が被る場合は、ディクショナリーの構成上、どちらかを捨てざるを得ませんが、たとえ残容量が被ってなくても、「これよりも有利な状況があきらかに存在する」と判断できる場合があれば、その情報は積極的に捨てていってよいのです。

さて・・・それは、どのような状況でしょうか?

たとえば、n 個目までの荷物を使った組み合わせで、次の2つの状況が残ったとします。

・残容量 100 で合計価値 200 ---- (1)
・残容量 80 で合計価値 180 ---- (2)

これ、(1) の方があきらかに有利ですよね。(2) は残容量が少ない上に、合計価値も負けています。

この後、n+1 個目以降の荷物についてさまざまな組み合わせがやってきますが、どう考えても (2) の組み合わせが最終的な答え(合計価値が最大になる組み合わせ)に至ることはありません。n+1 個目以降はまったく同じ組み合わせで、これを (1) と組み合わせた方が最後の合計価値は確実に大きくなります。

つまり、残容量の降順に情報をソートした場合、合計価値は単調に増加していくべきで、もしも、合計価値が減るような部分があれば、その情報はここで捨ててしまっても後の計算には影響しないのです。dp[n] を計算するごとにこのような「刈り込み」を行なって、ディクショナリーに記録された情報を積極的に減らせば、実行時間を改善することができるかも知れません。(このような不要な情報が発生しない様に、意図的に用意されたデータの場合もあるやも知れませんが・・・まずは、試してみましょう。)

n についてのループの先頭で、こんな感じの刈り込み処理を追加します。

  for n in range(2, N+1):
    # 刈り込み
    dpk = list(dp.keys())
    dpk.sort(reverse=True)
    pre_val = -1
    for available_weight in dpk: # 残り容量が減ると価値は上がるべき
      if dp[available_weight] <= pre_val:
        del dp[available_weight]
      else:
        pre_val = dp[available_weight]

この修正を加えたコードを提出した結果は・・・

atcoder.jp

はい。全問正解です。おめでとー。

まとめ

というわけで、古典的なナップサック問題を例にして、「DP(動的計画法)の心」をひもといてみました。

現実にDPを使用する際は、そうは言っても、「後ろから前を振り返る」的な「とんち問題」の発想で絶妙な遷移処理を組み立てる(思いつく)必要があります。さまざまな典型問題にあたって、パターンを掴むことも重要ですが、単なるパターンとして覚えるのではなく、今回やったように、

・あえて「しらみつぶし」のアルゴリズムを考えてみる
・1個の場合を考えて、それを踏まえて2個の場合を考えて・・・とデータ数が数個程度の場合を具体的に考察することで、適切な遷移処理のヒントを得る

と言った取り組みに時間をかけるのもよいのではないでしょうか。

おまけ

今回の問題については、最終的に「刈り込み」でディクショナリーに保存するキーの数を削減することで、ディクショナリーのコピーやキーについてのループにかかる時間を削減することができました。が、刈り込みに気づくまでは、キーが多いままの状態で、ディクショナリーのコピーなど、(キーが多い場合に)時間のかかる処理をできるだけ減らすための地味なチューニングも試みていました。まあ、こういうのって、いわゆる「早すぎる最適化」(というか実際には最適化にもなっていない)というやつで、刈り込みがあれば、結局は不要だったわけですが、せっかくなので何を考えたか紹介しておきます。

ディクショナリーのコピーを避ける

今回の実装では、n-1 時点での dp と n 時点の dp_n の2つのディクショナリーを使っていますが、下記の様にディクショナリー全体のコピーが発生しています。

    dp_n = copy.copy(dp)  # n個目を入れない場合

いっその事、1つのディクショナリーだけを使って、次の様に実装すれば、コピーを避けることができるのではないでしょうか?

    dpk = list(dp.keys()) # イテレーター dp.keys()についてのループ中にキーが変わるとランタイムエラーが発生するので、事前にリスト化する
    for w0 in dpk:
      if w[n] <= w0:      # n個目を入れる場合
        if w0-w[n] in dp.keys():
          dp[w0-w[n]] = max(dp[w0-w[n]], dp[w0] + v[n]) # ---- (*)
        else:
          dp[w0-w[n]] = dp[w0] + v[n] # ---- (*)

ただし、このままではうまく行きません。(*) 部分で n-1 時点の情報 wp[w0] を参照した際に、それより前にその情報が上書きされていると、正しい情報が得られなくなるからです。これは、キーをリスト化した直後に、昇順にソートしてループをまわせばOKです。

    dpk = list(dp.keys())
    dpk.sort()  # ハッシュの使い回しをするので未チェックの n-1 の情報を上書きしないように更新順序を工夫する

なぜなら、(*) で上書きされるのはキー w0-w[n] に対する値なので、その時点でループに使っているキー w0 よりもかならず小さい値になります。したがって、w0 を昇順に処理すれば、未処理の値を誤って上書きすることはありません。

ただし・・・この場合、キーをソートする時間が加わるので、コピーする方がまだ早いんじゃね?と言われればその通りです。

最終結果をスキャンして最大値を検索することを避ける

DP のループが終わったら、最後に、ディクショナリーに残された合計価値の中から、その最大値を選択して答えが得られます。

  print(max(dp.values()))

ただ、ここに残される合計価値の値は、DP のループ中に書き込まれたものですので、ループ中に書き込む値の最大値をトラッキングしておけば、最後に検索する手間をはぶけるかも知れません。

    dpk = list(dp.keys())
    for w0 in dpk:
      if w[n] <= w0:      # n個目を入れる場合
        if w0-w[n] in dp.keys():
          dp[w0-w[n]] = max(dp[w0-w[n]], dp[w0] + v[n])
        else:
          dp[w0-w[n]] = dp[w0] + v[n]
      max_val = max(max_val, dp[w0-w[n]]) # その時点での最大値を保持

  print(max_val) # あらためて最大値を検索しなくてよい

まぁ、O(2^N) の処理を N 回ループした後に、最後の一回の O(2^N) の検索を削っても大勢に影響はないんですけどね。。。。刈り込みによって、根本の O(2^N) をなくすことがやはり本質だというわけです。