何の話かというと
前回の記事では、畳み込みのフィルターを固定的に手で与えて、後段の処理(特徴変数の抽出とSoftmax関数による分類)のみを機械学習で最適化するという例を紹介しました。次のステップは、畳み込みのフィルターそのものを機械学習で最適化するという処理になります。
そこで、まずは試しに、前回のコードに対して、フィルター部分も学習するように処理を書き換えてみます。
コードの書き換え
なんと! TensorFlowを利用すると、これは、たった1行の修正でできてしまいます。フィルターのパラメーター W_conv1 を手で与えた定数(constant)から、学習対象の変数(Variable)に変えるだけです。W_conv1 の初期値は乱数で与えられます。
変更前
with tf.name_scope('convolution'): W_conv1 = weight_constant([5,5,1,2]) # <--- modify here h_conv1 = tf.nn.relu(tf.abs(conv2d(x_image, W_conv1))-0.2) _ = tf.histogram_summary('h_conv1', h_conv1)
変更後
with tf.name_scope('convolution'): W_conv1 = weight_variable([5,5,1,2]) h_conv1 = tf.nn.relu(tf.abs(conv2d(x_image, W_conv1))-0.2) _ = tf.histogram_summary('h_conv1', h_conv1)
ただし残念ながら、これだけでは、うまく学習できません。最適化処理のパラメーターを少し修正します。具体的には、tf.train.GradientDescentOptimizer() の引数を少し小さくします。
変更前
with tf.name_scope('optimizer'): y_ = tf.placeholder(tf.float32, [None, 3], name='y-input') cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) optimizer = tf.train.GradientDescentOptimizer(0.005) # <--- modify here train_step = optimizer.minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
変更後
with tf.name_scope('optimizer'): y_ = tf.placeholder(tf.float32, [None, 3], name='y-input') cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) optimizer = tf.train.GradientDescentOptimizer(0.002) train_step = optimizer.minimize(cross_entropy) correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
このパラメーターは、1回の学習処理で変数(Variable)の値をどの程度大きく修正するかというものです。これが大きいと学習の速度は上がりますが、最適解の範囲が狭いとそこにたどり着くことができません。一方、小さすぎると、最適解にたどり着くのに時間がかかります。学習曲線(正解率の変化)を見ながら、ベストな値をさぐる必要があります。
コードの実行結果
上記の修正をしたコードを実行すると次の結果が得られます。正解率の変化は次のとおりです。
途中でちょっと苦労しているようですが、20回程度のIterationで正解率100%に達しています。
学習後のフィルターは次のようになっています。
期待通りに、縦棒と横棒を抽出するフィルターが得られていますが、縦棒のフィルターが少し右に傾いています。これは、学習用データの縦棒が右に傾いたものが多いことに起因しているのかも知れません。つまり、縦棒が右に傾いているという特徴を自然に学習したわけです。変数 の分布も期待通りです。
最後に、TensorBoardで見えるネットワークのグラフは次のようになります。
今回は、フィルター(Convolution)も学習対象になっているので、すべてのレイヤーからOptimizerへ入力が入っています。
次回予告
これで、畳み込みフィルターそのものを学習させるという方針がうまくいくことが分かりました。次回は、TensorFlow Tutorialで紹介されているCNNのコードが、本質的にこれと同じ構造になっていることを解説します。
次回の記事はこちらです。