last_model on tf experimental 1.14
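Typical invocation (a sketch: the filename mnist_lstm.py and the paths are just examples, while the flags come from the argparse definitions below): python mnist_lstm.py --model_dir /tmp/mnist_lstm/model --tflite_model_file /tmp/mnist_lstm.tflite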
# Note this needs to happen before importing tensorflow.
import os
os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1'

import argparse
import sys

import tensorflow as tf
class MnistLstmModel(object):
  """Build a simple LSTM based MNIST model.

  Attributes:
    time_steps: The maximum length of the time_steps, but since we're just
      using the 'width' dimension as time_steps, it's actually a fixed number.
    input_size: The LSTM layer input size.
    num_lstm_layer: Number of LSTM layers for the stacked LSTM cell case.
    num_lstm_units: Number of units in the LSTM cell.
    units: The units for the last layer.
    num_class: Number of classes to predict.
  """

  def __init__(self, time_steps, input_size, num_lstm_layer, num_lstm_units,
               units, num_class):
    self.time_steps = time_steps
    self.input_size = input_size
    self.num_lstm_layer = num_lstm_layer
    self.num_lstm_units = num_lstm_units
    self.units = units
    self.num_class = num_class
  def build_model(self):
    """Build the model using the given configs.

    Returns:
      x: The input placeholder tensor.
      logits: The logits of the output.
      output_class: The prediction.
    """
    x = tf.placeholder(
        'float32', [None, self.time_steps, self.input_size], name='INPUT')
    lstm_layers = []
    for _ in range(self.num_lstm_layer):
      lstm_layers.append(
          # Important:
          #
          # Note here, we use `tf.lite.experimental.nn.TFLiteLSTMCell`
          # (OpHinted LSTMCell) so the converter can recognize the LSTM.
          tf.lite.experimental.nn.TFLiteLSTMCell(
              self.num_lstm_units, forget_bias=0))
    # Weights and biases for the output softmax layer.
    out_weights = tf.Variable(tf.random_normal([self.units, self.num_class]))
    out_bias = tf.Variable(tf.zeros([self.num_class]))

    # Transpose input x to make it time major: [time, batch, input].
    lstm_inputs = tf.transpose(x, perm=[1, 0, 2])
    lstm_cells = tf.keras.layers.StackedRNNCells(lstm_layers)
    # Important:
    #
    # Note here, we use `tf.lite.experimental.nn.dynamic_rnn` and `time_major`
    # is set to True.
    outputs, _ = tf.lite.experimental.nn.dynamic_rnn(
        lstm_cells, lstm_inputs, dtype='float32', time_major=True)

    # Transpose the outputs back to [batch, time, output].
    outputs = tf.transpose(outputs, perm=[1, 0, 2])
    outputs = tf.unstack(outputs, axis=1)
    # Only the output of the last time step is used for classification.
    logits = tf.matmul(outputs[-1], out_weights) + out_bias
    output_class = tf.nn.softmax(logits, name='OUTPUT_CLASS')

    return x, logits, output_class
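# Example (a sketch, not part of the original gist; it mirrors the
# configuration used in train_and_export below). Each 28x28 MNIST image is
# fed as 28 time steps of 28 features:
#
#   model = MnistLstmModel(time_steps=28, input_size=28, num_lstm_layer=2,
#                          num_lstm_units=64, units=64, num_class=10)
#   x, logits, output_class = model.build_model()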
def train(model,
          model_dir,
          batch_size=20,
          learning_rate=0.001,
          train_steps=2000,
          eval_steps=500,
          save_every_n_steps=1000):
  """Train & save the MNIST recognition model."""
  # Train & test dataset.
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  train_iterator = train_dataset.shuffle(
      buffer_size=1000).batch(batch_size).repeat().make_one_shot_iterator()
  x, logits, output_class = model.build_model()
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  test_iterator = test_dataset.batch(
      batch_size).repeat().make_one_shot_iterator()

  # Input label placeholder.
  y = tf.placeholder(tf.int32, [
      None,
  ])
  one_hot_labels = tf.one_hot(y, depth=model.num_class)

  # Loss function.
  loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=one_hot_labels))
  correct = tf.nn.in_top_k(output_class, y, 1)
  accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

  # Optimization.
  opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

  # Initialize variables.
  init = tf.global_variables_initializer()
  saver = tf.train.Saver()
  batch_x, batch_y = train_iterator.get_next()
  batch_test_x, batch_test_y = test_iterator.get_next()
  with tf.Session() as sess:
    sess.run([init])
    for i in range(train_steps):
      batch_x_value, batch_y_value = sess.run([batch_x, batch_y])
      _, loss_value = sess.run([opt, loss],
                               feed_dict={
                                   x: batch_x_value,
                                   y: batch_y_value
                               })
      if i % 100 == 0:
        tf.logging.info('Training step %d, loss is %f' % (i, loss_value))
      if i > 0 and i % save_every_n_steps == 0:
        accuracy_sum = 0.0
        for _ in range(eval_steps):
          test_x_value, test_y_value = sess.run([batch_test_x, batch_test_y])
          accuracy_value = sess.run(
              accuracy, feed_dict={
                  x: test_x_value,
                  y: test_y_value
              })
          accuracy_sum += accuracy_value
        tf.logging.info('Training step %d, accuracy is %f' %
                        (i, accuracy_sum / (eval_steps * 1.0)))
        saver.save(sess, model_dir)
def export(model, model_dir, tflite_model_file,
           use_post_training_quantize=True):
  """Export the trained model to a TfLite model."""
  tf.reset_default_graph()
  x, _, output_class = model.build_model()
  saver = tf.train.Saver()
  sess = tf.Session()
  saver.restore(sess, model_dir)

  # Convert to a TfLite model.
  converter = tf.lite.TFLiteConverter.from_session(sess, [x], [output_class])
  converter.post_training_quantize = use_post_training_quantize
  tflite = converter.convert()
  with open(tflite_model_file, 'wb') as f:
    f.write(tflite)
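# Illustrative addition, not part of the original gist: a minimal sketch of
# verifying the exported .tflite file with the TFLite Python interpreter.
# `check_tflite_model` is a hypothetical helper name; the tf.lite.Interpreter
# calls below are the standard TF 1.14 API.
def check_tflite_model(tflite_model_file):
  """Run a single random input through the converted model as a sanity check."""
  import numpy as np  # Only needed for this check.
  interpreter = tf.lite.Interpreter(model_path=tflite_model_file)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # The converter fixes the batch dimension, so use the reported input shape.
  dummy_input = np.random.random_sample(
      input_details[0]['shape']).astype(np.float32)
  interpreter.set_tensor(input_details[0]['index'], dummy_input)
  interpreter.invoke()
  probabilities = interpreter.get_tensor(output_details[0]['index'])
  tf.logging.info('TFLite output shape: %s' % (probabilities.shape,))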
def train_and_export(parsed_flags):
  """Train the MNIST LSTM model and export it to TfLite."""
  model = MnistLstmModel(
      time_steps=28,
      input_size=28,
      num_lstm_layer=2,
      num_lstm_units=64,
      units=64,
      num_class=10)
  tf.logging.info('Starting training...')
  train(model, parsed_flags.model_dir)
  tf.logging.info('Finished training, exporting the TfLite model to %s ...' %
                  parsed_flags.tflite_model_file)
  export(model, parsed_flags.model_dir, parsed_flags.tflite_model_file,
         parsed_flags.use_post_training_quantize)
  tf.logging.info(
      'Finished exporting, model is %s' % parsed_flags.tflite_model_file)
def run_main(_):
  """Main in the TfLite LSTM tutorial."""
  parser = argparse.ArgumentParser(
      description='Train an MNIST recognition model, then export it to TfLite.')
  parser.add_argument(
      '--model_dir',
      type=str,
      help='Directory where the model will be stored.',
      required=True)
  parser.add_argument(
      '--tflite_model_file',
      type=str,
      help='Full filepath for the exported tflite model file.',
      required=True)
  parser.add_argument(
      '--use_post_training_quantize',
      action='store_true',
      default=True,
      help='Whether or not to use post_training_quantize.')
  parsed_flags, _ = parser.parse_known_args()
  train_and_export(parsed_flags)
if __name__ == '__main__':
  # Make the tf.logging.info messages above visible on the console.
  tf.logging.set_verbosity(tf.logging.INFO)
  run_main(sys.argv[:1])