Skip to content

Instantly share code, notes, and snippets.

@chris838
Created July 22, 2016 16:17
Show Gist options
  • Save chris838/cea1987c38e0f29a2a514ad229454c0e to your computer and use it in GitHub Desktop.
Save chris838/cea1987c38e0f29a2a514ad229454c0e to your computer and use it in GitHub Desktop.
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
class MyLSTM(Chain):
    """Single-layer recurrent network over one-hot encoded symbols.

    Each input index is mapped to a one-hot row, fed through an LSTM layer,
    and projected by a linear layer back to ``vocab_size`` unnormalized
    scores (no activation is applied to the output).

    Args:
        vocab_size: Number of distinct symbols; width of both the one-hot
            input and the output layer.
        hidden_size: Number of LSTM units.
    """

    def __init__(self, vocab_size, hidden_size):
        super(MyLSTM, self).__init__(
            mid=L.LSTM(vocab_size, hidden_size),
            out=L.Linear(hidden_size, vocab_size),
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        # Fixed identity table: row i is the one-hot code of symbol i, so
        # embed_id(x, W) turns indices into one-hot vectors. It is a plain
        # attribute, not a registered parameter, so the optimizer never
        # touches it.
        # NOTE(review): being a plain numpy array it presumably will not
        # follow to_gpu(); confirm before running on GPU.
        self.W = np.eye(vocab_size, dtype=np.float32)

    def reset_state(self):
        """Forget the recurrent state carried by the LSTM layer."""
        self.mid.reset_state()

    def __call__(self, x):
        """Advance the LSTM one step and return per-symbol scores.

        Args:
            x: Batch of symbol indices (presumably a 1-D int32 array, as the
                script below builds; confirm for other callers).
        """
        return self.out(self.mid(F.embed_id(x, self.W)))
np.random.seed(10)

# hyper parameters
hidden_size = 50

# data I/O
with open('data/sherlock-300.txt', 'r') as f:
    text = f.read()
# Sort the vocabulary so the char <-> index mapping is identical on every
# run (iteration order of a set of str is randomized per process on
# Python 3, which would undermine the fixed seed above).
chars = sorted(set(text))
data_size, vocab_size = len(text), len(chars)
print('data has %d characters, %d unique.' % (data_size, vocab_size))
char_to_idx = {ch: idx for idx, ch in enumerate(chars)}
idx_to_char = {idx: ch for idx, ch in enumerate(chars)}
text_as_idxs = np.array([char_to_idx[ch] for ch in text]).astype(np.int32)

# (input, target) pairs: each character is used to predict the character
# that follows it. list() because the pairs are sliced below and zip() is
# lazy on Python 3.
data = list(zip(text_as_idxs[:-1], text_as_idxs[1:]))
test_size = 2000
test, train = data[:test_size], data[test_size:]
# NOTE(review): train_iter shuffles (SerialIterator's default), so
# consecutive training batches come from unrelated text positions and the
# carried LSTM state adds little. Real sequence modelling would feed
# contiguous streams with truncated BPTT (see Chainer's PTB example).
train_iter = iterators.SerialIterator(train, batch_size=100)
test_iter = iterators.SerialIterator(
    test, batch_size=100, repeat=False, shuffle=False)

# create net
net = MyLSTM(vocab_size, hidden_size)
model = L.Classifier(net)
optimizer = optimizers.SGD()
optimizer.setup(model)

# Evaluate on a copy that shares the parameter arrays but keeps its own LSTM
# state. That state is reset before each evaluation pass, so evaluation
# neither continues from nor overwrites the training model's state.
eval_model = model.copy()
eval_net = eval_model.predictor


def truncate_lstm_history(trainer):
    """Detach the training LSTM state from the graph after each update.

    The training state is never reset, so without this every backward pass
    would reach back through the graphs of all previous iterations. Step
    time and memory would then grow with the length of training.
    """
    for state in (net.mid.h, net.mid.c):
        if state is not None:
            state.unchain_backward()


updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (20, 'epoch'), out='result')
trainer.extend(truncate_lstm_history, name='truncate_lstm_history',
               trigger=(1, 'iteration'))
trainer.extend(extensions.Evaluator(
    test_iter, eval_model, eval_hook=lambda _: eval_net.reset_state()))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss']))
trainer.extend(extensions.ProgressBar(update_interval=10))

# start training
trainer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment