Last active: March 12, 2019 12:11
Save gridcellcoder/db8e045bb920e066b2f6e2ac68ea0676 to your computer and use it in GitHub Desktop.
TPU Cross Shard Optimizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pprint

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_function
# Resolve the address of the Colab TPU worker from the runtime environment.
# Every session below connects to `tpu_address`, so if the variable is
# missing we stop immediately. Printing and continuing would only crash
# later with a NameError when `tpu_address` is first used.
if 'COLAB_TPU_ADDR' not in os.environ:
    raise SystemExit('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')
tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
print('TPU address is', tpu_address)

# Sanity check: list the devices the TPU worker exposes.
with tf.Session(tpu_address) as session:
    devices = session.list_devices()
print('TPU devices:')
pprint.pprint(devices)
# CrossShardOptimizer reads the shard count from the TPU context. This
# script does not build its step with tf.contrib.tpu.rewrite/shard, so
# the count is forced here through the internal tpu_function API.
# NUM_SHARDS = 8 presumably means one shard per TPU core -- TODO confirm
# against the device list for the runtime in use.
# NOTE(review): because the training step is not wrapped in
# tf.contrib.tpu.rewrite/shard, it is not compiled as a replicated TPU
# computation. Verify that CrossShardOptimizer's cross-replica reduction
# actually works (and is meaningful) in this setup.
NUM_SHARDS = 8
tpu_function.get_tpu_context().set_number_of_shards(NUM_SHARDS)

# x and y are the training-data placeholders (shape unconstrained; the
# training loop feeds one scalar sample per step).
x = tf.placeholder("float")
y = tf.placeholder("float")

# w holds the trainable coefficients, initialised with starting "guesses".
# Only w[0] (the slope) and w[1], w[2] (two intercept terms) enter the
# model. w[3] is unused, receives a zero gradient, and stays at its
# initial value.
w = tf.Variable([1.0, 2.0, 3.0, 4.0], name="w")

# Model: y = w[0]*x + w[1] + w[2] + 3
y_model = tf.multiply(x, w[0]) + w[1] + w[2] + 3

# Loss: squared difference between the fed target and the prediction.
error = tf.square(y - y_model)

# Adam (learning rate 0.01) does the optimisation. Despite the names,
# `train_op` holds the optimizer object and `optimizer` holds the training
# op returned by minimize(). Both names are kept for compatibility with
# code that refers to them.
train_op = tf.train.AdamOptimizer(0.01)
optimizer = tf.contrib.tpu.CrossShardOptimizer(train_op).minimize(error)  # TPU change 1

# Op that initialises all global variables; run once in the session.
model = tf.global_variables_initializer()
# Initialise the TPU system and the model variables, train on synthetic
# samples, then report the learned coefficients. shutdown_system() runs in
# `finally` so the TPU system is shut down even if training raises.
with tf.Session(tpu_address) as session:
    session.run(tf.contrib.tpu.initialize_system())
    print('init')
    try:
        session.run(model)
        for i in range(10000):
            # Report progress periodically instead of printing every step.
            if i % 1000 == 0:
                print(i)
            x_value = np.random.rand()
            # Ground truth the model should recover: y = 2x + 14.
            y_value = x_value * 2 + 6 + 5 + 3
            session.run(optimizer, feed_dict={x: x_value, y: y_value})
        w_value = session.run(w)
        # Mirrors the graph's model y = w[0]*x + w[1] + w[2] + 3.
        # w[3] does not take part in the model, so it is not reported.
        print("Predicted model: {a:.3f}x + {b:.3f} + {c:.3f} + 3".format(
            a=w_value[0], b=w_value[1], c=w_value[2]))
    finally:
        # Was `tpu.shutdown_system()`: `tpu` is never defined (NameError).
        session.run(tf.contrib.tpu.shutdown_system())
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.