Skip to content

Instantly share code, notes, and snippets.

@chinalwb
Last active January 15, 2024 06:45
Show Gist options
  • Save chinalwb/48abcbc6db40d2bc64d024aa3ad03afa to your computer and use it in GitHub Desktop.
TensorFlow: add metadata to a TFLite model
# This script snippet comes from https://www.tensorflow.org/lite/models/convert/metadata_writer_tutorial#audio_classifiers
# Before running, check out these prerequisites:
# https://www.tensorflow.org/lite/models/convert/metadata_writer_tutorial#prerequisites
# pip install tflite-support-nightly
from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils
import sys
AudioClassifierWriter = audio_classifier.MetadataWriter

# Command line: <model_path> <label_file_path> <save_to_path>
# Fail fast with a usage message rather than an IndexError traceback when an
# argument is missing. Any extra trailing arguments are ignored.
args = sys.argv[1:]
if len(args) < 3:
    # sys.exit() with a string prints it to stderr and exits with status 1.
    sys.exit(f"usage: {sys.argv[0]} <model_path> <label_file_path> <save_to_path>")

# Path of the input .tflite model to which metadata is added.
_MODEL_PATH = args[0]
print(f"model file path == {_MODEL_PATH}")
# Path of the label file; it must use the format the Task Library expects
# (see the metadata writer tutorial linked in the header comment).
_LABEL_FILE = args[1]
print(f"label file path == {_LABEL_FILE}")
# Expected sampling rate of the input audio buffer.
_SAMPLE_RATE = 16000
# Expected number of channels of the input audio buffer. Note: the Task
# Library only supports a single channel so far.
_CHANNELS = 1
# Destination path for the model with the metadata populated.
_SAVE_TO_PATH = args[2]
print(f"save file path == {_SAVE_TO_PATH}")

# Create the metadata writer from the raw model bytes.
writer = AudioClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), _SAMPLE_RATE, _CHANNELS, [_LABEL_FILE])
# Print the generated metadata as JSON so it can be verified.
print(writer.get_metadata_json())
# Populate the metadata into the model and write the result to disk.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment