何の話かというと
A Neural Representation of Sketch Drawings でスケッチの次のストロークを予測するモデルとして、混合ガウス分布が使われており、ガウス分布の混合係数、平均、分散を Latent Variable z を入力とする RNN で計算するという手法が用いられています。
上図のデコーダ部分の出力 y が混合係数、平均、分散にあたります。その後、この分布から次のストロークのサンプルを取得することで、非決定的に画像を生成します。
このモデルは、Bishop先生のMixture Density Networksが元ネタになっており、PRMLにも解説があります。そこで、勉強のためにPRMLで紹介されているサンプルをTensorFlowで実装してみました。
モデルの説明
座標 x に依存して、平均と分散が変化する正規分布 を3つ混合したモデルを考えます。
ここで、混合係数 も座標に依存します。さらに、平均、分散、混合係数の x に対する依存性は、ニューラルネットワークで計算されます。ここでは、一例として、5ノードの隠れ層を1層だけ持つモデルを使用します。
このモデルを用いて、下記のデータセットを学習すると、3つのパートを個別の正規分布でフィッティングできるものと期待されます。
誤差関数には、対数尤度の符号違いを用います。
TensorFlowを用いて実装した結果が下記になります。
実際のコードはこちらです。