めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

PRML 第1章の「ベイズ推定によるパラメータフィッティング」の解説(その2)

ちなみに

本記事のタイトルは「PRML第1章」とついていますが、実質的には、「3.3 Bayesian Linear Regression」の説明になっています。

ベイズ推定によるカーブフィッティング

下記の記事では、平均 \mu が未知の正規分布について、ベイズ推定でフィッティングする例を説明しました。

ここでは、下記の記事の続きとして、「正弦波+正規分布のノイズ」を多項式でフィッティングする例について、ベイズ推定によるフィッティングを適用してみます。

しつこく繰り返しますが、ベイズ推定は、『「観測データ」を元に「ある事柄の確率」を洗練していく』という手法です。今の場合、「観測データ」に相当するのは「N 個の観測ポイント \{x_n\}_{n=1}^N から得られた観測値 \{t_n\}_{n=1}^N」です。一方、「ある事柄」に相当するのは、M次多項式の係数群 \{w_m\}_{m=0}^M です。それぞれ、次のようにベクトル表記しておきます。

 \mathbf{x} := \begin{pmatrix}x_1 \\ \vdots \\ x_N\end{pmatrix}\,\,\,\,\mathbf{t} := \begin{pmatrix}t_1 \\ \vdots \\ t_N\end{pmatrix}\,\,\,\,\mathbf{w} := \begin{pmatrix}w_0 \\ \vdots \\ w_M\end{pmatrix}

したがって、係数群 \mathbf{w} について、これらがとり得る値について、それぞれの確率 P({\mathbf w}) が決まっていると考えることにします。

まず、前提条件が無い段階での確率(事前分布)は、例によって(?)適当に決めます。たとえば、\alpha を任意の定数として、平均0、分散 \alpha^{-1} の正規分布を仮定します。

 p({\mathbf w}) = {\mathscr N}({\mathbf w} \mid {\mathbf 0},\alpha^{-1}\mathbf{I})=\left(\frac{\alpha}{2\pi}\right)^{(M+1)/2}\exp\left\{-\frac{\alpha}{2}{\mathbf w}^{\rm T}{\mathbf w}\right\} ―― (1)

つづいて、トレーニングセットが与えられた場合の確率(事後分布)は、ベイズの定理から次のようになります。

 p({\mathbf w} \mid {\mathbf t}) = \frac{p({\mathbf t} \mid {\mathbf w})}{\int p(\mathbf t \mid \mathbf w)p({\mathbf w})\,d{\mathbf w}} p({\mathbf w}) ―― (2)

ここで、最尤推定の場合と同様に、今回得られたトレーニングセットを与える情報源は、「何らかの関数 y(x) に従って値が決まっているが、それぞれのデータには一定の正規分布の誤差が入り込む」という性質を持っていることが予めわかっているものとします。この誤差の分散を \beta^{-1} として、次の関係が成り立ちます。

 p({\mathbf t} | {\mathbf w}) = \prod_{n=1}^N{\mathscr N}\left(t_n\,\middle|\,y(x_n,{\mathbf w}),\beta^{-1}\right)

 \,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,=\left(\frac{\beta}{2\pi}\right)^{\frac{N}{2}}\exp\left\{-\frac{\beta}{2}\sum_{n=1}^N\left(y(x_n,{\mathbf w})-t_n\right)^2\right\} ―― (3)

 y(x, {\mathbf w}) := \sum_{m=0}^M w_m x^m

(1)(3)を(2)に代入して、{\mathbf w} に依存する比例項を取り出すと次が得られます。

 p({\mathbf w} \mid {\mathbf t}) \propto \exp\left\{-\frac{\beta}{2}\sum_{n=1}^N\left(y(x_n,{\mathbf w})-t_n\right)^2 -\frac{\alpha}{2}{\mathbf w}^{\rm T}{\mathbf w}\right\} ―― (4)

最後の式の \exp の中身は、PRML 第1章の多項式フィッティングの例を再現で紹介した下記の誤差関数と同じ形をしていることが分かります。

 \tilde E := \frac{1}{2}\sum_{n=1}^N\left(t_n-\sum_{m=0}^{M} w_m x^m\right)^2+\frac{\lambda}{2}\sum_{m=0}^Mw_m^2

この時は、オーバーフィッティングを抑えるために {\mathbf w} が大きくなりすぎないように手で誤差関数を修正しました。一方、今の場合は、事前分布の影響として第2項が出現しています。つまり、ベイズ推定の場合は、事前分布を調整することでオーバーフィッティングを抑えることができるわけです。当然ながら、「どのようにオーバーフィッティングを抑えるのか」という方法はたくさん考えられらますので、どのような事前分布がベストかを判定するのは簡単ではありません。

次に得られるデータの推定

(4)を元にして、ある観測点 x_0 から次に得られるデータ t_0 を推測してみましょう。PRML 第1章の「ベイズ推定によるパラメータフィッティング」の解説(その1)で説明したように、{\mathbf w} のあらゆる値について確率が与えられているので、それらについて期待値をとる必要があります。

 p(t_0 \mid {\mathbf t}) = \int p({\mathbf w} \mid {\mathbf t})p(t_0 \mid {\mathbf w})\,d{\mathbf w}

ここで、積分に含まれる2つの確率は、どちらも正規分布になっている事に注意します。

 p(t_0 \mid {\mathbf w}) = {\mathscr N}(t_0 \mid {\mathbf \phi}(x_0)^{\rm T}{\mathbf w},\,\beta^{-1})

 p({\mathbf w} \mid {\mathbf t}) = {\mathscr N}(\mathbf{w} \mid \beta{\mathbf S}\sum_{n=1}^Nt_n{\mathbf \phi}(x_n),\,{\mathbf S})

ここに、

 {\mathbf \phi}(x) := \begin{pmatrix}x^0 \\ x^1 \\ \vdots \\ x^M\end{pmatrix},\,\,\,\,{\mathbf S}^{-1} := \alpha{\mathbf I} + \beta\sum_{n=1}^N{\mathbf \phi}(x_n){\mathbf \phi}(x_n)^{\rm T} ―― (5)

計算式はこちら

また、一般に、正規分布について次の公式が成立します。(証明はPRMLの「2.3.3 Bayes' theorem for Gaussian variables」を参照。)

 \int {\mathscr N}({\mathbf w} \mid {\mathbf \mu},\,{\mathbf \Sigma}){\mathscr N}(t_0 \mid {\mathbf a}^{\rm T}{\mathbf w},\,\beta^{-1})\,d{\mathbf w} = {\mathscr N}(t_0 \mid {\mathbf a}^{\rm T}{\mathbf \mu},\,\beta^{-1}+{\mathbf a}^{\rm T}{\mathbf \Sigma a})

以上を利用すると、最終的に次の結果が得られます。

 p(t_0 \mid {\mathbf t}) = {\mathscr N}\left(t_0 \,\middle|\, \beta\mathbf{\phi}(x_0)^{\rm T}{\mathbf S}\sum_{n=1}^Nt_n{\mathbf \phi}(x_n),\,\beta^{-1}+{\mathbf \phi}(x_0)^{\rm T}{\mathbf S \phi}(x_0)\right) ―― (6)

少し複雑な式になりましたが、(5)からは、観測点 N が増えると |S| は小さくなることが分かります。つまり、(6)の分散は小さくなります。観測点が少ない場合は、平均値の確信度が低くなるので、その分、大きな分散で次の点を予測するという、PRML 第1章の「ベイズ推定によるパラメータフィッティング」の解説(その1)と同じ挙動になっていることが分かります。

数値計算で確認

それでは、(6)の分布をグラフ表示してみましょう。(6)は x_0 の関数と見なせますが、値が確率になっていますので、

  • 平均: m(x_0) := \beta\mathbf{\phi}(x_0)^{\rm T}{\mathbf S}\sum_{n=1}^Nt_n{\mathbf \phi}(x_n)
  • 分散: s(x_0) := \beta^{-1}+{\mathbf \phi}(x_0)^{\rm T}{\mathbf S \phi}(x_0)

と置いて、

y = m(x)、および、y = m(x) \pm \sqrt{s(x)} のグラフを描きます。多項式の次数は M=9 としています。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import Series, DataFrame

from numpy.random import normal

beta = 1.0/(0.3)**2
alpha = 1.0/100**2
order = 9

def create_dataset(num):
    dataset = DataFrame(columns=['x','y'])
    for i in range(num):
        x = float(i)/float(num-1)
        y = np.sin(2.0*np.pi*x) + normal(scale=0.3)
        dataset = dataset.append(Series([x,y], index=['x','y']),
                                 ignore_index=True)
    return dataset

def resolve(dataset, m):
    t = dataset.y
    phis = DataFrame()
    for i in range(0,m+1):
        p = dataset.x**i
        p.name="x**%d" % i
        phis = pd.concat([phis,p], axis=1)

    for index, line in phis.iterrows():
        phi = DataFrame(line)
        if index == 0:
            phiphi = np.dot(phi,phi.T)
        else:
            phiphi += np.dot(phi,phi.T)
    s_inv = alpha * DataFrame(np.identity(m+1)) + beta * phiphi
    s = np.linalg.inv(s_inv)

    def mean_fun(x0):
        phi_x0 = DataFrame([x0 ** i for i in range(0,m+1)])
        for index, line in phis.iterrows():
            if index == 0:
                tmp = t[index] * line
            else:
                tmp += t[index] * line
        return (beta * np.dot(np.dot(phi_x0.T, s), DataFrame(tmp))).flatten()

    def deviation_fun(x0):
        phi_x0 = DataFrame([x0 ** i for i in range(0,m+1)])
        deviation = np.sqrt(1.0/beta + np.dot(np.dot(phi_x0.T, s), phi_x0))
        return deviation.diagonal()

    return mean_fun, deviation_fun

if __name__ == '__main__':
    df_ws = DataFrame()

    # Show fitting curves
    fig = plt.figure()
    ax = {}
    for c, num in enumerate([4,5,10,100]): # Num of datapoints
        train_set = create_dataset(num)
        mean_fun, deviation_fun = resolve(train_set, order)
        ax[c] = fig.add_subplot(2,2,c+1)
        ax[c].set_xlim(-0.05,1.05)
        ax[c].set_ylim(-2,2)
        ax[c].set_title("N=%d" % num)

        # dataset
        ax[c].scatter(train_set.x, train_set.y, marker='o', color='blue')

        # correct curve
        linex = np.arange(0,1.01,0.01)
        liney = np.sin(2*np.pi*linex)
        ax[c].plot(linex, liney, color='green')

        # polynomial fit
        m = np.array(mean_fun(linex))
        d = np.array(deviation_fun(linex))
        ax[c].plot(linex, m, color='red', label="mean")
        ax[c].legend(loc=1)
        ax[c].plot(linex, m-d, color='black', linestyle='--')
        ax[c].plot(linex, m+d, color='black', linestyle='--')
    fig.show()

実行結果は次のようになります。

これを見ると次のようなことが分かります。

  • 観測点が少ない場合、平均値のカーブは真のカーブより大きくずれる部分もあるが、その分だけ分散も大きくなっており、真のカーブは標準偏差の範囲内には収まっている。
  • 観測点が多くなると分散は小さくなっており、十分なデータがあれば本来の分散である 0.3 付近に収まっている。
  • 事前分布の影響でオーバーフィッティングが抑えられており、N=10 においてすべての点を通るような形にはなっていない。

ところで、上記の N=4 のグラフをみると観測点から離れた部分は、分散が非常に大きくなっています。これは、次のように理解することができます。

まず、上記のグラフでは、観測点 x_0 を固定して、その点における平均/分散を考えましたが、実際に推定しているのは、下記で与えられる係数 {\mathbf w} の値です。

 p({\mathbf w} \mid {\mathbf t}) = {\mathscr N}(\mathbf{w} \mid \beta{\mathbf S}\sum_{n=1}^Nt_n{\mathbf \phi}(x_n),\,{\mathbf S})

したがって、上記の正規分布にしたがって、{\mathbf w} の値が1つランダムに決まると、それに対応した下記の推定曲線が決まります。

 y(x, {\mathbf w}) = \sum_{m=0}^M w_m x^m

そこで、ランダムに選んだいくつかの {\mathbf w} に対応する推定曲線を上に重ねて描いてみます。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import Series, DataFrame

from numpy.random import normal, multivariate_normal

beta = 1.0/(0.3)**2
alpha = 1.0/100**2
order = 9

def create_dataset(num):
    dataset = DataFrame(columns=['x','y'])
    for i in range(num):
        x = float(i)/float(num-1)
        y = np.sin(2.0*np.pi*x) + normal(scale=0.3)
        dataset = dataset.append(Series([x,y], index=['x','y']),
                                 ignore_index=True)
    return dataset

def resolve(dataset, m):
    t = dataset.y
    phis = DataFrame()
    for i in range(0,m+1):
        p = dataset.x**i
        p.name="x**%d" % i
        phis = pd.concat([phis,p], axis=1)

    for index, line in phis.iterrows():
        phi = DataFrame(line)
        if index == 0:
            phiphi = np.dot(phi,phi.T)
        else:
            phiphi += np.dot(phi,phi.T)
    s_inv = alpha * DataFrame(np.identity(m+1)) + beta * phiphi
    s = np.linalg.inv(s_inv)

    def mean_fun(x0):
        phi_x0 = DataFrame([x0 ** i for i in range(0,m+1)])
        for index, line in phis.iterrows():
            if index == 0:
                tmp = t[index] * line
            else:
                tmp += t[index] * line
        return (beta * np.dot(np.dot(phi_x0.T, s), DataFrame(tmp))).flatten()

    def deviation_fun(x0):
        phi_x0 = DataFrame([x0 ** i for i in range(0,m+1)])
        deviation = np.sqrt(1.0/beta + np.dot(np.dot(phi_x0.T, s), phi_x0))
        return deviation.diagonal()

    for index, line in phis.iterrows():
        if index == 0:
            tmp = t[index] * line
        else:
            tmp += t[index] * line
    mean = beta * np.dot(s, DataFrame(tmp)).flatten()

    return mean_fun, deviation_fun, mean, s

if __name__ == '__main__':
    df_ws = DataFrame()

    # Show fitting curves
    fig = plt.figure()
    ax = {}
    for c, num in enumerate([4,5,10,100]): # Num of datapoints
        train_set = create_dataset(num)
        mean_fun, deviation_fun, mean, sigma = resolve(train_set, order)
        ax[c] = fig.add_subplot(2,2,c+1)
        ax[c].set_xlim(-0.05,1.05)
        ax[c].set_ylim(-2,2)
        ax[c].set_title("N=%d" % num)
        ws_samples = DataFrame(multivariate_normal(mean,sigma,4))

        # dataset
        ax[c].scatter(train_set.x, train_set.y, marker='o', color='blue')

        # correct curve
        linex = np.arange(0,1.01,0.01)
        liney = np.sin(2*np.pi*linex)
        ax[c].plot(linex, liney, color='green')

        # polynomial fit
        m = np.array(mean_fun(linex))
        d = np.array(deviation_fun(linex))
        liney = m
        ax[c].plot(linex, liney, color='red', label="mean")
        ax[c].legend(loc=1)
        liney = m-d
        ax[c].plot(linex, liney, color='black', linestyle='--')
        liney = m+d
        ax[c].plot(linex, liney, color='black', linestyle='--')

        def f(x, ws):
            y = 0
            for i, w in enumerate(ws):
                y += w * (x ** i)
            return y

        for index, ws in ws_samples.iterrows():
            liney = f(linex, ws)
            ax[c].plot(linex, liney, color='red', linestyle='--')

    fig.show()

これを実行すると、次の結果が得られます。

赤い破線がランダムに選んだ {\mathbf w} で決まる曲線です。これらは、観測点の近くを通るようにフィッティングされているので、結果的に、観測点のまわりは分散が小さく、観測点から離れると分散が大きくなることが分かります。