# Author: Robert Guthrie
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances

torch.manual_seed(1)
CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()
# By deriving a set from `raw_text`, we deduplicate the array
vocab = set(raw_text)
vocab_size = len(vocab)
EMBEDDING_DIM = 10

word_to_ix = {word: i for i, word in enumerate(vocab)}
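# word_to_ix maps each vocabulary word to a unique integer index,
# i.e. its row in the embedding table.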
data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data.append((context, target))
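# Each training example pairs the surrounding words with the middle word,
# e.g. data[0] == (['We', 'are', 'to', 'study'], 'about')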
def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    return torch.tensor(idxs, dtype=torch.long)
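# e.g. make_context_vector(data[0][0], word_to_ix) yields a length-4 LongTensor of word indices.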
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # The 2 * context_size context words are embedded and concatenated,
        # so linear1 takes a context_size * 2 * embedding_dim input.
        self.linear1 = nn.Linear(context_size * 2 * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        # Look up the context embeddings and flatten them into a single row vector.
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs
losses = []
loss_function = nn.NLLLoss()
model = CBOW(vocab_size, EMBEDDING_DIM, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.001)
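# NLLLoss consumes the log-probabilities produced by log_softmax in forward(),
# so the combination is equivalent to a cross-entropy loss.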
for epoch in range(100):
    total_loss = 0
    for context, target in data:

        # Step 1. Prepare the inputs to be passed to the model (i.e., turn the words
        # into integer indices and wrap them in tensors)
        context_idxs = make_context_vector(context, word_to_ix)

        # Step 2. Recall that torch *accumulates* gradients. Before passing in a
        # new instance, you need to zero out the gradients from the old instance.
        model.zero_grad()

        # Step 3. Run the forward pass, getting log probabilities over the target word.
        log_probs = model(context_idxs)

        # Step 4. Compute your loss function. (Again, Torch wants the target
        # word wrapped in a tensor.)
        loss = loss_function(log_probs, torch.tensor([word_to_ix[target]], dtype=torch.long))

        # Step 5. Do the backward pass and update the parameters.
        loss.backward()
        optimizer.step()

        # Get the Python number from a 1-element Tensor by calling tensor.item()
        total_loss += loss.item()

    losses.append(total_loss)
    print(total_loss)  # The loss should decrease with every epoch over the training data.
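# losses now holds one summed-loss value per epoch.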
# embedding_words = []
ix_to_word = {ix: word for word, ix in word_to_ix.items()}

with torch.no_grad():
    # In the training text this context surrounds the word "things"
    # ("... manipulate other abstract things called data.").
    for context in [["other", "abstract", "called", "data."]]:
        bow_vec = make_context_vector(context, word_to_ix)
        log_probs = model(bow_vec)
        # The index with the highest log-probability is the predicted middle word.
        index = torch.argmax(log_probs, dim=1).item()
        print(ix_to_word[index])
        # embedding_words.append((context[0], log_probs))

# for i in range(len(embedding_words)):
#     for j in range(len(embedding_words)):
#         print(embedding_words[i][0], embedding_words[j][0])
#         print(euclidean_distances(embedding_words[i][1], embedding_words[j][1]))