Skip to content

Instantly share code, notes, and snippets.

@sherjilozair
Created April 4, 2019 19:33
Show Gist options
  • Save sherjilozair/e7b6da63a308fe8adf2a8a4f8dcc5a13 to your computer and use it in GitHub Desktop.
Save sherjilozair/e7b6da63a308fe8adf2a8a4f8dcc5a13 to your computer and use it in GitHub Desktop.
import tensorflow.compat.v1 as tf
import os
# Everything below builds a TF1 graph that is executed through a Session,
# so eager execution is switched off before any op is created.
tf.disable_eager_execution()

# Resolve the TPU named by the TPU_ENDPOINT environment variable, initialize
# its system, and wrap it in a distribution strategy that the rest of the
# script uses for variable creation, input distribution and replicated steps.
tpu_endpoint = os.environ['TPU_ENDPOINT']
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_endpoint)
tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_strategy = tf.distribute.experimental.TPUStrategy(resolver)
with tpu_strategy.scope():
    # A single Dense unit over 1-feature inputs. It is created inside the
    # strategy scope so the strategy controls where its variables live.
    dense_layer = tf.keras.layers.Dense(1, input_shape=(1,))
    model = tf.keras.Sequential([dense_layer])
    # TF1-style Adam optimizer shared by every replica's train step.
    optimizer = tf.train.AdamOptimizer(learning_rate=2e-4)
with tpu_strategy.scope():
    # 1000 copies of the single (feature, label) example ([1.], [1.]),
    # grouped into global batches of 64.
    # drop_remainder=True discards the trailing 40-row batch so that every
    # batch has exactly 64 rows. TPU (XLA) compilation needs a static batch
    # dimension, and the training loss divides by the constant 64.
    # That still leaves 15 full batches.
    dataset = (
        tf.data.Dataset.from_tensors(([1.], [1.]))
        .repeat(1000)
        .batch(64, drop_remainder=True)
    )
    # Splits each global batch across the strategy's replicas.
    input_iterator = tpu_strategy.make_dataset_iterator(dataset)
@tf.function
def train_step():
    """Run one distributed training step and return the loss over the batch.

    Pulls one global batch from ``input_iterator``, runs ``step_fn`` on each
    replica of ``tpu_strategy``, and sums the per-replica (pre-scaled) losses.

    Returns:
        A scalar tensor. Fetching it also applies the optimizer update.
    """

    def step_fn(inputs):
        """Forward pass, loss and optimizer update for one replica's shard."""
        features, labels = inputs
        logits = model(features)
        # NOTE(review): the model has a single output unit. Softmax over one
        # logit is identically 1, so this cross-entropy and its gradient are
        # always 0, and training changes nothing. A sigmoid or regression loss
        # is probably intended; confirm before relying on results.
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        # Divide by the *global* batch size (64, matching the dataset's
        # batch size), not by this replica's share. tf.distribute all-reduces
        # gradients with SUM, so this yields the gradient of the batch mean.
        loss = tf.reduce_sum(cross_entropy) * (1.0 / 64)
        # minimize()'s second positional parameter is global_step, so the
        # variables to train must be passed by keyword.
        update_op = optimizer.minimize(
            loss, var_list=model.trainable_variables)
        # Tie the returned loss to the update so fetching it trains the model.
        with tf.control_dependencies([update_op]):
            return tf.identity(loss)

    per_replica_losses = tpu_strategy.experimental_run(
        step_fn, input_iterator)
    # Each replica's loss is already divided by the global batch size, so
    # SUM across replicas is the batch mean. MEAN would shrink it by the
    # number of replicas.
    mean_loss = tpu_strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses)
    return mean_loss
with tpu_strategy.scope():
    # Build the training ops exactly once. Calling train_step() inside the
    # loop would add another copy of the step to the graph every iteration.
    step_loss = train_step()
    # In graph mode initialize() only *returns* the iterator initializer
    # ops. They must be run in the session, otherwise fetching the first
    # batch fails on an uninitialized iterator.
    iterator_init = input_iterator.initialize()

# Created after train_step() so it also picks up variables created while
# building the update (e.g. optimizer slot variables).
var_init = tf.global_variables_initializer()

# Connect to the TPU worker instead of executing the graph in-process.
with tf.Session(resolver.master()) as sess:
    sess.run(var_init)
    sess.run(iterator_init)
    for _ in range(10):
        print(sess.run(step_loss))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment