-
-
Save alexwal/decd4cf124023113b2633dac9ef34fc5 to your computer and use it in GitHub Desktop.
Example of using shared counters to implement Barrier primitive
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Alex Walczak, 2017 | |
Example of barrier implementation using TensorFlow shared variables | |
across a multi-machine cluster. | |
All workers synchronize on the barrier, copy global parameters to local versions, | |
and increment the global parameter variable asynchronously. | |
On each worker run: | |
$ killall python3 | |
If you have a cluster of 4 machines, then on the first machine run: | |
$ python3 cluster_barrier_for_tensorflow.py --job_name=ps --task_index=0 --ps_hosts=... --worker_hosts=... & | |
(with the trailing &) | |
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=0 --ps_hosts=... --worker_hosts=... | |
And on the other 3 machines, run one of: | |
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=1 --ps_hosts=... --worker_hosts=... | |
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=2 --ps_hosts=... --worker_hosts=... | |
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=3 --ps_hosts=... --worker_hosts=... | |
You should see something like this for each worker k=0,1,2,3: | |
Worker k: local_param 0, global_param 1 | |
Worker k: local_param 4, global_param 5 | |
Worker k: local_param 8, global_param 9 | |
Worker k: local_param 12, global_param 13 | |
Worker k: local_param 16, global_param 17 | |
(Tested with Tensorflow r1.5) | |
Thanks to Yaroslav Bulatov (yaroslavvb) for the original implementation, | |
which spawns multiple processes on a single machine. | |
https://gist.github.com/yaroslavvb/ef407a599f0f549f62d91c3a00dcfb6c
''' | |
import numpy as np | |
import os | |
import time | |
import tensorflow as tf | |
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # optionally suppress TF warnings in log files
# Command-line flags describing the cluster layout and this process's role.
# ps_hosts / worker_hosts name every task in the cluster; job_name + task_index
# identify which one of those tasks THIS process is.
tf.app.flags.DEFINE_string('ps_hosts', '', 'comma separated list of ps_host_ip:port')
tf.app.flags.DEFINE_string('worker_hosts', '', 'comma separated list of worker_host_ip:port')
tf.app.flags.DEFINE_string('job_name', '', 'the job: ps | worker')
tf.app.flags.DEFINE_integer('task_index', 0, 'which task number for the ps or worker')
# Polling interval used inside the barrier's busy-wait loop (seconds).
tf.app.flags.DEFINE_float('sleep_interval', 0.1, 'how long to sleep in wait loop')
tf.app.flags.DEFINE_integer('iters', 10, 'maximum number of steps to run per worker')
FLAGS = tf.app.flags.FLAGS  # parsed flag values, read throughout run_test()
def default_config():
    """Return a tf.ConfigProto shared by the server and the sessions.

    Graph-rewrite optimizations are turned off (opt_level L0) and device
    placement is strict: allow_soft_placement=False means an op pinned to a
    device that cannot run it is an error rather than a silent fallback.
    """
    optimizer_opts = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
    graph_opts = tf.GraphOptions(optimizer_options=optimizer_opts)
    cfg = tf.ConfigProto(graph_options=graph_opts)
    cfg.log_device_placement = False  # keep logs quiet
    cfg.allow_soft_placement = False  # fail loudly on bad device pinning
    return cfg
def run_test():
    """Start this process's cluster task and, on workers, run the barrier demo.

    A ps task simply hosts the shared variables and blocks forever in
    server.join(). Each worker task builds the same graph (global parameter,
    two shared barrier counters on the ps, one local parameter per worker),
    then repeatedly: barrier -> copy global param to its local copy ->
    barrier -> increment the global param, printing both values each round.
    """
    # TF DIST SETUP
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    # Create a cluster from the parameter server and worker hosts.
    cluster = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
    # Create and start a server for the local task.
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index,
                             config=default_config())
    if FLAGS.job_name == 'ps':
        # Parameter server: host variables and block forever.
        server.join()
    elif FLAGS.job_name == 'worker':
        dtype = tf.int32
        num_workers = len(worker_hosts)
        # vars and ops
        init_op = None
        train_ops = []  # worker local train ops, read local params, update global
        counter_vars = []  # counters for barrier
        counter_adder_ops = []  # atomic +1 ops on the barrier counters
        global_param_var = None
        local_param_vars = []  # one local copy of the parameter per worker
        local_param_sync_ops = []  # local_param <- global_param, per worker
        # all ps and worker tasks
        ps_device = '/job:ps/task:0/cpu:0'
        ps_host = ps_hosts[0]  # assume we are only launching a single ps host and many workers
        worker_devices = ['/job:worker/task:{}'.format(i) for i in range(num_workers)]
        # create global parameters
        # Shared state lives on the ps so every worker sees the same variables.
        with tf.device(ps_device):
            global_param_var = tf.get_variable('param', shape=(), dtype=dtype,
                                               initializer=tf.zeros_initializer)
            # Two counters: the barrier alternates between them so a fast
            # worker cannot start incrementing a counter for the next barrier
            # while slow workers are still polling it for the current one.
            for i in range(2):
                counter_var = tf.get_variable('counter-{}'.format(i), (), dtype,
                                              initializer=tf.zeros_initializer)
                counter_vars.append(counter_var)
                # use_locking=True makes concurrent increments from many
                # workers atomic instead of lost-update racy.
                counter_adder_ops.append(counter_var.assign_add(1, use_locking=True))
        # create local version of parameters
        for (i, device) in enumerate(worker_devices):
            with tf.device(device):
                local_param_var = tf.get_variable('local_param-{}'.format(i), (), dtype,
                                                  initializer=tf.zeros_initializer)
                local_param_vars.append(local_param_var)
                local_param_sync_op = local_param_var.assign(global_param_var)
                local_param_sync_ops.append(local_param_sync_op)
                # "Training" here is just incrementing the shared global param.
                train_op = global_param_var.assign_add(1)
                train_ops.append(train_op)
        init_op = tf.global_variables_initializer()
        with tf.Session('grpc://{}'.format(ps_host), config=default_config()) as sess:  # Workers connect to the same Session
            def barrier():
                # When this function returns, every worker will execute the following line next.
                for i in range(2):
                    sess.run(counter_adder_ops[i])  # Increment global counter once on this worker
                    # After every worker has incremented, the counter is a
                    # multiple of num_workers; poll until then.
                    while sess.run(counter_vars[i]) % num_workers != 0:  # Wait until every worker has incremented the global counter
                        time.sleep(FLAGS.sleep_interval)  # Sleep ensures that every worker will increment the global counter
            # NOTE(review): every worker runs init_op; assumes re-initializing
            # shared variables concurrently is benign here — confirm.
            sess.run(init_op)
            worker_id = FLAGS.task_index
            local_param_var = local_param_vars[worker_id]
            sync_op = local_param_sync_ops[worker_id]
            train_op = train_ops[worker_id]
            for i in range(FLAGS.iters):
                # 1. Wait for all workers to finish incrementing global_param_var
                barrier()
                sess.run(sync_op)
                # 2. Wait for all workers to finish assigning global_param_var to local_param_var
                barrier()
                local_val, global_val = sess.run([local_param_var, train_op])  # Increment global_param_var
                print('Worker {}: local_param {}, global_param {}'.format(worker_id, local_val, global_val))
            # Final sync so every worker reports the same converged values.
            barrier()
            sess.run(sync_op)
            local_val, global_val = sess.run([local_param_var, global_param_var])
            print('+++ Final value for worker {}: local_param {}, global_param {}'.format(worker_id, local_val, global_val))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Why use two counters in barrier? Because with a single counter, a fast worker leaving the barrier could immediately re-enter it and increment the same counter again, so slower workers still polling would never observe the count at an exact multiple of num_workers; alternating between two counters keeps consecutive barriers from interfering.