
Thierry Herrmann (thierryherrmann), Montreal, Canada
inspect_checkpoint(model_dir + '/variables/variables', print_values=True,
                   variables=['model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE'])
tensor : model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (30,)
[ 3.51306298e-05 3.61366037e-05 -3.67252505e-06 9.21028666e-04
7.78463436e-04 2.24373052e-05 6.05550595e-04 7.36912712e-04
-4.31884764e-05 1.44443940e-04 1.24389135e-05 8.46692594e-04
1.70874955e-05 3.72679904e-04 5.41794288e-05 6.08396949e-04
1.95211032e-06 8.75406899e-04 9.23899701e-04 2.17679326e-06
8.70055985e-04 6.87883934e-04 5.30559737e-06 5.81342028e-04
2.78645912e-05 4.61369600e-05 7.27826264e-04 1.64074972e-05
-6.21771906e-05 1.15486218e-05]
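The dumped tensor is a slot variable (`.OPTIMIZER_SLOT/opt/m`) attached to the first layer's bias, so it has the same shape, (30,). Assuming `opt` is a Keras Adam optimizer (the slot name `m` is what Keras's Adam uses for its first-moment estimate), each entry is an exponential moving average of the gradients of one bias element. The stdlib-only sketch below shows that update rule with Keras's default `beta_1 = 0.9`; `update_m` is a hypothetical helper, the gradient values are made up, and 3 entries stand in for the 30.

```python
# Sketch of the Adam first-moment ("m") slot update, assuming `opt` is Adam
# with the default beta_1 = 0.9. Gradients below are invented for illustration.
def update_m(m, grad, beta_1=0.9):
    """Exponential moving average of gradients, one value per parameter."""
    return [beta_1 * m_i + (1.0 - beta_1) * g_i for m_i, g_i in zip(m, grad)]

# A freshly created slot starts at zero, one entry per bias element.
m = [0.0, 0.0, 0.0]
for grad in ([1.0, -2.0, 0.5], [1.0, -2.0, 0.5]):
    m = update_m(m, grad)  # each entry keeps the sign of its recent gradients
print(m)
```

This is why the slot holds small values of mixed sign: each entry drifts toward the recent average gradient of its own bias element.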
signature_def['my_serve']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['X'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 8)
        name: my_serve_X:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
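In the signature dump above, `-1` marks a dimension whose size is not fixed by the signature: `my_serve` accepts float batches of any length with 8 features each and returns one value per row. The hypothetical helper below (plain Python, not a TensorFlow API) spells out that matching rule; TensorFlow's own check is `tf.TensorShape.is_compatible_with`, where unknown dimensions are `None` rather than `-1`.

```python
# Hypothetical helper: does a concrete shape fit a signature shape,
# where -1 stands for "any size" (as printed in the signature dump)?
def is_compatible(shape, signature_shape):
    """True if ranks match and every fixed dimension matches exactly."""
    return len(shape) == len(signature_shape) and all(
        want == -1 or have == want for have, want in zip(shape, signature_shape)
    )

SERVE_INPUT = (-1, 8)   # inputs['X'] of my_serve
SERVE_OUTPUT = (-1, 1)  # outputs['output_0']

print(is_compatible((1, 8), SERVE_INPUT))    # True: a single-row batch
print(is_compatible((256, 8), SERVE_INPUT))  # True: any batch size works
print(is_compatible((8,), SERVE_INPUT))      # False: the batch axis is required
print(is_compatible((1, 1), SERVE_OUTPUT))   # True: one prediction per row
```

Slicing with `X_train[0:1]` rather than `X_train[0]` keeps the batch axis, which is what a `(-1, 8)` input expects.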
# instantiate a new module and save it untrained
module = CustomModule()
save_module(module, model_dir)
del module
print('\n\n========== Reload module ===========')
# this also works when reloading in a different Python process
model_dir = 'saved_model'
new_module = tf.saved_model.load(model_dir)
print('type of reloaded module:', type(new_module))
print('type of instantiated module:', type(CustomModule()))
print('my_train function:', new_module.my_train)
print('__call__ function:', new_module.__call__)
# demo: calling the module directly invokes its __call__() method
print('sample prediction: ', new_module(X_train[0:1]).numpy())
type of reloaded module: <class 'tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject'>
type of instantiated module: <class '__main__.CustomModule'>
my_train function: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7fc1c5872390>
__call__ function: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7fc1c58cbfd0>
sample prediction: [[0.54084957]]
# continue training the reloaded module, plot the loss, and save it again
loss_hist = train_module(new_module, train_dataset, valid_dataset)
plot_loss(loss_hist)
save_module(new_module, model_dir)
inspect_checkpoint(model_dir + '/variables/variables', print_values=True,
                   variables=['model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE'])
tensor : model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (30,)
[ 1.3162548e-04 -1.0862495e-03 8.3323405e-04 8.4080239e-06
-1.6426330e-04 -9.0881845e-04 4.7971989e-04 -6.0352772e-06
-9.3550794e-04 -3.1544755e-03 5.4244534e-04 1.0909925e-03
1.3340317e-03 -1.0700974e-03 3.7469756e-04 -1.5879219e-03
-2.1641832e-03 -1.7716389e-03 2.8458738e-04 -6.3899945e-04
-2.9655998e-03 -1.7114554e-03 -3.9885961e-03 2.6567639e-05
-3.6036890e-05 6.1224034e-04 -1.0181948e-03 1.6523007e-04
-4.8340447e-03 1.5539475e-03]
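Between the two dumps the `m` values have changed (the first entry went from 3.51e-05 to 1.32e-04), which suggests the optimizer slots are saved with the variables, restored on reload, and then updated by the extra training. The plain-Python sketch below illustrates why keeping them matters. It uses a textbook Adam update (not Keras's exact arithmetic) on a single scalar and a made-up gradient stream whose sign flips halfway; `adam_step`, `run` and the gradient values are all hypothetical.

```python
import math

def adam_step(theta, grad, state, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """One textbook Adam update on a scalar parameter; state = (m, v, t)."""
    m, v, t = state
    t += 1
    m = beta_1 * m + (1.0 - beta_1) * grad         # first moment (the "m" slot)
    v = beta_2 * v + (1.0 - beta_2) * grad * grad  # second moment (the "v" slot)
    m_hat = m / (1.0 - beta_1 ** t)                # bias corrections
    v_hat = v / (1.0 - beta_2 ** t)
    theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, (m, v, t)

def run(theta, state, grads):
    """Apply adam_step once per gradient, threading the optimizer state."""
    for g in grads:
        theta, state = adam_step(theta, g, state)
    return theta, state

FRESH = (0.0, 0.0, 0)
first_half, second_half = [1.0] * 10, [-1.0] * 10  # slope flips sign halfway

# Uninterrupted run over all 20 gradients.
theta_full, _ = run(0.0, FRESH, first_half + second_half)

# "Save" after 10 steps, then resume with the saved optimizer state.
theta_mid, saved_state = run(0.0, FRESH, first_half)
theta_resumed, _ = run(theta_mid, saved_state, second_half)

# Resume from the same weights, but with a brand-new optimizer (slots reset).
theta_reset, _ = run(theta_mid, FRESH, second_half)

print(theta_full == theta_resumed)  # True: restored slots reproduce the full run
print(theta_reset - theta_resumed)  # positive: the reset optimizer lost momentum
```

The reset optimizer follows the new negative gradients immediately, whereas the restored momentum first keeps pushing `theta` downward, so the two resumed runs diverge even though they start from identical weights.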
del new_module
# reload, this time with the Keras loader
new_module_2 = tf.keras.models.load_model(model_dir)
loss_hist = train_module(new_module_2, train_dataset, valid_dataset)
plot_loss(loss_hist)
save_module(new_module_2, model_dir)