めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

017 - Crossing Segments(★7)の解説

何の話かと言うと

atcoder.jp

上記の問題をネタに、計算量の観点から解法を考える、という話をしたいと思います。

問題の内容

問題では、円周上の点になっていますが、点 1 のすぐ左で切って円を開けば、1 〜 N の区間に複数の区間 (-----) がばら撒かれており、下記のように区間がクロスする部分をカウントする問題になります。(一方の区間がもう一方の区間に完全含まれている場合は、カウントされません。)

(------)
    (--------)

またこの問題では、点の数 N とばら撒かれた区間の数 M が与えられていますが、それぞれの上限を考えると O(N\times M) では間に合いそうにありません。(実行時間を気にしなければ)色々な解法が考えられますが、今回は、M と N のそれぞれについてループする方法(M と N の二重ループ、もしくは、M の二重ループなど)は使えない事になります。

で・・・どうしようか・・・なのですが、最終的には、N についてループしながら各ループ内を O(\log N) もしくは O(\log M) でがんばるか、もしくは、M についてループしながら、ループ内を O(\log N) もしくは O(\log M) でがんばるかのどちらかになるはずです。N でループするならどんな方法がありそうか、もしくは、M でループするならどんな方法がありそうか、それぞれ考えてみましょう。

N でループする場合

この場合、N 個の点を端から順番にチェックすることになります。それぞれの区間の左端 ( 、もしくは、右端 ) が順番に現れてきます。それぞれの区間について、「まだ現れていない」「左端が現れた」「右端が現れて閉じた」という 3 つの状態が変化していきます。各区間の状態をリストにまとめてトラッキングすることができそうです。

で、これをトラッキングすると何ができるかというと・・・

ある区間 m の右端 ) が登場した時に、対応する左端 ( との間に、(他の区間の)「左端が現れた」という事象が何個あるかをカウントすれば、区間 m とクロスする区間の数になりそうです。ただし、すでに右端が現れて閉じたというものは、除外する必要があります。

この時、他の区間の状態をひとつひとつチェックしながらクロスするものをカウントしてもよいのですが、これでは、O(\log M) では計算が終わりません。(他の区間は最大で M-1 個ありますからね。)重要なのは、「左端が現れたという事象の数」なので、もうちょっと単純に、この事象を1つのリストに詰め込んでいけばどうでしょうか?左端が現れた場所が重要ですので、左端が現れるたびに、リスト open_list にその点の位置 n を追加していきます。(これは昇順にソートされたリストになりますね。)で、右端が現れた場合に、自分のペアの左端の位置より大きな n の個数をカウントすればOKです。おっと、すでに閉じたものは除外する必要がありますが、これは、右端が現れたタイミングで、ペアとなる左端の位置を open_list から削除しておけばよいでしょう。

で・・・うーーーーんと考えて作ったのが下記のコードです。

import sys, bisect

def main(f):
  N, M = map(int, f.readline().split())

  ls = [[] for _ in range(N+1)]
  rs = [[] for _ in range(N+1)]
  pos_l = [None] * (M+1)
  for i in range(1, M+1):
    l, r = map(int, f.readline().split()) # l < r
    ls[l].append(i)
    rs[r].append(i)
    pos_l[i] = l

  count = 0
  open_list = []

  for n in range(1, N+1):
    for i in rs[n]:
      open_list.remove(pos_l[i]) # O(M)
    for i in rs[n]:
      index = bisect.bisect(open_list, pos_l[i]) # O(log M)
      count += len(open_list) - index
    for i in ls[n]:
      open_list.append(n) # sorted list
  print(count)

main(sys.stdin)

位置 n をスキャンしていくので、位置 n にある左端の区間番号を集めたリスト ls[n]、および、右端の番号を集めたリスト rs[n] を用意しています。また、右端が出てきた時に、対応する左端の位置を区間番号から取り出すためのリスト pos_l も用意しています。

特に重要なのは、「右端が現れた場合に、自分のペアの左端の位置より大きな n の個数をカウント」する部分について、バイナリーサーチ bisect を用いる事で O(\log M) で処理している点です。すばらしい。

なのですが・・・、実は、コード内のコメントにあるように、「右端が現れたタイミングで、ペアとなる左端の位置を open_list から削除」する処理は、(Python の仕様上)リストの長さに比例する時間がかかります。ここが O(M) になってしまうため、残念ながらこのコードは TLE になります。

atcoder.jp

ぐぬぬぬぬ。

M でループする場合

仕方がないので、M でループする方法を考えてみましょう。

この場合は、区間 (---) を一つずつ取り出しながら、すでに取り出した区間との位置関係を見て、クロスの個数を数えるという流れになるはずです。一般に、このような処理では、「取り出すもの」を事前にソートしておくと、場合分けが制限できて考えやすくなります。たとえばですが、区間の左端の位置について、昇順になるように取り出すとしてみましょう。

すると・・・・

(------)
    (--------)  ← 今取り出したもの。この ( より右には他の ( は存在しない。


という状況から、今取り出した区間内にある他の区間の右端 ) の個数がクロスの個数になります。「この ( より右には他の ( は存在しない。」というコメントに注意すると、今取り出した区間内に完全に含まれている他の区間は存在しない点に注意してください。

今回は、「右端の個数を数える」ということですので、先ほどと同様にバイナリーサーチを使えば、\log M でカウントできそうです。

で、うーーーーんと考えて実装したのがこちらです。

import sys, bisect

def main(f):
  N, M = map(int, f.readline().split())

  links = []
  for i in range(1, M+1):
    l, r = map(int, f.readline().split()) # l < r
    links.append((l, r))
  links.sort()

  rs = []
  count = 0
  pre_l = 0
  dup = 0
  for l, r in links:
    index1 = bisect.bisect(rs, l) # O(log M)
    index2 = bisect.bisect_left(rs, r) # O(log M)
    count += index2 - index1
    if pre_l == l:
      dup += 1
      count -= dup
    else:
      dup = 0
    pre_l = l
    rs.insert(index2, r) # O(M)
  print(count)

main(sys.stdin)

同じ位置に連続して左端が現れる場合をハンドリングする部分がちょっと面倒ですが、解法としてはOKです。

なのですが・・・・

ここにもトラップがありました。今回は、右端の位置をカウントするために、これらをソート済みリストにまとめていますが、右端の位置は昇順に現れるとは限らないため、昇順を保ってリストにインサートする部分が O(M) になってしまいます。

このコードも TLE になります。がーーーん。

atcoder.jp

解決策

じゃあどうすればいいの?

という感じですが、いずれの場合も、ソート済みリストに挿入・削除する部分が O(M) という共通の課題があります。

で、これは知っているかどうかだけの問題なのですが、実は、ソート済みのリストに対して、ソートを保った挿入・削除を O(\log M) で実行できるデータ構造があります。(二分探索木の気の利いたやつで、AVL木と言います)。ただ残念な事に、Python の標準ライブラリーに入っていないので、自分で実装して使う必要があります。

ここでは、個人の方がブログで公開しているこちらの実装をそのまま使わせていただくことにします。

PythonでAVL木を競プロ用に実装した

これを利用すると、先に紹介した2つの方法は、それぞれ、O(N\log M)、および、O(M\log M) に改善されて AC になります。意外なところで、別解が2つも見つかっちゃいましたね。ふふふふふ。

※ 今回用いた AVL 木は同じ値を重複して add できない制限があったので、微小な値を加えてずらすという方法を使っています。

N でループする場合

atcoder.jp

M でループする場合

atcoder.jp