Last active
January 15, 2024 06:45
-
-
Save chinalwb/48abcbc6db40d2bc64d024aa3ad03afa to your computer and use it in GitHub Desktop.
TensorFlow add metadata to tflite model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This script snippet comes from https://www.tensorflow.org/lite/models/convert/metadata_writer_tutorial#audio_classifiers
# Before running, check the prerequisites:
# https://www.tensorflow.org/lite/models/convert/metadata_writer_tutorial#prerequisites
# pip install tflite-support-nightly
import argparse
import sys

from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils
# Expected sampling rate of the input audio buffer.
_SAMPLE_RATE = 16000
# Expected number of channels of the input audio buffer. Note: the Task Library
# only supports a single channel so far.
_CHANNELS = 1


def parse_args(argv):
    """Parse the three required positional paths from the command line.

    Args:
        argv: Command-line arguments without the program name, in the order
            ``MODEL_PATH LABEL_FILE SAVE_TO_PATH``.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``model_path``,
        ``label_file`` and ``save_to_path``.

    Raises:
        SystemExit: If fewer than three paths are given; argparse prints a
            usage message to stderr first.
    """
    parser = argparse.ArgumentParser(
        description="Add audio-classifier metadata to a TFLite model.")
    parser.add_argument("model_path", help="path of the input .tflite model")
    parser.add_argument(
        "label_file",
        help="path of the label file, in the format the Task Library expects")
    parser.add_argument(
        "save_to_path", help="where to write the model with metadata")
    # parse_known_args keeps the historical behaviour of silently ignoring
    # any arguments after the third one.
    parsed, _ignored = parser.parse_known_args(argv)
    return parsed


def main(argv=None):
    """Write audio-classifier metadata into a TFLite model file.

    Args:
        argv: Optional argument list (without the program name). Defaults to
            ``sys.argv[1:]`` when omitted.
    """
    options = parse_args(sys.argv[1:] if argv is None else argv)

    print("model file path == " + options.model_path)
    print("label file path == " + options.label_file)
    print("save file path == " + options.save_to_path)

    # Create the metadata writer.
    writer = audio_classifier.MetadataWriter.create_for_inference(
        writer_utils.load_file(options.model_path),
        _SAMPLE_RATE,
        _CHANNELS,
        [options.label_file],
    )

    # Print the generated metadata so it can be verified by eye.
    print(writer.get_metadata_json())

    # Populate the metadata into the model and save the result.
    writer_utils.save_file(writer.populate(), options.save_to_path)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment