@rpicatoste
Last active October 28, 2022 09:40
Function to convert a Keras LSTM model trained as stateless into a stateful model for inference. The batch size is fixed to 1 and the number of time steps is left unspecified, so the model can be fed a single sample one time step at a time.
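import json

from keras.models import model_from_json


def convert_to_inference_model(original_model):
    # Serialize the trained (stateless) model's architecture to a JSON dict.
    original_model_json = original_model.to_json()
    inference_model_dict = json.loads(original_model_json)

    # Layer configs of a Sequential model (older Keras JSON layout).
    layers = inference_model_dict['config']
    for layer in layers:
        # Make every recurrent layer stateful so it keeps its state between calls.
        if 'stateful' in layer['config']:
            layer['config']['stateful'] = True

        # Stateful RNNs need a fixed batch size: use 1 sample, and leave the
        # number of time steps unspecified so one step can be fed per call.
        if 'batch_input_shape' in layer['config']:
            layer['config']['batch_input_shape'][0] = 1
            layer['config']['batch_input_shape'][1] = None

    # Rebuild the model from the edited JSON and copy over the trained weights.
    inference_model = model_from_json(json.dumps(inference_model_dict))
    inference_model.set_weights(original_model.get_weights())

    return inference_model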
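A minimal usage sketch, assuming Keras 2.x, where models expose reset_states() and predict_on_batch(), and a converted model whose input shape is (1, None, n_features). The helper name predict_step_by_step is hypothetical. It resets the state once at the start of a sequence, then feeds one time step per call so the stateful LSTM carries its state forward between calls.

```python
import numpy as np


def predict_step_by_step(stateful_model, sequence):
    """Feed `sequence` (shape: timesteps x features) one step at a time.

    Hypothetical helper. Assumes `stateful_model` was built with batch size 1
    (e.g. by convert_to_inference_model) and exposes Keras 2.x-style
    `reset_states()` and `predict_on_batch()`.
    """
    sequence = np.asarray(sequence, dtype="float32")
    if sequence.ndim != 2:
        raise ValueError("expected a (timesteps, features) array")

    # Start this sequence from a clean hidden/cell state.
    stateful_model.reset_states()

    outputs = []
    for t in range(sequence.shape[0]):
        # Shape (batch=1, timesteps=1, features). The stateful layers keep
        # their state after this call, so step t+1 continues from step t.
        step = sequence[np.newaxis, t:t + 1, :]
        outputs.append(np.asarray(stateful_model.predict_on_batch(step)))
    return outputs
```

Resetting only at sequence boundaries is the important design choice: without the reset, state from one sequence would leak into the next.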
@jbonyun
jbonyun commented Feb 9, 2021

Works in keras 2.3.0-tf if you change:

layers = inference_model_dict['config']

to

layers = inference_model_dict['config']['layers']
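To avoid editing that line for each Keras version, one option is to look up the layer list in whichever layout the JSON uses. This is a small sketch: get_layer_configs is a hypothetical helper, and it assumes the only two layouts are the older list stored directly under 'config' and the newer dict with a 'layers' key described in this comment.

```python
def get_layer_configs(model_dict):
    """Return the list of layer dicts from a parsed Keras model JSON.

    Hypothetical helper. Older Sequential JSON stores the layers directly as
    a list under 'config'; newer versions (e.g. keras 2.3.0-tf) nest them
    under config['layers'].
    """
    config = model_dict["config"]
    if isinstance(config, dict):
        # Newer layout: {"config": {"name": ..., "layers": [...]}}
        return config["layers"]
    # Older layout: {"config": [...]}
    return config
```

With this, the loop in convert_to_inference_model would start from layers = get_layer_configs(inference_model_dict).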

I also had to change a Reshape layer in my model, which was reshaping to the number of timesteps I used in training.
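The exact Reshape change is not shown, so the following is only a guess at its shape: a hypothetical helper that overwrites the time dimension of each Reshape layer's target_shape in the parsed JSON before model_from_json is called. It assumes the time axis is the first entry of target_shape (Keras's target_shape excludes the batch dimension) and that the inference model is fed one time step per call, hence the default of 1.

```python
def patch_reshape_timesteps(layer_configs, timesteps=1):
    """Set the time dimension of every Reshape layer's target_shape.

    Hypothetical helper. Assumes target_shape[0] is the time axis, which
    holds when the layer reshapes to (timesteps, ...).
    """
    for layer in layer_configs:
        if layer.get("class_name") == "Reshape":
            # JSON turns tuples into lists; copy so the edit is explicit.
            target_shape = list(layer["config"]["target_shape"])
            target_shape[0] = timesteps
            layer["config"]["target_shape"] = target_shape
    return layer_configs
```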
