Skip to content

Instantly share code, notes, and snippets.

@hartikainen
Created February 1, 2020 14:21
Show Gist options
  • Save hartikainen/b688fe29ebedcb3bca11c4368250f3e2 to your computer and use it in GitHub Desktop.
import pickle
import tensorflow as tf
class TestModule(tf.Module):
    """Minimal ``tf.Module`` that stores a single model on ``self.model``.

    Args:
        model: The object to keep on ``self.model``. In this script it is a
            ``tf.keras.Model``.
        name: Optional module name forwarded to ``tf.Module``. ``None`` (the
            default) lets ``tf.Module`` pick its default name.
    """

    def __init__(self, model, name=None):
        # tf.Module.__init__ initializes the module's name and name scope;
        # skipping it leaves attributes such as `.name` unset on the instance.
        super().__init__(name=name)
        self.model = model
def main():
    """Pickle and unpickle a ``TestModule`` that wraps a small Keras model.

    Builds a functional Keras model (``Input((3,))`` -> ``Dense(5)``) and
    calls it once. It then wraps the model in ``TestModule``, round-trips
    the wrapper through ``pickle.dumps`` / ``pickle.loads``, and compares the
    ``units`` of the final (Dense) layer on both sides.

    Raises:
        AssertionError: If the unpickled model's final layer reports a
            different ``units`` value than the original.
    """
    x = tf.keras.layers.Input((3, ))
    y = tf.keras.layers.Dense(5)(x)
    # Note, model *is not* picklable.
    model = tf.keras.Model(x, y)
    # Call the model once on a random (1, 3) batch.
    _ = model(tf.random.uniform((1, 3)))

    module_1 = TestModule(model)
    module_2 = pickle.loads(pickle.dumps(module_1))

    # `units` is an attribute of the Dense layer, not of `tf.keras.Model`.
    # The functional model's last layer is the Dense(5) created above.
    units_1 = module_1.model.layers[-1].units
    units_2 = module_2.model.layers[-1].units
    # Explicit raise instead of `assert` so the check still runs under `python -O`.
    if units_1 != units_2:
        raise AssertionError(
            'units mismatch: {!r} != {!r}'.format(units_1, units_2))
# Run the pickling round-trip check only when this file is executed directly,
# not when it is imported.
if '__main__' == __name__:
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment