めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

TensorFlow Tutorialの数学的背景 − Deep MNIST for Experts(その1)

何の話かというと

enakai00.hatenablog.com

上記の記事では、与えられたデータをそのまま分類するのではなく、分類に適した「特徴」を抽出した後、その特徴を表す変数(特徴変数)に対して分類処理をほどこすという考え方を紹介しました。今回は、とくに「畳み込み演算」によって、画像の特徴を抽出する方法を解説します。これは、Deep MNIST for Experts で紹介されているCNN(畳み込みニューラルネットワーク)による画像認識の基礎となります。

畳み込み演算とは?

(参考資料)コンボリューションを用いた画像の平滑化、鮮鋭化とエッジ検出

はじめに、畳み込み演算を簡単に説明しておきます。簡単な例として、画像処理ソフトウェアで、画像を「ぼかす」フィルターをかける場合を考えます。これは、画像の各ピクセルにおいて、その部分の色をその周りのピクセルの色とまぜて平均化した色に置き換えれば実現できます。白黒256階調などの単色画像であれば、3x3の範囲を考えて、各ピクセルの値(濃度)を下図のような重みで足し合わせたものを中央のピクセルの値(濃度)に置き換えます。

左の例は中央と周りのピクセルをほぼ同じ重みにしており、右の例は中央のピクセルの重みを大きくしています。つまり、左のほうが「ぼかし効果」は、より強くなります。(重みの合計が1になっている点に注意してください。)

このような演算処理を「畳み込み演算」と言います。

―― え? それだけ。

はい。それだけです。

ただ面白いことに、上記の重みのとり方を色々かえると、ぼかす以外の処理も可能になります。たとえば、次の例はどうでしょうか?

重みにマイナスの値が入っていますが、合計値がマイナスの場合は、絶対値をとるものとしてください。落ち着いて考えるとわかるように、これは、横に伸びた線を消去する効果があります。横に同じ色が続いている部分は、左右の±がキャンセルして0になるためです。つまり、これは縦のエッジを抽出する効果があります。同様に、上記フィルターを90度回転させると横のエッジが抽出されます。

あるいは、次のようにフィルターの大きさを5x5に広げると、ある程度の幅を持った縦線を残して、横線を消去することも可能です。このフィルターを適用する実例は、この後で登場します。

2   1   0   -1  -2
3   2   0   -2  -3
4   3   0   -3  -4
3   2   0   -2  -3
2   1   0   -1  -2

画像の特徴とは?

例によって、まずは、極端に簡単化した例題で解説を進めます。表題のチュートリアルでは、0〜9の手書き数字を分類していますが、ここでは、次のような「−」「|」「+」の3種類の図形を分類する問題を考えてみます。

容易に想像できるように、これらの図形は、「縦棒」と「横棒」の2つの変数で分類することが可能です。

図形 縦棒 z_0 横棒 z_1
-1 1
1 -1
1 1

したがって、前述の畳み込み演算を利用して縦棒と横棒を抽出すれば、それぞれの図形を上記の2変数 (z_0, z_1) に変換することができそうです。これができてしまえば、(z_0, z_1) から図形の種類を判別するのは容易です。具体的には、「TensorFlow Tutorialの数学的背景 − MNIST For ML Beginners(その2)」の冒頭で紹介した線形多項分類器(Softmax関数)を用いると、(z_0, z_1) 平面を3つの領域に分けて、(-1,1),(1,-1),(1,1) の3つの値を分類することができます。

TensorFlowによる畳み込み演算の実施

それでは、TensorFlowを利用して、次の3つの処理を実装します。

(1) 元の画像データから縦棒と横棒を抽出する

(2) 抽出したデータを変数 (z_0, z_1) に変換する

(3) 特徴変数 (z_0, z_1) から図形の種類を判定する

事前準備として、元の画像データを作成します。先ほどのデータは、MNISTの手書き数字データから比較的綺麗な「1」の画像を取り出して、それを90度回転したり重ねあわせたりして作っています。次のコードを実行するとデータファイル「mnist_simple.data」が作成されて、画像イメージが一覧表示されます。

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import cPickle as pickle
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
np.random.seed(1)

candidates = [
1,2,3,4,7,8,11,13,15,21,23,26,32,35,39,49,51,52,59,60,62,64,69,75,77,
81,82,85,86,89,90,92,94,95,97,100,107,108,111,113,114,115,120,122,123,
124,126,128,130,131,133,134,135,138,139,146,156,161,165,173,176,177,178,
182,183,184,185,187,188,189,191,193,195,197,198,201,202,203,206,209]

img_data, img_label = [], []
n = 0
c = 0
while(c < len(candidates)):
  imgs, labels = mnist.train.next_batch(1)
  img = imgs[0]
  label = labels[0]
  if label[1] != 1:
    continue
  n += 1
  if n in candidates:
    img0 = img
    img1 = np.transpose(img.reshape(28,28)).flatten()
    img2 = np.array([max(img0[x], img1[x]) for x in range(28*28)])
    img_data.append(img0)
    img_label.append([1,0,0])
    img_data.append(img1)
    img_label.append([0,1,0])
    img_data.append(img2)
    img_label.append([0,0,1])
    c += 1

images = np.c_[np.array(img_label), np.array(img_data)]
np.random.shuffle(images)
img_label, img_data = np.hsplit(images,[3])

with open('mnist_simple.data', 'wb') as file:
      pickle.dump((img_data, img_label), file)

fig = plt.figure()
for i, img in enumerate(img_data):
  i += 1
  subplot = fig.add_subplot(12,20,i)
  subplot.set_xticks([])
  subplot.set_yticks([])
  subplot.imshow(img.reshape(28,28), vmax=1, vmin=0,
                 cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

それでは、(1)の処理を行う畳み込み演算を定義します。コードの全体像は最後にまとめてお見せします。

縦棒と横棒を抽出するフィルターを定義します。

def weight_constant(shape):
  form0 = np.array(
            [[ 2, 1, 0,-1,-2],
             [ 3, 2, 0,-2,-3],
             [ 4, 3, 0,-3,-4],
             [ 3, 2, 0,-2,-3],
             [ 2, 1, 0,-1,-2]]) / 25.0
  form1 = np.array(
            [[ 2, 3, 4, 3, 2],
             [ 1, 2, 3, 2, 1],
             [ 0, 0, 0, 0, 0],
             [-1,-2,-3,-2,-1],
             [-2,-3,-4,-3,-2]]) / 25.0
  form = np.zeros(5*5*1*2).reshape(5,5,1,2)
  form[:,:,0,0] = form0
  form[:,:,0,1] = form1
  return tf.constant(form, dtype=tf.float32)

元データに対して、これらのフィルターを畳み込み演算する処理を書きます。

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')
with tf.Graph().as_default():
  with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    x_image = tf.reshape(x, [-1,28,28,1])

  with tf.name_scope('convolution'):
    W_conv1 = weight_constant([5,5,1,2])
    h_conv1 = tf.nn.relu(tf.abs(conv2d(x_image, W_conv1))-0.2)
    _ = tf.histogram_summary('h_conv1', h_conv1)

x_image は画像データが入るプレースホルダーです。tf.abs(conv2d(x_image, W_conv1)) で先ほどのフィルターとの畳み込みを行ないます。負の値は絶対値をとるように tf.abs() を被せています。その外側にある tf.nn.relu() は次のように負の値をカットする関数です。

ここでは、フィルターの効果を強調するために、値が0.2以下のピクセル値をカットしています。

この処理によって、1枚の画像に対して、2種類のフィルターがかかった2枚の画像が得られます。これは、ピクセル数で単純に考えるとデータ量が2倍に増えています。そこでフィルター後の画像の解像度を 28×28 から 14×14 に落としてしまいます。

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
  with tf.name_scope('pooling'):
    h_pool1 = max_pool_2x2(h_conv1)
    h_pool1_flat = tf.reshape(h_pool1, [-1, 14*14*2])

CNN(畳み込みニューラルネットワーク)では、このような処理を「プーリング」と言います。この例では、h_pool1 がプーリングした後の画像データになります。実は、ここで解像度を落とすことには、重要な意味があります。今大事なのはあくまで「縦棒」「横棒」という情報なので、高解像度の画像の詳細は、むしろ分類処理に対するノイズになってしまいます。解像度を落とすことで、畳み込みで得られた特徴をさらに強調することになるのです。実際に「畳み込み+プーリング」を行った画像は、次のようになります。解像度を落とすと同時に、画像に「ブレ」の効果を入れています。

左端の図形はフィルターで、その右にいくつかのサンプルが並んでいます。それぞれのフィルターによって縦横の棒が消えて、さらにプーリングによって残った棒が太く強調されていることが分かります。棒の端点が消えずに残っているは・・・、フィルターの特性上仕方がないですね。(「連続した」横棒、縦棒を消去するというフィルターなので、エッジは消えずに残ります。)

そして上記のコードでは、「畳み込み+プーリング」後の2種類の画像のピクセルデータを全部まとめて横一列にならべた 14*14*2 次元ベクトルとして、h_pool1_flat を定義しています。図形の種類よって、この横に長いベクトルのどの辺りに大きな値が集中するかが決まります。「|」は左半分、「−」は右半分、「+」は左右両側、ということです。

したがって、このような特徴に合わせて、横長ベクトルをバイナリ変数  (z_0, z_1) に変換することができそうな気がします。ここが、(2)の処理になります。

特徴変数の学習

とはいえ・・・。14*14*2 次元ベクトルを特徴変数  (z_0, z_1) に変換する関数を見つけるのは簡単ではありません。この部分は、機械学習で自動的に見つけることにします。まず、発見したい関数を次のように表現します。

  z_0 = \tanh({\mathbf h}{\mathbf W_0} + b'_0)

  z_1 = \tanh({\mathbf h}{\mathbf W_1} + b'_1)

{\mathbf h} がピクセル値をならべた 14*14*2 次元ベクトルで、{\mathbf W_0}, {\mathbf W_1}(14*14*2次元の縦ベクトル)と b'_0, b'_1(定数)は未知のパラメーターです。これらのパラメーターを調整することで、欲しい関数が得られるものと期待します。これは、次のように定義できます。

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1, seed=1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0, 1, shape=shape)
  return tf.Variable(initial)
  with tf.name_scope('features'):
    W_fc1 = weight_variable([14 * 14 * 2, 2])
    b_fc1 = bias_variable([2])
    h_fc1 = tf.nn.tanh(tf.matmul(h_pool1_flat, W_fc1) + b_fc1)

W_fc1, b_fc1 が未知のパラメーターに対応する変数です。h_fc1 が (z_0, z_1) に対応する配列です。

これで、(z_0, z_1) が決まったものとして、あとはこれらの値から最終結果を判別するsoftmax関数を用意します。先にあげた、(3)の処理になります。

  with tf.name_scope('readout'):
    W_fc2 = weight_variable([2, 3])
    b_fc2 = bias_variable([3])
    y_conv=tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)

y_conv は、[ P_0,P_1,P_2 ] という3成分の配列で、各成分の値が3種類の図形それぞれの確率を表します。最終的には、この確率が最も大きい図形だと判定します。

ここで、softmax関数を簡単に復習すると次のような仕組みでした。まず、次のように、(z_0, z_1) の3つの一次関数 f_0, f_1, f_2 を用意します。

 (f_0,f_1,f_2) = (z_0, z_1)
\left(
\begin{array}{rr}
w_{00} & w_{01} & w_{02} \\
w_{10} & w_{11} & w_{12} \\
\end{array}
\right)+ (b_0,b_1,b_2)

 これらは、3種類のそれぞれの図形に対する「確率」を表すと考えます。ただし、このままでは、値が 0〜1 に収まらないので本当の意味での確率にはなりません。次の変換を施すことで、本当の意味での確率になります。(P_0+P_1+P_2=1 が成立します。)

 P_0 = e^{f_0} / \sum_{k=0}^2 e^{f_k}

 P_1 = e^{f_1} / \sum_{k=0}^2 e^{f_k}

 P_2 = e^{f_2} / \sum_{k=0}^2 e^{f_k}

これが先ほどの y_conv の各成分になります。この関数も未知のパラメーターを含んでおり、これも機械学習で決定します。

最後に、y_conv による判定結果と元のデータに付属する正解ラベルを比較して、正解率が上がるように最適化を行ないます。最適化に使用する関数は次のとおりです。

  with tf.name_scope('optimizer'):
    y_ = tf.placeholder(tf.float32, [None, 3], name='y-input')
    cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
    optimizer = tf.train.GradientDescentOptimizer(0.005)
    train_step = optimizer.minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

cross_entropyが「対数尤度×(-1)」で、これが最小になるように学習します。accuracyは実際の正解率を示します。

ここまでの処理をまとめると、次のようなネットワークを構成したことになります。

いまの場合、左側の特徴抽出の部分は未知のパラメーターを含まない固定的な処理で、右側の判定処理に含まれるパラメーターを機械学習でチューニングする形になります。

コード実行例

以上の処理をまとめたコードが次になります。visualize() は計算結果を表示する関数ですので、ここは無視してその後ろだけみれば十分です。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import cPickle as pickle
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(11)
tf.set_random_seed(11)

with open('mnist_simple.data', 'rb') as file:
  (train_data, train_label) = pickle.load(file)

def visualize(sess):
  fig1 = plt.figure()
  C = 12
  raw_data = train_data[:C]
  res = h_pool1_flat.eval(session=sess, feed_dict={x: train_data[:C]})
  res2 = h_fc1.eval(session=sess, feed_dict={x: train_data[:C]})
  res3 = y_conv.eval(session=sess, feed_dict={x: train_data[:C]})
  w = W_conv1.eval()

  vmax = np.max(res[:C].flatten())
  vmin = np.min(res[:C].flatten())

  # show filters
  for i in range(2):
    subplot = fig1.add_subplot(2+1, C+1, 1+(C+1)*(i+1))
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.imshow(w[:,:,0,i], cmap=plt.cm.gray_r, interpolation='nearest')
  for c in range(C):
    # show raw data
    chars = res[c].reshape(14,14,2)
    subplot = fig1.add_subplot(2+1, C+1, c+2)
    subplot.set_xticks([])
    subplot.set_yticks([])
    subplot.imshow(raw_data[c].reshape(28,28),
                   vmax=1, vmin=0,
                   cmap=plt.cm.gray_r, interpolation='nearest')

    # show filtered data
    for i in range(2):
      subplot = fig1.add_subplot(2+1, C+1, 1+(C+1)*(i+1)+c+1)
      subplot.set_xticks([])
      subplot.set_yticks([])
      subplot.imshow(chars[:,:,i],
                     vmax=vmax, vmin=vmin,
                     cmap=plt.cm.gray_r, interpolation='nearest')

  fig2 = plt.figure()
  res2 = h_fc1.eval(session=sess, feed_dict={x: train_data})
  data0_x, data0_y = [], []
  data1_x, data1_y = [], []
  data2_x, data2_y = [], []
  subplot = fig2.add_subplot(1,1,1)
  for c, label in enumerate(train_label):
    if label[0]:
      data0_x.append(res2[c][0])
      data0_y.append(res2[c][1])
    if label[1]:
      data1_x.append(res2[c][0])
      data1_y.append(res2[c][1])
    if label[2]:
      data2_x.append(res2[c][0])
      data2_y.append(res2[c][1])
  subplot.scatter(data0_x, data0_y, marker='o', color='red')
  subplot.scatter(data1_x, data1_y, marker='o', color='blue')
  subplot.scatter(data2_x, data2_y, marker='o', color='green')
  plt.show()

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1, seed=1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0, 1, shape=shape)
  return tf.Variable(initial)

def weight_constant(shape):
  form0 = np.array(
            [[ 2, 1, 0,-1,-2],
             [ 3, 2, 0,-2,-3],
             [ 4, 3, 0,-3,-4],
             [ 3, 2, 0,-2,-3],
             [ 2, 1, 0,-1,-2]]) / 25.0
  form1 = np.array(
            [[ 2, 3, 4, 3, 2],
             [ 1, 2, 3, 2, 1],
             [ 0, 0, 0, 0, 0],
             [-1,-2,-3,-2,-1],
             [-2,-3,-4,-3,-2]]) / 25.0
  form = np.zeros(5*5*1*2).reshape(5,5,1,2)
  form[:,:,0,0] = form0
  form[:,:,0,1] = form1
  return tf.constant(form, dtype=tf.float32)

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

with tf.Graph().as_default():
  with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    x_image = tf.reshape(x, [-1,28,28,1])
  
  with tf.name_scope('convolution'):
    W_conv1 = weight_constant([5,5,1,2])
    h_conv1 = tf.nn.relu(tf.abs(conv2d(x_image, W_conv1))-0.2)
    _ = tf.histogram_summary('h_conv1', h_conv1)
  
  with tf.name_scope('pooling'):
    h_pool1 = max_pool_2x2(h_conv1)
    h_pool1_flat = tf.reshape(h_pool1, [-1, 14*14*2])

  with tf.name_scope('features'):
    W_fc1 = weight_variable([14 * 14 * 2, 2])
    b_fc1 = bias_variable([2])
    h_fc1 = tf.nn.tanh(tf.matmul(h_pool1_flat, W_fc1) + b_fc1)
  
  with tf.name_scope('readout'):
    W_fc2 = weight_variable([2, 3])
    b_fc2 = bias_variable([3])
    y_conv=tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
  
  with tf.name_scope('optimizer'):
    y_ = tf.placeholder(tf.float32, [None, 3], name='y-input')
    cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
    optimizer = tf.train.GradientDescentOptimizer(0.005)
    train_step = optimizer.minimize(cross_entropy)

    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  
  # Logging data for TensorBoard
  _ = tf.scalar_summary('cross entropy', cross_entropy)
  _ = tf.scalar_summary('accuracy', accuracy)
  
  with tf.Session() as sess:
    writer = tf.train.SummaryWriter('/tmp/simple_cnn',
                                    graph_def=sess.graph_def)
    sess.run(tf.initialize_all_variables())
    for i in range(51):
      batch = [train_data, train_label] 
      train_step.run(feed_dict={x: batch[0], y_: batch[1]})
      summary_str, acc = sess.run([tf.merge_all_summaries(), accuracy],
                                  feed_dict={x: batch[0], y_: batch[1]})
      writer.add_summary(summary_str, i)
      print("step %d, training accuracy %g" % (i, acc))
    visualize(sess)

TensorBoardでは、次のグラフが確認できます。

畳み込みは固定演算でパラメーターを調整する必要がないので、Poolingの出力以降の値のみがOptimizerに入っています。扱うデータが単純なので、数回のイテレーションで正解率100%になります。(ここでは、過学習の問題は気にしないでおきます。)

また、コードを実行すると先に示したフィルターとプーリング後の画像サンプルと特徴変数 (z_0,z_1) の分布のグラフが表示されます。分布のグラフは次のとおりです。

3つの図形に対応して、(z_0,z_1) = (-1,1),(1,-1),(1,1) に値が集中していることが分かります。すばらしぃ。

次回予告

今回は、畳み込みのフィルターを固定的に手で与えました。これは分類対象の図形が単純で、最初から取り出すべき特徴がわかっているからできたことです。実際のチュートリアルにある「数字の分類」となるとそういうわけには行きません。そこで、次のステップは、フィルターそのものを機械学習で決定しようという事になります。つまり、最初はフィルターをランダムに用意しておき、分類精度があがるようにフィルターを修正していくというわけです。

次回の記事はこちらです。

enakai00.hatenablog.com