Created
March 22, 2018 14:37
-
-
Save jaidevd/1412d99caf1bcd1c46a26d18ff9301cc to your computer and use it in GitHub Desktop.
Sample W2V script in keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# vim:fenc=utf-8 | |
""" | |
Training Word embeddings on the MTsamples dataset. | |
""" | |
from keras.models import Model | |
from keras.layers import Dense, Embedding, Input, Reshape, Dot | |
from keras.preprocessing.sequence import make_sampling_table, skipgrams | |
from keras.preprocessing.text import Tokenizer | |
import numpy as np | |
import json | |
# --- Hyperparameters for the word2vec run ---
V = 24000            # vocabulary size retained by the tokenizer
vector_dim = 256     # dimensionality of each learned word vector
window_size = 2      # half-window; a full context window of n tokens corresponds to (n - 1) / 2
print('Building model...')

# Skip-gram model: a single shared embedding maps both the target word id and
# the context word id into the same vector space; their cosine similarity
# (normalized dot product) is passed through a sigmoid to classify the pair
# as a true skip-gram (label 1) or a negative sample (label 0).
input_target = Input((1,))
input_context = Input((1,))

# One Embedding layer shared by both inputs, so target and context vectors
# are drawn from the same learned table.
embedding = Embedding(V, vector_dim, input_length=1)
target = Reshape((vector_dim, 1))(embedding(input_target))
context = Reshape((vector_dim, 1))(embedding(input_context))

# normalize=True turns the dot product into cosine similarity in [-1, 1].
op = Dot(axes=1, normalize=True)([target, context])
op = Reshape((1,))(op)
op = Dense(1, activation='sigmoid')(op)

# Fix: the legacy `input=`/`output=` keyword arguments were removed in
# Keras 2 — the functional API requires `inputs=`/`outputs=`.
model = Model(inputs=[input_target, input_context], outputs=op)
model.compile(loss='binary_crossentropy', optimizer='rmsprop')
print('Building corpus...')

# `mt_sents.json` is expected to contain a JSON list of strings, one sentence
# per element. Fix: use a context manager — the original
# `json.load(open(...))` never closed the file handle.
with open('mt_sents.json') as fin:
    sents = json.load(fin)

vect = Tokenizer(num_words=V)
vect.fit_on_texts(sents)
# Reverse lookup table: integer word id -> word string (useful for
# inspecting the learned embeddings later).
revdict = {v: k for k, v in vect.word_index.items()}

# The sampling table down-weights very frequent words when drawing
# skip-gram training pairs.
st = make_sampling_table(V)
t2s = vect.texts_to_sequences(sents)

# Accumulate (target, context) id pairs and their 1/0 labels over all
# sentences; negative_samples=0.5 draws one negative pair per two positives.
X = []
y = []
for sent_seq in t2s:
    _x, _y = skipgrams(sent_seq, V, window_size=window_size,
                       negative_samples=0.5, sampling_table=st)
    X.extend(_x)
    y.extend(_y)

# Fail loudly with a clear message instead of an opaque unpack error from
# `zip(*[])` when the corpus yields no pairs.
if not X:
    raise ValueError('No skip-gram pairs generated; check mt_sents.json.')

x_tar, x_con = zip(*X)
x_tar = np.array(x_tar)
x_con = np.array(x_con)
y = np.array(y)
# Train the skip-gram classifier end to end, then persist both the model
# weights and the tokenizer vocabulary needed to interpret them.
print('Training...')
model.fit([x_tar, x_con], y, batch_size=128, epochs=5)
# HDF5 snapshot of the full model (architecture + embedding weights).
model.save('mtsamples.h5')
# Save the word -> integer-id mapping so embedding rows can be looked up
# by word after reloading the model.
with open('mtsamples_dict.json', 'w') as fout:
    json.dump(vect.word_index, fout)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment