[Python 3.5.2] Script showing that A) predicting and training a Keras model built from non-stateful sub-models works, B) predicting and training a Keras model with the stateful processing embedded directly (no sub-model) works, and C) training a Keras model with a stateful sub-model works, BUT predicting does NOT. Implementation C) is the desired one.
from keras.layers import Input
from keras.layers.recurrent import GRU, LSTM, SimpleRNN
from keras.layers.wrappers import TimeDistributed
from keras.layers.core import Dense, Activation, RepeatVector
from keras.layers.merge import Concatenate
from keras.layers import Dropout
from keras.optimizers import Adam
from keras.models import Model
import numpy as np
###################### SETUP #######################
batch_size = 128
seq_len = 30
num_symbols = 50
state_size = 256
RNNLayer = GRU
#####################################################
def create_encoder():
    """
    EXAMPLE STATELESS ENCODER MODEL
    """
    encoder_input = Input(
        name="encoder-input",
        batch_shape=(batch_size, seq_len, num_symbols))

    output = RNNLayer(state_size,
                      name="encoder-layer",
                      stateful=False,
                      go_backwards=True,
                      return_sequences=False)(encoder_input)
    output = Activation('tanh')(output)
    output = Dropout(0.2)(output)

    encoder = Model(inputs=[encoder_input],
                    outputs=[output],
                    name='encoder-model')
    return encoder

def _process_context(encoded_user_chat_context_input,
                     encoded_previous_response_context_input,
                     stateful=True):
    """
    ALL PROCESSING FOR CONTEXT MODEL
    """
    # Convert the encoded vectors into one-step sequences
    encoded_previous_response_seq = RepeatVector(1)(encoded_previous_response_context_input)
    encoded_user_chat_seq = RepeatVector(1)(encoded_user_chat_context_input)

    # Create a sequence of two time steps
    sequence = Concatenate(axis=1)([encoded_previous_response_seq, encoded_user_chat_seq])

    output = RNNLayer(state_size,
                      name="context-layer",
                      stateful=stateful,
                      return_sequences=False)(sequence)
    output = Activation('tanh')(output)
    output = Dropout(0.2)(output)
    return output

def create_context_processor(stateful=True):
    """
    EXAMPLE STATEFUL CONTEXT MODEL
    """
    encoded_user_chat_context_input = Input(
        name="encoded-user-chat-context-input",
        batch_shape=(batch_size, state_size))
    encoded_previous_response_context_input = Input(
        name="encoded-previous-response-context-input",
        batch_shape=(batch_size, state_size))

    output = _process_context(encoded_user_chat_context_input,
                              encoded_previous_response_context_input,
                              stateful)

    context_processor = Model(inputs=[encoded_user_chat_context_input,
                                      encoded_previous_response_context_input],
                              outputs=[output],
                              name='context-model')
    return context_processor

def create_decoder():
    """
    EXAMPLE STATELESS DECODER MODEL
    """
    encoded_user_chat_decoder_input = Input(
        name="encoded-user-chat-decoder-input",
        batch_shape=(batch_size, state_size))
    encoded_previous_response_decoder_input = Input(
        name="encoded-previous-response-decoder-input",
        batch_shape=(batch_size, state_size))
    current_context_decoder_input = Input(
        name="current-context-decoder-input",
        batch_shape=(batch_size, state_size))

    # Concatenate encoded user chat, previous response and conversation
    # context features
    all_features = Concatenate(axis=1)([encoded_user_chat_decoder_input,
                                        encoded_previous_response_decoder_input,
                                        current_context_decoder_input])
    all_features_repeated = RepeatVector(seq_len,
                                         name='all-features-repeater')(all_features)

    output = RNNLayer(state_size,
                      name="decoder-layer",
                      stateful=False,
                      return_sequences=True)(all_features_repeated)
    output = Activation('tanh')(output)
    output = Dropout(0.2)(output)
    output = TimeDistributed(Dense(num_symbols,
                                   name="decoder-output",
                                   activation='softmax'))(output)

    decoder = Model(inputs=[encoded_user_chat_decoder_input,
                            encoded_previous_response_decoder_input,
                            current_context_decoder_input],
                    outputs=[output],
                    name='decoder-model')
    return decoder

def create_train_model(context_processor):
    """
    CREATE TRAINING MODEL BASED ON SUB-MODELS

    NOTE: uses the globally defined `encoder` and `decoder` sub-models.
    """
    user_chat_input = Input(
        name="user-chat-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    previous_response_input = Input(
        name="previous-response-input",
        batch_shape=(batch_size, seq_len, num_symbols))

    encoded_user_chat = encoder(user_chat_input)
    encoded_previous_response = encoder(previous_response_input)

    current_context = context_processor([encoded_user_chat, encoded_previous_response])

    response = decoder([encoded_user_chat, encoded_previous_response, current_context])

    train_model = Model(inputs=[user_chat_input, previous_response_input],
                        outputs=[response],
                        name='train-model')
    return train_model

def create_train_model_embed_context_processing_without_submodel():
    """
    CREATE TRAINING MODEL BASED ON SUB-MODELS, BUT WITH THE STATEFUL
    CONTEXT PROCESSING EMBEDDED DIRECTLY, WITHOUT A SUB-MODEL
    """
    user_chat_input = Input(
        name="user-chat-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    previous_response_input = Input(
        name="previous-response-input",
        batch_shape=(batch_size, seq_len, num_symbols))

    encoded_user_chat = encoder(user_chat_input)
    encoded_previous_response = encoder(previous_response_input)

    # NO STATEFUL CONTEXT SUB-MODEL; EMBED THE PROCESSING DIRECTLY INSTEAD
    current_context = _process_context(encoded_user_chat,
                                       encoded_previous_response,
                                       stateful=True)

    response = decoder([encoded_user_chat, encoded_previous_response, current_context])

    train_model = Model(inputs=[user_chat_input, previous_response_input],
                        outputs=[response],
                        name='train-model')
    return train_model

################## CREATE MODELS ###################
encoder = create_encoder()
context_processor_stateful = create_context_processor()
context_processor_not_stateful = create_context_processor(False)
decoder = create_decoder()
train_model_stateful = create_train_model(context_processor_stateful)
train_model_not_stateful = create_train_model(context_processor_not_stateful)
train_model_embedded_context_processing = \
    create_train_model_embed_context_processing_without_submodel()
print('Encoder : ')
encoder.summary()
print('Context processor [STATEFUL] : ')
context_processor_stateful.summary()
print('Context processor [NOT STATEFUL] : ')
context_processor_not_stateful.summary()
print('Decoder : ')
decoder.summary()
print('Train model [STATEFUL, DESIRED IMPLEMENTATION] : ')
train_model_stateful.summary()
print('Train model [NOT STATEFUL] : ')
train_model_not_stateful.summary()
print('Train model [STATEFUL, EMBEDDED CONTEXT PROCESSING, WITHOUT SUB-MODEL] : ')
train_model_embedded_context_processing.summary()
####################################################
################# TRAIN & PREDICT ###################
def create_fake_chats(bs, sl, ns):
    chats = np.random.rand(bs, sl, ns)
    chats = np.argmax(chats, axis=2)

    # one-hot encoding of characters
    batch = np.zeros((bs, sl, ns), dtype=int)
    seq = np.arange(sl)
    for i in range(bs):
        batch[i, seq, chats[i, seq]] = 1
    return batch

train_model_stateful.compile(loss="categorical_crossentropy",
                             sample_weight_mode="temporal",
                             optimizer=Adam(),
                             weighted_metrics=["accuracy"])
train_model_not_stateful.compile(loss="categorical_crossentropy",
                                 sample_weight_mode="temporal",
                                 optimizer=Adam(),
                                 weighted_metrics=["accuracy"])
train_model_embedded_context_processing.compile(loss="categorical_crossentropy",
                                                sample_weight_mode="temporal",
                                                optimizer=Adam(),
                                                weighted_metrics=["accuracy"])
# INPUTS
user_chat_batch = create_fake_chats(batch_size, seq_len, num_symbols)
previous_response_batch = create_fake_chats(batch_size, seq_len, num_symbols)
# EXPECTED OUTPUTS
expected_response_batch = create_fake_chats(batch_size, seq_len, num_symbols)
# TIME STEP WEIGHTS
sample_weights_batch = np.ones((batch_size, seq_len), dtype=float)
print("=== EXPERIMENT 1 : TRAIN AND PREDICT MODEL : NO STATEFULNESS ===")
print("TRY TRAINING ON A BATCH...")
metrics = train_model_not_stateful.train_on_batch([user_chat_batch, previous_response_batch],
                                                  expected_response_batch,
                                                  sample_weight=sample_weights_batch)
print("WORKS!")
print("TRY PREDICTING ON BATCH...")
predicted_response_batch = train_model_not_stateful.predict_on_batch([user_chat_batch, previous_response_batch])
print("WORKS!")
print("=== EXPERIMENT 2 : TRAIN AND PREDICT MODEL : STATEFUL, CONTEXT PROCESSING EMBEDDED DIRECTLY, WITHOUT SUB-MODEL ===")
print("TRY TRAINING ON BATCH...")
metrics = train_model_embedded_context_processing.train_on_batch(
    [user_chat_batch, previous_response_batch],
    expected_response_batch,
    sample_weight=sample_weights_batch)
print("WORKS!")
print("TRY PREDICTING ON BATCH...")
predicted_response_batch = train_model_embedded_context_processing.predict_on_batch(
    [user_chat_batch, previous_response_batch])
print("WORKS!")
print("=== EXPERIMENT 3 : TRAIN AND PREDICT MODEL : STATEFUL, WITH SUB-MODEL (DESIRED IMPLEMENTATION) ===")
print("TRY TRAINING ON A BATCH...")
metrics = train_model_stateful.train_on_batch(
    [user_chat_batch, previous_response_batch],
    expected_response_batch,
    sample_weight=sample_weights_batch)
print("WORKS!")
print("TRY PREDICTING ON BATCH...")
try:
    # TRY PREDICTING ON BATCH : DOES NOT WORK :( :( :(
    predicted_response_batch = train_model_stateful.predict_on_batch([user_chat_batch, previous_response_batch])
    print("WORKS!")
except Exception as e:
    print("!!! DOES NOT WORK :( !!!")
    print("Exception : ")
    print(e)
#####################################################
visionscaper commented Jan 15, 2018

Implementation C), with a stateful sub-model, results in this exception:

You must feed a value for placeholder tensor 'encoded-previous-response-context-input' with dtype float and shape [128,256]
     [[Node: encoded-previous-response-context-input = Placeholder[dtype=DT_FLOAT, shape=[128,256], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
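
The placeholder named in the error is the stateful sub-model's own Input tensor, which suggests that at predict time the composed model is not mapping its outer tensors onto the sub-model's placeholders. One possible workaround (an untested sketch using only the models defined in the gist above, not part of the original script) is to bypass the composed model at inference time and chain the sub-models manually; each predict_on_batch call then feeds that sub-model's own Input placeholders directly, and the stateful context-layer should still carry its state between calls because the same layer instance is reused:

# Sketch: run the sub-models sequentially instead of the composed
# stateful train model.
encoded_user_chat = encoder.predict_on_batch(user_chat_batch)
encoded_previous_response = encoder.predict_on_batch(previous_response_batch)

# The stateful sub-model is fed through its own Input placeholders here,
# so the missing-placeholder error should not occur.
current_context = context_processor_stateful.predict_on_batch(
    [encoded_user_chat, encoded_previous_response])

predicted_response_batch = decoder.predict_on_batch(
    [encoded_user_chat, encoded_previous_response, current_context])

Training would still go through train_model_stateful; only prediction is rerouted through the sub-models.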
