Created
August 17, 2015 11:32
-
-
Save Smerity/418a4e7f9e719ff02bf3 to your computer and use it in GitHub Desktop.
Epoch tuning through early stopping for bAbI RNN in Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import print_function | |
from functools import reduce | |
import re | |
import tarfile | |
import numpy as np | |
np.random.seed(1337) # for reproducibility | |
from keras.callbacks import EarlyStopping
from keras.datasets.data_utils import get_file | |
from keras.initializations import normal, identity | |
from keras.layers.embeddings import Embedding | |
from keras.layers.core import Dense, Dropout, Merge | |
from keras.layers import recurrent | |
from keras.models import Sequential | |
from keras.preprocessing.sequence import pad_sequences | |
''' | |
Trains two recurrent neural networks based upon a story and a question. | |
The resulting merged vector is then queried to answer a range of bAbI tasks. | |
The results are comparable to those for an LSTM model provided in Weston et al.: | |
"Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks" | |
http://arxiv.org/abs/1502.05698 | |
Task Number | FB-LSTM | LSTM | GRU | IRNN | RNN | | |
----------- | ------- | ---- | --- | ---- | --- | | |
QA1 - Single Supporting Fact | 50 | 51.2 | 52.1 | 47.7 | 52.7 | | |
QA2 - Two Supporting Facts | 20 | 21.8 | 37.0 | 19.7 | 27.8 | | |
QA3 - Three Supporting Facts | 20 | 20.1 | 20.5 | 21.3 | 22.4 | | |
QA4 - Two Arg. Relations | 61 | 56.2 | 62.9 | 69.0 | 20.0 | | |
QA5 - Three Arg. Relations | 70 | 46.8 | 61.9 | 32.7 | 38.8 | | |
QA6 - Yes/No Questions | 48 | 49.1 | 50.7 | 49.3 | 44.8 | | |
QA7 - Counting | 49 | 76.1 | 78.9 | 75.4 | 63.2 | | |
QA8 - Lists/Sets | 45 | 72.1 | 77.2 | 73.7 | 41.0 | | |
QA9 - Simple Negation | 64 | 63.5 | 64.0 | 58.6 | 63.8 | | |
QA10 - Indefinite Knowledge | 44 | 47.6 | 47.7 | 47.7 | 42.8 | | |
QA11 - Basic Coreference | 72 | 71.9 | 74.9 | 74.0 | 75.1 | | |
QA12 - Conjunction | 74 | 73.2 | 76.4 | 71.0 | 77.2 | | |
QA13 - Compound Coreference | 94 | 94.0 | 94.4 | 94.0 | 94.4 | | |
QA14 - Time Reasoning | 27 | 23.7 | 34.8 | 30.5 | 19.9 | | |
QA15 - Basic Deduction | 21 | 21.7 | 32.4 | 54.0 | 23.9 | | |
QA16 - Basic Induction | 23 | 44.4 | 50.6 | 49.4 | 41.8 | | |
QA17 - Positional Reasoning | 51 | 52.1 | 49.1 | 48.9 | 52.4 | | |
QA18 - Size Reasoning | 52 | 91.0 | 90.8 | 58.4 | 54.8 | | |
QA19 - Path Finding | 8 | 9.5 | 9.0 | 11.5 | 7.1 | | |
QA20 - Agent's Motivations | 91 | 93.5 | 90.7 | 97.6 | 92.2 | | |
For the resources related to the bAbI project, refer to: | |
https://research.facebook.com/researchers/1543934539189348 | |
Notes: | |
- The task does not traditionally parse the question separately. This likely | |
improves accuracy and is a good example of merging two RNNs. | |
- The word vector embeddings are not shared between the story and question RNNs. | |
- See how the accuracy changes given 10,000 training samples (en-10k) instead | |
of only 1000. 1000 was used in order to be comparable to the original paper. | |
- Experiment with SimpleRNN, IRNN, GRU, and JZS1-3. | |
- The length and noise (i.e. 'useless' story components) impact the ability for | |
LSTMs / GRUs to provide the correct answer. Given only the supporting facts, | |
these RNNs can achieve 100% accuracy on many tasks. Memory networks and neural | |
networks that use attentional processes can efficiently search through this | |
noise to find the relevant statements, improving performance substantially. | |
This becomes especially obvious on QA2 and QA3, both far longer than QA1. | |
''' | |
def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    The sentence is split on runs of non-word characters. The capturing group
    keeps those runs in the result, so punctuation survives as its own token;
    pieces that are empty or whitespace-only are dropped.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # The separator group must NOT be optional: a pattern such as '(\W+)?' can
    # match the empty string, which makes re.split (Python 3.7+) cut between
    # every character and emit None for the unmatched group.
    pieces = (piece.strip() for piece in re.split(r'(\W+)', sent))
    return [piece for piece in pieces if piece]
def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.

    Every non-blank line has the form "<id> <text>". A line whose id is 1
    starts a new story. A line containing tab characters is a question made of
    three tab-separated fields: the question, the answer, and the
    space-separated ids of the supporting sentences. Any other line is a story
    sentence.

    lines: iterable of lines, either bytes (decoded as UTF-8) or text.
    only_supporting: if true, only the sentences that support the answer are
        kept; otherwise every sentence seen so far in the story is kept.

    Returns a list of (substory, question_tokens, answer) tuples, where
    substory is a list of tokenized sentences and answer is the raw answer
    string.
    '''
    data = []
    story = []
    for line in lines:
        # Lines read from a binary file object are bytes; text is accepted too.
        if isinstance(line, bytes):
            line = line.decode('utf-8')
        line = line.strip()
        if not line:
            # Blank lines carry no id and would break the split below.
            continue
        nid, line = line.split(' ', 1)
        if int(nid) == 1:
            story = []
        if '\t' in line:
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory (ids are 1-based).
                substory = [story[int(i) - 1] for i in supporting.split()]
            else:
                # Provide all the substories, skipping question placeholders.
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # Placeholder so later supporting ids still index into `story`.
            story.append('')
        else:
            story.append(tokenize(line))
    return data
def get_stories(f, only_supporting=False, max_length=None):
    '''Read stories from an open file object and flatten each one.

    The lines of `f` are parsed with parse_stories, then the sentences of each
    story are concatenated into a single list of tokens.

    f: file-like object providing readlines().
    only_supporting: forwarded to parse_stories.
    max_length: if supplied (and non-zero), only stories with fewer than
        max_length tokens are kept; the rest are discarded.

    Returns a list of (story_tokens, question_tokens, answer) tuples.
    '''
    parsed = parse_stories(f.readlines(), only_supporting=only_supporting)
    stories = []
    for story, q, answer in parsed:
        # Flatten once per story; this also copes with a story that has no
        # sentences, which a reduce() without an initial value would not.
        tokens = [token for sentence in story for token in sentence]
        if not max_length or len(tokens) < max_length:
            stories.append((tokens, q, answer))
    return stories
def vectorize_stories(data):
    '''Turn (story, query, answer) triples into model-ready arrays.

    Story and query tokens are mapped to integer ids through the module-level
    word_idx table, and each answer becomes a one-hot vector of length
    vocab_size. The id sequences are passed through pad_sequences with
    story_maxlen / query_maxlen respectively.

    Returns a (stories, queries, answers) tuple.
    '''
    story_ids = []
    query_ids = []
    answers = []
    for story, query, answer in data:
        story_ids.append([word_idx[token] for token in story])
        query_ids.append([word_idx[token] for token in query])
        one_hot = np.zeros(vocab_size)
        one_hot[word_idx[answer]] = 1
        answers.append(one_hot)
    padded_stories = pad_sequences(story_ids, maxlen=story_maxlen)
    padded_queries = pad_sequences(query_ids, maxlen=query_maxlen)
    return padded_stories, padded_queries, np.array(answers)
def IRNN(*args, **kwargs):
    '''Build a recurrent.SimpleRNN layer with the "IRNN" configuration.

    All positional and keyword arguments are forwarded to SimpleRNN. On top of
    them the layer is given a ReLU activation, `init` drawn from normal() with
    scale 0.001, and `inner_init` set to identity() with scale 1.0.
    '''
    return recurrent.SimpleRNN(
        *args,
        init=lambda shape: normal(shape, scale=0.001),
        inner_init=lambda shape: identity(shape, scale=1.0),
        activation='relu',
        **kwargs)
# Recurrent layer class used by both the story and the question encoder.
# Uncomment one of the alternatives to try a different architecture.
#RNN = recurrent.SimpleRNN
#RNN = IRNN
#RNN = recurrent.GRU
RNN = recurrent.LSTM

# Layer sizes.
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100

# Training settings; MAX_EPOCHS is the upper bound for the early-stopping run.
BATCH_SIZE = 32
MAX_EPOCHS = 100

print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(
    RNN,
    EMBED_HIDDEN_SIZE,
    SENT_HIDDEN_SIZE,
    QUERY_HIDDEN_SIZE,
))
# Download (or reuse the cached copy of) the bAbI tasks archive.
path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz')

# Task selection: '{}' is filled with 'train' / 'test' below.
# QA1 with 1000 samples
# challenge = 'tasks_1-20_v1-2/en/qa1_single-supporting-fact_{}.txt'
# QA1 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt'
# QA2 with 1000 samples
# challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
# QA2 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt'
# QA11 with 1000 samples (the task actually used)
challenge = 'tasks_1-20_v1-2/en/qa11_basic-coreference_{}.txt'
#import sys
#print('Running on', sys.argv[1])
#challenge = 'tasks_1-20_v1-2/en/' + sys.argv[1].replace('train', '{}')

# Open the archive in a context manager so it is closed once both splits
# have been read.
with tarfile.open(path) as tar:
    train = get_stories(tar.extractfile(challenge.format('train')))
    test = get_stories(tar.extractfile(challenge.format('test')))

# The vocabulary is every token that appears in a story, question or answer
# of either split.
vocab = sorted(set().union(*(set(story + q + [answer]) for story, q, answer in train + test)))
# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
word_idx = {c: i + 1 for i, c in enumerate(vocab)}
story_maxlen = max(len(story) for story, _, _ in train + test)
query_maxlen = max(len(q) for _, q, _ in train + test)

X, Xq, Y = vectorize_stories(train)
tX, tXq, tY = vectorize_stories(test)

print('vocab = {}'.format(vocab))
print('X.shape = {}'.format(X.shape))
print('Xq.shape = {}'.format(Xq.shape))
print('Y.shape = {}'.format(Y.shape))
print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))
def construct_model():
    '''Build and compile a fresh story/question model.

    Two Sequential branches, one for the story and one for the question, each
    consist of an Embedding, an RNN layer that returns only its final output,
    and Dropout(0.3). Their outputs are concatenated and fed to a softmax
    Dense layer over the vocabulary. The model is compiled with the adam
    optimizer and categorical cross-entropy loss.

    Uses the module-level RNN, vocab_size and *_HIDDEN_SIZE settings.
    Returns the compiled model.
    '''
    # Story branch.
    sentrnn = Sequential()
    sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True))
    #sentrnn.add(Dropout(0.1))
    sentrnn.add(RNN(EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, return_sequences=False))
    sentrnn.add(Dropout(0.3))

    # Question branch.
    # NOTE(review): unlike the story embedding, this one does not pass
    # mask_zero=True - confirm that the asymmetry is intended.
    qrnn = Sequential()
    qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE))
    #qrnn.add(Dropout(0.1))
    qrnn.add(RNN(EMBED_HIDDEN_SIZE, QUERY_HIDDEN_SIZE, return_sequences=False))
    qrnn.add(Dropout(0.3))

    # Concatenate both encodings and predict the answer word.
    model = Sequential()
    model.add(Merge([sentrnn, qrnn], mode='concat'))
    model.add(Dense(SENT_HIDDEN_SIZE + QUERY_HIDDEN_SIZE, vocab_size, activation='softmax'))

    model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical')
    return model
# Phase 1: train with a held-out validation split and early stopping to find
# out how many epochs are worth running.
print('Finding best number of epochs...')
model = construct_model()
early_stop = EarlyStopping(monitor='val_loss', patience=20, verbose=1)
model.fit(
    [X, Xq],
    Y,
    batch_size=BATCH_SIZE,
    nb_epoch=MAX_EPOCHS,
    validation_split=0.05,
    show_accuracy=True,
    callbacks=[early_stop],
)

# Phase 2: retrain a fresh model on all of the training data for that many
# epochs.
# NOTE(review): this relies on the callback exposing a `best_epoch`
# attribute - confirm the installed Keras version provides it.
print('Training using {} epochs...'.format(early_stop.best_epoch + 1))
model = construct_model()
model.fit(
    [X, Xq],
    Y,
    batch_size=BATCH_SIZE,
    nb_epoch=early_stop.best_epoch + 1,
    show_accuracy=True,
)

# Report loss and accuracy on the test split.
loss, acc = model.evaluate(
    [tX, tXq],
    tY,
    batch_size=BATCH_SIZE,
    show_accuracy=True,
)
print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@Smerity could you please say why you said "As such, I agree strongly with you that this won't make a good test dataset for testing various RNN architectures" ? Thank you. I am currently doing my thesis on this dataset and I am wondering