keras char-rnn — a stateful character-level LSTM in Keras: a training script and a text-generation script
# Training script: fits a stateful character-level LSTM on the text in input.txt
import sys
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense, Activation, Dropout, LSTM, Embedding, Lambda
from keras import optimizers
from keras import backend as K
from keras.callbacks import ModelCheckpoint
# hyperparameters
filepath = 'input.txt'
seq_len = 80
#embedding_size = 40
batch_size = 64
num_epochs = 100
dropout = 0.2
validation_split = 0.05

# read the corpus and build the character vocabulary
raw_text = open(filepath, encoding='utf8').read()
chars = sorted(list(set(raw_text)))
char_to_idx = { c: i for i, c in enumerate(chars) }
n_chars = len(raw_text)
n_vocab = len(chars)

# split the text into non-overlapping sequences and group them into batches
num_patterns = (n_chars - 1) // seq_len
num_batches_total = num_patterns // batch_size
num_batches_val = int(num_batches_total * validation_split)
num_batches_train = num_batches_total - num_batches_val

print("Total characters: ", n_chars)
print("Vocabulary size: ", n_vocab)
print("Sequence length: ", seq_len)
print("Num batches: train(%s) validation(%s)" % (num_batches_train, num_batches_val))

train_size = num_batches_train * batch_size * seq_len
train_data = raw_text[:train_size + 1]  # extra char for the last sample's label
val_data = raw_text[train_size:]
def batch_generator(data, num_batches):
    # Yields (X, Y) batches forever. Row i of batch b starts at
    # offset i * num_batches * seq_len + b * seq_len, so across consecutive
    # batches each row walks through one contiguous stripe of the text --
    # the layout that the stateful LSTM layers below rely on.
    while True:
        for batch_num in range(num_batches):
            # generate one batch of samples:
            X = np.zeros((batch_size, seq_len), np.int32)
            Y = np.zeros((batch_size, n_vocab), np.int32)
            for i in range(batch_size):
                offset = i * num_batches * seq_len + batch_num * seq_len
                seq_in = data[offset:offset + seq_len]
                seq_out = data[offset + seq_len]   # the next char after the sequence is the label
                X[i] = [char_to_idx[ch] for ch in seq_in]
                Y[i][ char_to_idx[seq_out] ] = 1   # one-hot label
            yield (X, Y)
# resume from a saved model if a path was given on the command line, otherwise build a fresh one
if len(sys.argv) > 1:
    model = load_model(sys.argv[1])
else:
    model = Sequential()
    # turn integer char indices (batch_size, seq_len) into one-hot vectors (batch_size, seq_len, n_vocab)
    model.add(Lambda(K.one_hot, arguments={'num_classes': n_vocab},
                     batch_input_shape=(batch_size, seq_len),
                     output_shape=(seq_len, n_vocab), dtype=np.int32))
    #model.add(Embedding(n_vocab, embedding_size, input_length=seq_len, batch_input_shape=(batch_size, seq_len)))
    model.add(LSTM(512, batch_input_shape=(batch_size, seq_len, n_vocab), dropout=dropout,
                   recurrent_dropout=dropout, stateful=True, return_sequences=True))
    model.add(LSTM(512, dropout=dropout, recurrent_dropout=dropout, stateful=True, return_sequences=True))
    model.add(LSTM(512, dropout=dropout, recurrent_dropout=dropout, stateful=True))
    model.add(Dense(n_vocab, activation='softmax'))

    optimizer = optimizers.Adam(clipnorm=1.5)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# checkpoint the model whenever the monitored loss improves
output_file = 'model.E{epoch:02d}-L{loss:.2f}.h5'  # 'weights.best.h5'
checkpoint = ModelCheckpoint(output_file, verbose=1, save_best_only=True)

model.summary()

model.fit_generator(
    batch_generator(train_data, num_batches_train),
    validation_data = batch_generator(val_data, num_batches_val),
    steps_per_epoch = num_batches_train,
    validation_steps = num_batches_val,
    epochs = num_epochs,
    callbacks = [checkpoint]
)
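The offset arithmetic in batch_generator pairs with stateful=True: row i of consecutive batches must cover one contiguous stripe of the text so that the hidden state carried across batches stays meaningful. The following standalone sketch (not part of the gist; the tiny constants and names are made up for illustration) checks that property:

# Minimal sketch, not part of the gist: verifies that row i of consecutive batches
# covers one contiguous stripe of text under the generator's offset scheme.
# toy_* constants are illustrative, not the gist's real hyperparameters.
import numpy as np

toy_seq_len, toy_batch_size = 4, 3
toy_text = ''.join(chr(ord('a') + (k % 26)) for k in range(200))
toy_num_batches = (len(toy_text) - 1) // toy_seq_len // toy_batch_size

for i in range(toy_batch_size):                      # each row of the batch
    stripe = ''
    for b in range(toy_num_batches):                 # walk through consecutive batches
        off = i * toy_num_batches * toy_seq_len + b * toy_seq_len
        stripe += toy_text[off:off + toy_seq_len]
    # the concatenated slices form one contiguous chunk of the text,
    # so the LSTM state carried from batch to batch stays consistent
    assert stripe == toy_text[i * toy_num_batches * toy_seq_len:
                              (i + 1) * toy_num_batches * toy_seq_len]
print('row-wise stripes are contiguous; consistent with stateful=True')

The second script below rebuilds the same stack with batch size 1 and a single timestep, loads trained weights, and samples text character by character.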
# Generation script: rebuilds the network with batch size 1 and one timestep,
# loads trained weights, and samples text character by character
import numpy as np
import sys
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM, Embedding, Lambda
from keras import backend as K
filepath = 'input.txt'
weights_file = 'model.best.h5'

# rebuild the vocabulary from the same corpus used for training
raw_text = open(filepath, encoding='utf8').read()
chars = sorted(list(set(raw_text)))
char_to_idx = { c: i for i, c in enumerate(chars) }
idx_to_char = { i: c for i, c in enumerate(chars) }
n_chars = len(raw_text)
n_vocab = len(chars)

print("Total characters: ", n_chars)
print("Vocabulary size: ", n_vocab)

seq_len = 80
#embedding_size = 40

# same architecture as in training, but stateful with batch size 1 and a single timestep
model = Sequential()
model.add(Lambda(K.one_hot, arguments={'num_classes': n_vocab},
                 batch_input_shape=(1, 1), output_shape=(1, n_vocab), dtype=np.int32))
#model.add(Embedding(n_vocab, embedding_size, input_length=1, batch_input_shape=(1, 1)))
model.add(LSTM(512, batch_input_shape=(1, 1, n_vocab), stateful=True, return_sequences=True))
model.add(LSTM(512, stateful=True, return_sequences=True))
model.add(LSTM(512, stateful=True))
model.add(Dense(n_vocab, activation='softmax'))

model.load_weights(weights_file)
model.compile(loss='categorical_crossentropy', optimizer='adam')
def sample(preds, temperature=1.0):
    # sample an index from a probability array
    preds = np.asarray(preds).astype('float64')
    preds[preds == 0.0] = 0.0000001  # otherwise np.log will warn when preds contains 0
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
# seed text: either the command-line arguments or a random slice of the corpus
if len(sys.argv) > 1:
    init_text = ' '.join( sys.argv[1:] )
else:
    start = np.random.randint(0, n_chars - seq_len - 1)
    init_text = raw_text[start:start + seq_len]

print('\n\nSeed:\n')
print(init_text + '{{SEED END}}')
for diversity in [0.6, 0.8, 1.0]:
    print('\n\n--------------\nDiversity: ', diversity)
    print('\nGenerated:\n')

    # prime the stateful layers by feeding the seed one character at a time
    model.reset_states()
    for ch in init_text[:-1]:
        idx = char_to_idx[ch]
        x = np.array([idx], np.int32)
        model.predict(x)
    idx = char_to_idx[ init_text[-1] ]

    # generate 600 characters, feeding each prediction back in as the next input
    for i in range(600):
        x = np.array([idx], np.int32)
        prediction = model.predict(x)[0]
        if diversity > 0:
            idx = sample(prediction, diversity)
        else:
            idx = np.argmax(prediction)
        result = idx_to_char[idx]
        sys.stdout.write(result)
        sys.stdout.flush()
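The effect of the diversity (temperature) values used above can be seen in isolation. Below is a minimal standalone sketch, not part of the gist and using made-up probabilities, of the same rescaling that sample() performs: dividing the log-probabilities by the temperature before re-normalizing sharpens the distribution when the temperature is below 1 and flattens it when it is above 1.

# Minimal sketch, not part of the gist: shows how temperature reshapes a toy distribution.
import numpy as np

def rescale(preds, temperature):
    # same transform sample() applies before drawing from the distribution
    preds = np.asarray(preds, dtype='float64')
    preds[preds == 0.0] = 0.0000001
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    return exp_preds / np.sum(exp_preds)

p = [0.5, 0.3, 0.15, 0.05]                 # toy softmax output over four "characters"
for t in (0.2, 0.6, 1.0, 1.5):
    print(t, np.round(rescale(p, t), 3))   # t < 1 sharpens toward the argmax, t > 1 flattens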