何の話かと言うと
この問題をネタに、「頂点についてのループと辺についてのループ」の話をします。・・・・というのは、ちょっと無理矢理感があって、ぶっちゃけは、別解を紹介したいだけです。
辺についてのループ
一般に、グラフの問題は、「N 個の頂点について順番に処理する」という発想で解きますが、この問題に関して言うと、(ちょっとした発想の転換で)「N-1 個の辺について順番に処理する」という方法で、驚くほど簡単に解くことができます。詳しくは、公式解説を参照ください。
頂点についてのループでは解けないの?
なのですが、「頂点ごとに考える」という普通の発想でも解けなくはないので、別解として、こちらの解法を紹介します。
この問題では、任意の2頂点の全ての組みについて、それらの最短経路を考える必要がありますが、頂点の数()を考えると、各頂点を1回ずつしか処理できそうにありません。それでは、たとえば、深さ優先探索で、木を下から順番にたどった時に、ある特定の頂点について、どのような経路をカバーすることができるでしょうか?
一般に、木構造グラフでは、任意の2点を結ぶ最短経路は、「(共通の親まで)登って、そこからまた下る」という形になりますが、少なくとも「親まで登る経路の距離」は、深さ優先探索のDPでまとめあげることができます。つまり、
・dp[i] = "頂点 i のすべての子孫 j についての「j から i に至る距離」の合計"
が計算できます。もうちょっと厳密に言うと、「i を頂点とするサブツリーのノード数」の情報が補助的に必要になるので、
・dp[i] = (c, d) : i を頂点とするサブツリーについての (ノード数, 「子孫 j から親 i の距離」の合計)
という形になります。
import sys from collections import defaultdict, deque def main(f): N = int(f.readline()) link = [[] for _ in range(N+1)] for _ in range(N-1): a, b = map(int, f.readline().split()) link[a].append(b) link[b].append(a) # dp[i] = (c, d) : i を頂点とするサブツリーについての (ノード数, 「子孫 j から親 i の距離」の合計) dp = [(0, 0)] * (N+1) q = deque() q.append((1, 0)) # (node, parent) while q: i, p = q.pop() if i < 0: c, d = 0, 0 for j in link[-i]: if j == p: continue c_j, d_j = dp[j] c += c_j d += d_j + 1*(c_j) # j->i の距離 1 を各子孫について加える dp[-i] = (c+1, d) continue q.append((-i, p)) # 帰りがけ順で処理するための番兵 for j in link[i]: if j == p: continue q.append((j, i)) print(dp[1:]) main(sys.stdin)
# input 4 1 2 2 3 1 4 # output [(4, 4), (2, 1), (1, 0), (1, 0)]
得られた距離をすべて足しあげれば、「登って終わり」というパターンの経路についての合計が得られます。
「登ってから下る」パターンはどうするの?
もちろん、このままでは、まだ、「すべての頂点の組み合わせ」がカバーできていません。頂点 i に注目した際に、
・頂点 i まで登って終わり
という経路が網羅されている状況ですので、これに加えて、
・頂点 i の直下の子ノード j < k について、「サブツリー j 内のノードから i まで登って、サブツリー k 内のノードまで下る」
という経路を加える必要があります。これをすべての頂点 i について網羅すれば、あらゆる経路がカバーされますよね・・・?
(i を折り返し点とするものは、左右の対称性があるので、一方のみを採用するとしてください。)
で、ここが頭の使い所なのですが・・・
実は、すでに計算ずみの dp[i] を利用すると、「i を折り返し点とする全経路の距離の合計」が簡単に計算できます。
今、頂点 i の下に直接ぶら下がる 2 つの頂点 j, k について、
・j の子孫 → i → k の子孫
という経路について、前半と後半を分けて考えます。
・前半:j の子孫 → i
・後半:i → k の子孫
ここで、
・(c_j, d_j) = dp[j]
と置くと、k の子孫を 1 つ固定した場合、これに至る経路は c_j(j 以下のサブツリーのノード数)本ありますが、これらの前半の経路の総距離は、d_j + 1 * c_j になります。(j の子孫 -> j の各経路について、j->i の距離 1 を加えている点に注意してください。)
したがって、c_k 個ある(k 自身を含めた)k のすべての子孫を考慮すると、前半の経路の総計は、
・(d_j + 1 * c_j) * c_k
で計算されます。
同様に、
・(c_k, d_k) = dp[k]
と置いて、j の子孫を 1 つ固定した場合、ここからスタートする経路は c_k(k 以下のサブツリーのノード数)本あり、これらの後半の経路の総距離は、d_k + 1 * c_k になります。したがって、c_j 個ある(j 自身を含めた)j のすべての子孫を考慮すると、後半の経路の総計は、
・c_j * (d_k + 1 * c_k)
で計算されます。
前半と後半を合わせると、i で折り返す経路の総距離は、
・(d_j + c_j) * c_k + c_j * (d_k + c_k)
とまとまります。i の子ノードが 3 つ以上ある場合は、任意の 2 個の組み合わせについて、すべて合計すればOKです。
ここまでをまとめると、次の実装になります。
import sys from collections import defaultdict, deque def main(f): N = int(f.readline()) link = [[] for _ in range(N+1)] for _ in range(N-1): a, b = map(int, f.readline().split()) link[a].append(b) link[b].append(a) # dp[n] = (c, d) : n を頂点とするサブツリーについての (ノード数, 「子孫 i から n の距離」の合計) dp = [(0, 0)] * (N+1) dists = 0 q = deque() q.append((1, 0)) # (node, parent) while q: i, p = q.pop() if i < 0: c, d = 0, 0 for j in link[-i]: if j == p: continue c_j, d_j = dp[j] c += c_j d += d_j + 1*(c_j) # j->i の距離 1 を各子孫について加える dp[-i] = (c+1, d) links = [] for j in link[-i]: if j == p: continue links.append(j) for index1 in range(len(links)): j = links[index1] c_j, d_j = dp[j] for index2 in range(index1+1, len(links)): k = links[index2] c_k, d_k = dp[k] dists += (d_j+c_j)*c_k + c_j*(d_k+c_k) # k -> i -> j の距離 continue q.append((-i, p)) # 帰りがけ順で処理するための番兵 for j in link[i]: if j == p: continue q.append((j, i)) for n in range(1, N+1): dists += dp[n][1] # k -> i の距離をすべて加える print(dists) main(sys.stdin)
もう一歩だけ最適化
ただし、上記のコードを提出すると、一部のケースで TLE になります。
i の子ノードが大量にある場合、「任意の2つの組み合わせ」の場合の数が大きくなるためです。ただ、実際の計算内容をよく見ると、一方の子ノード j を固定すると、もう一方の子ノード k からの寄与は、c_k および d_k それぞれの合計としてまとめることができます。そこで、これらの部分和を事前に計算しておけば、i ごとに個別に k のループを回す必要がなくなります。これを実装したのが、最終的なこちらの解答になります。