めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

089 - Partitions and Inversions(★7)の解説

何の話かと言うと

atcoder.jp

上記の問題について、AVL 木(ソート済みのリストに対して、ソートを保った挿入・削除を O(\log M) で実行できるデータ構造)を用いた別解を紹介します。

再帰処理による解法

まずは直感的にわかりやすい再帰的な解法を考えます。

与えられた数列の最初の分割位置について場合分けをします。たとえば、最初の区間として可能なものが次の 3 つだとします。

(A_1)
(A_1,\,A_2)
(A_1,\,A_2,\,A_3)

それぞれの場合について、

(A_1) の場合: (A_2,\cdots,A_N) に対する解 dp[2]
(A_1,\,A_2) の場合: (A_3,\cdots,A_N) に対する解 dp[3]
(A_1,\,A_2,\,A_3) の場合: (A_4,\cdots,A_N) に対する解 dp[4]

があらかじめ求められているとすれば、

・dp[1] = dp[2] + dp[3] + dp[4]

が求める答えになります。

そこで一般に、dp[n] (n=1,...,N) を求める関数を solve(n) として実装すれば、再帰的に計算することができます。

最初の区間として可能な範囲は、先頭から要素を順にソート済みリストに(ソートを保ちながら)追加していき、追加する際に「後ろから飛び越える要素の数」を加えていき、これが K を超えたところで打ち切れば求まります。

bisect を利用してソート済みリストに要素を挿入していけば、挿入位置から「飛び越える要素数」を求めることができます。

この方針で実装すると下記になります。

import sys, bisect
from functools import lru_cache
sys.setrecursionlimit(10**6)
mod = 10**9 + 7

@lru_cache(maxsize=None)
def solve(n): # A[n:] に対する解を求める
  global N, K, A
  if n == N + 1:
    return 1

  result = 0
  L = []
  i = n
  skips = 0
  while i <= N:
    skips += len(L) - bisect.bisect_right(L, A[i])
    if skips > K:
      break
    bisect.insort_right(L, A[i])
    result += solve(i+1)
    result %= mod
    i += 1

  return result

def main(f):
  global N, K, A
  N, K = map(int, f.readline().split())
  A = [None] + list(map(int, f.readline().split()))
  print(solve(1))

main(sys.stdin)

計算手続きとしては問題ありませんが、このままでは TLE します。

再帰処理をやめて DP で実装する

・dp[1] = dp[2] + dp[3] + dp[4]

という関係を見ると、dp[N], dp[N-1], ... と後ろから順に求めれば、再帰処理をやめて DP で計算できることがわかります。個々の dp[n] の計算方法は先ほどと同じで、次のように実装できます。

import sys, copy, heapq, math, bisect
from collections import defaultdict, deque
mod = 10**9 + 7

def main(f):
  global N, K, A
  N, K = map(int, f.readline().split())
  A = [None] + list(map(int, f.readline().split()))

  dp = [None] * (N+2)
  dp[N+1] = 1
  for n in range(N, 0, -1):
    dp[n] = 0
    L = []
    i = n
    skips = 0
    while i <= N:
      skips += len(L) - bisect.bisect_right(L, A[i])
      if skips > K:
        break
      bisect.insort_right(L, A[i])
      dp[n] += dp[i+1]
      dp[n] %= mod
      i += 1

  print(dp[1])

main(sys.stdin)

これもまた正しい実装ですが、やはりまだ TLE します。

ここで計算量を確認してみましょう。まず、dp[N], dp[N-1], ..., dp[1] と計算していく O(N) のループが一番外側にあります。

次に、個々の dp[n] を求める際に、最初の区間として可能なものを調べていきますが、これは最悪ケースで O(n) のループとなるので、この段階で、全体として O(N^2) になります。これは確かに TLE します。

「最初の区間として可能な範囲」の計算を分離する

そこで、dp[n] の計算と、「最初の区間として可能な範囲」の計算を分離します。

仮に、n = 1,2,..., N について、

A_n から始めた場合に、最初の区間として可能な範囲は、(A_n,\cdots,A_m) である

という情報を十分高速に決定できたとします。上記の m は、リスト skip_to に保存されているものとします。(m = skip_to[n])

この場合、DP による dp[n] の計算は次のように簡単化されます。

  dp = [None] * (N+2)
  dp[N+1] = 1
  for n in range(N, 0, -1):
    dp[n] = 0
    for i in range(n+1, skip_to[n]+2):
      dp[n] += dp[i]

  print(dp[1])

ただし、これは、skip_to[n] の計算を外だししただけで、O(N^2) のループであることには変わりありません。

しかしながら・・・、上記は「部分和」の計算になっているので、総和の差分として求めれば、計算量を O(N) に減らすことができます。

  dp = [None] * (N+2)
  dp_sum = [None] * (N+2)
  dp[N+1] = 1
  dp_sum[N+1] = 1
  for n in range(N, 0, -1):
    if skip_to[n] + 2 > N+1:
      dp[n] = dp_sum[n+1]
    else:
      dp[n] = dp_sum[n+1] - dp_sum[skip_to[n]+2]
    dp[n] %= mod
    dp_sum[n] = dp_sum[n+1] + dp[n]
    dp_sum[n] %= mod

  print(dp[1])

というわけで、あとは、「最初の区間として可能な範囲」の計算をうまく実装できればOKです。

・先頭から要素を順にソート済みリストに(ソートを保ちながら)追加していき、追加する際に「後ろから飛び越える要素の数」を加えていき、これが K を超えたところで打ち切る

という前述の作戦を n = 1, 2, ..., N について実行するのはどうでしょうか?

ダメですね。

n = 1 に対して、お尻の部分を i = 2, 3, ... と伸ばしていって打ち切る。次に n = 2 に対して、再び、お尻の部分を i = 3, 4, ... と伸ばしていって・・・とすると、O(N^2) になってしまいます。

実は、このやり方の場合、n = 1 についてお尻を伸ばし切って打ち切った際に、再び、お尻を i = 2 にリセットする部分に無駄があります。A_1 から始まる場合のお尻と、A_2 から始まる場合のお尻を比べれば、A_2 から始める場合の方が、お尻の位置はより後ろに伸びるはずなので、毎回、お尻をリセットするのではなく、そのままお尻を伸ばし続ければよいのです。

ただし、A_1 を取り除くことによって、「飛び越えた数」をその分だけ減らす必要があります。具体的には、これまでに追加した要素で、A_1 より小さいものの個数を減らせばOKです。この個数も bisect で求めることにすれば、次のような実装が可能です。

  skip_to = [None] * (N+1)
  skips = 0
  L = [A[1]]
  i = 2
  for n in range(1, N+1):
    while True:
      if i == N+1:
        skip_to[n] = N
        break
 
      if len(L) == 0:
        L.append(A[i])
        i += 1
        continue
 
      delta = len(L) - bisect.bisect_right(L, A[i])
      if skips + delta > K:
        skip_to[n] = i - 1
        count = bisect.bisect_left(L, A[n])
        skips -= count
        del L[count]
        break
      skips += delta
      bisect.insort_right(L, A[i])
      i += 1

ここまでの実装を提出すると・・・

atcoder.jp

惜しいです・・・。1 つだけ TLE が残ります。。。。

bisect の効率化

skip_to を求める上記の実装は、表面的には、N 回のループで O(N) に見えますが、それぞれのループの中で、bisect を用いた処理を行なっています。bisect の処理は O(N\log N) なので、全体として O(N^2\log N) なんですね。残念。

しかしながら、bisect と同等の処理を O(\log N) で実施する AVL 木がありました。

enakai00.hatenablog.com

bisect を AVL 木に置き換えれば、無事に AC です。

atcoder.jp

AVL木の実装は、個人の方がブログで公開しているこちらの実装を使っています。

PythonでAVL木を競プロ用に実装した

この実装は同一の値を保存できないので、(上記の AC するコードでは)A_1, A_2, \cdots に少しずつ小さな値を加える工夫をしてあります。