Created
June 24, 2018 10:05
-
-
Save urigoren/4e97b16157308cfdd2ce33e27ae0c534 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training records: each tuple is a set of co-occurring category ids.
# Every pair of ids inside one record becomes a (input, target) training example.
categorical_data = [
    (0, 1, 2),
    (0, 1),
    *[(0, 1, 3)] * 3,   # this record appears three times
    (0, 1, 2, 3),
    *[(2, 3)] * 3,      # this record appears three times
]
import itertools | |
import numpy as np | |
from tqdm import tqdm | |
import torch | |
from torch.autograd import Variable | |
from torch.nn import functional as F | |
import torch.optim as optim | |
def one_hot(word_idx, size=None):
    """Return a float32 one-hot vector with 1.0 at position *word_idx*.

    Args:
        word_idx: index of the entry to set; must be a valid index for a
            vector of length *size*.
        size: length of the vector. Defaults to the module-level
            ``vocabulary_size`` (looked up at call time), which keeps the
            original single-argument call form working.

    Returns:
        A 1-D ``torch.float32`` tensor of length *size*.
    """
    if size is None:
        size = vocabulary_size
    x = torch.zeros(size, dtype=torch.float32)
    x[word_idx] = 1.0
    return x
def learning_rate(epoch):
    """Return the SGD step size to use during *epoch*.

    Schedule: 0.01 for epoch 0, 0.005 for epochs up to and including 10,
    and 0.001 for every later epoch.
    """
    if epoch <= 10:
        return 0.01 if epoch == 0 else 0.005
    return 0.001
# Model hyper-parameters.
vocabulary_size = 4   # category ids in categorical_data range over 0..3
embedding_dims = 2
epochs = 10**5

# Trainable weights. torch.autograd.Variable has been a deprecated no-op
# wrapper since PyTorch 0.4; tensors created with requires_grad=True are
# autograd leaves directly.
W1 = torch.randn(embedding_dims, vocabulary_size, dtype=torch.float32,
                 requires_grad=True)  # one-hot input -> embedding
W2 = torch.randn(vocabulary_size, embedding_dims, dtype=torch.float32,
                 requires_grad=True)  # embedding -> vocabulary logits
# Train the two weight matrices with plain SGD. For every record, each
# unordered pair of ids is used in both directions (flip=False/True): the
# first id of the ordered pair is the input, the second is the target class.
for epoch in tqdm(range(epochs)):
    lr = learning_rate(epoch)  # constant within an epoch; hoisted out of the inner loops
    loss_val = 0.0
    n_samples = 0
    for categorical in categorical_data:
        for pair, flip in itertools.product(itertools.combinations(categorical, 2), [False, True]):
            data, target = (pair[1], pair[0]) if flip else (pair[0], pair[1])
            x = one_hot(data)
            y_true = torch.tensor([target], dtype=torch.long)
            z1 = torch.matmul(W1, x)   # (embedding_dims,) projection of the one-hot input
            z2 = torch.matmul(W2, z1)  # (vocabulary_size,) logits
            log_softmax = F.log_softmax(z2, dim=0)
            loss = F.nll_loss(log_softmax.view(1, -1), y_true)
            # .item() replaces loss.data[0]; indexing a 0-dim tensor raises
            # IndexError on PyTorch >= 0.4.
            loss_val += loss.item()
            n_samples += 1
            loss.backward()
            # Manual SGD step; no_grad keeps the in-place update out of the
            # autograd graph (replaces the deprecated .data manipulation).
            with torch.no_grad():
                W1 -= lr * W1.grad
                W2 -= lr * W2.grad
                W1.grad.zero_()
                W2.grad.zero_()
    if epoch % 10**4 == 0:
        # Mean loss per (input, target) sample over the whole epoch. The old
        # formula divided by a count derived from the last record only.
        print(f'Loss at epoch {epoch}: {loss_val / max(n_samples, 1)}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment