Skip to content

Instantly share code, notes, and snippets.

@yongjun823
Created October 9, 2018 08:19
Show Gist options
  • Select an option

  • Save yongjun823/13549136ee1d316bf792edd3f43f8b67 to your computer and use it in GitHub Desktop.

Select an option

Save yongjun823/13549136ee1d316bf792edd3f43f8b67 to your computer and use it in GitHub Desktop.
Fashion-MNIST on a TPU with Keras (TensorFlow)
import tensorflow as tf
import numpy as np
# Fetch the Fashion-MNIST train/test splits as (images, integer labels) pairs.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# The image arrays come back without a channel axis, but Conv2D expects one, so
# append a single (grayscale) channel as the last dimension.
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
# Convolutional classifier: three identical feature-extraction stages
# (BatchNorm -> 5x5 'same' conv with ELU -> 2x2 max-pool -> 25% dropout) with
# 64, 128 and 256 filters, then a 256-unit ELU dense layer, 50% dropout and a
# 10-way softmax over the class labels.
#
# Only the very first layer declares input_shape: Keras uses it solely on the
# first layer of a Sequential model to create the model input. The later
# BatchNormalization layers infer their shape from the preceding layer, so
# passing input_shape to them (as this script previously did) had no effect.
model = tf.keras.models.Sequential([
    # Stage 1. The pixels are fed in without rescaling, so this first
    # BatchNormalization also serves as the input normalization.
    tf.keras.layers.BatchNormalization(input_shape=x_train.shape[1:]),
    tf.keras.layers.Conv2D(64, (5, 5), padding='same', activation='elu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    # Stage 2.
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(128, (5, 5), padding='same', activation='elu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    # Stage 3.
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(256, (5, 5), padding='same', activation='elu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    # Classifier head.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256),
    tf.keras.layers.Activation('elu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax'),
])
model.summary()
import os

# Locate the TPU. Use TPU_NAME when it is set; otherwise fall back to the
# Colab-style COLAB_TPU_ADDR, which holds a bare address and therefore needs
# the grpc:// scheme prepended. Checking both lets this cell run in either
# environment instead of raising KeyError when TPU_NAME is absent.
if 'TPU_NAME' in os.environ:
    tpu_address = os.environ['TPU_NAME']
else:
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']

# Wrap the Keras model so training and evaluation run on the TPU.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    ),
)

# Integer class labels (not one-hot), hence the sparse loss and metric.
tpu_model.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy'],
)
def train_gen(batch_size, x=None, y=None):
    """Yield an endless stream of contiguous (inputs, labels) batches.

    Each batch is the slice of ``batch_size`` consecutive samples starting at
    a uniformly random offset, so every valid window -- including the one that
    ends at the last sample -- can be drawn. The data itself is not shuffled.

    Args:
        batch_size: Number of samples per batch; must be positive and no
            larger than the number of samples.
        x: Array sliced along its first axis. Defaults to the module-level
            ``x_train``.
        y: Labels aligned with ``x`` along the first axis. Defaults to the
            module-level ``y_train``.

    Yields:
        Tuple ``(x_batch, y_batch)``, each holding ``batch_size`` rows.

    Raises:
        ValueError: On the first ``next()`` call, if ``batch_size`` is not in
            ``1..len(x)``.
    """
    if x is None:
        x = x_train
    if y is None:
        y = y_train

    num_samples = x.shape[0]
    if not 0 < batch_size <= num_samples:
        raise ValueError(
            'batch_size must be in 1..%d, got %r' % (num_samples, batch_size))

    # np.random.randint excludes its upper bound, so add 1 to make the final
    # window [num_samples - batch_size, num_samples) reachable.
    max_offset = num_samples - batch_size
    while True:
        offset = np.random.randint(0, max_offset + 1)
        yield x[offset:offset + batch_size], y[offset:offset + batch_size]
# Train on random contiguous 1024-sample windows of the training set for
# 20 epochs of 100 steps each, evaluating on the test split as validation data.
tpu_model.fit_generator(
    train_gen(1024),
    epochs=20,
    steps_per_epoch=100,
    validation_data=(x_test, y_test),
)
tpu_model.save('fashion_mnist.h5')

# Obtain a CPU-side Keras model with the trained weights for local inference
# (presumably a weight copy from the TPU -- see sync_to_cpu docs).
cpu_model = tpu_model.sync_to_cpu()

# Human-readable Fashion-MNIST class names, indexed by integer label 0-9.
# Digit strings ('0'-'9') would be MNIST's classes, not this dataset's.
LABEL_NAMES = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

# Print predicted vs. true class names for the first 16 test images.
predictions = [LABEL_NAMES[idx] for idx in np.argmax(cpu_model.predict(x_test[:16]), axis=1)]
label_map = [LABEL_NAMES[idx] for idx in y_test[:16]]
print(predictions)
print(label_map)
import os

# Rebuild a TPU model from the HDF5 file written after training: load it as a
# plain Keras model, then convert it for the TPU again.
keras_model = tf.keras.models.load_model('fashion_mnist.h5')

# Locate the TPU. Prefer Colab's COLAB_TPU_ADDR (a bare address, so the grpc://
# scheme is prepended); fall back to TPU_NAME so this also works where only
# that variable is set instead of raising KeyError.
if 'COLAB_TPU_ADDR' in os.environ:
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
else:
    tpu_address = os.environ['TPU_NAME']

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    keras_model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    ),
)

# Compile with the same optimizer, loss and metric used for training.
# NOTE(review): presumably required because the tf.train optimizer is not
# restored from the .h5 file -- confirm.
tpu_model.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy'],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment