めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

PRML Figure 5.21 を再現するコード

何の話かというと

A Neural Representation of Sketch Drawings でスケッチの次のストロークを予測するモデルとして、混合ガウス分布が使われており、ガウス分布の混合係数、平均、分散を Latent Variable z を入力とする RNN で計算するという手法が用いられています。

上図のデコーダ部分の出力 y が混合係数、平均、分散にあたります。その後、この分布から次のストロークのサンプルを取得することで、非決定的に画像を生成します。

このモデルは、Bishop先生のMixture Density Networksが元ネタになっており、PRMLにも解説があります。そこで、勉強のためにPRMLで紹介されているサンプルをTensorFlowで実装してみました。

モデルの説明

座標 x に依存して、平均と分散が変化する正規分布 {\mathcal N}(t\mid \mu(x),\sigma^2(x)) を3つ混合したモデルを考えます。

\displaystyle p(t\mid x) = \sum_{k=1}^3 \pi_k(x){\mathcal N}(t\mid \mu_k(x),\sigma_k^2(x))
\displaystyle\sum_{k=1}^3 \pi_k(x)=1

ここで、混合係数 \pi_k(x) も座標に依存します。さらに、平均、分散、混合係数の x に対する依存性は、ニューラルネットワークで計算されます。ここでは、一例として、5ノードの隠れ層を1層だけ持つモデルを使用します。

このモデルを用いて、下記のデータセットを学習すると、3つのパートを個別の正規分布でフィッティングできるものと期待されます。

誤差関数には、対数尤度の符号違いを用います。

\displaystyle E(\theta) = -\sum_{n=1}^N\log\left\{\sum_{k=1}^3\pi_k(x_n,\theta){\mathcal N}(t_n\mid \mu_k(x_n,\theta),\sigma_k^2(x_n,\theta))\right\}

TensorFlowを用いて実装した結果が下記になります。

実際のコードはこちらです。


Mixture Density Network