distributed TensorFlow
# -*- coding: utf-8 -*-
# File: base.py
# Author: Patrick Wieschollek <[email protected]>
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
"""
This example demonstrates how to use TensorFlow in a distributed setting.
Prepare
-----------
Make sure all machines can communicate with each other. UFW (uncomplicated firewall) can used by
ssh machineA
sudo ufw allow from machineB
ssh machineB
sudo ufw allow from machineA
Run
--------
- ssh machine A
- python dist.py --job_name ps --task_index 0
- python dist.py --job_name worker --task_index 0
- ssh machine B
- python dist.py --job_name worker --task_index 1
"""
def run_training(server, cluster_spec, gpu_index):
    """Define the graph layout and run the training loop.

    Args:
        server (tf.train.Server): information about the current entity
        cluster_spec (tf.train.ClusterSpec): information about the entire cluster
        gpu_index (int): id of the GPU which should be used by the current worker
    """
    num_workers = len(cluster_spec.as_dict()['worker'])
    task_index = server.server_def.task_index
    is_chief = (task_index == 0)

    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % task_index,
                                                      cluster=cluster_spec)):
            # keep the global step on the parameter device (CPU)
            with tf.device('/cpu:0'):
                global_step = tf.get_variable('global_step', [],
                                              initializer=tf.constant_initializer(0), trainable=False)

            with tf.device('/gpu:%d' % gpu_index):
                # simple classification model (a single fully-connected layer)
                x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28])
                y = tf.placeholder(dtype=tf.int64, shape=[None])
                W = tf.Variable(tf.zeros([784, 10]))  # noqa
                b = tf.Variable(tf.zeros([10]))

                x_re = tf.reshape(x, [-1, 28 * 28])
                logits = tf.matmul(x_re, W) + b
                individual_costs = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
                costs = tf.reduce_mean(individual_costs)
                y_pred = tf.nn.softmax(logits)

                # define loss and optimizer; SyncReplicasOptimizer aggregates the
                # gradients of all replicas before applying an update
                opt = tf.train.GradientDescentOptimizer(0.01)
                opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=num_workers,
                                                     total_num_replicas=num_workers)
                train_step = opt.minimize(costs, global_step=global_step)

                # test the trained model
                correct_prediction = tf.equal(y, tf.argmax(y_pred, 1))
                accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

            init_token_op = opt.get_init_tokens_op()
            chief_queue_runner = opt.get_chief_queue_runner()

        init = tf.global_variables_initializer()
        sv = tf.train.Supervisor(is_chief=is_chief,
                                 init_op=init,
                                 global_step=global_step)

        # create a session for running ops on the graph
        config = tf.ConfigProto(allow_soft_placement=True)
        sess = sv.prepare_or_wait_for_session(server.target, config=config)

        if is_chief:
            sv.start_queue_runners(sess, [chief_queue_runner])
            sess.run(init_token_op)

        # train on random dummy data (a stand-in for a real dataset such as MNIST)
        for i in range(100000):
            x_data = np.random.randn(100, 28, 28)
            y_data = np.random.randint(10, size=100)
            _, cost, acc, step = sess.run([train_step, costs, accuracy, global_step],
                                          feed_dict={x: x_data, y: y_data})
            print(cost)


def main():
    """
    In the distributed setting there are, as far as I understand, two kinds of entities:
    - parameter server ("ps")
    - worker ("worker")

    Each of these has a task_index. With 2 ps and 3 workers, the task_index is
    - 0 for the first ps
    - 1 for the second ps
    - 0 for the first worker
    - 1 for the second worker
    - 2 for the third worker
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_index', help='index of the task within its job', type=int)
    parser.add_argument('--job_name', help='either "ps" or "worker"', type=str)
    args = parser.parse_args()

    assert args.job_name in ['ps', 'worker']

    # simple config similar to a cluster_spec, but it additionally specifies which
    # GPU each worker should use
    distr_config = {
        'ps': ['machineA:2222'],
        'worker': [
            {'host': 'machineA:2223', 'gpu': 0},
            {'host': 'machineB:2224', 'gpu': 1}
        ]
    }

    # specify the cluster layout
    ps_hosts = distr_config['ps']
    worker_hosts = [k['host'] for k in distr_config['worker']]
    cluster_spec = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

    # specify the current entity
    server = tf.train.Server(cluster_spec, job_name=args.job_name, task_index=args.task_index)

    if args.job_name == "ps":
        # parameter servers just host the variables and wait
        server.join()
    elif args.job_name == "worker":
        # each worker is assigned to a different GPU
        gpu_index = distr_config['worker'][args.task_index]['gpu']
        run_training(server, cluster_spec, gpu_index)


if __name__ == '__main__':
    # tf.app.run()
    main()