Skip to content

Instantly share code, notes, and snippets.

@hartikainen
Created February 1, 2020 14:21
Show Gist options
  • Save hartikainen/d8901fb2c360cfba54eaa7b53aa7af6c to your computer and use it in GitHub Desktop.
Save hartikainen/d8901fb2c360cfba54eaa7b53aa7af6c to your computer and use it in GitHub Desktop.
import tensorflow as tf
class TestModule(tf.Module):
    """Minimal ``tf.Module`` that holds a Keras model as a tracked attribute.

    Assigning the model to ``self.model`` makes it (and the variables it
    owns) part of the object graph that ``tf.saved_model.save`` exports, so
    after loading it is reachable again as ``<restored>.model``.
    """

    def __init__(self, model, name=None):
        """Store *model* on the module.

        Args:
            model: The Keras model to track and export.
            name: Optional module name, forwarded to ``tf.Module``. When
                omitted, ``tf.Module`` derives a default name.
        """
        # The tf.Module subclassing docs call super().__init__(name=name) so
        # the module's name / name scope are initialised; the original skipped it.
        super().__init__(name=name)
        self.model = model
# Build a minimal functional Keras model: 3 input features -> Dense(5).
x = tf.keras.layers.Input((3, ))
y = tf.keras.layers.Dense(5)(x)
model = tf.keras.Model(x, y)

# Wrap the model in a tf.Module so it is exported as the `.model` attribute of
# the SavedModel root, then round-trip it through disk.
EXPORT_DIR = "./foo"
module_1 = TestModule(model)
tf.saved_model.save(module_1, EXPORT_DIR)
imported = tf.saved_model.load(EXPORT_DIR)

original_weights = module_1.model.weights
# `tf.saved_model.load` restores generic trackable objects rather than a Keras
# model; the restored `variables` attribute is itself the list of restored
# variables, so index it directly instead of going through the container's
# extra `.weights` property.
# NOTE(review): assumes the restored list keeps the save-time order of
# `model.weights` — confirm for models that also have non-trainable variables.
imported_weights = list(imported.model.variables)

# Check the counts first: indexing alone raises IndexError on a shorter
# restored list and silently ignores any extra restored variables.
assert len(original_weights) == len(imported_weights), (
    "expected {} restored variables, got {}".format(
        len(original_weights), len(imported_weights)))

for weight_idx, (original, restored) in enumerate(
        zip(original_weights, imported_weights)):
    original_value = original.numpy()
    restored_value = restored.numpy()
    # Compare shapes explicitly: `==` on mismatched shapes would broadcast or
    # collapse to a scalar instead of reporting a clear failure.
    assert original_value.shape == restored_value.shape, (
        "variable {}: shape {} != restored shape {}".format(
            weight_idx, original_value.shape, restored_value.shape))
    # A save/load round trip is expected to reproduce values exactly, so
    # exact equality (not a tolerance) is intended here.
    assert (original_value == restored_value).all(), (
        "variable {}: restored values differ from the original".format(
            weight_idx))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment