{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
"import torch\n", | |
"from torch.autograd import Variable\n", | |
"import numpy as np\n", | |
"import torch.functional as F\n", | |
"import torch.nn.functional as F" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    'he is a king',\n",
    "    'she is a queen',\n",
    "    'he is a man',\n",
    "    'she is a woman',\n",
    "    'warsaw is poland capital',\n",
    "    'berlin is germany capital',\n",
    "    'paris is france capital',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_corpus(corpus):\n",
    "    tokens = [x.split() for x in corpus]\n",
    "    return tokens\n",
    "\n",
    "tokenized_corpus = tokenize_corpus(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocabulary = []\n",
    "for sentence in tokenized_corpus:\n",
    "    for token in sentence:\n",
    "        if token not in vocabulary:\n",
    "            vocabulary.append(token)\n",
    "\n",
    "word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}\n",
    "idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}\n",
    "\n",
    "vocabulary_size = len(vocabulary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "window_size = 2\n",
    "idx_pairs = []\n",
    "# for each sentence\n",
    "for sentence in tokenized_corpus:\n",
    "    indices = [word2idx[word] for word in sentence]\n",
    "    # for each word, treated as the center word\n",
    "    for center_word_pos in range(len(indices)):\n",
    "        # for each window position\n",
    "        for w in range(-window_size, window_size + 1):\n",
    "            context_word_pos = center_word_pos + w\n",
    "            # make sure we don't jump outside the sentence\n",
    "            if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:\n",
    "                continue\n",
    "            context_word_idx = indices[context_word_pos]\n",
    "            idx_pairs.append((indices[center_word_pos], context_word_idx))\n",
    "\n",
    "idx_pairs = np.array(idx_pairs)  # it will be useful to have this as a NumPy array"
   ]
  },
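  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (added, not part of the original gist): decode a few of the generated (center, context) index pairs back into words to confirm the windowing logic. The slice size of 5 is an arbitrary choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check (added): print a few training pairs as words\n",
    "for center_idx, context_idx in idx_pairs[:5]:\n",
    "    print(idx2word[center_idx], '->', idx2word[context_idx])"
   ]
  },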
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_input_layer(word_idx):\n",
    "    # one-hot encode a word index as a float vector of length vocabulary_size\n",
    "    x = torch.zeros(vocabulary_size).float()\n",
    "    x[word_idx] = 1.0\n",
    "    return x"
   ]
  },
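  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small illustration (added, not in the original): one-hot encode a sample word and print the resulting vector. The probe word 'king' is an arbitrary pick from the corpus."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration (added): the one-hot input vector for an example word\n",
    "print(get_input_layer(word2idx['king']))"
   ]
  },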
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss at epo 0: 4.3920796749847275\n",
      "Loss at epo 10: 4.044160846940109\n",
      "Loss at epo 20: 3.8000677623919077\n",
      "Loss at epo 30: 3.6165322759321756\n",
      "Loss at epo 40: 3.472428318858147\n",
      "Loss at epo 50: 3.3555546615804945\n",
      "Loss at epo 60: 3.2583227664232255\n",
      "Loss at epo 70: 3.1757891476154327\n",
      "Loss at epo 80: 3.104614196930613\n",
      "Loss at epo 90: 3.042463874391147\n",
      "Loss at epo 100: 2.9876448248113903\n"
     ]
    }
   ],
   "source": [
    "embedding_dims = 5\n",
    "W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(), requires_grad=True)\n",
    "W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(), requires_grad=True)\n",
    "num_epochs = 101\n",
    "learning_rate = 0.001\n",
    "\n",
    "for epo in range(num_epochs):\n",
    "    loss_val = 0\n",
    "    for data, target in idx_pairs:\n",
    "        x = Variable(get_input_layer(data)).float()\n",
    "        y_true = Variable(torch.from_numpy(np.array([target])).long())\n",
    "\n",
    "        # forward pass: project the one-hot input, then score every vocabulary word\n",
    "        z1 = torch.matmul(W1, x)\n",
    "        z2 = torch.matmul(W2, z1)\n",
    "\n",
    "        log_softmax = F.log_softmax(z2, dim=0)\n",
    "\n",
    "        loss = F.nll_loss(log_softmax.view(1, -1), y_true)\n",
    "        loss_val += loss.item()\n",
    "        loss.backward()\n",
    "        # manual SGD step, then reset gradients for the next pair\n",
    "        W1.data -= learning_rate * W1.grad.data\n",
    "        W2.data -= learning_rate * W2.grad.data\n",
    "\n",
    "        W1.grad.data.zero_()\n",
    "        W2.grad.data.zero_()\n",
    "    if epo % 10 == 0:\n",
    "        print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')"
   ]
  },
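  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A follow-up sketch (added, not part of the original gist): column i of W1 holds the learned embedding of the word with index i, so related words should end up with similar vectors. The helper names `word_vector` and `similarity` and the probe word pairs are illustrative choices; with 101 epochs on a seven-sentence corpus, the similarities will be noisy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (added): compare learned embeddings with cosine similarity.\n",
    "# Column i of W1 is the embedding of the word with index i.\n",
    "def word_vector(word):\n",
    "    return W1.data[:, word2idx[word]]\n",
    "\n",
    "def similarity(word_a, word_b):\n",
    "    return F.cosine_similarity(word_vector(word_a), word_vector(word_b), dim=0).item()\n",
    "\n",
    "print('he / she:', similarity('he', 'she'))\n",
    "print('king / queen:', similarity('king', 'queen'))"
   ]
  }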
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |