Created
January 9, 2023 00:49
-
-
Save petered/6d085852f5393c69f48893fa0c2f5220 to your computer and use it in GitHub Desktop.
How do I save a stateful TFLite model whose state shape depends on the input?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shutil | |
import tempfile | |
from dataclasses import dataclass | |
from typing import Optional, Callable, Any, Mapping | |
import os | |
import numpy as np | |
import tensorflow as tf | |
def save_signatures_to_tflite_model(
    concrete_function_dict: Mapping[str, Callable],
    path: str,
    parent_object: Any,
    allow_custom_ops=False,
):
    """Export ``parent_object`` with the given signatures as a SavedModel, then convert it to TFLite.

    Args:
        concrete_function_dict: Mapping from signature name to a ``tf.function`` / concrete
            function. Passed straight to ``tf.saved_model.save(signatures=...)``.
        path: Destination file for the serialized TFLite flatbuffer. ``~`` is expanded, matching
            ``load_tflite_model_func``.
        parent_object: Trackable object (e.g. a ``tf.Module``) that owns the variables the
            signature functions use; passed as ``obj`` to ``tf.saved_model.save``.
        allow_custom_ops: Forwarded to ``TFLiteConverter.allow_custom_ops``.
    """
    # The intermediate SavedModel only exists to feed the converter; TemporaryDirectory removes
    # it even if saving or conversion raises.
    with tempfile.TemporaryDirectory() as saved_model_dir:
        tf.saved_model.save(obj=parent_object, export_dir=saved_model_dir, signatures=concrete_function_dict)
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        # Allow both TFLite builtin kernels and fallback to full TensorFlow ("select") ops.
        converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS, tf.lite.OpsSet.TFLITE_BUILTINS]
        # Required so tf.Variable reads/assigns inside the signatures are converted as
        # resource variables instead of failing conversion.
        converter.experimental_enable_resource_variables = True
        converter.allow_custom_ops = allow_custom_ops
        serialized_model = converter.convert()
    with open(os.path.expanduser(path), 'wb') as f:
        f.write(serialized_model)
def load_tflite_model_func(path: str) -> Callable:
    """Load a ``.tflite`` file and wrap its interpreter in a plain Python callable.

    The returned function takes one array per model input, in the order reported by
    ``Interpreter.get_input_details()``. It returns a single array when the model has exactly
    one output, otherwise a list of arrays. One interpreter is reused for every call, so any
    model state (e.g. resource variables) carries over between calls.

    If an argument's shape differs from the interpreter's current input shape, that input is
    resized and the tensors are re-allocated before the call.
    """
    interpreter = tf.lite.Interpreter(model_path=os.path.expanduser(path))
    interpreter.allocate_tensors()

    def model_func(*args):
        # Re-query each call: after a resize the reported input shapes change (indices do not).
        input_details = interpreter.get_input_details()
        if len(args) != len(input_details):
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise TypeError(f"Model expects {len(input_details)} input(s), got {len(args)}.")
        values = [np.asarray(a) for a in args]

        needs_realloc = False
        for detail, value in zip(input_details, values):
            if tuple(detail['shape']) != value.shape:
                interpreter.resize_tensor_input(detail['index'], value.shape)
                needs_realloc = True
        if needs_realloc:
            # NOTE(review): unverified whether re-allocation preserves resource-variable state
            # from earlier calls — confirm if a model's input shape changes mid-sequence.
            interpreter.allocate_tensors()

        for detail, value in zip(input_details, values):
            interpreter.set_tensor(detail['index'], value)
        interpreter.invoke()

        output_details = interpreter.get_output_details()
        # The flatbuffer does not record whether a single output was a 1-tuple or a bare
        # tensor, so a lone output is always unwrapped.
        if len(output_details) == 1:
            return interpreter.get_tensor(output_details[0]['index'])
        return [interpreter.get_tensor(o['index']) for o in output_details]

    return model_func
@dataclass
class TimeDelta(tf.Module):
    """Stateful module that returns the change between the current input and the previous one.

    The previous input lives in a ``tf.Variable``. The variable is created lazily on the first
    call and zero-initialised, so the first call returns the input unchanged.

    NOTE(review): the ``__init__`` generated by ``@dataclass`` does not call
    ``tf.Module.__init__``, so Module attributes such as ``name`` / ``name_scope`` are never
    set. Confirm before relying on them.
    """

    # Holds the last input seen; None until compute_delta runs for the first time.
    _last_val: Optional[tf.Tensor] = None

    def compute_delta(self, arr: tf.Tensor):
        """Return ``arr`` minus the previously seen input, then store ``arr`` as the new state."""
        if self._last_val is None:
            # tf.Variable takes its shape from the initial value, so the state is shaped like
            # the first (traced) input.
            initial_state = tf.zeros(tf.shape(arr))
            self._last_val = tf.Variable(initial_state)
        state = self._last_val
        difference = arr - state
        state.assign(arr)
        return difference
def test_save_delta():
    """Round-trip ``TimeDelta`` through TFLite and check the stateful delta on three inputs.

    The model is traced with ``compile_time_shape`` but run with ``runtime_shape``. With
    matching shapes the assertions pass. With the mismatched shapes below, the saved state
    variable does not match the input, which is the open question this script demonstrates.
    """
    compile_time_shape = 30, 40
    delta = TimeDelta()
    # TemporaryDirectory instead of tempfile.mktemp(): mktemp is deprecated (a race between
    # choosing the name and creating the file), and the old model file was never deleted.
    with tempfile.TemporaryDirectory() as tmpdir:
        tflite_model_file_path = os.path.join(tmpdir, 'delta.tflite')
        save_signatures_to_tflite_model(
            {'delta': tf.function(delta.compute_delta, input_signature=[tf.TensorSpec(shape=compile_time_shape)])},
            path=tflite_model_file_path,
            parent_object=delta,
        )

        # Load the model and run test inputs.
        # OPEN QUESTION: how can the state variable be resized to match the runtime shape?
        func = load_tflite_model_func(tflite_model_file_path)
        # runtime_shape = compile_time_shape  # With matching shapes, the test passes.
        runtime_shape = 60, 80
        rng = np.random.RandomState(1234)
        ims = [rng.randn(*runtime_shape).astype(np.float32) for _ in range(3)]
        # State starts at zero, so the first delta is the input itself.
        assert np.allclose(func(ims[0]), ims[0])
        assert np.allclose(func(ims[1]), ims[1] - ims[0])
        assert np.allclose(func(ims[2]), ims[2] - ims[1])


if __name__ == "__main__":
    test_save_delta()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment