めもめも

このブログに記載の内容は個人の見解であり、必ずしも所属組織の立場、戦略、意見を代表するものではありません。

分散学習用TensorFlowコードの書き方

何の話かというと

Google Cloud MLを利用して、TensorFlowの分散学習を行う方法です。取り急ぎ、自分用のメモとして公開しておきます。

分散学習にはいくつかのパターンがありますが、最もシンプルな「データ分散」の場合を説明します。各ノードは同じモデルに対して、個別に学習データを適用して、Variableを修正する勾配ベクトルを計算します。それぞれで計算した勾配ベクトルを用いて、共通のVariableを修正していきます。

前提知識

TensorFlowの分散学習処理を行う際は、3種類のノードを使用します。

・Parameter Server:Workerが計算した勾配ベクトルを用いて、Variableのアップデートを行います。
・Worker:教師データから勾配ベクトルを計算します。
・Master:Workerと同様の処理に加えて、学習済みモデルの保存やテストセットに対する評価などの追加処理を行います。

一般にParameter Serverは1〜2ノード、Workerは必要に応じた沢山のノード、Masterは1ノードだけで動かします。Cloud MLにジョブを投げると、これらのノード群がコンテナで生成されて、各ノードでコードの実行が行われます。

クラスター構成情報とノードの役割の取得

TensorFlowのコードは全ノードで共通ですが、コード内でノードの種類に応じて処理を分岐させます。コード内では環境変数 TF_CONFIG を通じて、クラスターの構成情報と自分自身の役割を取得します。次のコードは、クラスター構成情報オブジェクト cluster_spec を作成して、job_name と task_index にノードの役割を格納します。

  # Get cluster and node info
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_info = env.get('cluster', None)
  cluster_spec = tf.train.ClusterSpec(cluster_info)
  task_info = env.get('task', None)
  job_name, task_index = task_info['type'], task_info['index']

job_name には、「ps」「master」「worker」のいずれかの文字列が入ります。task_index は同じ種類のノードが複数ある際に 0 からの通し番号が入ります。これらの情報を用いて、次のように Server オブジェクトを作成すると、このオブジェクトがクラスター内の他のノードとの通信処理を担います。

  server = tf.train.Server(cluster_spec,
                           job_name=job_name, task_index=task_index)

Parameter Serverの処理

Parameter Serverの場合は、次のコマンドで Server オブジェクトを起動すれば、それで必要な処理は終わりです。あとは Server オブジェクトが Parameter Server としての機能を提供してくれます。このコマンドは外部からプロセスを停止するまで、戻ってくることはありません。

  if job_name == "ps": # Parameter server
    server.join()

Workerの処理

Worker(および、Master)では、学習処理のループを回す必要がありますが、この際、他のノードと協調動作するために、Supervisorオブジェクトを生成した後に、このオブジェクトを経由して、チェックポイントファイルの保存やセッションの作成といった処理を実施します。

次は、Supervisorオブジェクトを生成するコードの例です。

  if job_name == "master" or job_name == "worker": # Worker node
    is_chief = (job_name == "master")
...
        global_step = tf.Variable(0, trainable=False)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        # Create a supervisor
        sv = tf.train.Supervisor(is_chief=is_chief, logdir=LOG_DIR,
                                 init_op=init_op, saver=saver, summary_op=None,
                                 global_step=global_step, save_model_secs=0)

・is_chief:Masterノードの場合に True を渡します。
・logdir:チェックポイントファイルの保存ディレクトリー
・init_op:セッション作成時のVariable初期化処理
・saver:チェックポイントファイルを保存するためのSaverオブジェクト
・global_step:最適化処理の実施回数をカウントするVariable
・summary_op:TensorBoard用のサマリーオブジェクト(Supervisorオブジェクトを介さずにサマリーを保存する際はNoneを指定)
・save_model_secs:チェックポイントの定期保存間隔(自動保存ではなく、明示的に保存する際はNoneを指定)

また、モデルを定義する際は、次の様に、job_name と task_index を用いて、自分の役割を tf.device で設定した with 構文の中で定義していきます。

  device_fn = tf.train.replica_device_setter(
    cluster=cluster_spec,
    worker_device="/job:%s/task:%d" % (job_name, task_index)
  )
...
  if job_name == "master" or job_name == "worker": # Worker node
    is_chief = (job_name == "master")

    with tf.Graph().as_default() as graph:
      with tf.device(device_fn):

そして次は、セッションを生成して、学習処理のループを回す部分です。

        # Create a session and run training loops
        with sv.managed_session(server.target) as sess:
          reports, step = 0, 0
          start_time = time.time()
          while not sv.should_stop() and step < MAX_STEPS:
             images, labels = mnist_data.train.next_batch(BATCH_SIZE)
             feed_dict = {x:images, t:labels, keep_prob:0.5}
             _, loss_val, step = sess.run([train_step, loss, global_step],
                                          feed_dict=feed_dict)
             if step > CHECKPOINT * reports:
               reports += 1
               logging.info("Step: %d, Train loss: %f" % (step, loss_val))
               if is_chief:
                 # Save checkpoint
                 sv.saver.save(sess, sv.save_path, global_step=step)
...
                 # Save summary
                 feed_dict = {test_loss:loss_val, test_accuracy:acc_val}
                 sv.summary_computed(sess,
                   sess.run(summary, feed_dict=feed_dict), step)
                 sv.summary_writer.flush()

ここでのポイントは、最適化アルゴリズム train_step を評価して Variable をアップデートする際に、global_step を一緒に評価する点です。これにより、global_step の値が 1 増加して、他のノードを含めた学習処理のトータルの回数が取得できます。この例では取得した値を変数 step に格納して、全体として CHECKPOINT 回評価するごとに進捗をログ出力するということを行っています。また、先に定義しておいた is_chief (Masterの場合に True)を用いて、Masterだけで追加の処理をすることもできます。この例では、サマリーの出力とチェックポイントの保存を行っています。sv.save_path には、Supervisorを作成した時に logdir で指定したディレクトリーが入ります。

学習済みモデルの保存

分散学習でトリッキーな点の1つが学習済みモデルの保存方法です。先に構築したモデルは、分散学習用にノードの情報がひも付いていますが、学習済みモデルで分類を行う際は、これらの情報は不要です。そこで、単体ノードでもリストア可能な分類処理専用のモデルを再構築して、それを保存するという処理を行います。

まず、学習処理のループを抜けた部分で次を実行します。これは、最終状態の Variable を一旦チェックポイントファイルに保存して、それからモデルの再構築・保存処理(export_model)を呼び出しています。

          if is_chief: # Export the final model
            sv.saver.save(sess, sv.save_path, global_step=sess.run(global_step))
            export_model(tf.train.latest_checkpoint(LOG_DIR))

そして、再構築・保存処理の例は次のようになります。

def export_model(last_checkpoint):
  # create a session with a new graph.
  with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.float32, [None, 784])
    p = mnist.get_model(x, None, training=False)

    # Define key elements
    input_key = tf.placeholder(tf.int64, [None,])
    output_key = tf.identity(input_key)

    # Define API inputs/outpus object
    inputs = {'key': input_key.name, 'image': x.name}
    outputs = {'key': output_key.name, 'scores': p.name}
    tf.add_to_collection('inputs', json.dumps(inputs))
    tf.add_to_collection('outputs', json.dumps(outputs))

    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Restore the latest checkpoint and save the model
    saver = tf.train.Saver()
    saver.restore(sess, last_checkpoint)
    saver.export_meta_graph(filename=MODEL_DIR + '/export.meta')
    saver.save(sess, MODEL_DIR + '/export',
               write_meta_graph=False)

ここでは、新たなグラフとセッションを用意して、入力 x に対して、予測結果 p を計算する最低限の関係を定義した後に、先ほど保存したチェックポイントの内容をリストアしています。このセッションに含まれていない Variable の値は単純に無視されます。また、分類用コードに入出力変数を渡すために、入出力に関連した変数名を JSON にまとめたものを collection に入れてあります。output_key は、Placeholder の input_key にいれた値がそのまま出てくる変数で、複数データをバッチ処理する際にどの出力データがどの入力データに対応するかを紐付けるために使用します。

コードの全体像

モデル定義の中身は適当に用意した MNIST 用 CNN です。

trainer
├── __init__.py # 空ファイル
├── mnist.py # モデル定義
└── task.py # 学習処理用コード


trainer/task.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import time, json, os, logging

import mnist

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 100,
                     'Batch size. Must divide evenly into the dataset sizes.')
flags.DEFINE_integer('max_steps', 10000, 'Number of steps to run trainer.')
flags.DEFINE_integer('checkpoint', 100, 'Interval steps to save checkpoint.')
flags.DEFINE_string('log_dir', '/tmp/logs',
                    'Directory to store checkpoints and summary logs')
flags.DEFINE_string('model_dir', '/tmp/model',
                    'Directory to store trained model')


# Global flags
BATCH_SIZE = FLAGS.batch_size
MODEL_DIR = FLAGS.model_dir
LOG_DIR = FLAGS.log_dir
MAX_STEPS = FLAGS.max_steps
CHECKPOINT = FLAGS.checkpoint


def export_model(last_checkpoint):
  # create a session with a new graph.
  with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.float32, [None, 784])
    p = mnist.get_model(x, None, training=False)

    # Define key elements
    input_key = tf.placeholder(tf.int64, [None,])
    output_key = tf.identity(input_key)

    # Define API inputs/outpus object
    inputs = {'key': input_key.name, 'image': x.name}
    outputs = {'key': output_key.name, 'scores': p.name}
    tf.add_to_collection('inputs', json.dumps(inputs))
    tf.add_to_collection('outputs', json.dumps(outputs))

    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Restore the latest checkpoint and save the model
    saver = tf.train.Saver()
    saver.restore(sess, last_checkpoint)
    saver.export_meta_graph(filename=MODEL_DIR + '/export.meta')
    saver.save(sess, MODEL_DIR + '/export',
               write_meta_graph=False)
  

def run_training():
  # Get cluster and node info
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_info = env.get('cluster', None)
  cluster_spec = tf.train.ClusterSpec(cluster_info)
  task_info = env.get('task', None)
  job_name, task_index = task_info['type'], task_info['index']

  device_fn = tf.train.replica_device_setter(
    cluster=cluster_spec,
    worker_device="/job:%s/task:%d" % (job_name, task_index)
  )

  logging.info('Start job:%s, index:%d' % (job_name, task_index))

  # Create server
  server = tf.train.Server(cluster_spec,
                           job_name=job_name, task_index=task_index)

  if job_name == "ps": # Parameter server
    server.join()

  if job_name == "master" or job_name == "worker": # Worker node
    is_chief = (job_name == "master")

    with tf.Graph().as_default() as graph:
      with tf.device(device_fn):

        # Prepare training data
        mnist_data = input_data.read_data_sets("/tmp/data/", one_hot=True)

        # Create placeholders
        x = tf.placeholder_with_default(
          tf.zeros([BATCH_SIZE, 784], tf.float32), shape=[None, 784])
        t = tf.placeholder_with_default(
          tf.zeros([BATCH_SIZE, 10], tf.float32), shape=[None, 10])
        keep_prob = tf.placeholder_with_default(
          tf.zeros([], tf.float32), shape=[])
        global_step = tf.Variable(0, trainable=False)

        # Add test loss and test accuracy to summary
        test_loss = tf.placeholder_with_default(
          tf.zeros([], tf.float32), shape=[])
        test_accuracy = tf.placeholder_with_default(
          tf.zeros([], tf.float32), shape=[])
        tf.summary.scalar("Test_loss", test_loss) 
        tf.summary.scalar("Test_accuracy", test_accuracy) 

        # Define a model
        p = mnist.get_model(x, keep_prob, training=True)
        train_step, loss, accuracy = mnist.get_trainer(p, t, global_step)

        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        summary = tf.summary.merge_all()

        # Create a supervisor
        sv = tf.train.Supervisor(is_chief=is_chief, logdir=LOG_DIR,
                                 init_op=init_op, saver=saver, summary_op=None,
                                 global_step=global_step, save_model_secs=0)
    
        # Create a session and run training loops
        with sv.managed_session(server.target) as sess:
          reports, step = 0, 0
          start_time = time.time()
          while not sv.should_stop() and step < MAX_STEPS:
             images, labels = mnist_data.train.next_batch(BATCH_SIZE)
             feed_dict = {x:images, t:labels, keep_prob:0.5}
             _, loss_val, step = sess.run([train_step, loss, global_step],
                                          feed_dict=feed_dict)
             if step > CHECKPOINT * reports:
               reports += 1
               logging.info("Step: %d, Train loss: %f" % (step, loss_val))
               if is_chief:
                 # Save checkpoint
                 sv.saver.save(sess, sv.save_path, global_step=step)

                 # Evaluate the test loss and test accuracy
                 loss_vals, acc_vals = [], []
                 for _ in range(len(mnist_data.test.labels) // BATCH_SIZE):
                   images, labels = mnist_data.test.next_batch(BATCH_SIZE)
                   feed_dict = {x:images, t:labels, keep_prob:1.0}
                   loss_val, acc_val = sess.run([loss, accuracy],
                                                feed_dict=feed_dict)
                   loss_vals.append(loss_val)
                   acc_vals.append(acc_val)
                 loss_val, acc_val = np.sum(loss_vals), np.mean(acc_vals)

                 # Save summary
                 feed_dict = {test_loss:loss_val, test_accuracy:acc_val}
                 sv.summary_computed(sess,
                   sess.run(summary, feed_dict=feed_dict), step)
                 sv.summary_writer.flush()

                 logging.info("Time elapsed: %d" % (time.time() - start_time))
                 logging.info("Step: %d, Test loss: %f, Test accuracy: %f" %
                              (step, loss_val, acc_val))

          # Finish training
          if is_chief: # Export the final model
            sv.saver.save(sess, sv.save_path, global_step=sess.run(global_step))
            export_model(tf.train.latest_checkpoint(LOG_DIR))

        sv.stop()  


def main(_):
  run_training()


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO) 
  tf.app.run()

trainer/mnist.py

import tensorflow as tf
import json

def get_model(x, keep_prob, training=True):
  num_filters1 = 32
  num_filters2 = 64

  with tf.name_scope('cnn'):
    with tf.name_scope('convolution1'):
      x_image = tf.reshape(x, [-1,28,28,1])
      
      W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1],
                                                stddev=0.1))
      h_conv1 = tf.nn.conv2d(x_image, W_conv1,
                             strides=[1,1,1,1], padding='SAME')
      
      b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))
      h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)
      
      h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],
                               strides=[1,2,2,1], padding='SAME')

    with tf.name_scope('convolution2'):
      W_conv2 = tf.Variable(
                  tf.truncated_normal([5,5,num_filters1,num_filters2],
                                      stddev=0.1))
      h_conv2 = tf.nn.conv2d(h_pool1, W_conv2,
                             strides=[1,1,1,1], padding='SAME')
      
      b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))
      h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)
      
      h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],
                               strides=[1,2,2,1], padding='SAME')

    with tf.name_scope('fully-connected'):
      h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])
      num_units1 = 7*7*num_filters2
      num_units2 = 1024
      w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))
      b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))
      hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)

    with tf.name_scope('output'):
      if training:
        hidden2_drop = tf.nn.dropout(hidden2, keep_prob)
      else:
        hidden2_drop = hidden2
      w0 = tf.Variable(tf.zeros([num_units2, 10]))
      b0 = tf.Variable(tf.zeros([10]))
      p = tf.nn.softmax(tf.matmul(hidden2_drop, w0) + b0)

  tf.summary.histogram("conv_filters1", W_conv1)
  tf.summary.histogram("conv_filters2", W_conv2)

  return p

  
def get_trainer(p, t, global_step):
  with tf.name_scope('optimizer'):
    loss = -tf.reduce_sum(t * tf.log(p), name='loss')
    train_step = tf.train.AdamOptimizer(0.0001).minimize(loss, global_step=global_step)
      
  with tf.name_scope('evaluator'):
    correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,
                                      tf.float32), name='accuracy')

  return train_step, loss, accuracy

このコードを使って、Cloud MLで学習する際は、Cloud Shellから次のコマンドを実行していきます。(trainerディレクトリーの親ディクレクトリーで実行します。)

$ PROJECT_ID=project01 # your project ID
$ TRAIN_BUCKET="gs://$PROJECT_ID-mldata"
$ gsutil mkdir $TRAIN_BUCKET

$ cat << EOF > config.yaml
trainingInput:
  # Use a cluster with many workers and a few parameter servers.
  scaleTier: STANDARD_1
EOF

$ JOB_NAME="job01"
$ gsutil rm -rf $TRAIN_BUCKET/$JOB_NAME
$ touch .dummy
$ gsutil cp .dummy $TRAIN_BUCKET/$JOB_NAME/train/
$ gsutil cp .dummy $TRAIN_BUCKET/$JOB_NAME/model/

$ gcloud beta ml jobs submit training ${JOB_NAME} \
  --package-path=trainer \
  --module-name=trainer.task \
  --staging-bucket="${TRAIN_BUCKET}" \
  --region=us-central1 \
  --config=config.yaml \
  -- \
  --log_dir=$TRAIN_BUCKET/$JOB_NAME/train \
  --model_dir=$TRAIN_BUCKET/$JOB_NAME/model \
  --max_steps=10000

TensorBoardで進捗を見る時は、次を実行します。

$  tensorboard --port 8080 --logdir $TRAIN_BUCKET/$JOB_NAME/train

学習が終わると、「$TRAIN_BUCKET/$JOB_NAME/model」以下に学習済みモデル(export.meta、および、exprot-data.xxxx)が出力されます。

学習済みモデルをリストアして、予測処理を行うコードの例は次になります。

#!/usr/bin/python
import tensorflow as tf
import numpy as np
import json
from tensorflow.examples.tutorials.mnist import input_data

model_meta = 'gs://project01-mldata/job01/model/export.meta'
model_param = 'gs://project01-mldata/job01/model/export'

with tf.Graph().as_default() as graph:
  sess = tf.InteractiveSession()
  saver = tf.train.import_meta_graph(model_meta)
  saver.restore(sess, model_param)

  inputs = json.loads(tf.get_collection('inputs')[0])
  outputs = json.loads(tf.get_collection('outputs')[0])
  x = graph.get_tensor_by_name(inputs['image'])
  input_key = graph.get_tensor_by_name(inputs['key'])
  p = graph.get_tensor_by_name(outputs['scores'])
  output_key = graph.get_tensor_by_name(outputs['key'])

  mnist_data = input_data.read_data_sets("/tmp/data/", one_hot=True)
  images, labels = mnist_data.test.next_batch(10)
  index = range(10)
  keys, preds = sess.run([output_key, p], feed_dict={input_key:index, x:images})
  for key, pred, label in zip(keys, preds, labels):
    print key, np.argmax(pred), np.argmax(label)


Disclaimer: All code snippets are released under Apache 2.0 License. This is not an official Google product.