Last active
September 24, 2020 15:02
-
-
Save Namburger/8369600feae227004deb9bd5a6c1f50f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tempfile | |
import os | |
import tensorflow as tf | |
import numpy as np | |
from tensorflow import keras | |
# --- Baseline model ---------------------------------------------------------
# Load MNIST dataset.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture. The final Dense(10) has no activation, so it
# emits logits; the loss below is configured with from_logits=True to match.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
])

# Train the digit classification model.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
    train_images,
    train_labels,
    epochs=4,
    validation_split=0.1,
)

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)

# mkstemp returns an already-open OS-level file descriptor together with the
# path. save_model reopens the file by path, so close the descriptor here
# instead of leaking it.
keras_fd, keras_file = tempfile.mkstemp('.h5')
os.close(keras_fd)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)
import tensorflow_model_optimization as tfmot

# --- Magnitude pruning of the baseline model --------------------------------
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
# Number of optimizer steps per epoch on the training split, times the epochs.
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning: the schedule goes from initial_sparsity (0.50) to
# final_sparsity (0.80) between begin_step and end_step.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])
model_for_pruning.summary()

logdir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Pruning summaries are written under the fresh temporary directory above.
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

# Same batch_size / epochs / validation_split as used for end_step above.
model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)

_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

# Remove the pruning wrappers before exporting the model.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

# Close mkstemp's open descriptor; save_model writes the file by path.
pruned_fd, pruned_keras_file = tempfile.mkstemp('.h5')
os.close(pruned_fd)
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

# Converter for the first TFLite export; its quantization options are set
# further below, after the representative dataset generator is defined.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
def representative_data_gen(num_samples=100):
    """Yield calibration inputs for TFLite post-training quantization.

    Loads the MNIST training images and takes the first ``num_samples`` of
    them. It divides their pixel values by 255 as float32 (matching the
    normalization used for training) and yields them one at a time with a
    leading batch dimension of 1. Each sample is wrapped in a single-element
    list because the model has exactly one input.

    Args:
        num_samples: how many training images to yield. Defaults to 100, the
            previously hard-coded count, so calling with no arguments (as the
            TFLite converter does) is unchanged. Expected to be positive.
    """
    (train_x, _), _ = tf.keras.datasets.mnist.load_data()
    # Slice before casting so only the calibration subset is converted to
    # float32, rather than the whole training set.
    images = tf.cast(train_x[:num_samples], tf.float32) / 255.0
    for input_value in tf.data.Dataset.from_tensor_slices(images).batch(1):
        yield [input_value]
# --- First TFLite export: post-training quantization of the pruned model ---
# Optimize.DEFAULT turns quantization on.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.int8]
# Allow only int8 builtin ops, so conversion errors out instead of silently
# leaving an op unquantized.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Input and output tensors become uint8 (these converter options exist since r2.3).
converter.inference_input_type = converter.inference_output_type = tf.uint8
# Calibration samples used to quantize the activations.
converter.representative_dataset = representative_data_gen
# Opt out of the new converter path for this export.
converter.experimental_new_converter = False

pruned_tflite_model = converter.convert()

pruned_tflite_file = 'mnist_pruned_ptq.tflite'
with open(pruned_tflite_file, 'wb') as ptq_out:
    ptq_out.write(pruned_tflite_model)
print('Saved pruned TFLite model to:', pruned_tflite_file)
def get_gzipped_model_size(file):
    """Return the size in bytes of ``file`` after DEFLATE compression.

    Despite the name, this builds a ZIP archive (``zipfile.ZIP_DEFLATED``),
    not a gzip stream. The archive is written to a temporary file, measured,
    and then deleted.

    Args:
        file: path of the file to compress.

    Returns:
        Size in bytes of a ZIP archive containing only ``file``.
    """
    import os
    import zipfile

    fd, zipped_file = tempfile.mkstemp('.zip')
    # mkstemp hands back an open OS-level descriptor; ZipFile reopens the file
    # by path, so close the descriptor to avoid leaking it.
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            # Store the member under its bare file name so the length of the
            # directory path does not inflate the measured size.
            f.write(file, arcname=os.path.basename(file))
        return os.path.getsize(zipped_file)
    finally:
        # Do not leave one stray archive in the temp dir per call.
        os.remove(zipped_file)
# Compressed on-disk size of each artifact exported so far.
baseline_zip_size = get_gzipped_model_size(keras_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (baseline_zip_size))
pruned_keras_zip_size = get_gzipped_model_size(pruned_keras_file)
print("Size of gzipped pruned Keras model: %.2f bytes" % (pruned_keras_zip_size))
pruned_tflite_zip_size = get_gzipped_model_size(pruned_tflite_file)
print("Size of gzipped pruned TFlite model: %.2f bytes" % (pruned_tflite_zip_size))

# --- Second TFLite export of the same pruned model ---------------------------
# Unlike the first export, only Optimize.DEFAULT is set here: no representative
# dataset, no int8-only op set, and no uint8 input/output types.
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

quantized_and_pruned_tflite_file = 'mnist_pruned_quantized.tflite'
with open(quantized_and_pruned_tflite_file, 'wb') as quantized_out:
    quantized_out.write(quantized_and_pruned_tflite_model)
print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

baseline_zip_size = get_gzipped_model_size(keras_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (baseline_zip_size))
quantized_zip_size = get_gzipped_model_size(quantized_and_pruned_tflite_file)
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (quantized_zip_size))
import numpy as np


def evaluate_model(interpreter, images=None, labels=None):
    """Return the top-1 accuracy of a TFLite interpreter on labelled images.

    Each image is run through the interpreter on its own, with a batch
    dimension of 1 added. The predicted class is the argmax of the first
    output tensor.

    Args:
        interpreter: a TFLite interpreter whose tensors are already allocated
            (anything providing get_input_details, get_output_details,
            set_tensor, invoke and get_tensor).
        images: images to classify. Defaults to the module-level
            ``test_images``.
        labels: ground-truth class ids aligned with ``images``. Defaults to
            the module-level ``test_labels``.

    Returns:
        Fraction of images whose predicted class equals its label.
    """
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_details = interpreter.get_input_details()[0]
    input_index = input_details["index"]
    input_dtype = input_details.get("dtype", np.float32)
    input_scale, input_zero_point = input_details.get("quantization", (0.0, 0))
    output_index = interpreter.get_output_details()[0]["index"]

    # Integer-input models (e.g. full-integer post-training quantization with
    # uint8/int8 I/O) reject float tensors. Quantize with the model's own
    # parameters: q = round(x / scale + zero_point), clipped to the dtype range.
    quantize_input = np.issubdtype(input_dtype, np.integer) and input_scale != 0
    if quantize_input:
        limits = np.iinfo(input_dtype)

    # Run predictions on every image in the dataset.
    prediction_digits = []
    for i, image in enumerate(images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add the batch dimension and match the model's input dtype.
        batch = np.expand_dims(image, axis=0)
        if quantize_input:
            batch = np.clip(np.round(batch / input_scale + input_zero_point),
                            limits.min, limits.max).astype(input_dtype)
        else:
            batch = batch.astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        # Run inference.
        interpreter.invoke()

        # Post-processing: drop the batch dimension and take the highest-scoring
        # class. get_tensor returns a copy, so no reference into the
        # interpreter's internal buffers is kept across invocations.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))
    print('\n')

    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    return (prediction_digits == labels).mean()
# Measure the accuracy of the pruned + quantized TFLite model, then report it
# next to the pruned Keras model's accuracy.
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)
for accuracy_label, accuracy_value in (
        ('Pruned and quantized TFLite test_accuracy:', test_accuracy),
        ('Pruned TF test accuracy:', model_for_pruning_accuracy),
):
    print(accuracy_label, accuracy_value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment