Skip to content

Instantly share code, notes, and snippets.

@tarun-ssharma
Forked from dansitu/convert.py
Last active April 22, 2022 02:44
Show Gist options
  • Save tarun-ssharma/6bbf1ffe1f1276a9603206a177c8fe0c to your computer and use it in GitHub Desktop.
Convert a Keras model to an integer-quantized TFLite model, using a randomly generated representative dataset for calibration.
# Build a small fully-connected classifier over 28x28 inputs, train it for a
# single epoch, then report loss/accuracy on the held-out split.
# NOTE(review): `tf`, x_train/y_train and x_test/y_test are not defined in this
# snippet; presumably they come from an earlier import/data-loading step.
classifier_layers = [
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
]
model = tf.keras.models.Sequential(classifier_layers)

# Integer class labels -> sparse categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

model.fit(x_train, y_train, epochs=1)
model.evaluate(x_test, y_test, verbose=2)

image_shape = (28, 28)  # spatial shape of one model input (matches the Flatten layer)


def representative_dataset_gen(num_calibration_images=10):
    """Yield random calibration samples for TFLite post-training quantization.

    Each sample is a single-element list holding a tensor of shape
    ``(1, *image_shape)`` drawn from ``tf.random.normal`` (float32 by default).
    The TFLite converter calls this with no arguments, so the default count is
    used unless a caller binds a different one (e.g. ``functools.partial``).

    NOTE(review): standard-normal noise presumably does not match the value
    range of the real preprocessed inputs; calibrating on a sample of the
    training images would likely yield tighter quantization ranges — confirm.

    Args:
        num_calibration_images: How many calibration samples to yield.

    Yields:
        A list containing one input tensor per model input (here, one).
    """
    sample_shape = [1, *image_shape]  # prepend a batch dimension of 1
    for _ in range(num_calibration_images):
        yield [tf.random.normal(sample_shape)]

# Convert the trained Keras model to an integer-quantized TFLite model whose
# input and output tensors are uint8. Activation ranges are calibrated with the
# samples from `representative_dataset_gen`.
# NOTE(review): the output file name suggests a Coral Edge TPU target — confirm.
converter = tf.lite.TFLiteConverter.from_keras_model(model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Restrict conversion to int8 builtin ops (full-integer quantization); per the
# TFLite post-training quantization guide, unsupported ops then raise an error
# instead of being kept in float.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_dataset_gen
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_quant_model = converter.convert()

# Use a context manager so the file is flushed and closed deterministically,
# rather than relying on the unreferenced handle being garbage-collected.
with open("coral8.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_quant_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment