めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

P - Independent Set の解説

何の話かと言うと

atcoder.jp

この問題をネタに、木構造データに関する問題の基本を説明します。(DPというよりは、木構造データの取り扱いがメインの問題ですね。)

木構造データの取り扱い

データとして与えられるのはノード間のリンク情報のみで、どちらが親ノードかは決められていません。このような場合は、どちらが親かを気にせずに、すべての子ノードとして記録しておきます。

  children = [[] for _ in range(N+1)]
  for _ in range(N-1):
    x, y = list(map(int, f.readline().split()))
    children[y].append(x)
    children[x].append(y)

この例では children[i] には、ノード i に結合したすべてのノードが収められています。

そして、あるノード i から子ノードを利用する関数を呼ぶ際は、ノード i 自身の親ノードを補足情報として渡します。たとえば・・・、

def tree_search(i, parent=-1):
  global children

  # 「上から順」の場合は、ここでノード i についての処理をする

  for j in children[i]:
    if j == parent: # 親ノードはスキップする
      continue
    tree_search(j, i) # i を親ノードとして、子ノード j を処理する

  # 「下から順」の場合は、ここでノード i についての処理をする

上記の関数を tree_search(1) として呼ぶと、ノード 1 をルートとした深さ優先探索が行われます。

再帰による実装

木構造データ問題は、まずは、(直感的にわかりやすい)再帰的な実装で、自分の考え方が間違っていないことを確認するのがよいでしょう。この問題の場合は、まず、自分が黒の場合と白の場合で、次の様な場合分けが発生します。

・自分が黒の場合:子ノードは、白の場合だけ
・自分が白の場合:子ノードは、黒の場合と白の場合がある

したがって、子ノード j について、

・j が白の場合の j 以下の部分木の場合の数 b[j] ---- (1)
・j が黒の場合の j 以下の部分木の場合の数 w[j] ---- (2)

が分かっていれば、親ノード i 以下の部分木の場合の数は、

・i が黒の場合:b[i] = \prod_j w[j]
・i が白の場合:w[i] = \prod_j (w[j]+b[j])

と決まります。この関係を再帰的に呼び出す関数を作ればOKです。ルートノードを 1 として、b[1] + w[1] が最終的な答えになります。

def solve(i, black=True, parent=-1):
  global children, mod
  if children[i] == [parent]: # 子ノードが無い場合
    return 1

  result = 1
  for j in children[i]:
    if j == parent:
      continue
    c = solve(j, False, i) # white
    if not black:
      c += solve(j, True, i) # black
    result *= c
    result %= mod
  return result

この関数は、i 以下の部分木の場合の数を計算します。オプション black で i が黒の場合と白の場合を指定できます。i に子ノードが無い場合は、場合の数は 1 になります。(i 自身の色は指定されているので)

これを用いて、solve(1, True) + solve(1, False) で答えが得られます。関数をメモ化したものを提出してみましょう。

atcoder.jp

なんと!最後の1ケースだけが TLE しています。おしいです。

メモ化せずに Tree order で遷移

メモ化で高速化されるとは言え、キャッシュから検索する時間のオーバーヘッドがやはり問題なのでしょうか・・・。今回計算するべきものは、前述の (1)(2) と分かっていますので、これらを(DP っぽく)dp_b[j]、dp_w[j] として、メモ化にたよらずに、下から順番にまじめに遷移計算していくことにしましょう。木構造データについて「下から準備に計算する」際は、先ほどの深さ優先探索の関数が利用できます。

def tree_search(i, parent=-1):
  global children, dp_b, dp_w
  for j in children[i]:
    if j == parent: # 親ノードはスキップする
      continue
    tree_search(j, i) # i を親ノードとして、子ノード j を処理する

  dp_b[i], dp_w[i] = calc_dp(i, parent) # dp_b[i], dp_w[i] を計算する関数

こうすれば、calc_dp(i, parent) は下から順に呼び出されるので、calc_dp(i, parent) の実行時は、ノード i の子ノード j について、dp_b[j]、dp_w[j] が確定していることが保証されます。

「dp_b[i], dp_w[i] を計算する関数」の中身はこんな感じです。

def calc_dp(i, parent=-1):
  global children, dp_b, dp_w, mod
  if children[i] == [parent]: # bottom case
    return 1, 1

  result_w = 1
  result_b = 1
  for j in children[i]:
    if j == parent:
      continue
    result_w *= (dp_w[j] + dp_b[j])
    result_b *= dp_w[j]
    result_w %= mod
    result_b %= mod

  return result_b, result_w

これで、すべてのケースにパスするコードができました。

atcoder.jp

再帰呼び出しを使わずに実装する

上記のコードでは、深さ優先探索で下から順に処理する際に、関数の再帰呼び出しを利用しています。再帰呼び出し(関数呼び出し)のオーバーヘッドが問題になる場合は、次の様にキューを用いた深さ優先探索を利用することができます。ただし、キューを用いた実装では、「上から順に処理する」ことしかできないので(番兵を仕込む方法もありますが、それは後ほど・・・)、この例では、上から順にノードを詰めた(また別の)キュー tree_order を構築したのちに、tree_order を右から順に処理する事で、「下から順」の処理を実装しています。

  # 下から順に処理するためのキュー tree_order を用意する
  tree_order = deque()
  q = deque()
  q.append((1, -1))
  while q:
    i, parent = q.pop()
    tree_order.append((i, parent))
    for j in children[i]:
      if j == parent:
        continue
      q.append((j, i))

  # tree_order を右から順に処理する
  while tree_order:
    i, parent = tree_order.pop()
    dp_b[i], dp_w[i] = calc_dp(i, parent) # dp_b[i], dp_w[i] を計算する関数

  print((dp_b[1] + dp_w[1]) % mod)

再帰呼び出しに比べると、すこーしだけ実行時間が短くなります。

atcoder.jp

一般には、次の様なテンプレートになるでしょう。(キューを用いた「上から順」の処理)

  q = deque()
  q.append((1, -1))
  while q:
    i, parent = q.pop()

    # ここでノード i に関する処理をする

    for j in children[i]:
      if j == parent:
        continue
      q.append((j, i))

これが利用できる、「木を上から順に処理する」パターンの問題には、下記があります。

atcoder.jp

おまけ:番兵を用いて下から順に処理する方法

深さ優先探索のループの中で、直接に「下から順」の処理をするには、次のような番兵を用いる方法もあります。

  while q:
    i, parent = q.pop()
    if i < 0:
      # ここでノード i に関する処理をする(i -> -i に置き換え)
      continue

    q.append((-i, parent)) # 下が終わった後に引っ掛けるための番兵

    for j in children[i]:
      if j == parent:
        continue
      q.append((j, i))

atcoder.jp