めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

HaskellのContinuation Monadを理解する(1)

はじめに

HaskellのContinuation Monadを調べる機会があったのでメモしておきます。

参考文献はこのあたりです。

本当はcallCCの解体まで行いたかったのですが、長くなるので、今回はContinuation Monadの説明までです。

まずはCPSの説明

関数型言語のプログラミングスタイルで、CPS(Continuation Passing Style)と呼ばれるものがあります。

一般に計算結果の値を他の関数に渡す際は、値そのものを変数に格納して、その変数を他の関数に喰わせます。当たり前ですね。次は、nに格納した値を関数incに喰わせる例です。

$ ghci
Prelude> let inc x = x+1
Prelude> let n = 3
Prelude> inc n
4

一方、関数型言語では、「関数の関数を変数に代入する」という事が可能ですので、「喰ってもらう関数を待ち受ける値」を変数に代入することも可能です。分かりにくいですね。でも、次の例を見れば、意図は明白です。

Prelude> let n = \k -> k 3
Prelude> n inc
4

変数nには、3という値を代入する代わりに、「関数kを受け取ると、kに3を喰わせてあげる」という「関数待受状態」を代入しているわけです。nに関数incを代入すると、先と同じ4という結果が得られます。

ここで、少し記法を工夫しておきます。

Prelude> let f =: x = x f
Prelude> inc =: n
4

「関数待受状態」の変数nに、左から関数が喰い付いている感じですね :-)

ここまでは、喰わせる値としてリテラル値3を使っていましたが、もちろん、何らかの計算をしてその結果を「関数待受状態」として格納する関数を作ることも可能です。ここからは(結果を再利用するために)ファイルcps.hsに書いてghciからロードする形にします。

$ cat cps.hs
infixr 5 =:
(=:) :: a -> (a->b) -> b
x =: f = f x

add :: Int -> Int -> ((Int -> r) -> r)
add x y = \k -> k (x + y)

square :: Int -> ((Int -> r) -> r)
square x = \k -> k (x * x)

square_square :: Int -> ((Int -> r) -> r)
square_square n = (\x -> square x) =: square n

pythagoras :: Int -> Int -> ((Int -> r) -> r)
pythagoras n m = (\x -> ((\y -> add x y) =: square m)) =: square n

$ ghci cps.hs
*Main> print =: square 3
9
*Main> print =: add 2 3
5
*Main> print =: square_square 3
81
*Main> print =: pythagoras 2 3
13

addとsquareはわかりやすいと思います。ghciからは、addとsquareで作成した「関数待受状態」を関数printに喰わせることで、計算結果を表示しています。

square_squareは、squareを2回適用する例です。少し複雑ですが「f =: square n」は「f (square nの計算結果)」と同じことですので、ここでは最右項の「square n」に対して、fとして、再度、squareを適用しています。

同様に考えると、「pythagoras n m」は、平方和になります。次のように解釈すると理解できるでしょう。

(\x -> ((\y -> add x y) =: square m)) =: square n
-- 「\x -> (・・・)」という関数が最右項の「square n」にを食べて、引数xに「square nの計算結果」が入る。
==> (\y -> add (square nの計算結果) y) =: square m
-- 「\y -> add (square nの計算結果) y」という関数が「square m」を食べて、引数yに「square mの計算結果」が入る。
==> add (square nの計算結果) (square mの計算結果)

CPSの説明は以上になりますが、そろそろMonadの香りがしてきたと思います。要するに、nという生の値の代わりに、nという値を包んだ「関数待受状態」を構成するわけですが、この後で説明するContinuation Monadは、「関数待受状態」のコンテナに他なりません。

Continuation Monadを作る

Continuation Monadはモジュール「Control.Monad.Cont」で定義されていますが、ここでは理解のために、モジュールはインポートせずに直接にMonadを定義していきます。

まず、前述のように、「関数待受状態」のコンテナとして、Cont型を定義します。

cps.hsに追加

newtype Cont r a = Cont { runCont :: (a -> r) -> r }

先に作った関数squareの結果は、「関数待受状態」ですので、Cont型に格納できます。実際に関数を適用する際は、runContで格納した中身を取り出す必要があります。

$ ghci cps.hs
*Main> let s3 = Cont (square 3)
*Main> print =: runCont s3
9

さらに、Cont型をMonadとして構成して、次のようなbind(>>=)演算ができるようにすることを考えます。

$ ghci cps.hs
*Main> let { m :: Cont r Int; m = return 3 }
*Main> let { f :: Int -> Cont r Int; f x = Cont (square x) }
*Main> let result = m >>= f
*Main> print =: runCont result
9

「return 3」は、3という数字をそのまま包んだMonad「Cont (\k -> k 3)」を返します。つまり、

return x = Cont (\k -> k x)

という塩梅です。そして、これを「square x」を包んだMonad(への関数f)にbindでつなぐと、「Cont (square 3)」が得られるという寸法です。これを実現するbindの定義をじっくり考えていきます。

まず、先の関数「square_square」の定義を思い返すと、次の同値関係が想像できます。

         ??
square 3 == (\x -> square x) =: (\k' -> k' 3)
         == (\x -> runCont (f x)) =: runCont m

大人の事情で、束縛変数kをk'に書き換えています :-)

後半の「==」は、runContで(f x)とmの中身を取り出していることから自明に成立します。しかしながら、前半の「==」は不正確です。なぜなら、「square 3」は「関数待受状態」ですので、自分を食べてくれる関数を受け取ることができます。

square 3 = \k -> k (3 * 3)

先の(誤った)同値関係の右項には、関数を受け取る部分がありませんので、これをくっつけます。

         ??
square 3 == \k -> (\x -> square x) =: (\k' -> k' 3)
         == \k -> (\x -> runCont (f x)) =: runCont m

まだ不十分ですね。右項は、受け取ったkに何を喰わせるべきでしょうか? もちろん、最終計算結果「square x」、すなわち「runCont (f x)」の中身です。

square 3 == \k -> (\x -> k =: square x) =: (\k' -> k' 3)
         == \k -> (\x -> k =: runCont (f x) =: runCont m

ゴールが見えてきました。欲しかったのは、次の結果ですから、

Cont (square 3) == m >>= f

下記の定義が使えそうです。

m >>= f = Cont (\k -> (\x -> k =: runCont (f x) =: runCont m)

これは、「Control.Monad.Cont」における下記の定義と同じになります。「喰い付き演算子 =:」の役割をよく思い出してくださいね。。。

cps.hsに追加

instance Monad (Cont r) where
    return x = Cont (\k -> k x)
    m >>= f = Cont (\k -> runCont m (\x -> runCont (f x) k))

Monad演算を使うと例のごとく、組み合わせ演算を「手続きっぽく」表現することができます。先のpythagorasの例は次のようになります。

cps.hsに追加

squareCont x = Cont (\k -> k (x * x))
addCont x y = Cont (\k -> k (x + y))

pythagoras2 :: Int -> Int -> Cont r Int
pythagoras2 n m = do
    x <- squareCont n
    y <- squareCont m
    result <- addCont x y
    return result
$ ghci cps.hs 
*Main> print =: runCont (pythagoras2 2 3)
13

今回はここまでです。次回はcallCCを解体していく予定です。