What is this?
enakai00.hatenablog.com
As I mentioned in the article above, you need to make some modifications to your TensorFlow code when you train a model in distributed mode. In this note, I explain the key points you need to know to write TensorFlow code for distributed training.
There are several strategies for training a model on multiple nodes, and I focus on the simplest one, "asynchronous data parallel". All nodes share the same neural network, and each node independently calculates a gradient vector from a different part of the training set (in the same manner as mini-batch training). The shared variables are then updated with the gradient vectors from those nodes as they arrive, without waiting for the other nodes. Roughly speaking, if you iterate 10,000 batches with 10 nodes, each node works on 1,000 batches in parallel. You can find more details in the Google research paper.
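To make the idea concrete, here is a toy simulation of asynchronous data parallel training in plain Python (no TensorFlow). It is only a sketch under my own assumptions: the "model" is a single scalar w fitted by squared error, workers take turns in round-robin order so the run is reproducible, and each worker reads a copy of w that may be up to max_staleness updates old, which mimics other workers updating the shared variables in the meantime. The function name and its parameters are hypothetical.

```python
import random


def async_data_parallel_sgd(data, num_workers=10, lr=0.05, max_staleness=3, seed=0):
    """Toy model of asynchronous data-parallel SGD (not TensorFlow code).

    Fits a single scalar w to minimize mean((w - y)**2) over `data`.
    Each worker owns a disjoint shard of the data and computes a gradient
    from one example per "batch", using a copy of w that may be a few
    updates old. The parameter server applies every gradient as soon as
    it arrives, without waiting for the other workers.
    """
    rng = random.Random(seed)
    shards = [data[i::num_workers] for i in range(num_workers)]
    w = 0.0              # the shared variable held by the parameter server
    history = [w]        # past values of w, used to emulate stale reads
    batches_done = [0] * num_workers
    for step in range(len(data)):
        k = step % num_workers            # workers take turns (deterministic demo)
        shard = shards[k]
        y = shard[batches_done[k] % len(shard)]
        # The worker read w a few updates ago; others may have changed it since.
        lag = rng.randint(0, max_staleness)
        stale_w = history[max(0, len(history) - 1 - lag)]
        grad = 2.0 * (stale_w - y)
        w -= lr * grad                    # applied immediately (asynchronous)
        history.append(w)
        batches_done[k] += 1
    return w, batches_done


if __name__ == "__main__":
    data = [5.0 + 0.1 * ((i % 7) - 3) for i in range(10000)]
    w, per_worker = async_data_parallel_sgd(data)
    print(round(w, 2), per_worker)
```

With 10,000 batches and 10 workers, each worker processes 1,000 batches, and w still ends up close to the mean of the data even though many gradients were computed from slightly outdated values.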
Note: The code snippets in this note are based on TensorFlow r0.12.
Basic architecture
Distributed training involves three roles:
- Parameter server: updates the variables with gradient vectors from the workers (and the master).
- Worker: calculates gradient vectors from the training set.
- Master: coordinates the operations of the workers. (The master also computes gradients like a worker, and it can do some additional housekeeping tasks if necessary.)
In typical deployments, there are a few parameter servers, a single master and a bunch of workers. When you submit a job into Cloud ML, these nodes are created in containers and your code starts running on them.
Getting a cluster configuration and a role of the node
Since the same code runs on all nodes, it needs to branch based on the node's role. So first of all, you need to get the cluster configuration and the node's role with the following code.
```python
# Get cluster and node info
env = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_info = env.get('cluster', None)
cluster_spec = tf.train.ClusterSpec(cluster_info)
task_info = env.get('task', None)
job_name, task_index = task_info['type'], task_info['index']
```
It retrieves the necessary information from the environment variable TF_CONFIG and stores the following values.
- cluster_spec: the cluster configuration
- job_name: the node's role (ps, master, or worker)
- task_index: a serial number starting from 0 that distinguishes nodes with the same role.
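For reference, here is a hypothetical TF_CONFIG value and a plain-Python helper that extracts the same fields as the snippet above, plus the is_chief flag used later. get_node_info is my own name, not a TensorFlow API, and the host names are made up; on Cloud ML the actual values are set by the service. Unlike the original snippet, the helper raises an error early if TF_CONFIG is missing.

```python
import json
import os


def get_node_info(environ=None):
    """Parse TF_CONFIG like the snippet above and also derive is_chief."""
    environ = os.environ if environ is None else environ
    env = json.loads(environ.get('TF_CONFIG', '{}'))
    cluster_info = env.get('cluster', None)   # dict passed to tf.train.ClusterSpec
    task_info = env.get('task', None)
    if cluster_info is None or task_info is None:
        raise ValueError('TF_CONFIG is missing "cluster" or "task"')
    job_name, task_index = task_info['type'], task_info['index']
    is_chief = (job_name == 'master')
    return cluster_info, job_name, task_index, is_chief


# Hypothetical TF_CONFIG: one master, two workers and two parameter servers.
example_env = {
    'TF_CONFIG': json.dumps({
        'cluster': {
            'master': ['master-0:2222'],
            'worker': ['worker-0:2222', 'worker-1:2222'],
            'ps': ['ps-0:2222', 'ps-1:2222'],
        },
        'task': {'type': 'worker', 'index': 1},
    })
}

if __name__ == '__main__':
    cluster, job_name, task_index, is_chief = get_node_info(example_env)
    print(job_name, task_index, is_chief, cluster['worker'][task_index])
```

Here the node is the second worker (task_index 1), so it is not the chief, and its own address is cluster['worker'][1].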
By passing cluster_spec, job_name and task_index to tf.train.Server, you create a server object that handles the communication with the other nodes.
```python
server = tf.train.Server(cluster_spec, job_name=job_name, task_index=task_index)
```
Code for parameter servers
If job_name is 'ps', you simply start the server object and the node works as a parameter server. That's it. server.join() blocks and never returns until the process is terminated by an external signal.
```python
if job_name == "ps":  # Parameter server
    server.join()
```
Code for workers
On the workers and the master, you run the variable update loop under the coordination mechanism provided by a Supervisor object. Tasks that require coordination, such as creating a session and saving checkpoints, are done through the supervisor. Internally, the master node works as the chief coordinator.
The following code creates a Supervisor object.
```python
if job_name == "master" or job_name == "worker":  # Worker node
    is_chief = (job_name == "master")
    ...
    global_step = tf.Variable(0, trainable=False)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    # Create a supervisor
    sv = tf.train.Supervisor(is_chief=is_chief, logdir=LOG_DIR, init_op=init_op,
                             saver=saver, summary_op=None,
                             global_step=global_step, save_model_secs=0)
```
Options:
- is_chief: True if the node is the master.
- logdir: a directory to store checkpoint files.
- init_op: an operation that initializes the variables when the session is created.
- saver: a Saver object to save checkpoint files.
- global_step: a global counter of the training loop.
- summary_op: a summary operation (used by TensorBoard). Pass None if you save summary logs by hand instead of letting the supervisor do it.
- save_model_secs: the time interval (in seconds) between automatic checkpoints. Pass 0 to disable them if you save checkpoints by hand.
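The model must be defined inside a "with tf.device" block whose device function is created by tf.train.replica_device_setter from the cluster configuration and the node's job_name and task_index. This device function places the variables on the parameter servers and the other operations on this node.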
```python
device_fn = tf.train.replica_device_setter(
    cluster=cluster_spec,
    worker_device="/job:%s/task:%d" % (job_name, task_index)
)
...
if job_name == "master" or job_name == "worker":  # Worker node
    is_chief = (job_name == "master")
    with tf.Graph().as_default() as graph:
        with tf.device(device_fn):
            # Define the model here
            ...
```
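The following plain-Python function is a simplified model (not the real implementation) of what tf.train.replica_device_setter does by default: variable ops are assigned to parameter server tasks in round-robin order, and all other ops stay on the worker_device. The real function checks a configurable list of op types (its ps_ops argument) and accepts a custom ps_strategy; place_ops and the op names below are hypothetical.

```python
def place_ops(ops, num_ps, worker_device):
    """Simplified model of replica_device_setter's default placement.

    ops: list of (name, op_type) pairs in graph-construction order.
    Variables go to parameter-server tasks in round-robin order;
    every other op stays on this node's worker device.
    """
    variable_types = {'Variable', 'VariableV2'}
    placement = {}
    next_ps = 0
    for name, op_type in ops:
        if op_type in variable_types:
            placement[name] = '/job:ps/task:%d' % next_ps
            next_ps = (next_ps + 1) % num_ps
        else:
            placement[name] = worker_device
    return placement


if __name__ == '__main__':
    ops = [('W_conv1', 'VariableV2'), ('b_conv1', 'VariableV2'),
           ('conv1', 'Conv2D'), ('W_conv2', 'VariableV2'),
           ('global_step', 'VariableV2')]
    for name, device in sorted(place_ops(ops, 2, '/job:worker/task:1').items()):
        print(name, device)
```

With two parameter servers, the variables alternate between /job:ps/task:0 and /job:ps/task:1, while computations such as the convolution run on the local worker. This is why every worker reads and updates the same copies of the variables.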
Now you can create a session and start the training loop. The following is a simple example to clarify some key points.
```python
# Create a session and run training loops
with sv.managed_session(server.target) as sess:
    reports, step = 0, 0
    start_time = time.time()
    while not sv.should_stop() and step < MAX_STEPS:
        images, labels = mnist_data.train.next_batch(BATCH_SIZE)
        feed_dict = {x:images, t:labels, keep_prob:0.5}
        _, loss_val, step = sess.run([train_step, loss, global_step],
                                     feed_dict=feed_dict)
        if step > CHECKPOINT * reports:
            reports += 1
            logging.info("Step: %d, Train loss: %f" % (step, loss_val))
            if is_chief:
                # Save checkpoint
                sv.saver.save(sess, sv.save_path, global_step=step)
                ...
                # Save summary
                feed_dict = {test_loss:loss_val, test_accuracy:acc_val}
                sv.summary_computed(sess, sess.run(summary, feed_dict=feed_dict), step)
                sv.summary_writer.flush()
```
The first point is that you evaluate the variable global_step together with the optimization operation train_step. Because global_step is passed to the optimizer's minimize() (see get_trainer in mnist.py below), each run of train_step increments it by one, and since the variable is shared by all nodes, its value (stored in step in this example) shows the total iteration count of the whole cluster. By checking the value of step, each node outputs the training step and training loss periodically with an interval of CHECKPOINT. In addition, the master (identified by the is_chief variable defined before) does some additional tasks such as saving a checkpoint file and a summary log. These tasks are done through the Supervisor object sv. (sv.save_path is the checkpoint file path inside the directory specified with the logdir option.)
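Why `step > CHECKPOINT * reports` rather than `step % CHECKPOINT == 0`? Because global_step is shared, the values a single worker observes skip the increments made by the other workers, so an exact multiple of CHECKPOINT may never show up on a given node. The sketch below replays the rule from the loop above on a hypothetical sequence of observed steps; report_steps is my own helper name.

```python
def report_steps(observed_steps, checkpoint):
    """Replay the reporting rule used in the training loop above.

    observed_steps: global_step values that one worker gets back from
    sess.run(); other workers also increment global_step, so values are skipped.
    Returns the steps at which this worker would log (and, on the chief,
    save a checkpoint and a summary).
    """
    reports = 0
    triggered = []
    for step in observed_steps:
        if step > checkpoint * reports:
            reports += 1
            triggered.append(step)
    return triggered


if __name__ == "__main__":
    steps = [1, 3, 7, 12, 15, 22, 31]            # hypothetical observed values
    print(report_steps(steps, 10))               # [1, 12, 22, 31]
    print([s for s in steps if s % 10 == 0])     # []: the modulo test never fires
```

Note that reports grows by only one per report, so if a single jump crosses several boundaries (for example 1 then 35), the rule fires on the next few iterations in a row until reports catches up.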
Exporting the trained model
One of the tricky points in distributed training is exporting the trained model. The model defined in the previous code contains additional cluster and node information (the device placements), which is unnecessary when you restore the model for predictions, especially on a single node. So you need to build another model that can be restored on a single node, and then export it.
This should be done at the exit point of the training loop. In the following example, it saves a checkpoint file with the latest variables and calls a model export function.
```python
if is_chief:
    # Export the final model
    sv.saver.save(sess, sv.save_path, global_step=sess.run(global_step))
    export_model(tf.train.latest_checkpoint(LOG_DIR))
```
The following is an example of the model exporter which builds a new model for predictions and restores the latest checkpoint. The extra variables in the checkpoint (which are not used in the new model) are simply ignored. Then the model is exported with the saver object. The API inputs/outputs objects stored in a collection are required by the Cloud ML's prediction service.
```python
def export_model(last_checkpoint):
    # create a session with a new graph.
    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.float32, [None, 784])
        p = mnist.get_model(x, None, training=False)

        # Define key elements
        input_key = tf.placeholder(tf.int64, [None,])
        output_key = tf.identity(input_key)

        # Define API inputs/outputs object
        inputs = {'key': input_key.name, 'image': x.name}
        outputs = {'key': output_key.name, 'scores': p.name}
        tf.add_to_collection('inputs', json.dumps(inputs))
        tf.add_to_collection('outputs', json.dumps(outputs))

        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Restore the latest checkpoint and save the model
        saver = tf.train.Saver()
        saver.restore(sess, last_checkpoint)
        saver.export_meta_graph(filename=MODEL_DIR + '/export.meta')
        saver.save(sess, MODEL_DIR + '/export', write_meta_graph=False)
```
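Why are the extra variables simply ignored? A Saver created in the new graph only knows about the variables of that graph, and restore() looks each of them up by name in the checkpoint. The plain-Python sketch below models that matching with dictionaries; restore_by_name is a hypothetical helper, and the variable names are illustrative. It also shows why export_model builds the network with the same mnist.get_model function: the variable names (including name scopes such as cnn/convolution1) have to match those saved during training.

```python
def restore_by_name(checkpoint, graph_variable_names):
    """Toy model of name-based checkpoint restoration.

    checkpoint: dict mapping saved variable names to values.
    graph_variable_names: names of the variables in the new (export) graph.
    Returns (restored values, checkpoint entries that were not used).
    """
    missing = [n for n in graph_variable_names if n not in checkpoint]
    if missing:
        # A real restore also fails when a graph variable is absent.
        raise KeyError('not found in checkpoint: %s' % missing)
    restored = {n: checkpoint[n] for n in graph_variable_names}
    ignored = sorted(set(checkpoint) - set(graph_variable_names))
    return restored, ignored


if __name__ == "__main__":
    # Illustrative names: training also saved global_step and Adam slot variables.
    checkpoint = {
        'Variable': 10000,                          # global_step
        'cnn/convolution1/Variable': 'W_conv1 values',
        'cnn/convolution1/Variable/Adam': 'Adam slot values',
    }
    restored, ignored = restore_by_name(checkpoint, ['cnn/convolution1/Variable'])
    print(sorted(restored))   # ['cnn/convolution1/Variable']
    print(ignored)            # ['Variable', 'cnn/convolution1/Variable/Adam']
```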
The full example code.
```
trainer
├── __init__.py  # empty file.
├── mnist.py     # model definition.
└── task.py      # training code.
```
Note that the model definition is just a quick example of a CNN for the MNIST dataset.
trainer/task.py
```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import time, json, os, logging

import mnist

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 100, 'Batch size. Must divide evenly into the dataset sizes.')
flags.DEFINE_integer('max_steps', 10000, 'Number of steps to run trainer.')
flags.DEFINE_integer('checkpoint', 100, 'Interval steps to save checkpoint.')
flags.DEFINE_string('log_dir', '/tmp/logs', 'Directory to store checkpoints and summary logs')
flags.DEFINE_string('model_dir', '/tmp/model', 'Directory to store trained model')

# Global flags
BATCH_SIZE = FLAGS.batch_size
MODEL_DIR = FLAGS.model_dir
LOG_DIR = FLAGS.log_dir
MAX_STEPS = FLAGS.max_steps
CHECKPOINT = FLAGS.checkpoint


def export_model(last_checkpoint):
    # create a session with a new graph.
    with tf.Session(graph=tf.Graph()) as sess:
        x = tf.placeholder(tf.float32, [None, 784])
        p = mnist.get_model(x, None, training=False)

        # Define key elements
        input_key = tf.placeholder(tf.int64, [None,])
        output_key = tf.identity(input_key)

        # Define API inputs/outputs object
        inputs = {'key': input_key.name, 'image': x.name}
        outputs = {'key': output_key.name, 'scores': p.name}
        tf.add_to_collection('inputs', json.dumps(inputs))
        tf.add_to_collection('outputs', json.dumps(outputs))

        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Restore the latest checkpoint and save the model
        saver = tf.train.Saver()
        saver.restore(sess, last_checkpoint)
        saver.export_meta_graph(filename=MODEL_DIR + '/export.meta')
        saver.save(sess, MODEL_DIR + '/export', write_meta_graph=False)


def run_training():
    # Get cluster and node info
    env = json.loads(os.environ.get('TF_CONFIG', '{}'))
    cluster_info = env.get('cluster', None)
    cluster_spec = tf.train.ClusterSpec(cluster_info)
    task_info = env.get('task', None)
    job_name, task_index = task_info['type'], task_info['index']

    device_fn = tf.train.replica_device_setter(
        cluster=cluster_spec,
        worker_device="/job:%s/task:%d" % (job_name, task_index)
    )

    logging.info('Start job:%s, index:%d' % (job_name, task_index))

    # Create server
    server = tf.train.Server(cluster_spec,
                             job_name=job_name, task_index=task_index)

    if job_name == "ps":  # Parameter server
        server.join()

    if job_name == "master" or job_name == "worker":  # Worker node
        is_chief = (job_name == "master")

        with tf.Graph().as_default() as graph:
            with tf.device(device_fn):

                # Prepare training data
                mnist_data = input_data.read_data_sets("/tmp/data/", one_hot=True)

                # Create placeholders
                x = tf.placeholder_with_default(
                    tf.zeros([BATCH_SIZE, 784], tf.float32), shape=[None, 784])
                t = tf.placeholder_with_default(
                    tf.zeros([BATCH_SIZE, 10], tf.float32), shape=[None, 10])
                keep_prob = tf.placeholder_with_default(
                    tf.zeros([], tf.float32), shape=[])
                global_step = tf.Variable(0, trainable=False)

                # Add test loss and test accuracy to summary
                test_loss = tf.placeholder_with_default(
                    tf.zeros([], tf.float32), shape=[])
                test_accuracy = tf.placeholder_with_default(
                    tf.zeros([], tf.float32), shape=[])
                tf.summary.scalar("Test_loss", test_loss)
                tf.summary.scalar("Test_accuracy", test_accuracy)

                # Define a model
                p = mnist.get_model(x, keep_prob, training=True)
                train_step, loss, accuracy = mnist.get_trainer(p, t, global_step)

                init_op = tf.global_variables_initializer()
                saver = tf.train.Saver()
                summary = tf.summary.merge_all()

                # Create a supervisor
                sv = tf.train.Supervisor(is_chief=is_chief, logdir=LOG_DIR,
                                         init_op=init_op, saver=saver, summary_op=None,
                                         global_step=global_step, save_model_secs=0)

                # Create a session and run training loops
                with sv.managed_session(server.target) as sess:
                    reports, step = 0, 0
                    start_time = time.time()
                    while not sv.should_stop() and step < MAX_STEPS:
                        images, labels = mnist_data.train.next_batch(BATCH_SIZE)
                        feed_dict = {x:images, t:labels, keep_prob:0.5}
                        _, loss_val, step = sess.run([train_step, loss, global_step],
                                                     feed_dict=feed_dict)
                        if step > CHECKPOINT * reports:
                            reports += 1
                            logging.info("Step: %d, Train loss: %f" % (step, loss_val))
                            if is_chief:
                                # Save checkpoint
                                sv.saver.save(sess, sv.save_path, global_step=step)

                                # Evaluate the test loss and test accuracy
                                loss_vals, acc_vals = [], []
                                for _ in range(len(mnist_data.test.labels) // BATCH_SIZE):
                                    images, labels = mnist_data.test.next_batch(BATCH_SIZE)
                                    feed_dict = {x:images, t:labels, keep_prob:1.0}
                                    loss_val, acc_val = sess.run([loss, accuracy],
                                                                 feed_dict=feed_dict)
                                    loss_vals.append(loss_val)
                                    acc_vals.append(acc_val)
                                loss_val, acc_val = np.sum(loss_vals), np.mean(acc_vals)

                                # Save summary
                                feed_dict = {test_loss:loss_val, test_accuracy:acc_val}
                                sv.summary_computed(
                                    sess, sess.run(summary, feed_dict=feed_dict), step)
                                sv.summary_writer.flush()

                                logging.info("Time elapsed: %d" % (time.time() - start_time))
                                logging.info("Step: %d, Test loss: %f, Test accuracy: %f" %
                                             (step, loss_val, acc_val))

                    # Finish training
                    if is_chief:
                        # Export the final model
                        sv.saver.save(sess, sv.save_path, global_step=sess.run(global_step))
                        export_model(tf.train.latest_checkpoint(LOG_DIR))

                    sv.stop()


def main(_):
    run_training()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    tf.app.run()
```
trainer/mnist.py
```python
import tensorflow as tf
import json


def get_model(x, keep_prob, training=True):
    num_filters1 = 32
    num_filters2 = 64

    with tf.name_scope('cnn'):
        with tf.name_scope('convolution1'):
            x_image = tf.reshape(x, [-1,28,28,1])

            W_conv1 = tf.Variable(tf.truncated_normal([5,5,1,num_filters1], stddev=0.1))
            h_conv1 = tf.nn.conv2d(x_image, W_conv1, strides=[1,1,1,1], padding='SAME')

            b_conv1 = tf.Variable(tf.constant(0.1, shape=[num_filters1]))
            h_conv1_cutoff = tf.nn.relu(h_conv1 + b_conv1)

            h_pool1 = tf.nn.max_pool(h_conv1_cutoff, ksize=[1,2,2,1],
                                     strides=[1,2,2,1], padding='SAME')

        with tf.name_scope('convolution2'):
            W_conv2 = tf.Variable(
                tf.truncated_normal([5,5,num_filters1,num_filters2], stddev=0.1))
            h_conv2 = tf.nn.conv2d(h_pool1, W_conv2, strides=[1,1,1,1], padding='SAME')

            b_conv2 = tf.Variable(tf.constant(0.1, shape=[num_filters2]))
            h_conv2_cutoff = tf.nn.relu(h_conv2 + b_conv2)

            h_pool2 = tf.nn.max_pool(h_conv2_cutoff, ksize=[1,2,2,1],
                                     strides=[1,2,2,1], padding='SAME')

        with tf.name_scope('fully-connected'):
            h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*num_filters2])
            num_units1 = 7*7*num_filters2
            num_units2 = 1024
            w2 = tf.Variable(tf.truncated_normal([num_units1, num_units2]))
            b2 = tf.Variable(tf.constant(0.1, shape=[num_units2]))
            hidden2 = tf.nn.relu(tf.matmul(h_pool2_flat, w2) + b2)

        with tf.name_scope('output'):
            if training:
                hidden2_drop = tf.nn.dropout(hidden2, keep_prob)
            else:
                hidden2_drop = hidden2
            w0 = tf.Variable(tf.zeros([num_units2, 10]))
            b0 = tf.Variable(tf.zeros([10]))
            p = tf.nn.softmax(tf.matmul(hidden2_drop, w0) + b0)

    tf.summary.histogram("conv_filters1", W_conv1)
    tf.summary.histogram("conv_filters2", W_conv2)

    return p


def get_trainer(p, t, global_step):
    with tf.name_scope('optimizer'):
        loss = -tf.reduce_sum(t * tf.log(p), name='loss')
        train_step = tf.train.AdamOptimizer(0.0001).minimize(loss, global_step=global_step)

    with tf.name_scope('evaluator'):
        correct_prediction = tf.equal(tf.argmax(p, 1), tf.argmax(t, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

    return train_step, loss, accuracy
```
You can train the model on Cloud ML by running the following commands from the Cloud Shell. You should run the commands in the parent directory of "trainer".
```
$ PROJECT_ID=project01   # your project ID
$ TRAIN_BUCKET="gs://$PROJECT_ID-mldata"
$ gsutil mb $TRAIN_BUCKET

$ cat << EOF > config.yaml
trainingInput:
  # Use a cluster with many workers and a few parameter servers.
  scaleTier: STANDARD_1
EOF

$ JOB_NAME="job01"
$ gsutil rm -rf $TRAIN_BUCKET/$JOB_NAME
$ touch .dummy
$ gsutil cp .dummy $TRAIN_BUCKET/$JOB_NAME/train/
$ gsutil cp .dummy $TRAIN_BUCKET/$JOB_NAME/model/

$ gcloud beta ml jobs submit training ${JOB_NAME} \
    --package-path=trainer \
    --module-name=trainer.task \
    --staging-bucket="${TRAIN_BUCKET}" \
    --region=us-central1 \
    --config=config.yaml \
    -- \
    --log_dir=$TRAIN_BUCKET/$JOB_NAME/train \
    --model_dir=$TRAIN_BUCKET/$JOB_NAME/model \
    --max_steps=10000
```
You can also see the training process on TensorBoard with the following command.
```
$ tensorboard --port 8080 --logdir $TRAIN_BUCKET/$JOB_NAME/train
```
After the training, the resulting model is stored under $TRAIN_BUCKET/$JOB_NAME/model. The following is a quick example of using the exported model for a prediction.
```python
#!/usr/bin/python
from __future__ import print_function
import tensorflow as tf
import numpy as np
import json
from tensorflow.examples.tutorials.mnist import input_data

model_meta = 'gs://project01-mldata/job01/model/export.meta'
model_param = 'gs://project01-mldata/job01/model/export'

with tf.Graph().as_default() as graph:
    sess = tf.InteractiveSession()
    saver = tf.train.import_meta_graph(model_meta)
    saver.restore(sess, model_param)

    inputs = json.loads(tf.get_collection('inputs')[0])
    outputs = json.loads(tf.get_collection('outputs')[0])
    x = graph.get_tensor_by_name(inputs['image'])
    input_key = graph.get_tensor_by_name(inputs['key'])
    p = graph.get_tensor_by_name(outputs['scores'])
    output_key = graph.get_tensor_by_name(outputs['key'])

    mnist_data = input_data.read_data_sets("/tmp/data/", one_hot=True)
    images, labels = mnist_data.test.next_batch(10)
    index = list(range(10))
    keys, preds = sess.run([output_key, p], feed_dict={input_key:index, x:images})
    for key, pred, label in zip(keys, preds, labels):
        print(key, np.argmax(pred), np.argmax(label))
```
The more sophisticated example
The example used in this note is straightforward but may not be practical. If you are interested in a more sophisticated example, see the following one.
Disclaimer: All code snippets are released under Apache 2.0 License. This is not an official Google product.