ナップサック問題との類似性で考える
「花を取り除く」という立て付けですが、「残しておく花を選ぶ」と考えると、「価値が最大になるように荷物を選ぶ」というナップサック問題に類似していることに気がつきます。
まずは、下記の記事でナップサック問題の(ディクショナリーを使った)解法を確認してください。
enakai00.hatenablog.com
enakai00.hatenablog.com
これを踏まえると、
dp[h] = 最大の高さ h で実現できる最大価値
というディクショナリーを用意して、花を順番に加えていくという解法を思いつきます。ナップサック問題とほぼ同じロジックなので、詳細は省いてコードをお見せするとこちらになります。
import sys, copy from collections import defaultdict, deque def main(f): N = int(f.readline()) H = [None] + list(map(int, f.readline().split())) A = [None] + list(map(int, f.readline().split())) # n = 1 dp = defaultdict(int) # dp[h] : 最大の高さ h で実現できる最大価値 dp[0] = 0 dp[H[1]] = A[1] # n = 2 for n in range(2, N+1): # 刈り込み # h が増えれば価値も増えるべき hs = list(dp.keys()) hs.sort() pre = -1 for h in hs: if dp[h] <= pre: del dp[h] else: pre = dp[h] dp_n = copy.copy(dp) # n 個目を加えない場合 for h in dp.keys(): if H[n] > h: dp_n[H[n]] = max(dp_n[H[n]], dp[h] + A[n]) dp = dp_n #print(dp) print(max(dp.values())) main(sys.stdin)
なのですが、このコードを提出すると、半数強のケースで TLE します。
方向性としては悪くなさそうですが、さらなる効率化が必要です。
ループ処理の無駄を見つける
上記のコードのどこに無駄があるのでしょうか・・・。無駄なループ(本当はループしなくても実現できる処理をわざわざループで処理している部分)を探すのがテクニックの一つです。
たとえば、このループを考えます。
for h in dp.keys(): if H[n] > h: dp_n[H[n]] = max(dp_n[H[n]], dp[h] + A[n])
新しく追加する花の高さ H[n] に対して、H[n] > h となる h をループで探していますが、最終的に残るのは、dp[h] + A[n](すなわち dp[h])が最大になるものだけです。この「本当に必要な h」をループを回さずに一発で発見できないものでしょうか。
で・・・ここが絶妙なところなのですが・・・
上記のコードでは、新しい花を追加するごとに、dp に刈り込み処理を入れていますが、これにより、
・dp.keys() を昇順にソートすれば、対応する dp の値も昇順ソートされる
という著しい特徴があります。
したがって、hs = 「dp.keys() をソートしたリスト」として、(ソート状態を保ったまま)H[n] を挿入できる位置、すなわち、hs[index-1] < H[n] < hs[index] を満たす index を発見すれば、h = hs[index-1] が求める h になります。このような挿入位置は、バイナリーサーチ(Python であれば、bisect.bisect 関数)で高速に検索することができます。
バイナリーサーチを活用した実装
これで、愚直なループをより高速なバイナリーサーチに置き換えられることがわかりましたが、このためには、dp.keys() をソートしたリスト hs をメンテナンスする必要があります。
そこで、ディクショナリー dp のキーとバリューを別々のリストに分解してしまいます。
# 最大の高さを hs[i] とした場合の最大価値が dp[i] hs = [] # 昇順ソートを保つ dp = [] # hs がソートされていると(刈り込みによって)こちらもソートされるはず
たとえば、n = 1 のケースのデータは次の様に代入されることになります。
# n = 1 hs.append(0) hs.append(H[1]) dp.append(0) dp.append(A[1])
次に、n = 2,...,N では、H[n] > h を満たす最大の h を探して、最大の高さを H[n] とする場合の最大価値として、dp[h] + A[n] を記録することができます。具体的には、次のようなコードになります。
# n = 2,...,N for n in range(2, N+1): index = bisect.bisect(hs, H[n]) # 高さによる挿入位置 val_new = dp[index-1] + A[n] hs.insert(index, H[n]) # 高さのオーダーで挿入 ---- (1) dp.insert(index, val_new) # 高さのオーダーで挿入 ---- (2) ...
bisect.bisect(hs, H[n]) は、hs[index-1] < H[n] < hs[index] を満たす index を返します。
この時点では、hs はソート状態が保たれていますが、ds はソート状態が保たれなくなります。ただし、この後、刈り込みを入れる事で ds もソート状態にもどります。高さが上がれば、最大価値も上がるべきなので、今、ds に挿入した val_new に対して、これよりも後ろに val_new 未満の値があれば、それらは削除して構いません。これにより、ds もソートされた状態になります。
では、削除するべき要素はどうやって発見するかというと・・・
これもうまくできているのですが、(1)(2) の挿入を行う前に、dp に対して、バイナリサーチをかけると、
・dp[index2-1] < val_new < dp[index2]
を満たす index2 が発見できます。したがって、リスト dp の中で、「実際に val_new を挿入した位置から index2-1 の範囲」にある値が削除対象になります。
というわけで、刈り込みについても、ループを回さずに実施できてしまいます。
# n = 2,...,N for n in range(2, N+1): index = bisect.bisect(hs, H[n]) # 高さによる挿入位置 val_new = dp[index-1] + A[n] index2 = bisect.bisect(dp, val_new) # 価値による挿入位置 hs.insert(index, H[n]) # 高さのオーダーで挿入 dp.insert(index, val_new) # 高さのオーダーで挿入 if index+1 < index2+1: # index+1 〜 index2 のデータは不要(刈り込み) del hs[index+1:index2+1] del dp[index+1:index2+1] # これで dp も sorted になる
index2 を検索した時点では、index2-1 までのデータが不要でしたが、その後、val_new を挿入しているので、実際に削除するのは、val_new の挿入位置 index に対して、index+1 〜 index2 の範囲になります。
できてしまえば、実にあっさりですが、これですべてのケースがパスできます。(これに気づくまで、実際には数時間かかってますが。。。。)