Skip to content

Instantly share code, notes, and snippets.

@hartikainen
Created February 1, 2020 14:21
Show Gist options
  • Save hartikainen/c619523b89065a5526287e16dda076cf to your computer and use it in GitHub Desktop.
Save hartikainen/c619523b89065a5526287e16dda076cf to your computer and use it in GitHub Desktop.
import tensorflow as tf
class TestModule(tf.Module):
    """Minimal ``tf.Module`` container pairing a plain value with a Keras model.

    Storing the model as an attribute makes it reachable from this module, so
    ``tf.saved_model.save(module, ...)`` can find and export it.
    """

    def __init__(self, value, model):
        """Attach ``value`` and ``model`` to the module.

        Args:
            value: Arbitrary Python value (``main`` passes the int ``4``).
                NOTE(review): non-trackable Python values are presumably not
                serialized by SavedModel — confirm before relying on it.
            model: The ``tf.keras.Model`` to be tracked by this module.
        """
        # tf.Module.__init__ initializes the module's name / name scope;
        # subclasses are expected to call it before using the module.
        super().__init__()
        self.value = value
        self.model = model
def main(export_dir="./foo"):
    """Round-trip a ``TestModule`` through SavedModel and compare the copies.

    Builds a small functional Keras model and wraps it in a ``TestModule``.
    It exports the module with ``tf.saved_model.save`` to ``export_dir`` and
    reloads it with ``tf.saved_model.load``. It then checks that the trainable
    variables and the ``value`` attribute match the originals.

    Args:
        export_dir: Directory the SavedModel is written to. Defaults to
            ``"./foo"`` (relative to the current working directory).

    Raises:
        tf.errors.InvalidArgumentError: (eager mode) if the two models expose a
            different number of trainable variables, or any compared pair
            differs.
    """
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    # The model is not picklable, which is why SavedModel is used to persist it.
    model = tf.keras.Model(x, y)
    # One forward pass before exporting (presumably so the model has been
    # called at least once); the output itself is discarded.
    _ = model(tf.random.uniform((1, 3)))

    module_1 = TestModule(value=4, model=model)
    tf.saved_model.save(module_1, export_dir)
    module_2 = tf.saved_model.load(export_dir)

    variables_1 = module_1.model.trainable_variables
    variables_2 = module_2.model.trainable_variables
    # zip() stops at the shorter sequence, so compare the counts first;
    # otherwise a missing variable on either side would pass vacuously.
    tf.debugging.assert_equal(
        len(variables_1), len(variables_2),
        message="Original and restored models have a different number of "
                "trainable variables.")
    for variable_1, variable_2 in zip(variables_1, variables_2):
        tf.debugging.assert_equal(variable_1, variable_2)

    # NOTE(review): the SavedModel guide states plain Python attributes are not
    # serialized, so `module_2.value` may be absent after loading — confirm.
    tf.debugging.assert_equal(module_1.value, module_2.value)
if __name__ == '__main__':
    # Run the SavedModel round-trip check only when executed as a script,
    # not when this file is imported as a module.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment