Union-find
Union-find の基本的な実装はこちらになります。
group_parent = defaultdict(lambda:None) def create_group(x): global group_parent group_parent[x] = x return x def get_group(x): global group_parent p = group_parent[x] if p == None: return None while group_parent[p] != p: p = group_parent[p] # 最適化 p = group_parent[x] group_parent[x] = real_p while group_parent[p] != p: p = group_parent[p] group_parent[p] = real_p return p def merge_group(x, y): global group_parent group_parent[get_group(y)] = get_group(x)
グループ分けしたい(Hashable な)オブジェクトに対して、グループ ID を割り当てて、必要に応じてグループをマージしていくことができます。
・create_group(x) : グループ ID を持たないオブジェクト x に新しいグループ ID を割り当てて、グループ ID を返す。
・get_group(x) : オブジェクト x のグループ ID を返す。(グループ ID を持たない場合は None が返る。)
・merge_group(x, y) : オブジェクト x の属するグループとオブジェクト y の属するグループをマージする。(オブジェクト y の属するグループのグループ ID をオブジェクト x の属するグループのグループ IDに変更する。)
実装の中身に注目すると、各オブジェクト x に対して、グループ ID を示す p = group_parent[x] を割り当てるのですが、グループ ID p 自身もさらに group_parent[p] を持っています。仮に、p == group_parent[p] (自己参照)の場合は、p が正しいグループ ID になりますが、p != group_parent[p] の場合は、group_parent のチェインをたどっていって、px == group_parent[px] となった時点で、px が正しいグループ ID になります。
このチェイニングは、グループをマージする際に発生するもので、x と y のグループ ID を px, py とする時、元々は、
・group_parent[px] = px
・group_parent[py] = py
となっているわけですが、これを
・group_parent[py] = px
と書き換えることで、グループをマージすることができます。
ただし、グループのマージが連続すると、group_parent のチェインが長くなり、get_group(x) の処理時間が伸びます。そこで、get_group(x) を実行した際は、最終的に発見されたグループ ID real_p を用いて、チェインに含まれる group_parent が指す先を real_p に置き換えてショートカットします。(先ほどの実装の「# 最適化」とコメントした部分。)
問題の解説
数列の隣り合う2項の関係が与えられていくので、Union-find を用いて、与えられた関係でつながった項をグループ化していきます。その後、x と y の関係についての Query が与えられた際に、x と y が同じグループであれば、Query に答えることができます。
そして、公式解説 では、すべての Query を読み込んだ後に、Query の答えを計算する順序を工夫することで、計算量を減らすという手法(バッチ方式)が説明されていますが、ここでは、別解として、Query を読み込みながら答えていく方法(オンライン方式)を考えてみます。
まず、この問題では「隣り合う2項の和」が与えられますが、ちょっと工夫すると、「隣り合う2項の差」が与えられる問題に置き換えることができます。具体的には次の置き換えをします。
・ と の和 を に置き換える
・ の値 を に置き換える
「隣り合う2項の差」になると何が嬉しいかというと、部分和のテクニックが使えることです。たとえば、すべての隣り合う項の差がわかっている場合であれば、これらを加えることで、 を事前に計算することができます。これにより、任意の x, y について、 と の関係を と で計算することができます。
今回の問題では、互いにつながった項のグループが分かれるので、グループごとに同様の計算をすることが考えられます。グループの先頭の項を として、グループごとに を計算するのです。先ほどの Union-find の実装を思い出すと、(「# 最適化」の処理を削除すれば)要素 x に対して、get_group(x) で得られるグループ ID が、ちょうど、グループの先頭の要素になることに気がつきます。
グループをマージする際は、マージする各グループの先頭の要素間の差を記録しておきます。これにより、複数のグループがマージされた後でも、group_parent[x] のチェーンをたどりながら差分を加えることで、 を( とまではいきませんが)効率的に計算することができます。
実装結果はこちらになります。
※ この解法、最悪計算時間は になるので、最悪計算時間になるように意図的に組まれたデータに対しては TLE になります。