keras stateful lstm - word-level rnn
Generation script (loads the trained weights from model.best.h5, warms up on a seed, and samples text at several temperatures):
import numpy as np
import sys
from keras.models import Sequential
from keras.layers import Dense, LSTM, Embedding, Activation
from keras.layers.wrappers import TimeDistributed
from keras import backend as K
from utils import Data, weighted_sample

filepath = 'input.txt'
weights_file = 'model.best.h5'
seq_len = 30
embedding_size = 20
num_word_tokens = 200

raw_text = open(filepath, encoding='utf8').read()
data = Data(raw_text, num_word_tokens=num_word_tokens)
print("Vocabulary size: ", data.n_vocab)

model = Sequential()
#model.add(Lambda(K.one_hot, arguments={'num_classes': data.n_vocab}, batch_input_shape=(1, 1), output_shape=(1, data.n_vocab), dtype=np.int32))
model.add(Embedding(data.n_vocab, embedding_size, input_length=1, batch_input_shape=(1, 1)))
model.add(LSTM(128, batch_input_shape=(1, 1, data.n_vocab), stateful=True, return_sequences=True))
#model.add(LSTM(128, stateful=True, return_sequences=True))
model.add(LSTM(128, stateful=True, return_sequences=True))
model.add(TimeDistributed(Dense(data.n_vocab)))
model.add(Activation('softmax'))

model.load_weights(weights_file)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

if len(sys.argv) > 1:
    init_text = ' '.join(sys.argv[1:])
else:
    start = np.random.randint(0, len(raw_text) - seq_len - 1)
    init_text = raw_text[start:start + seq_len]

print('\n\nSeed:\n')
print(init_text + '{{SEED END}}')

for diversity in [0.6, 0.8, 1.0]:
    print('\n\n--------------\nDiversity: ', diversity)
    print('\nGenerated:\n')

    model.reset_states()

    # warm up the LSTM state on the seed, one token at a time
    init_text_idxs = data.data_to_idxs(init_text)
    for idx in init_text_idxs[:-1]:
        x = np.array([idx], np.int32)
        model.predict(x)

    # generate: sample the next token, feed it back in, repeat
    idx = init_text_idxs[-1]
    for i in range(600):
        x = np.array([idx], np.int32)
        prediction = model.predict(x)[0, 0]
        idx = weighted_sample(prediction, diversity)
        result = data.idx_to_chars(idx)
        sys.stdout.write(result)
        sys.stdout.flush()

print('\n\nDone.')
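The diversity values are sampling temperatures passed to weighted_sample (defined in utils.py below). As a rough illustration of their effect, here is a small sketch of mine (not part of the gist, with a made-up three-token distribution): temperatures below 1.0 sharpen the predicted distribution before sampling, while 1.0 leaves it unchanged.

import numpy as np

preds = np.array([0.7, 0.2, 0.1])  # hypothetical softmax output over three tokens

for temperature in (0.6, 0.8, 1.0):
    p = np.exp(np.log(preds) / temperature)  # the same rescaling weighted_sample applies
    p /= p.sum()
    print(temperature, np.round(p, 3))
# lower temperatures concentrate probability on the most likely token,
# so lower 'diversity' produces more conservative, repetitive text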
Training script (builds the stateful word/char LSTM, trains on input.txt, and checkpoints the best model to model.best.h5):
import sys
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense, Activation, LSTM, Embedding, Lambda
from keras.layers.normalization import BatchNormalization
from keras.layers.wrappers import TimeDistributed
from keras import optimizers
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from utils import Data

filepath = 'input.txt'
seq_len = 30
embedding_size = 20
num_word_tokens = 200
batch_size = 10
num_epochs = 100
dropout = 0.1
validation_split = 0.1

raw_text = open(filepath, encoding='utf8').read()
data = Data(raw_text, num_word_tokens=num_word_tokens, batch_size=batch_size, seq_len=seq_len)
data_idxs = data.data_to_idxs(raw_text)

num_patterns = (len(data_idxs) - 1) // seq_len
num_batches_total = num_patterns // batch_size
num_batches_val = int(num_batches_total * validation_split)
num_batches_train = num_batches_total - num_batches_val

print("Total characters: ", len(raw_text))
print("Num word tokens: ", len(data.word_tokens))
print("Vocabulary size: ", data.n_vocab)
print("Sequence length: ", seq_len)
print("Num batches: train(%s) validation(%s)" % (num_batches_train, num_batches_val))

train_size = num_batches_train * batch_size * seq_len
train_data = data_idxs[:train_size + 1]  # extra char for the last sample's label
val_data = data_idxs[train_size:]

if len(sys.argv) > 1:
    # resume training from a previously saved (and already compiled) model
    model = load_model(sys.argv[1])
else:
    model = Sequential()
    if embedding_size > 0:
        model.add(Embedding(data.n_vocab, embedding_size, input_length=seq_len, batch_input_shape=(batch_size, seq_len)))
    else:
        model.add(Lambda(K.one_hot, arguments={'num_classes': data.n_vocab}, batch_input_shape=(batch_size, seq_len), output_shape=(seq_len, data.n_vocab), dtype=np.int32))
        model.add(BatchNormalization())
    model.add(LSTM(128, batch_input_shape=(batch_size, seq_len, embedding_size or data.n_vocab), dropout=dropout, stateful=True, return_sequences=True))
    #model.add(LSTM(128, dropout=dropout, stateful=True, return_sequences=True))
    model.add(LSTM(128, dropout=dropout, stateful=True, return_sequences=True))
    model.add(TimeDistributed(Dense(data.n_vocab)))
    model.add(Activation('softmax'))

    optimizer = optimizers.Adam(clipnorm=1.5)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

output_file = 'model.best.h5'  # 'weights.E{epoch:02d}-L{loss:.2f}.h5' # 'weights.best.h5'
checkpoint = ModelCheckpoint(output_file, verbose=1, save_best_only=True)

print("\nModel summary:\n")
model.summary()

model.fit_generator(
    data.batch_generator(train_data, num_batches_train),
    validation_data=data.batch_generator(val_data, num_batches_val),
    steps_per_epoch=num_batches_train,
    validation_steps=num_batches_val,
    epochs=num_epochs,
    callbacks=[checkpoint]
)
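Because every layer is stateful, row i of one batch must be the direct continuation of row i of the previous batch, and data.batch_generator (in utils.py below) arranges this by splitting the training text into batch_size contiguous streams. The following worked example of its offset arithmetic is mine, with made-up sizes, and is not part of the gist.

batch_size, seq_len, num_batches = 2, 3, 2
data = list(range(13))  # 12 training tokens plus one extra token for the final label

for batch_num in range(num_batches):
    for i in range(batch_size):
        offset = i * num_batches * seq_len + batch_num * seq_len  # same formula as batch_generator
        print(batch_num, i, data[offset:offset + seq_len])
# batch 0: row 0 -> [0, 1, 2],  row 1 -> [6, 7, 8]
# batch 1: row 0 -> [3, 4, 5],  row 1 -> [9, 10, 11]
# each row keeps advancing through its own stream, so the LSTM state carried
# across batches always continues the right piece of text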
utils.py (the Data tokenizer and the weighted_sample helper imported by both scripts):
import numpy as np
import re
from collections import Counter


class Data:
    def __init__(self, raw_text, num_word_tokens=0, batch_size=32, seq_len=20):
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.chars = sorted(list(set(raw_text)))
        self.word_tokens = []
        if num_word_tokens:
            # keep the most frequent multi-character tokens; everything else is encoded char-by-char
            tokenized = re.findall(r'[a-zA-Zא-ת\']+|.', raw_text, re.DOTALL)
            ctr = Counter(tokenized)
            self.word_tokens = [word for word, freq
                                in ctr.most_common(num_word_tokens + len(self.chars))
                                if len(word) > 1][:num_word_tokens]
        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
        self.token_to_idx = {tok: i + len(self.chars) for i, tok in enumerate(self.word_tokens)}
        self.n_vocab = len(self.chars) + len(self.word_tokens)

    def data_to_idxs(self, data):
        if self.word_tokens:
            data = re.findall(r'[a-zA-Zא-ת\']+|.', data, re.DOTALL)
            idxs_list = []
            for token in data:
                if token in self.token_to_idx:
                    idxs_list.append(self.token_to_idx[token])
                else:
                    for ch in token:
                        idxs_list.append(self.char_to_idx[ch])
            idxs = np.array(idxs_list, dtype=np.int32)
        else:
            idxs = np.zeros(len(data), dtype=np.int32)
            for i, c in enumerate(data):
                idxs[i] = self.char_to_idx[c]
        return idxs

    def idx_to_chars(self, idx):
        if idx < len(self.chars):
            return self.chars[idx]
        else:
            return self.word_tokens[idx - len(self.chars)]

    def batch_generator(self, data, num_batches, labels='indices'):
        batch_size, seq_len, n_vocab = self.batch_size, self.seq_len, self.n_vocab
        while True:
            for batch_num in range(num_batches):
                # generate one batch of samples:
                X = np.zeros((batch_size, seq_len), np.int32)
                if labels == 'indices':  # for use with 'sparse_categorical_crossentropy' loss
                    Y = np.zeros((batch_size, seq_len, 1), np.int32)
                else:  # for use with 'categorical_crossentropy' loss (labels are one-hot)
                    Y = np.zeros((batch_size, seq_len, n_vocab), np.int32)
                for i in range(batch_size):
                    # sample i always continues the i-th contiguous stream of the data,
                    # so the state carried across batches by a stateful LSTM stays consistent
                    offset = i * num_batches * seq_len + batch_num * seq_len
                    seq_in = data[offset:offset + seq_len]
                    seq_out = data[offset + 1:offset + 1 + seq_len]
                    X[i] = seq_in
                    if labels == 'indices':
                        Y[i] = np.expand_dims(seq_out, -1)
                    else:
                        for j in range(seq_len):
                            Y[i, j, seq_out[j]] = 1
                yield (X, Y)


def weighted_sample(preds, temperature=1.0):
    # sample an index from a probability array
    if temperature <= 0:
        return np.argmax(preds)
    preds = np.asarray(preds).astype('float64')
    preds[preds == 0.0] = 0.0000001  # otherwise np.log will warn when preds contains 0
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
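For reference, a small usage sketch of these helpers; the toy text and numbers are mine and not part of the gist. Frequent multi-character tokens each get a single vocabulary index, everything else falls back to per-character indices, and idx_to_chars reverses the mapping.

import numpy as np
from utils import Data, weighted_sample

text = "the cat sat on the mat. the cat sat."
data = Data(text, num_word_tokens=3, batch_size=2, seq_len=5)

print(data.chars)        # every distinct character in the text
print(data.word_tokens)  # the 3 most frequent multi-character tokens: ['the', 'cat', 'sat']

idxs = data.data_to_idxs("the cat")  # 'the' and 'cat' map to word indices, the space to a char index
print(''.join(data.idx_to_chars(i) for i in idxs))  # round-trips back to 'the cat'

probs = np.array([0.7, 0.2, 0.1])
print(weighted_sample(probs, temperature=0.0))  # temperature <= 0 is greedy: always the argmax (0)
print(weighted_sample(probs, temperature=1.0))  # otherwise a random index drawn in proportion to probs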