Last active
February 5, 2018 15:30
-
-
Save visionscaper/06a75e9066a368fc2ed01cf0c3f606da to your computer and use it in GitHub Desktop.
[Python 3.5.2] Script showing that A) predicting and training a Keras model with non-stateful sub-models works, B) predicting and training a Keras model with stateful processing embedded directly (no sub-model) works, and C) training a Keras model with a stateful sub-model works, but predicting does not. Implementation C) is the desired one.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.layers import Input | |
from keras.layers.recurrent import GRU, LSTM, SimpleRNN | |
from keras.layers.wrappers import TimeDistributed | |
from keras.layers.core import Dense, Activation, RepeatVector | |
from keras.layers.merge import Concatenate | |
from keras.layers import Dropout | |
from keras.optimizers import Adam | |
from keras.models import Model | |
import numpy as np | |
###################### SETUP #######################
# Shared hyperparameters. Every Input below uses a fixed batch_shape
# because stateful RNN layers require a known batch size.
batch_size = 128   # samples per batch (fixed for stateful layers)
seq_len = 30       # time steps per chat sequence
num_symbols = 50   # vocabulary size (one-hot depth per time step)
state_size = 256   # RNN hidden size / encoding dimensionality
RNNLayer = GRU     # recurrent layer type used throughout (GRU/LSTM/SimpleRNN)
#####################################################
def create_encoder():
    """Build the stateless encoder sub-model.

    Consumes a one-hot sequence of shape
    (batch_size, seq_len, num_symbols) and returns one
    state_size-dimensional encoding per sequence.
    """
    chars_in = Input(
        name="encoder-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    # Single recurrent pass over the (reversed) sequence; only the
    # final state is kept (return_sequences=False).
    encoded = RNNLayer(state_size,
                       name="encoder-layer",
                       stateful=False,
                       go_backwards=True,
                       return_sequences=False)(chars_in)
    encoded = Dropout(0.2)(Activation('tanh')(encoded))
    return Model(inputs=[chars_in],
                 outputs=[encoded],
                 name='encoder-model')
def _process_context(encoded_user_chat_context_input,
                     encoded_previous_response_context_input,
                     stateful=True):
    """Run the shared context-tracking computation.

    Stacks the two encoded vectors into a two-step sequence
    (previous response first, then user chat) and feeds it through a
    single RNN layer whose statefulness is configurable.
    """
    # Lift each (batch, state_size) vector to a one-step sequence.
    prev_step = RepeatVector(1)(encoded_previous_response_context_input)
    chat_step = RepeatVector(1)(encoded_user_chat_context_input)
    # Time-axis concatenation -> shape (batch, 2, state_size).
    two_steps = Concatenate(axis=1)([prev_step, chat_step])
    context = RNNLayer(state_size,
                       name="context-layer",
                       stateful=stateful,
                       return_sequences=False)(two_steps)
    return Dropout(0.2)(Activation('tanh')(context))
def create_context_processor(stateful=True):
    """Build the context sub-model (stateful by default).

    Wraps _process_context in a Model taking the two encodings as
    separate inputs.
    """
    chat_in = Input(
        name="encoded-user-chat-context-input",
        batch_shape=(batch_size, state_size))
    prev_in = Input(
        name="encoded-previous-response-context-input",
        batch_shape=(batch_size, state_size))
    context_out = _process_context(chat_in, prev_in, stateful)
    return Model(inputs=[chat_in, prev_in],
                 outputs=[context_out],
                 name='context-model')
def create_decoder():
    """Build the stateless decoder sub-model.

    Merges the two encodings plus the current context into one
    feature vector, repeats it seq_len times, and decodes it into a
    per-time-step softmax over the symbol vocabulary.
    """
    chat_in = Input(
        name="encoded-user-chat-decoder-input",
        batch_shape=(batch_size, state_size))
    prev_in = Input(
        name="encoded-previous-response-decoder-input",
        batch_shape=(batch_size, state_size))
    context_in = Input(
        name="current-context-decoder-input",
        batch_shape=(batch_size, state_size))
    # Feature-axis concatenation of all three conditioning vectors.
    merged = Concatenate(axis=1)([chat_in, prev_in, context_in])
    # Repeat the merged features once per output time step.
    repeated = RepeatVector(seq_len,
                            name='all-features-repeater')(merged)
    hidden = RNNLayer(state_size,
                      name="decoder-layer",
                      stateful=False,
                      return_sequences=True)(repeated)
    hidden = Dropout(0.2)(Activation('tanh')(hidden))
    symbols = TimeDistributed(Dense(num_symbols,
                                    name="decoder-output",
                                    activation='softmax'))(hidden)
    return Model(inputs=[chat_in, prev_in, context_in],
                 outputs=[symbols],
                 name='decoder-model')
def create_train_model(context_processor, encoder_model=None, decoder_model=None):
    """Assemble the end-to-end training model from sub-models.

    Parameters
    ----------
    context_processor : Model
        Context sub-model (stateful or not) mapping the two encodings
        to the current conversation context.
    encoder_model : Model, optional
        Encoder sub-model; defaults to the module-level ``encoder``
        (preserves the original behavior for existing call sites).
    decoder_model : Model, optional
        Decoder sub-model; defaults to the module-level ``decoder``.

    Returns
    -------
    Model mapping (user chat, previous response) one-hot batches to a
    predicted response batch.
    """
    # The originals were hard-wired module globals; they are now
    # injectable while old call sites keep working unchanged.
    if encoder_model is None:
        encoder_model = encoder
    if decoder_model is None:
        decoder_model = decoder
    user_chat_input = Input(
        name="user-chat-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    previous_response_input = Input(
        name="previous-response-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    # The same encoder instance (shared weights) encodes both inputs.
    encoded_user_chat = encoder_model(user_chat_input)
    encoded_previous_response = encoder_model(previous_response_input)
    current_context = context_processor([encoded_user_chat, encoded_previous_response])
    response = decoder_model([encoded_user_chat, encoded_previous_response, current_context])
    return Model(inputs=[user_chat_input, previous_response_input],
                 outputs=[response],
                 name='train-model')
def create_train_model_embed_context_processing_without_submodel():
    """Build the training model with the stateful context processing
    inlined via _process_context instead of wrapped in its own
    sub-model (implementation B of the experiment).
    """
    user_chat_input = Input(
        name="user-chat-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    previous_response_input = Input(
        name="previous-response-input",
        batch_shape=(batch_size, seq_len, num_symbols))
    # Shared module-level encoder applied to both input sequences.
    encoded_user_chat = encoder(user_chat_input)
    encoded_previous_response = encoder(previous_response_input)
    # Stateful context layers embedded directly — no Model() wrapper.
    current_context = _process_context(encoded_user_chat,
                                       encoded_previous_response,
                                       stateful=True)
    response = decoder([encoded_user_chat,
                        encoded_previous_response,
                        current_context])
    return Model(inputs=[user_chat_input, previous_response_input],
                 outputs=[response],
                 name='train-model')
################## CREATE MODELS ###################
# NOTE: `encoder` and `decoder` are referenced as module globals by the
# create_train_model* functions above, so they must be bound before the
# train models are assembled below.
encoder = create_encoder()
context_processor_stateful = create_context_processor()
context_processor_not_stateful = create_context_processor(False)
decoder = create_decoder()
# Three variants under comparison: stateful sub-model (C, desired),
# non-stateful sub-model (A), and inlined stateful processing (B).
train_model_stateful = create_train_model(context_processor_stateful)
train_model_not_stateful = create_train_model(context_processor_not_stateful)
train_model_embedded_context_processing = \
    create_train_model_embed_context_processing_without_submodel()
print('Encoder : ')
encoder.summary()
print('Context processor [STATEFUL] : ')
context_processor_stateful.summary()
print('Context processor [NOT STATEFUL] : ')
context_processor_not_stateful.summary()
print('Decoder : ')
decoder.summary()
print('Train model [STATEFUL, DESIRED IMPLEMENTATION] : ')
train_model_stateful.summary()
print('Train model [NOT STATEFUL] : ')
train_model_not_stateful.summary()
print('Train model [STATEFUL, EMBEDDED CONTEXT PROCESSING, WITHOUT SUB-MODEL] : ')
train_model_embedded_context_processing.summary()
####################################################
################# TRAIN & PREDICT ###################
def create_fake_chats(bs, sl, ns):
    """Generate a batch of random one-hot encoded fake chat sequences.

    Parameters
    ----------
    bs : int
        Batch size (number of sequences).
    sl : int
        Sequence length (time steps per sequence).
    ns : int
        Number of symbols (one-hot depth).

    Returns
    -------
    numpy int array of shape (bs, sl, ns) where each (sample, step)
    slice is a one-hot vector selecting one random symbol.
    """
    # Draw one random symbol index per time step (argmax of uniform noise,
    # as in the original implementation, so seeded runs are reproducible).
    chats = np.argmax(np.random.rand(bs, sl, ns), axis=2)
    # Vectorized one-hot encoding via advanced indexing — replaces the
    # original per-sample Python loop; also handles bs == 0 cleanly.
    batch = np.zeros((bs, sl, ns), dtype=int)
    batch[np.arange(bs)[:, None], np.arange(sl), chats] = 1
    return batch
# All three train models share the same loss/optimizer configuration;
# sample_weight_mode="temporal" enables per-time-step sample weights.
train_model_stateful.compile(loss="categorical_crossentropy",
                             sample_weight_mode="temporal",
                             optimizer=Adam(),
                             weighted_metrics=["accuracy"])
train_model_not_stateful.compile(loss="categorical_crossentropy",
                                 sample_weight_mode="temporal",
                                 optimizer=Adam(),
                                 weighted_metrics=["accuracy"])
train_model_embedded_context_processing.compile(loss="categorical_crossentropy",
                                                sample_weight_mode="temporal",
                                                optimizer=Adam(),
                                                weighted_metrics=["accuracy"])
# INPUTS: random one-hot batches standing in for real chat data.
user_chat_batch = create_fake_chats(batch_size, seq_len, num_symbols)
previous_response_batch = create_fake_chats(batch_size, seq_len, num_symbols)
# EXPECTED OUTPUTS (random targets — only plumbing is under test here).
expected_response_batch = create_fake_chats(batch_size, seq_len, num_symbols)
# TIME STEP WEIGHTS: uniform, so every step contributes equally.
sample_weights_batch = np.ones((batch_size, seq_len), dtype=float)
print("=== EXPERIMENT 1 : TRAIN AND PREDICT MODEL : NO STATEFULNESS ===")
print("TRY TRAINING ON A BATCH...")
metrics = train_model_not_stateful.train_on_batch([user_chat_batch, previous_response_batch],
                                                  expected_response_batch,
                                                  sample_weight=sample_weights_batch)
print("WORKS!")
print("TRY PREDICTING ON BATCH...")
predicted_response_batch = train_model_not_stateful.predict_on_batch([user_chat_batch, previous_response_batch])
print("WORKS!")
print("=== EXPERIMENT 2 : TRAIN AND PREDICT MODEL : STATEFUL, CONTEXT PROCESSING EMBEDDED DIRECTLY, WITHOUT SUB-MODEL ===")
print("TRY TRAINING ON BATCH...")
metrics = train_model_embedded_context_processing.train_on_batch(
    [user_chat_batch, previous_response_batch],
    expected_response_batch,
    sample_weight=sample_weights_batch)
print("WORKS!")
print("TRY PREDICTING ON BATCH...")
predicted_response_batch = train_model_embedded_context_processing.predict_on_batch([user_chat_batch,
                                                                                    previous_response_batch])
print("WORKS!")
print("=== EXPERIMENT 3 : TRAIN AND PREDICT MODEL : STATEFUL, WITH SUB-MODEL (DESIRED IMPLEMENTATION) ===")
print("TRY TRAINING ON A BATCH...")
metrics = train_model_stateful.train_on_batch(
    [user_chat_batch, previous_response_batch],
    expected_response_batch,
    sample_weight=sample_weights_batch)
print("WORKS!")
print("TRY PREDICTING ON BATCH...")
try:
    # Per the author's report, predict_on_batch on the model with a
    # STATEFUL sub-model raises an exception (training works) — this
    # is the failure the gist demonstrates.
    predicted_response_batch = train_model_stateful.predict_on_batch([user_chat_batch, previous_response_batch])
    print("WORKS!")
except Exception as e:
    # Broad catch is deliberate: the demo should report, not crash on,
    # whatever exception Keras raises here.
    print("!!! DOES NOT WORK :( !!!")
    print("Exception : ")
    print(e)
#####################################################
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Implementation C), which uses a stateful sub-model, fails at predict time with the following exception: