Skip to content

Instantly share code, notes, and snippets.

@jaidevd
Created March 22, 2018 14:37
Show Gist options
  • Save jaidevd/1412d99caf1bcd1c46a26d18ff9301cc to your computer and use it in GitHub Desktop.
Sample word2vec (W2V) skip-gram embedding training script in Keras.
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
"""
Train word2vec-style skip-gram embeddings on the MTsamples dataset.

Reads sentences from ``mt_sents.json`` (a JSON list of strings, one
sentence per string), generates (target, context) skip-gram pairs with
negative sampling, trains a shared embedding layer with a
cosine-similarity + sigmoid objective, then saves the trained model and
the tokenizer's word index.
"""
import json

import numpy as np
from keras.layers import Dense, Dot, Embedding, Input, Reshape
from keras.models import Model
from keras.preprocessing.sequence import make_sampling_table, skipgrams
from keras.preprocessing.text import Tokenizer

# Constants
V = 24000          # vocab size (most frequent words kept by the tokenizer)
vector_dim = 256   # dimensions of the output vectors
window_size = 2    # If overall window size is `n`, this parameter should be (n - 1) / 2

print('Building model...')
# Two scalar word-index inputs: a target word and a (true or
# negative-sampled) context word.
input_target = Input((1,))
input_context = Input((1,))
# One Embedding layer shared by both inputs, so target and context words
# are mapped into the same vector space.
embedding = Embedding(V, vector_dim, input_length=1)
target = Reshape((vector_dim, 1))(embedding(input_target))
context = Reshape((vector_dim, 1))(embedding(input_context))
# normalize=True turns the dot product into a cosine similarity.
op = Dot(axes=1, normalize=True)([target, context])
op = Reshape((1,))(op)
op = Dense(1, activation='sigmoid')(op)
# FIX: Keras 2 uses the plural keywords `inputs`/`outputs`; the singular
# `input=`/`output=` form raises a TypeError on current Keras.
model = Model(inputs=[input_target, input_context], outputs=op)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')

print('Building corpus...')
# `sents` is a list of strings, each string a sentence.
# FIX: use a context manager so the file handle is closed deterministically.
with open('mt_sents.json') as fin:
    sents = json.load(fin)
vect = Tokenizer(num_words=V)
vect.fit_on_texts(sents)
# Reverse lookup: integer index -> word, for inspecting learned vectors.
revdict = {v: k for k, v in vect.word_index.items()}
# Sampling table that down-weights very frequent words during pair generation.
st = make_sampling_table(V)
t2s = vect.texts_to_sequences(sents)
X = []
y = []
# FIX: the loop body (and the `with` body at the bottom) had lost its
# indentation in the pasted source; restored here.
for sent_seq in t2s:
    # skipgrams returns (couples, labels): couples are (target, context)
    # index pairs, labels are 1 for true pairs and 0 for negative samples.
    _x, _y = skipgrams(sent_seq, V, window_size=window_size,
                       negative_samples=0.5, sampling_table=st)
    X.extend(_x)
    y.extend(_y)
x_tar, x_con = zip(*X)
x_tar = np.array(x_tar)
x_con = np.array(x_con)
y = np.array(y)

print('Training...')
model.fit([x_tar, x_con], y, batch_size=128, epochs=5)
model.save('mtsamples.h5')
with open('mtsamples_dict.json', 'w') as fout:
    json.dump(vect.word_index, fout)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.