めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

089 - Partitions and Inversions(★7)の解説

何の話かと言うと

atcoder.jp

上記の問題について、AVL 木(ソート済みのリストに対して、ソートを保った挿入・削除を O(\log M) で実行できるデータ構造)を用いた別解を紹介します。

再帰処理による解法

まずは直感的にわかりやすい再帰的な解法を考えます。

与えられた数列の最初の分割位置について場合分けをします。たとえば、最初の区間として可能なものが次の 3 つだとします。

(A_1)
(A_1,\,A_2)
(A_1,\,A_2,\,A_3)

それぞれの場合について、

(A_1) の場合: (A_2,\cdots,A_N) に対する解 dp[2]
(A_1,\,A_2) の場合: (A_3,\cdots,A_N) に対する解 dp[3]
(A_1,\,A_2,\,A_3) の場合: (A_4,\cdots,A_N) に対する解 dp[4]

があらかじめ求められているとすれば、

・dp[1] = dp[2] + dp[3] + dp[4]

が求める答えになります。

そこで一般に、dp[n] (n=1,...,N) を求める関数を solve(n) として実装すれば、再帰的に計算することができます。

最初の区間として可能な範囲は、先頭から要素を順にソート済みリストに(ソートを保ちながら)追加していき、追加する際に「後ろから飛び越える要素の数」を加えていき、これが K を超えたところで打ち切れば求まります。

bisect を利用してソート済みリストに要素を挿入していけば、挿入位置から「飛び越える要素数」を求めることができます。

この方針で実装すると下記になります。

import sys, bisect
from functools import lru_cache
sys.setrecursionlimit(10**6)
mod = 10**9 + 7

@lru_cache(maxsize=None)
def solve(n): # A[n:] に対する解を求める
  global N, K, A
  if n == N + 1:
    return 1

  result = 0
  L = []
  i = n
  skips = 0
  while i <= N:
    skips += len(L) - bisect.bisect_right(L, A[i])
    if skips > K:
      break
    bisect.insort_right(L, A[i])
    result += solve(i+1)
    result %= mod
    i += 1

  return result

def main(f):
  global N, K, A
  N, K = map(int, f.readline().split())
  A = [None] + list(map(int, f.readline().split()))
  print(solve(1))

main(sys.stdin)

計算手続きとしては問題ありませんが、このままでは TLE します。

再帰処理をやめて DP で実装する

・dp[1] = dp[2] + dp[3] + dp[4]

という関係を見ると、dp[N], dp[N-1], ... と後ろから順に求めれば、再帰処理をやめて DP で計算できることがわかります。個々の dp[n] の計算方法は先ほどと同じで、次のように実装できます。

import sys, copy, heapq, math, bisect
from collections import defaultdict, deque
mod = 10**9 + 7

def main(f):
  global N, K, A
  N, K = map(int, f.readline().split())
  A = [None] + list(map(int, f.readline().split()))

  dp = [None] * (N+2)
  dp[N+1] = 1
  for n in range(N, 0, -1):
    dp[n] = 0
    L = []
    i = n
    skips = 0
    while i <= N:
      skips += len(L) - bisect.bisect_right(L, A[i])
      if skips > K:
        break
      bisect.insort_right(L, A[i])
      dp[n] += dp[i+1]
      dp[n] %= mod
      i += 1

  print(dp[1])

main(sys.stdin)

これもまた正しい実装ですが、やはりまだ TLE します。

ここで計算量を確認してみましょう。まず、dp[N], dp[N-1], ..., dp[1] と計算していく O(N) のループが一番外側にあります。

次に、個々の dp[n] を求める際に、最初の区間として可能なものを調べていきますが、これは最悪ケースで O(n) のループとなるので、この段階で、全体として O(N^2) になります。これは確かに TLE します。

「最初の区間として可能な範囲」の計算を分離する

そこで、dp[n] の計算と、「最初の区間として可能な範囲」の計算を分離します。

仮に、n = 1,2,..., N について、

A_n から始めた場合に、最初の区間として可能な範囲は、(A_n,\cdots,A_m) である

という情報を十分高速に決定できたとします。上記の m は、リスト skip_to に保存されているものとします。(m = skip_to[n])

この場合、DP による dp[n] の計算は次のように簡単化されます。

  dp = [None] * (N+2)
  dp[N+1] = 1
  for n in range(N, 0, -1):
    dp[n] = 0
    for i in range(n+1, skip_to[n]+2):
      dp[n] += dp[i]

  print(dp[1])

ただし、これは、skip_to[n] の計算を外だししただけで、O(N^2) のループであることには変わりありません。

しかしながら・・・、上記は「部分和」の計算になっているので、総和の差分として求めれば、計算量を O(N) に減らすことができます。

  dp = [None] * (N+2)
  dp_sum = [None] * (N+2)
  dp[N+1] = 1
  dp_sum[N+1] = 1
  for n in range(N, 0, -1):
    if skip_to[n] + 2 > N+1:
      dp[n] = dp_sum[n+1]
    else:
      dp[n] = dp_sum[n+1] - dp_sum[skip_to[n]+2]
    dp[n] %= mod
    dp_sum[n] = dp_sum[n+1] + dp[n]
    dp_sum[n] %= mod

  print(dp[1])

というわけで、あとは、「最初の区間として可能な範囲」の計算をうまく実装できればOKです。

・先頭から要素を順にソート済みリストに(ソートを保ちながら)追加していき、追加する際に「後ろから飛び越える要素の数」を加えていき、これが K を超えたところで打ち切る

という前述の作戦を n = 1, 2, ..., N について実行するのはどうでしょうか?

ダメですね。

n = 1 に対して、お尻の部分を i = 2, 3, ... と伸ばしていって打ち切る。次に n = 2 に対して、再び、お尻の部分を i = 3, 4, ... と伸ばしていって・・・とすると、O(N^2) になってしまいます。

実は、このやり方の場合、n = 1 についてお尻を伸ばし切って打ち切った際に、再び、お尻を i = 2 にリセットする部分に無駄があります。A_1 から始まる場合のお尻と、A_2 から始まる場合のお尻を比べれば、A_2 から始める場合の方が、お尻の位置はより後ろに伸びるはずなので、毎回、お尻をリセットするのではなく、そのままお尻を伸ばし続ければよいのです。

ただし、A_1 を取り除くことによって、「飛び越えた数」をその分だけ減らす必要があります。具体的には、これまでに追加した要素で、A_1 より小さいものの個数を減らせばOKです。この個数も bisect で求めることにすれば、次のような実装が可能です。

  skip_to = [None] * (N+1)
  skips = 0
  L = [A[1]]
  i = 2
  for n in range(1, N+1):
    while True:
      if i == N+1:
        skip_to[n] = N
        break
 
      if len(L) == 0:
        L.append(A[i])
        i += 1
        continue
 
      delta = len(L) - bisect.bisect_right(L, A[i])
      if skips + delta > K:
        skip_to[n] = i - 1
        count = bisect.bisect_left(L, A[n])
        skips -= count
        del L[count]
        break
      skips += delta
      bisect.insort_right(L, A[i])
      i += 1

ここまでの実装を提出すると・・・

atcoder.jp

惜しいです・・・。1 つだけ TLE が残ります。。。。

bisect の効率化

skip_to を求める上記の実装は、表面的には、N 回のループで O(N) に見えますが、それぞれのループの中で、bisect を用いた処理を行なっています。bisect の処理は O(N\log N) なので、全体として O(N^2\log N) なんですね。残念。

しかしながら、bisect と同等の処理を O(\log N) で実施する AVL 木がありました。

enakai00.hatenablog.com

bisect を AVL 木に置き換えれば、無事に AC です。

atcoder.jp

AVL木の実装は、個人の方がブログで公開しているこちらの実装を使っています。

PythonでAVL木を競プロ用に実装した

この実装は同一の値を保存できないので、(上記の AC するコードでは)A_1, A_2, \cdots に少しずつ小さな値を加える工夫をしてあります。

081 - Friendly Group(★5)の解説

何の話かというと

atcoder.jp

上記の問題の別解です。

公式解説 では、身長と体重を 2 次元にプロットして、身長と体重を対称に扱う解法が示されています。あえてこれらを非対称に扱うことも可能です。

身長だけの場合

仮に、身長についてだけ考えればよいとすれば、割と簡単な問題です。事前にソートした上で、いわゆる「尺取り法」で実装できます。

  A.sort()
  head = tail = 0
  a_max = a_min = A[head]

  max_len = 0
  while True:
    while a_max - a_min > K:
      tail += 1
      a_min = A[tail]
    while a_max - a_min <= K:
      max_len = max(max_len, head - tail + 1)
      head += 1
      if head == N:
        print(max_len)
        return
      a_max = A[head]

体重を含めた場合

上記の尺取り法の中では、「身長の観点で許容されるグループ」が自然に網羅されています。許容されるグループの中で、人数が最大になるものを検索する形になります。

そこで、「身長の観点で許容されるグループ」のそれぞれに対して、そのメンバーから「体重の観点でも許容されるサブグループ」を構成するという方法が考えられます。

極端な方法としては、「身長の観点で許容されるグループ」のそれぞれに対して、体重についての尺取り法を実行するというやり方が考えられます。

atcoder.jp

上記の実装では、tail を固定した状態で、head を伸ばしていき、最大限に伸ばし切ったタイミングで、体重についての尺取り法を呼び出すという実装になっています。いくつか TLE していますが、方向性としては悪くなさそうです。

TLE の 1 つの要因として、体重についての尺取り法を呼び出す際に、対象となるサブグループのコピーが発生する点がちょっと重い気がします。

そこで、体重に関しては、配列 num_B[i] を用意して、今考えている「身長の観点で許容されるグループ」の中で、

・num_B[i] = 体重が i 〜 i + K の範囲の人数

をトラッキングする様にします。こうすれば、max(num_B) として、身長・体重の両方について許容される最大人数を得ることができます。max(B) を計算するタイミングを最小限になるようにうまく調整すると無事に AC します。

atcoder.jp

074 - ABC String 2(★6)の解説

何の話かというと

atcoder.jp

上記の問題は、公式解説にもあるように、具体例を試しながら隠された規則性を発見する必要があります。どのような試行錯誤で、正解に至れるのか、一例を紹介します。

まずはしらみつぶし

まずは、しらみつぶしで、文字列が変化する様子を観察します。キューを用いた全件探索を実装します。

import sys, copy
from collections import defaultdict, deque

def main(f):
  N = int(f.readline())
  S = f.readline().strip()
  q = deque()
  q.append((0, S)) # count, string
  c_max = 0
  while q:
    c, s = q.pop()
    print(c,':', s)
    c_max = max(c, c_max)
    for i in range(len(s)):
      tmp = s.translate(str.maketrans('abc', 'bca'))
      if s[i] == 'b':
        s = tmp[:i] + 'a' + s[i+1:]
        q.append((c+1, s))
      if s[i] == 'c':
        s = tmp[:i] + 'b' + s[i+1:]
        q.append((c+1, s))

  print(c_max)

with open('input.txt', 'r') as f:
  main(f)
# input
3
aba

# output
0 : aba
1 : baa
2 : aaa
2

これを見ると文字列の右から順に、a が増えていって、最後はすべて a になって終わります。

ここで変換のルールを思い出すと、ある位置の文字を c -> b -> a と変更すると共に、それより前の文字をローテーションします。逆にいうと、c -> b -> a と変更する対象より後ろは、何も変化しません。したがって、ある位置から後ろがすべて a になれば、その部分はもはや変化することができないのです。

ということは、できるだけ長く変更を続けるには、前の方の文字を優先的に変更して、後ろの方の文字はできるだけ変更しない方がよさそう、と気づきます。

イメージでいうと、n 桁の数字を減らしていく作業に似ています。減らす操作をできるだけ長く続けるには、1 ずつ減らすのがベストです。

(1) 下の桁の値を 1 ずつ減らしていき、それ以上減らせなくなった(つまり 0 になった)ら、1つ上の桁を 1 減らす。
(2) 上の桁を 1 減らすと、その下の桁の数字はすべて 9 にもどる。

これって、変更した文字の手前をすべてローテーションする、という作業を (2) に対応させると、非常によく似ていることがわかります。

前を優先的に変更する

というわけで、しらみつぶしではなく、前の方の文字を優先的に変更するというロジックで組み直してみます。

import sys

def main(f):
  N = int(f.readline())
  s = list(f.readline().strip())
  while s[-1] == 'a' and len(s) > 1:
    s.pop()

  c = 0
  while True:
    print(c, ':', s)
    update = False
    for i in range(len(s)):
      if s[i] == 'c':
        s[i] = 'b'
        c += 1
        update = True
        break
      if s[i] == 'b':
        s[i] = 'a'
        if s[-1] == 'a':
          s.pop()
        c += 1
        update = True
        break
      s[i] = 'b'
    if not update:
      break

  print(c)

with open('input.txt', 'r') as f:
  main(f)

ここでは、文字列をリストに変換して、(文字列をコピーするのではなく)リストを直接変更していくように実装を変えています。また、結果を見やすくするために、末尾の a は削除していきます。

# input
5
baaca

# output
0 : ['b', 'a', 'a', 'c']
1 : ['a', 'a', 'a', 'c']
2 : ['b', 'b', 'b', 'b']
3 : ['a', 'b', 'b', 'b']
4 : ['b', 'a', 'b', 'b']
5 : ['a', 'a', 'b', 'b']
6 : ['b', 'b', 'a', 'b']
7 : ['a', 'b', 'a', 'b']
8 : ['b', 'a', 'a', 'b']
9 : ['a', 'a', 'a', 'b']
10 : ['b', 'b', 'b']
11 : ['a', 'b', 'b']
12 : ['b', 'a', 'b']
13 : ['a', 'a', 'b']
14 : ['b', 'b']
15 : ['a', 'b']
16 : ['b']
17 : []
17

問題で与えられたサンプルに対して、確かに正しい答えが得られています。

さきほど、数字を減らしていく作業との類似を指摘しましたが、出力を下から上に読むと、数字を 1 ずつ増やしていく操作にも見えてきます。実際、17 -> 16 -> ... -> 2 までの流れは、a = 0, b = 1 とした 2 進数で、0 〜 1111 まで数える作業とぴったり一致しています。c がなければ、単純な 2 進数への置き換えで答えが得られることになります。

c を含む場合を探ってみる

では、c は、この計算ルールにどのように影響しているのでしょうか? c を複数入れた例で様子を見ます。

# input
5
acbcb

# output
0 : ['a', 'c', 'b', 'c', 'a', 'b']
1 : ['b', 'b', 'b', 'c', 'a', 'b']
2 : ['a', 'b', 'b', 'c', 'a', 'b']
3 : ['b', 'a', 'b', 'c', 'a', 'b']
4 : ['a', 'a', 'b', 'c', 'a', 'b']
5 : ['b', 'b', 'a', 'c', 'a', 'b']
6 : ['a', 'b', 'a', 'c', 'a', 'b']
7 : ['b', 'a', 'a', 'c', 'a', 'b']
8 : ['a', 'a', 'a', 'c', 'a', 'b']
9 : ['b', 'b', 'b', 'b', 'a', 'b']
10 : ['a', 'b', 'b', 'b', 'a', 'b']
11 : ['b', 'a', 'b', 'b', 'a', 'b']
12 : ['a', 'a', 'b', 'b', 'a', 'b']
13 : ['b', 'b', 'a', 'b', 'a', 'b']
14 : ['a', 'b', 'a', 'b', 'a', 'b']
15 : ['b', 'a', 'a', 'b', 'a', 'b']
16 : ['a', 'a', 'a', 'b', 'a', 'b']
17 : ['b', 'b', 'b', 'a', 'a', 'b']
18 : ['a', 'b', 'b', 'a', 'a', 'b']
19 : ['b', 'a', 'b', 'a', 'a', 'b']
20 : ['a', 'a', 'b', 'a', 'a', 'b']
21 : ['b', 'b', 'a', 'a', 'a', 'b']
22 : ['a', 'b', 'a', 'a', 'a', 'b']
23 : ['b', 'a', 'a', 'a', 'a', 'b']
24 : ['a', 'a', 'a', 'a', 'a', 'b']
25 : ['b', 'b', 'b', 'b', 'b']
26 : ['a', 'b', 'b', 'b', 'b']
27 : ['b', 'a', 'b', 'b', 'b']
28 : ['a', 'a', 'b', 'b', 'b']
29 : ['b', 'b', 'a', 'b', 'b']
30 : ['a', 'b', 'a', 'b', 'b']
31 : ['b', 'a', 'a', 'b', 'b']
32 : ['a', 'a', 'a', 'b', 'b']
33 : ['b', 'b', 'b', 'a', 'b']
34 : ['a', 'b', 'b', 'a', 'b']
35 : ['b', 'a', 'b', 'a', 'b']
36 : ['a', 'a', 'b', 'a', 'b']
37 : ['b', 'b', 'a', 'a', 'b']
38 : ['a', 'b', 'a', 'a', 'b']
39 : ['b', 'a', 'a', 'a', 'b']
40 : ['a', 'a', 'a', 'a', 'b']
41 : ['b', 'b', 'b', 'b']
42 : ['a', 'b', 'b', 'b']
43 : ['b', 'a', 'b', 'b']
44 : ['a', 'a', 'b', 'b']
45 : ['b', 'b', 'a', 'b']
46 : ['a', 'b', 'a', 'b']
47 : ['b', 'a', 'a', 'b']
48 : ['a', 'a', 'a', 'b']
49 : ['b', 'b', 'b']
50 : ['a', 'b', 'b']
51 : ['b', 'a', 'b']
52 : ['a', 'a', 'b']
53 : ['b', 'b']
54 : ['a', 'b']
55 : ['b']
56 : []
56

出力結果を下から上に見ると、右から順に c が現れていることがわかります。一番右の c が現れる 1 つ前(出力で言うと、1行下)までを見ると、さきほどの2進数計算で 1 ずつ増えていって、

・c より前(c の位置を含む)は、すべて b
・c より後(c の位置は含まない)は、もとの文字列のまま

という値にたどり着いています。したがって、「上記で決まる 2 進数 + 1」回の(逆向きの)操作で、1 つ目の c が出現する所までいけるのです。

では、その上の行はどうなっているでしょうか? 一番右の c 以降は変化しませんので、c より前だけを見ればよいことに気が付きます。そして、一番右の c 以降を切り捨てて考えれば、次の c が現れるまで、同じルールの計算になっています。

つまり、右から順に c が出現する位置を探しながら、c が出現するごとに、先の「上記で決まる 2 進数 + 1」の値を加えていけばOKとわかります。さっくり実装すると、こんな感じですね。

import sys

def main(f):
  N = int(f.readline())
  s = list(f.readline().strip())

  c = 0
  bits = 0
  for i in range(len(s)-1, -1, -1):
    if s[i] == 'c':
      for j in range(i, -1, -1):
        bits = bits * 2 + 1
      c += bits + 1
      bits = 0
      continue
    if s[i] == 'a':
      bits = bits * 2 + 0
    if s[i] == 'b':
      bits = bits * 2 + 1
  c += bits
  print(c)

with open('input.txt', 'r') as f:
  main(f)

atcoder.jp