Keras stateful LSTM
Generation script (loads the trained weights from model.best.h5 and samples text one character at a time):
# Sample text from a trained stateful character-level LSTM (see the training script below).
import numpy as np
import sys
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM, Embedding, Lambda, Activation
from keras.layers.wrappers import TimeDistributed
from keras import backend as K

filepath = 'input.txt'
weights_file = 'model.best.h5'

raw_text = open(filepath, encoding='utf8').read()

chars = sorted(list(set(raw_text)))
char_to_idx = { c: i for i, c in enumerate(chars) }
idx_to_char = { i: c for i, c in enumerate(chars) }

n_chars = len(raw_text)
n_vocab = len(chars)

print("Total characters: ", n_chars)
print("Vocabulary size: ", n_vocab)

seq_len = 30
embedding_size = 40

#np.random.seed(123)

# Same architecture as the training script, but with batch size 1 and sequence length 1
# so the model can be fed one character at a time while keeping its LSTM state.
model = Sequential()
#model.add(Lambda(K.one_hot, arguments={'num_classes': n_vocab}, batch_input_shape=(1, 1), output_shape=(1, n_vocab), dtype=np.int32))
model.add(Embedding(n_vocab, embedding_size, input_length=1, batch_input_shape=(1, 1)))
model.add(LSTM(256, batch_input_shape=(1, 1, embedding_size), stateful=True, return_sequences=True))
model.add(LSTM(256, stateful=True, return_sequences=True))
model.add(LSTM(128, stateful=True, return_sequences=True))
model.add(TimeDistributed(Dense(n_vocab)))
model.add(Activation('softmax'))

model.load_weights(weights_file)
model.compile(loss='categorical_crossentropy', optimizer='adam')

def sample(preds, temperature=1.0):
    # sample an index from a probability array
    preds = np.asarray(preds).astype('float64')
    preds[preds == 0.0] = 0.0000001  # otherwise np.log will warn when preds contains 0
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

if len(sys.argv) > 1:
    init_text = ' '.join(sys.argv[1:])
else:
    start = np.random.randint(0, n_chars - seq_len - 1)
    init_text = raw_text[start:start + seq_len]

print('\n\nSeed:\n')
print(init_text + '{{SEED END}}')

for diversity in [0.6, 0.8, 1.0]:
    print('\n\n--------------\nDiversity: ', diversity)
    print('\nGenerated:\n')

    model.reset_states()

    # warm up the LSTM state by feeding the seed text, one character at a time
    for ch in init_text[:-1]:
        idx = char_to_idx[ch]
        x = np.array([idx], np.int32)
        model.predict(x)

    idx = char_to_idx[init_text[-1]]

    # generate 600 characters, feeding each sampled character back in
    for i in range(600):
        x = np.array([idx], np.int32)
        prediction = model.predict(x)[0, 0]
        if diversity > 0:
            idx = sample(prediction, diversity)
        else:
            idx = np.argmax(prediction)
        result = idx_to_char[idx]
        sys.stdout.write(result)
        sys.stdout.flush()

print('\n\nDone.')

# to prevent TF exception on exit:
K.clear_session()
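
The sample() helper above is the usual temperature trick: log-probabilities are divided by the temperature and re-normalized, so temperatures below 1.0 sharpen the distribution (more conservative output) and temperatures above 1.0 flatten it (more surprising output). A standalone sketch of the same math, with a made-up four-character distribution, shows the effect directly (the helper name and the numbers here are only for illustration):

import numpy as np

def sample_with_temperature(preds, temperature=1.0):
    # same math as sample() above: rescale log-probs by 1/temperature, renormalize, draw once
    preds = np.asarray(preds).astype('float64')
    preds = np.log(np.clip(preds, 1e-7, 1.0)) / temperature
    preds = np.exp(preds)
    preds /= np.sum(preds)
    return np.argmax(np.random.multinomial(1, preds, 1))

p = np.array([0.5, 0.3, 0.15, 0.05])   # made-up distribution over a 4-character vocabulary
for t in (0.2, 1.0, 2.0):
    draws = [sample_with_temperature(p, t) for _ in range(1000)]
    print('temperature %.1f -> empirical frequencies %s' % (t, np.bincount(draws, minlength=4) / 1000.0))

At temperature 0.2 nearly every draw picks the most likely character, at 1.0 the draws follow the original distribution, and at 2.0 the distribution is noticeably flatter.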
Training script (builds the stateful character-level LSTM and fits it on input.txt):
# Train a stateful character-level LSTM language model on input.txt.
import sys
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense, Activation, Dropout, LSTM, Embedding, Lambda
from keras.layers.wrappers import TimeDistributed
from keras import optimizers
from keras import backend as K
from keras.callbacks import ModelCheckpoint

filepath = 'input.txt'

seq_len = 30
embedding_size = 40
batch_size = 16
num_epochs = 100
dropout = 0.1
validation_split = 0.05

raw_text = open(filepath, encoding='utf8').read()

chars = sorted(list(set(raw_text)))
char_to_idx = { c: i for i, c in enumerate(chars) }

n_chars = len(raw_text)
n_vocab = len(chars)

num_patterns = (n_chars - 1) // seq_len
num_batches_total = num_patterns // batch_size
num_batches_val = int(num_batches_total * validation_split)
num_batches_train = num_batches_total - num_batches_val

print("Total characters: ", n_chars)
print("Vocabulary size: ", n_vocab)
print("Sequence length: ", seq_len)
print("Num batches: train(%s) validation(%s)" % (num_batches_train, num_batches_val))

train_size = num_batches_train * batch_size * seq_len
train_data = raw_text[:train_size + 1]  # extra char for the last sample's label
val_data = raw_text[train_size:]

def batch_generator(data, num_batches):
    while True:
        for batch_num in range(num_batches):
            # generate one batch of samples:
            X = np.zeros((batch_size, seq_len), np.int32)
            Y = np.zeros((batch_size, seq_len, n_vocab), np.int32)
            for i in range(batch_size):
                # row i of consecutive batches holds consecutive slices of the text,
                # so the state carried between batches stays on the same stream
                offset = i * num_batches * seq_len + batch_num * seq_len
                seq_in = data[offset:offset + seq_len]
                seq_out = data[offset + 1:offset + 1 + seq_len]
                X[i] = [char_to_idx[ch] for ch in seq_in]
                for j in range(seq_len):
                    Y[i, j, char_to_idx[seq_out[j]]] = 1
            yield (X, Y)

if len(sys.argv) > 1:
    # resume training from a previously saved model
    model = load_model(sys.argv[1])
else:
    model = Sequential()
    #model.add(Lambda(K.one_hot, arguments={'num_classes': n_vocab}, batch_input_shape=(batch_size, seq_len), output_shape=(seq_len, n_vocab), dtype=np.int32))
    model.add(Embedding(n_vocab, embedding_size, input_length=seq_len, batch_input_shape=(batch_size, seq_len)))
    model.add(LSTM(256, batch_input_shape=(batch_size, seq_len, embedding_size), dropout=dropout, stateful=True, return_sequences=True))
    model.add(LSTM(256, dropout=dropout, stateful=True, return_sequences=True))
    model.add(LSTM(128, dropout=dropout, stateful=True, return_sequences=True))
    model.add(TimeDistributed(Dense(n_vocab)))
    model.add(Activation('softmax'))

    optimizer = optimizers.Adam(clipnorm=1.5)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

output_file = 'model.best.h5'  # or: 'model.E{epoch:02d}-L{loss:.2f}.h5'
checkpoint = ModelCheckpoint(output_file, verbose=1, save_best_only=True)

print("\nModel summary:\n")
model.summary()

model.fit_generator(
    batch_generator(train_data, num_batches_train),
    validation_data=batch_generator(val_data, num_batches_val),
    steps_per_epoch=num_batches_train,
    validation_steps=num_batches_val,
    epochs=num_epochs,
    callbacks=[checkpoint]
)
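
The only subtle part of batch_generator is the offset arithmetic. Because the LSTM layers are stateful, row i of batch b+1 must continue exactly where row i of batch b left off; offset = i * num_batches * seq_len + batch_num * seq_len slices the text so that each row walks through its own contiguous stretch. A toy illustration with a made-up string and sizes (not part of the gist):

text = 'abcdefghijklmnopqrstuvwx'                         # 24 characters, for illustration only
batch_size, seq_len = 2, 3
num_batches = (len(text) - 1) // seq_len // batch_size     # 3 batches of 2 rows each

for batch_num in range(num_batches):
    rows = []
    for i in range(batch_size):
        offset = i * num_batches * seq_len + batch_num * seq_len
        rows.append(text[offset:offset + seq_len])
    print('batch %d: %s' % (batch_num, rows))

# batch 0: ['abc', 'jkl']
# batch 1: ['def', 'mno']
# batch 2: ['ghi', 'pqr']

Row 0 reads 'abc', 'def', 'ghi' across the three batches and row 1 reads 'jkl', 'mno', 'pqr', so the hidden state each row carries from one batch to the next always continues the same stream. This is also why the batches must be presented in order rather than shuffled.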