Last active: November 9, 2018 17:21
Save crypt3lx2k/2e5da6a94a180455290284a9dca4e143 to your computer and use it in GitHub Desktop.
Training and exporting a keras model to the TFLite format
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python
# Train a small fully-connected MNIST classifier with tf.keras and save it to
# 'keras_mnist.h5' so it can be converted to TFLite afterwards.
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale raw pixel values from [0, 255] to [0, 1] floats for training.
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    # Naming the first layer explicitly pins the auto-created input tensor
    # name to "flatten_input", the key the TFLite conversion step uses in
    # `quantized_input_stats`. Without it, the name depends on Keras'
    # per-session layer counter.
    tf.keras.layers.Flatten(input_shape=(28, 28), name='flatten'),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),
])

# Labels are integer class ids (not one-hot), hence the sparse loss.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.save('keras_mnist.h5')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python
# Convert the saved Keras model 'keras_mnist.h5' into a TFLite flatbuffer
# 'keras_mnist.tflite' using the TF 1.x `tensorflow.contrib.lite` API.
import tensorflow.contrib.lite as lite

converter = lite.TFLiteConverter.from_keras_model_file('keras_mnist.h5')

# Quantize the weights after training (TF 1.x converter flag).
converter.post_training_quantize = True

# Declare the model input as uint8 so raw pixel bytes can be fed directly.
converter.inference_input_type = lite.constants.QUANTIZED_UINT8

# (mean, std_dev) of the input, keyed by input tensor name. The converter
# interprets real = (quantized - mean) / std_dev, so (0.0, 255.0) maps raw
# uint8 pixels in [0, 255] onto the [0, 1] range the model was trained on.
# The key must match the Keras model's input tensor name.
converter.quantized_input_stats = {"flatten_input": (0.0, 255.0)}

flatbuffer = converter.convert()

with open('keras_mnist.tflite', 'wb') as outfile:
    outfile.write(flatbuffer)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python
# Run 'keras_mnist.tflite' on the first `batch_size` MNIST test images and
# print the predicted class ids next to the true labels.
import numpy as np
import tensorflow as tf
import tensorflow.contrib.lite as lite

mnist = tf.keras.datasets.mnist
batch_size = 32

# Only the test split is used. Images are intentionally NOT divided by 255:
# the converter declared the input as QUANTIZED_UINT8 with stats (0, 255),
# so raw pixel bytes are the expected input.
_, (x_test, y_test) = mnist.load_data()

interpreter = lite.Interpreter('keras_mnist.tflite')
input_info = interpreter.get_input_details()[0]
output_info = interpreter.get_output_details()[0]

# Resize the input to hold a whole batch; this must happen before
# allocate_tensors() so buffers are sized for the new shape.
interpreter.resize_tensor_input(input_info['index'], (batch_size, 28, 28))
interpreter.allocate_tensors()

interpreter.set_tensor(input_info['index'], x_test[0:batch_size])
interpreter.invoke()

# One row of class scores per image; argmax over the last axis gives the
# predicted digit.
probs = interpreter.get_tensor(output_info['index'])
print('predicted={}, label={}'.format(np.argmax(probs, axis=-1), y_test[0:batch_size]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.