Skip to content

Instantly share code, notes, and snippets.

@benoitdescamps
Last active May 10, 2019 01:14
Show Gist options
  • Save benoitdescamps/543bc3d68187dd2a2b15832c47a6f25a to your computer and use it in GitHub Desktop.
def map_fun(args, ctx):
    """Per-executor entry point for a distributed TensorFlow job
    (TensorFlowOnSpark style).

    Parameter-server tasks block in ``server.join()``; worker tasks build
    the graph under a replica device setter and train inside a
    ``MonitoredTrainingSession``, pulling batches from the Spark data feed.

    Args:
        args: dict of job configuration; this function reads
            ``args['save_dir']`` (checkpoint directory) and
            ``args['batch_size']``.
        ctx: TensorFlowOnSpark node context providing ``worker_num``,
            ``job_name``, ``task_index``, ``start_cluster_server`` and
            ``get_data_feed``.
    """
    worker_num = ctx.worker_num  # kept for parity with the template; unused below
    job_name = ctx.job_name
    task_index = ctx.task_index

    # Start (or connect to) the TF cluster for this task.
    # NOTE(review): the positional argument 1 presumably requests 1 GPU —
    # confirm against the TFNode.start_cluster_server signature.
    cluster, server = ctx.start_cluster_server(1)

    if job_name == "ps":
        # Parameter servers only serve variables; block until the job ends.
        server.join()
    elif job_name == "worker":
        # Task 0 is the chief: it owns checkpointing and initialization.
        is_chiefing = (task_index == 0)

        # Pin variables/ops to this worker via the replica device setter.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            def build_model():
                # Placeholder: define the model graph here.
                pass

        hooks = [...]  # Placeholder: e.g. tf.train.StopAtStepHook(...)

        # NOTE(review): nesting reconstructed from standard TFoS examples —
        # the scraped source had no indentation.
        with tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=is_chiefing,
                checkpoint_dir=args['save_dir'],  # fixed: original read arsg['save_dir'] (NameError)
                hooks=hooks,
                save_checkpoint_secs=600.) as mon_sess:
            tf_feed = ctx.get_data_feed(train_mode=True)
            while not mon_sess.should_stop() and not tf_feed.should_stop():
                # fixed: original line had an unbalanced closing parenthesis
                batch_data = tf_feed.next_batch(args['batch_size'])
                # Apply what you need to be done here.
                _ = mon_sess.run(...)
            if mon_sess.should_stop():
                # Tell the Spark feed we are done so the RDD task can finish.
                tf_feed.terminate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment