めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

PRML 第1章の多項式フィッティングの例を再現

何の話かというと

Pattern Recognition and Machine Learning (Information Science and Statistics)

Pattern Recognition and Machine Learning (Information Science and Statistics)

PRML(↑)の第1章(1.1 Example: Ploynomial Curve Fitting)では、正弦波にノイズの乗ったトレーニングセットを多項式でフィッティングする例があります。これと同じデータを生成して、この例を再現してみます。使用する言語は、Python + Pandas です。

PRMLの例の説明

トレーニングセット \{(x_n, t_n)\}_{n=1}^{N} は、次のように生成します。

説明変数 x_n に対して、下記の関数値(正弦波)に正規分布(平均 0、標準偏差 0.3)のノイズを加えたものを目的変数 t_n として生成します。x の値は 0 \le x \le 1 の範囲を10個のデータポイントに等分します。

 y=\sin(2\pi x)\,\,\,\, (0 \le x \le 1)

このようにして得られたトレーニングセットをM次多項式でフィッティングします。

 y = \sum_{m=0}^{M} w_m x^m

フィッティングの方法としては、下記の E_{\rm {RMS}} (二乗平均平方根誤差、RMS:Root Mean Square)を最小化するように、係数 w_m を決定します。

 E_{\rm {RMS}} := \sqrt{\frac{2E_{\rm D}}{N}},\,\,\,\,E_{\rm D}: = \frac{1}{2}\sum_{n=1}^N\left(\sum_{m=0}^{M} w_m x_n^m-t_n\right)^2

E_{\rm D}w_m による偏微分係数が0になるという条件から計算すると、係数は下記のように決まります。

 \mathbf{w} = \left(\mathbf{\Phi}^{\rm T}\mathbf{\Phi}\right)^{-1}\mathbf{\Phi}^{\rm T}\mathbf{t}

ここに、

 \mathbf{w} := \begin{pmatrix}w_0 \\ \vdots \\ w_M\end{pmatrix}\,\,\,\,\mathbf{t} := \begin{pmatrix}t_1 \\ \vdots \\ t_N\end{pmatrix}\,\,\,\,\mathbf{\Phi} := \begin{pmatrix}x_1^0&x_1^1&\dots&x_1^M \\ x_2^0&x_2^1&\dots&x_2^M \\ \vdots&\vdots&\ddots&\vdots \\ x_N^0&x_N^1&\dots&x_N^M\end{pmatrix}

証明はこちら:

PRMLの例を再現するコード

前述のトレーニングセットを生成して、上記の解法で多項式を決定するコードが下記になります。

# -*- coding: utf-8 -*-
#
# 誤差関数(最小二乗法)による回帰分析
#
# 2015/04/22 ver1.0
#

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import Series, DataFrame

from numpy.random import normal

#------------#
# Parameters #
#------------#
N=10            # サンプルを取得する位置 x の個数

# データセット {x_n,y_n} (n=1...N) を用意
def create_dataset(num):
    dataset = DataFrame(columns=['x','y'])
    for i in range(num):
        x = float(i)/float(num-1)
        y = np.sin(2*np.pi*x) + normal(scale=0.3)
        dataset = dataset.append(Series([x,y], index=['x','y']),
                                 ignore_index=True)
    return dataset

# 平方根平均二乗誤差(Root mean square error)を計算
def rms_error(dataset, f):
    err = 0
    for index, line in dataset.iterrows():
        x, y = line.x, line.y
        err += 0.5 * (y - f(x))**2
    return np.sqrt(2 * err / len(dataset))

# 最小二乗法で解を求める
def resolve(dataset, m):
    t = dataset.y
    phi = DataFrame()
    for i in range(0,m+1):
        p = dataset.x**i
        p.name="x**%d" % i
        phi = pd.concat([phi,p], axis=1)
    tmp = np.linalg.inv(np.dot(phi.T, phi))
    ws = np.dot(np.dot(tmp, phi.T), t)

    def f(x):
        y = 0
        for i, w in enumerate(ws):
            y += w * (x ** i)
        return y

    return (f, ws)

# Main
if __name__ == '__main__':
    train_set = create_dataset(N)
    test_set = create_dataset(N)
    df_ws = DataFrame()

    # 多項式近似の曲線を求めて表示
    fig = plt.figure()
    for c, m in enumerate([0,1,3,9]): # 多項式の次数
        f, ws = resolve(train_set, m)
        df_ws = df_ws.append(Series(ws,name="M=%d" % m))

        subplot = fig.add_subplot(2,2,c+1)
        subplot.set_xlim(-0.05,1.05)
        subplot.set_ylim(-1.5,1.5)
        subplot.set_title("M=%d" % m)

        # トレーニングセットを表示
        subplot.scatter(train_set.x, train_set.y, marker='o', color='blue')

        # 真の曲線を表示
        linex = np.arange(0,1.01,0.01)
        liney = np.sin(2*np.pi*linex)
        subplot.plot(linex, liney, color='green')

        # 多項式近似の曲線を表示
        linex = np.arange(0,1.01,0.01)
        liney = f(linex)
        label = "E(RMS)=%.2f" % rms_error(train_set, f)
        subplot.plot(linex, liney, color='red', label=label)
        subplot.legend(loc=1)

    # 係数の値を表示
    print "Table of the coefficients"
    print df_ws.transpose()
    fig.show()

    # トレーニングセットとテストセットでの誤差の変化を表示
    df = DataFrame(columns=['Training','Test'])
    for m in range(0,10):   # 多項式の次数
        f, ws = resolve(train_set, m)
        train_error = rms_error(train_set, f)
        test_error = rms_error(test_set, f)
        df = df.append(
                Series([train_error, test_error], index=['Training','Test']),
                  ignore_index=True)
    df.plot(title='RMS Error')
    plt.show()

これを実行すると次の結果が得られます。

まず下記は、M=0,1,3,9 の次数でフィッティングした結果です。緑の正弦波が正しいデータソースで、青い点はノイズが乗ったトレーニングセット、赤いラインがフィッティングで得られた関数です。

M=9 の場合は、10個のデータポイントに対して、10個のパラメータ w_m があるので、すべてのデータポイントを正確に再現することが可能です。しかしながら、グラフの形から分かるようにデータポイント以外の点に関しては値が大きく変動しており、未知のデータを推定する汎用性はむしろ失われています。いわゆる、オーバーフィッティングの状態になります。

そこで、トレーニングセットとは独立に用意したテストセットを用いて、得られた関数の汎用性を確認します。次数を増やしていくとトレーニングセットに対する誤差(E_{\rm RMS})は減少していきますが、テストセットに対する誤差は減少しなくなります。

M=0,1,3,9 における係数 w_m の具体的な値は下記のようになります。

Table of the coefficients
       M=0       M=1        M=3            M=9
0 -0.02844  0.498661  -0.575134      -0.528572
1      NaN -1.054202  12.210765     151.946893
2      NaN       NaN -29.944028   -3569.939743
3      NaN       NaN  17.917824   34234.907567
4      NaN       NaN        NaN -169228.812728
5      NaN       NaN        NaN  478363.615824
6      NaN       NaN        NaN -804309.985246
7      NaN       NaN        NaN  795239.975974
8      NaN       NaN        NaN -426702.757987
9      NaN       NaN        NaN   95821.189286

次数が上がると係数の絶対値が極端に大きくなることが分かります。これから、係数の絶対値に上限を設けることでオーバーフィッティングを防止するという手法が考えられることになります。具体的には、適当な定数 \lambda を決めて、下記の誤差関数を最小にするように w_m を決定します。

\tilde E := \frac{1}{2}\sum_{n=1}^N\left(t_n-\sum_{m=0}^{M} w_m x^m\right)^2+\frac{\lambda}{2}\sum_{m=0}^Mw_m^2

最適な \lambda の値は、テストセットに対するフィッティング状況を見ながらトライ&エラーで決定する必要があります。

さらにまた、オーバーフィッティングを避ける上で最適な次数Mをどのように選択するかも、別途、検討が必要です。この例では、データ数が10個なので、M=9(パラメーター数が10個)は明らかにやりすぎです。一方、データ数が100個の場合で同じことをやると次の結果が得られます。


Table of the coefficients
       M=0       M=1        M=3           M=9
0  0.00132  0.929047  -0.202370      0.043305
1      NaN -1.855454  11.959013      7.577173
2      NaN       NaN -34.543562    -61.640223
3      NaN       NaN  22.955558    678.551318
4      NaN       NaN        NaN  -4138.781098
5      NaN       NaN        NaN  13535.233669
6      NaN       NaN        NaN -25621.807711
7      NaN       NaN        NaN  28107.392828
8      NaN       NaN        NaN -16522.578322
9      NaN       NaN        NaN   4016.362911

この場合、M=9 でもオーバーフィッティングになっているとは言えません。しかしながら、M=3 を超えると誤差は 0.3 付近より減っておらず、できるだけ計算量の少ないシンプルなモデルで最適なものは、M=3 と想像されます。

ちなみに、今回の例では、元々のデータが標準偏差 0.3 のノイズを含んでいるので、完全なモデル(オリジナルの正弦波を与えるモデル)でも二乗平均平方根誤差は 0.3 になります。この観点で上記の結果を見ると、M=3 で完全なモデルに近い結果を出していることが分かります。ただし、現実の問題では、元データがどのようなノイズを含んでいるかは分かりませんので、誤差の絶対的な値をモデルの評価基準にすることはできません。この例のように、パラメータ数を変化させた時の誤差の相対的な変動を見る必要があります。

モデルの最適なパラメータ数を判断する、より一般的な手法としては、「AIC(赤池情報量規準)」などもあります。