めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

039 - Tree Distance(★5)の解説

何の話かと言うと

atcoder.jp

この問題をネタに、「頂点についてのループと辺についてのループ」の話をします。・・・・というのは、ちょっと無理矢理感があって、ぶっちゃけは、別解を紹介したいだけです。

辺についてのループ

一般に、グラフの問題は、「N 個の頂点について順番に処理する」という発想で解きますが、この問題に関して言うと、(ちょっとした発想の転換で)「N-1 個の辺について順番に処理する」という方法で、驚くほど簡単に解くことができます。詳しくは、公式解説を参照ください。

頂点についてのループでは解けないの?

なのですが、「頂点ごとに考える」という普通の発想でも解けなくはないので、別解として、こちらの解法を紹介します。

この問題では、任意の2頂点の全ての組みについて、それらの最短経路を考える必要がありますが、頂点の数(N=10^5)を考えると、各頂点を1回ずつしか処理できそうにありません。それでは、たとえば、深さ優先探索で、木を下から順番にたどった時に、ある特定の頂点について、どのような経路をカバーすることができるでしょうか?

一般に、木構造グラフでは、任意の2点を結ぶ最短経路は、「(共通の親まで)登って、そこからまた下る」という形になりますが、少なくとも「親まで登る経路の距離」は、深さ優先探索のDPでまとめあげることができます。つまり、

・dp[i] = "頂点 i のすべての子孫 j についての「j から i に至る距離」の合計"

が計算できます。もうちょっと厳密に言うと、「i を頂点とするサブツリーのノード数」の情報が補助的に必要になるので、

・dp[i] = (c, d) : i を頂点とするサブツリーについての (ノード数, 「子孫 j から親 i の距離」の合計)

という形になります。

import sys
from collections import defaultdict, deque

def main(f):
  N = int(f.readline())
  link = [[] for _ in range(N+1)]
  for _ in range(N-1):
    a, b = map(int, f.readline().split())
    link[a].append(b)
    link[b].append(a)

# dp[i] = (c, d) : i を頂点とするサブツリーについての (ノード数, 「子孫 j から親 i の距離」の合計)
  dp = [(0, 0)] * (N+1)

  q = deque()
  q.append((1, 0)) # (node, parent)
  while q:
    i, p = q.pop()
    if i < 0:
      c, d = 0, 0
      for j in link[-i]:
        if j == p:
          continue
        c_j, d_j = dp[j]
        c += c_j
        d += d_j + 1*(c_j) # j->i の距離 1 を各子孫について加える
      dp[-i] = (c+1, d)
      continue

    q.append((-i, p)) # 帰りがけ順で処理するための番兵

    for j in link[i]:
      if j == p:
        continue
      q.append((j, i))

  print(dp[1:])
main(sys.stdin)
# input
4
1 2
2 3
1 4

# output
[(4, 4), (2, 1), (1, 0), (1, 0)]

得られた距離をすべて足しあげれば、「登って終わり」というパターンの経路についての合計が得られます。

「登ってから下る」パターンはどうするの?

もちろん、このままでは、まだ、「すべての頂点の組み合わせ」がカバーできていません。頂点 i に注目した際に、

・頂点 i まで登って終わり

という経路が網羅されている状況ですので、これに加えて、

・頂点 i の直下の子ノード j < k について、「サブツリー j 内のノードから i まで登って、サブツリー k 内のノードまで下る」

という経路を加える必要があります。これをすべての頂点 i について網羅すれば、あらゆる経路がカバーされますよね・・・?

 \displaystyle \{全経路\} = \bigcup_i \{i まで登る経路\} \cup \{i を折り返し点とする経路\}

(i を折り返し点とするものは、左右の対称性があるので、一方のみを採用するとしてください。)

で、ここが頭の使い所なのですが・・・

実は、すでに計算ずみの dp[i] を利用すると、「i を折り返し点とする全経路の距離の合計」が簡単に計算できます。

今、頂点 i の下に直接ぶら下がる 2 つの頂点 j, k について、

・j の子孫 → i → k の子孫

という経路について、前半と後半を分けて考えます。

・前半:j の子孫 → i
・後半:i → k の子孫

ここで、

・(c_j, d_j) = dp[j]

と置くと、k の子孫を 1 つ固定した場合、これに至る経路は c_j(j 以下のサブツリーのノード数)本ありますが、これらの前半の経路の総距離は、d_j + 1 * c_j になります。(j の子孫 -> j の各経路について、j->i の距離 1 を加えている点に注意してください。)

したがって、c_k 個ある(k 自身を含めた)k のすべての子孫を考慮すると、前半の経路の総計は、

・(d_j + 1 * c_j) * c_k

で計算されます。

同様に、

・(c_k, d_k) = dp[k]

と置いて、j の子孫を 1 つ固定した場合、ここからスタートする経路は c_k(k 以下のサブツリーのノード数)本あり、これらの後半の経路の総距離は、d_k + 1 * c_k になります。したがって、c_j 個ある(j 自身を含めた)j のすべての子孫を考慮すると、後半の経路の総計は、

・c_j * (d_k + 1 * c_k)

で計算されます。

前半と後半を合わせると、i で折り返す経路の総距離は、

・(d_j + c_j) * c_k + c_j * (d_k + c_k)

とまとまります。i の子ノードが 3 つ以上ある場合は、任意の 2 個の組み合わせについて、すべて合計すればOKです。

ここまでをまとめると、次の実装になります。

import sys
from collections import defaultdict, deque

def main(f):
  N = int(f.readline())
  link = [[] for _ in range(N+1)]
  for _ in range(N-1):
    a, b = map(int, f.readline().split())
    link[a].append(b)
    link[b].append(a)

# dp[n] = (c, d) : n を頂点とするサブツリーについての (ノード数, 「子孫 i から n の距離」の合計)
  dp = [(0, 0)] * (N+1)
  dists = 0

  q = deque()
  q.append((1, 0)) # (node, parent)
  while q:
    i, p = q.pop()
    if i < 0:
      c, d = 0, 0
      for j in link[-i]:
        if j == p:
          continue
        c_j, d_j = dp[j]
        c += c_j
        d += d_j + 1*(c_j) # j->i の距離 1 を各子孫について加える
      dp[-i] = (c+1, d)

      links = []
      for j in link[-i]:
        if j == p:
          continue
        links.append(j)

      for index1 in range(len(links)):
        j = links[index1]
        c_j, d_j = dp[j]
        for index2 in range(index1+1, len(links)):
          k = links[index2]
          c_k, d_k = dp[k]
          dists += (d_j+c_j)*c_k + c_j*(d_k+c_k) # k -> i -> j の距離
      continue

    q.append((-i, p)) # 帰りがけ順で処理するための番兵

    for j in link[i]:
      if j == p:
        continue
      q.append((j, i))

  for n in range(1, N+1):
    dists += dp[n][1] # k -> i の距離をすべて加える
  print(dists)

main(sys.stdin)

もう一歩だけ最適化

ただし、上記のコードを提出すると、一部のケースで TLE になります。

i の子ノードが大量にある場合、「任意の2つの組み合わせ」の場合の数が大きくなるためです。ただ、実際の計算内容をよく見ると、一方の子ノード j を固定すると、もう一方の子ノード k からの寄与は、c_k および d_k それぞれの合計としてまとめることができます。そこで、これらの部分和を事前に計算しておけば、i ごとに個別に k のループを回す必要がなくなります。これを実装したのが、最終的なこちらの解答になります。

atcoder.jp