Skip to content

Instantly share code, notes, and snippets.

@yongjun823
Created November 19, 2018 10:35
Show Gist options
  • Save yongjun823/9468a318599efc02da271efa8519ed36 to your computer and use it in GitHub Desktop.
Example of training an MNIST classifier with TensorFlow + TensorLayer on a Colab TPU, using a plain TPU session.
import os
import pprint
import tensorflow as tf
import tensorlayer as tl
# Sanity-check that this Colab runtime is actually attached to a Cloud TPU.
# Colab exposes the TPU worker's address via the COLAB_TPU_ADDR env var.
if 'COLAB_TPU_ADDR' not in os.environ:
    print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')
else:
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    print('TPU address is', tpu_address)

    # Open a session against the TPU master and list its devices as a
    # quick connectivity check before building any graph.
    with tf.Session(tpu_address) as check_sess:
        tpu_devices = check_sess.list_devices()
        print('TPU devices:')
        pprint.pprint(tpu_devices)
def model_fn(input_x, input_y):
    """Build the MNIST classifier graph for TPU execution.

    Intended to be passed to ``tf.contrib.tpu.rewrite``, which traces this
    function into a TPU-compatible graph.

    Args:
        input_x: float32 placeholder/tensor of images, shape ``[None, 28, 28]``.
        input_y: int64 placeholder/tensor of integer class labels, shape ``[None]``.

    Returns:
        Tuple ``(acc, train_op)``: a scalar accuracy tensor and the
        cross-shard Adam training op.
    """
    # Two-layer MLP over the flattened 28x28 images.
    net = tl.layers.InputLayer(input_x, name='input_layer')
    net = tl.layers.FlattenLayer(net)
    net = tl.layers.DenseLayer(net, n_units=800, act=tf.nn.relu, name='relu12')
    net = tl.layers.DenseLayer(net, n_units=10, act=tf.identity,
                               name='output_layer2')

    logits = net.outputs
    # Cross-entropy on the raw logits; tl.cost.cross_entropy applies softmax
    # internally (third argument is the op name).
    cost = tl.cost.cross_entropy(logits, input_y, 'cost')
    # Prediction op — kept even though unused here, so the graph nodes it
    # creates match the original.
    y_op = tf.argmax(tf.nn.softmax(logits), 1)

    # Fraction of predictions that match the integer labels.
    hits = tf.equal(tf.argmax(logits, 1), input_y)
    acc = tf.reduce_mean(tf.cast(hits, tf.float32))

    # Adam wrapped in CrossShardOptimizer so gradients are aggregated
    # across TPU shards before the weight update.
    base_opt = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9,
                                      beta2=0.999, epsilon=1e-08,
                                      use_locking=False)
    shard_opt = tf.contrib.tpu.CrossShardOptimizer(base_opt)
    train_op = shard_opt.minimize(cost, var_list=net.all_params)
    return acc, train_op
# Placeholders for a batch of 28x28 grayscale images and their integer labels.
x = tf.placeholder(tf.float32, [None, 28, 28])
y_ = tf.placeholder(tf.int64, [None, ])

# Rewrite the model-building function into TPU-executable ops.
tpu_ops = tf.contrib.tpu.rewrite(model_fn, [x, y_])

# Load MNIST and rescale pixel values from [0, 255] to [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Train the model on the TPU.
#
# NOTE(review): the original code called network.print_params() /
# network.print_layers() here, but `network` is local to model_fn and is
# undefined at module scope — those calls would raise NameError, so they
# are removed.
with tf.Session(tpu_address) as sess:
    # Bring up the TPU system before running any TPU ops.
    sess.run(tf.contrib.tpu.initialize_system())
    tl.layers.initialize_global_variables(sess)

    batch_size = 1024
    for epoch in range(30):
        for i in range(0, len(x_train), batch_size):
            # Feed the full batch. The original sliced only i:i+10 while
            # stepping by 1024, silently skipping ~99% of each epoch.
            result = sess.run(tpu_ops, feed_dict={
                x: x_train[i:i + batch_size],
                y_: y_train[i:i + batch_size],
            })
            print(result)

    # Release the TPU so other programs (or a re-run) can claim it.
    sess.run(tf.contrib.tpu.shutdown_system())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment