Bengio et al. neural probabilistic language model, implemented in TensorFlow/Keras.
from collections import Counter

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from nltk import word_tokenize
class BengioModel(keras.Model):
    ''' Model that replicates the architecture of Bengio et al. '''

    def __init__(self, window_size: int, vocabulary_size: int, embedding_size: int = 60, hidden_units: int = 50, regulariser_l: float = 0.1):
        ''' Initialise the model.

        Args:
            - window_size :: Number of words used for context.
            - vocabulary_size :: Size of the vocabulary in the corpus.
            - embedding_size :: Size of the embedding layer.
            - hidden_units :: Number of hidden units in the hidden layer.
            - regulariser_l :: How strong the regularisation is (as l -> inf, regularisation gets arbitrarily strong and smooths the parameters).
              NOTE: The default value of 0.1 is *just* a placeholder, as the paper didn't specify the strength.
        '''
        super().__init__()
        self.window_size = window_size
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        # Takes the place of tanh(d + Hx).
        # You could easily chuck a few more layers in here if you wanted to experiment with depth.
        # Not sure why the original paper uses the tanh function (legacy?); substituting a relu would be a reasonable experiment.
        self.non_linear = layers.Dense(hidden_units, activation=tf.nn.tanh)
        # NOTE: The paper didn't specify whether the embedding is regularised.
        self.embedding = layers.Embedding(vocabulary_size, embedding_size)
        self.W = layers.Dense(vocabulary_size, use_bias=False, kernel_regularizer=keras.regularizers.l2(l=regulariser_l))
        self.U = layers.Dense(vocabulary_size, use_bias=False, kernel_regularizer=keras.regularizers.l2(l=regulariser_l))
        self.b = tf.Variable(tf.random.uniform((vocabulary_size,), minval=-1, maxval=1))
    def call(self, inputs):
        embed = self.embedding(inputs)
        # The embedding output is a tensor of shape (batch_size, self.window_size, self.embedding_size),
        # i.e. one embedding per word in the window.
        # This reshape concatenates all of the window's embeddings together.
        embed = tf.reshape(embed, (-1, self.embedding_size * self.window_size))
        act = self.non_linear(embed)
        non_linear = self.U(act)
        linear = self.W(embed)
        logit = linear + non_linear + self.b
        return logit
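
# Putting call() together: it computes the paper's y = b + Wx + U tanh(d + Hx), where x is the
# concatenation of the context word embeddings, W is the direct (linear) connection and
# U tanh(d + Hx) is the hidden-layer path. A quick shape sanity check (illustrative numbers,
# not from the original gist):
#
#     demo = BengioModel(window_size=6, vocabulary_size=100)
#     demo(tf.zeros((2, 6), dtype=tf.int32)).shape  # => TensorShape([2, 100]): one logit per vocabulary word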
def window(tokens, n: int, pad):
    ''' Produce a rolling window of n tokens over a list, paired with the token that
        follows each window (using pad if we have run out of elements). '''
    for i in range(len(tokens) - n + 1):
        yield tokens[i : i + n], (pad if i + n >= len(tokens) else tokens[i + n])
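
# For example (illustrative only, assuming the padding behaviour described in the docstring):
#
#     list(window([1, 2, 3, 4, 5], 3, pad=0))
#     # => [([1, 2, 3], 4), ([2, 3, 4], 5), ([3, 4, 5], 0)]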
def load_data(filename: str, window_size: int):
    ''' This code is almost identical to what Ben had. '''
    with open(filename, 'r') as f:
        counts = Counter()
        lines = []
        for line in f.readlines():
            tokenized = word_tokenize(line.lower())
            counts.update(tokenized)
            lines.append(tokenized)
    # Keep only words that occur at least twice; everything else is mapped to <UNK>,
    # and '#' is used as the padding label for the final window of each line.
    vocab = [w for w in counts if counts[w] >= 2]
    vocab.append('<UNK>')
    vocab.append('#')
    vocab_map = {word: index for index, word in enumerate(vocab)}
    windows = []
    labels = []
    for line in lines:
        line = [vocab_map.get(word, vocab_map['<UNK>']) for word in line]
        for win, label in window(line, window_size, vocab_map['#']):
            windows.append(win)
            labels.append(label)
    return vocab_map, tf.constant(windows), tf.constant(labels)
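
# Concretely (an illustrative example, not from the original gist): with window_size = 3,
# a line tokenised as ["the", "cat", "sat", "on", "the", "mat"] yields window/label pairs
# such as ["the", "cat", "sat"] -> "on" (as token ids), assuming every word occurs at least
# twice in the corpus; rarer words map to '<UNK>', and '#' labels the final window of a line.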
def perplexity(y_true, y_pred):
    ''' Compute the perplexity of the model. '''
    ce = tf.reduce_mean(tf.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True))
    return tf.exp(ce)
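
# Perplexity is exp of the mean cross-entropy (in nats), so it reads as an effective branching
# factor: exp(4.6) is roughly 100, meaning the model is about as uncertain as a uniform choice
# between 100 words. (Interpretation added for context; the numbers are illustrative. Note that
# Keras averages this metric over batches, so it approximates corpus-level perplexity.)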
WINDOW_SIZE = 6
HIDDEN_SIZE = 60
NUM_EPOCHS = 50
EMBED_DIM = 60
HIDDEN_DIM = 60
BATCH_SIZE = 100
SEED = 31415

vocab_map, windows, labels = load_data('brown.txt', WINDOW_SIZE)
# Shuffle the window and label tensors.
# The initial shuffling will determine the train/val/test split.
# The variable SEED controls what shuffle is produced.
tf.random.set_seed(SEED)
indices = tf.range(0, windows.shape[0], dtype=tf.int32)
shuffle = tf.random.shuffle(indices)
windows = tf.gather(windows, shuffle)
labels = tf.gather(labels, shuffle)
# This code splits the dataset into train/validation/test.
# The split is as follows:
#  train (64%)                 val (16%)  test (20%)
# <--------------------------><---------><----------->
# [....................................................] (dataset)
#
# Tweak the 0.8 factors in `split` and `val_split` below to change the proportions.
n = windows.shape[0]
split = int(0.8 * n)
val_split = int(0.8 * split)
train_windows = windows[:val_split]
train_labels = labels[:val_split]
val_windows = windows[val_split:split]
val_labels = labels[val_split:split]
test_windows = windows[split:]
test_labels = labels[split:]
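
# Sanity check of the proportions (illustrative numbers): with n = 1000 examples,
# split = 800 and val_split = 640, giving 640 train, 160 validation and 200 test
# examples, i.e. the 64% / 16% / 20% breakdown sketched above.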
# Checkpointing is super useful for making sure your progress isn't lost over a few hours of training.
# Basically it saves your weights to disk so they can be loaded back in case one epoch looks interesting or your computer dies.
# Check out https://keras.io/api/callbacks/model_checkpoint/ for details on the flags you can configure here.
checkpoint = keras.callbacks.ModelCheckpoint('model.{epoch:02d}-{val_loss:.2f}.h5',
                                             # Subclassed models can't be serialised whole to HDF5,
                                             # so only the weights are checkpointed.
                                             save_weights_only=True,
                                             # This will only save the weights if they beat the best validation loss so far.
                                             # Disable this to save every epoch at the cost of more disk space.
                                             save_best_only=True)
vocab_size = len(vocab_map) + 1
model = BengioModel(WINDOW_SIZE, vocab_size, embedding_size=EMBED_DIM, hidden_units=HIDDEN_DIM)
# Because BengioModel subclasses the keras Model class you can do all sorts of interesting things with it.
# Check out https://keras.io/api/models/model/ for a list of supported methods and properties.
model.compile(optimizer='adam',
              loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[perplexity])
model.fit(train_windows, train_labels,
          # Comment out this line to disable checkpointing.
          callbacks=[checkpoint],
          batch_size=BATCH_SIZE,
          validation_data=(val_windows, val_labels),
          epochs=NUM_EPOCHS)
model.evaluate(test_windows, test_labels)
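
# A minimal inference sketch (not part of the original gist): optionally reload checkpointed
# weights, then predict the most likely next word for the first test window. The checkpoint
# filename below is hypothetical; substitute whichever file the callback actually wrote.
# model.load_weights('model.42-5.12.h5')
reverse_map = {index: word for word, index in vocab_map.items()}
logits = model(test_windows[:1])
predicted_id = int(tf.argmax(logits, axis=-1)[0])
print('Context:', [reverse_map[int(i)] for i in test_windows[0]])
print('Predicted next word:', reverse_map.get(predicted_id, '<UNK>'))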