Skip to content

Instantly share code, notes, and snippets.

@bnaul
Created June 15, 2017 20:30
Show Gist options
  • Save bnaul/3ba7ac7c21e52ed5fa50b57c34e3b28a to your computer and use it in GitHub Desktop.
Save bnaul/3ba7ac7c21e52ed5fa50b57c34e3b28a to your computer and use it in GitHub Desktop.
from concurrent.futures import ThreadPoolExecutor

import keras
import numpy as np
def f():
    """Build and compile a tiny Keras model and print its initial prediction.

    The ``__main__`` block runs this function on a worker thread through a
    ``ThreadPoolExecutor``, so everything here executes off the main thread.

    Returns:
        A compiled ``keras.models.Sequential`` containing a single
        ``Dense(4)`` layer declared with ``input_shape=(4,)``.
    """
    # NOTE: removing this `get_session` call makes the later `get_weights`
    # call (made on the main thread in `recreate`) fail. Presumably the call
    # sets up backend session/graph state that the model's variables then
    # bind to; the exact mechanism is unconfirmed (TODO). Bind the session
    # once instead of calling `get_session` twice just to print it.
    session = keras.backend.get_session()
    print('Info:', session, session.graph)
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(4, input_shape=(4,)))
    model.compile(loss='mse', optimizer='sgd')
    # One all-ones sample whose width matches the layer's input_shape of (4,).
    print('Before:', model.predict(np.ones((1, 4))))
    return model
def recreate(model):
    """Rebuild ``model`` from its config in a freshly cleared Keras session.

    Steps:
    1. Copy the weights out of ``model``. This must happen before the session
       is cleared, because the old model's weights can no longer be read
       after the session is recreated.
    2. Clear the Keras backend session.
    3. Build a new model from ``model``'s config.
    4. Load the saved weights into the new model.

    This function does not call ``compile`` on the rebuilt model.

    Returns:
        The rebuilt model carrying the original weights.
    """
    saved_weights = model.get_weights()  # must precede clear_session()
    keras.backend.clear_session()
    # `_updated_config` is a private Keras method (leading underscore).
    config = model._updated_config()
    rebuilt = keras.models.model_from_config(config)
    rebuilt.set_weights(saved_weights)
    return rebuilt
if __name__ == '__main__':
    # Build the model on a worker thread, then rebuild it in a fresh session
    # on the main thread and predict with the rebuilt copy.
    # The `with` block shuts the executor down once the model has been
    # returned; previously the executor was never shut down.
    with ThreadPoolExecutor(max_workers=1) as executor:
        m = executor.submit(f).result()
    m = recreate(m)
    print('After:', m.predict(np.ones((1, 4))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment