Skip to content

Instantly share code, notes, and snippets.

@sherjilozair
Created February 10, 2020 18:07
Show Gist options
  • Save sherjilozair/602230a0bc3ebe8ee8203b8d77ca0711 to your computer and use it in GitHub Desktop.
Save sherjilozair/602230a0bc3ebe8ee8203b8d77ca0711 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import sonnet as snt
# Helper libraries
import numpy as np
import os
import sys
from absl import flags
from absl import app
# Command-line flags. Both default to None, i.e. they are unset unless the
# caller passes them explicitly.
flags.DEFINE_string("tpu", None, "TPU name.")
flags.DEFINE_integer("batch_size", None, "Batch size.")

# Global flag registry; values are readable once absl has parsed argv.
FLAGS = flags.FLAGS
class Architecture(snt.AbstractModule):
    """Minimal Sonnet module that produces a single scalar loss.

    The loss is computed from a constant all-ones tensor pushed through a
    one-unit linear layer; the tensor passed to the module is not used.
    """

    def _build(self, inputs):
        """Connect the module into the graph.

        Args:
            inputs: Accepted for interface compatibility; never read here.

        Returns:
            A dict with one key, 'loss', holding the mean of the linear
            layer's output over a [10, 10] tensor of ones.
        """
        projection = snt.Linear(1)
        constant_batch = tf.ones([10, 10])
        per_row_losses = projection(constant_batch)
        return {'loss': tf.reduce_mean(per_row_losses)}
def main(argv):
    """Build the TPU-distributed graph for `Architecture` and print its loss.

    Resolves the TPU named by --tpu, initializes the TPU system, builds the
    model and its input pipeline under a TPUStrategy, and then repeatedly
    evaluates the per-replica loss values inside a `tf.Session`, printing
    them after every step.

    Args:
        argv: Positional command-line arguments handed over by `app.run`;
            unused.

    Raises:
        app.UsageError: If --batch_size was not given or is not positive.
    """
    del argv  # Unused.

    # --batch_size defaults to None, which is not a usable batch size for the
    # input pipeline below. Fail early with a readable message instead of an
    # opaque error from inside graph construction.
    if FLAGS.batch_size is None or FLAGS.batch_size < 1:
        raise app.UsageError('--batch_size must be a positive integer.')

    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.contrib.distribute.initialize_tpu_system(resolver)
    strategy = tf.contrib.distribute.TPUStrategy(resolver)

    with strategy.scope():
        architecture = Architecture()
        # NOTE: the optimizer is currently disabled, so a step only evaluates
        # the loss and does not update any variables.
        # optimizer = tf.train.GradientDescentOptimizer(1e-4)

        def get_dataset(context):
            """Return this input pipeline's shard of an endless dummy dataset."""
            batch_size = context.get_per_replica_batch_size(FLAGS.batch_size)
            d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
            return d.shard(
                context.num_input_pipelines, context.input_pipeline_id)

        iterator = strategy.make_input_fn_iterator(get_dataset)
        iterator_init = iterator.initialize()

        def step_fn(inputs):
            """Run the model on one batch and return its loss tensor."""
            outputs = architecture(inputs)
            # optimizer_op = optimizer.minimize(
            #     outputs['loss'], architecture.trainable_variables)
            # The dependency list is empty while the optimizer op above is
            # disabled.
            with tf.control_dependencies([]):
                return tf.identity(outputs['loss'])

        run_values = strategy.experimental_run(step_fn, iterator).values

    config = tf.ConfigProto()
    config.allow_soft_placement = True
    cluster_spec = resolver.cluster_spec()
    if cluster_spec:
        config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    print('Starting training...')
    # Do all the computations inside a Session (as opposed to doing eager mode)
    with tf.Session(target=resolver.master(), config=config) as session:
        session.run(iterator_init)
        session.run(tf.global_variables_initializer())
        # The dataset is built with .repeat(), so this loop is only left if
        # the input pipeline nevertheless signals that it is exhausted.
        while True:
            try:
                print(session.run(run_values))
            except tf.errors.OutOfRangeError:
                break
# Script entry point: absl parses the command-line flags, then calls main.
if __name__ == '__main__':
    app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment