Skip to content

Instantly share code, notes, and snippets.

@alapini
Forked from mbednarski/w2v.ipynb
Created February 25, 2019 10:59
Show Gist options
  • Save alapini/349e7a83c8d3e4d65971b2bde3f2680d to your computer and use it in GitHub Desktop.
Save alapini/349e7a83c8d3e4d65971b2bde3f2680d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"# torch.nn.functional (not torch.functional) provides log_softmax / nll_loss used below\n",
"import torch.nn.functional as F\n",
"# legacy wrapper: since PyTorch 0.4 plain tensors with requires_grad=True do the same job\n",
"from torch.autograd import Variable"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Toy corpus: one lowercase, whitespace-separated sentence per entry.\n",
"corpus = [\n",
"    # people\n",
"    'he is a king',\n",
"    'she is a queen',\n",
"    'he is a man',\n",
"    'she is a woman',\n",
"    # capital cities\n",
"    'warsaw is poland capital',\n",
"    'berlin is germany capital',\n",
"    'paris is france capital',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_corpus(corpus):\n",
"    \"\"\"Split each sentence in `corpus` on whitespace.\n",
"\n",
"    Returns a list holding one list of tokens per input sentence.\n",
"    \"\"\"\n",
"    return [sentence.split() for sentence in corpus]\n",
"\n",
"\n",
"tokenized_corpus = tokenize_corpus(corpus)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Unique tokens in order of first appearance; the set gives O(1) membership checks\n",
"# instead of rescanning the growing list for every token.\n",
"vocabulary = []\n",
"_seen_tokens = set()\n",
"for sentence in tokenized_corpus:\n",
"    for token in sentence:\n",
"        if token not in _seen_tokens:\n",
"            _seen_tokens.add(token)\n",
"            vocabulary.append(token)\n",
"\n",
"word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}\n",
"idx2word = dict(enumerate(vocabulary))\n",
"\n",
"vocabulary_size = len(vocabulary)\n",
"assert len(word2idx) == vocabulary_size, 'vocabulary contains duplicate tokens'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"window_size = 2  # number of context words taken on each side of the center word\n",
"idx_pairs = []\n",
"for sentence in tokenized_corpus:\n",
"    indices = [word2idx[word] for word in sentence]\n",
"    # treat each word in turn as the center word\n",
"    for center_word_pos, center_word_idx in enumerate(indices):\n",
"        # clamp the window to the sentence so it never runs past either end\n",
"        window_start = max(0, center_word_pos - window_size)\n",
"        window_end = min(len(indices), center_word_pos + window_size + 1)\n",
"        for context_word_pos in range(window_start, window_end):\n",
"            if context_word_pos == center_word_pos:\n",
"                continue  # a word is not its own context\n",
"            idx_pairs.append((center_word_idx, indices[context_word_pos]))\n",
"\n",
"# when non-empty: a (num_pairs, 2) array of (center, context) vocabulary indices\n",
"idx_pairs = np.array(idx_pairs)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def get_input_layer(word_idx):\n",
"    \"\"\"Return a float32 one-hot vector of length `vocabulary_size` with a 1 at `word_idx`.\"\"\"\n",
"    one_hot = torch.zeros(vocabulary_size, dtype=torch.float32)\n",
"    one_hot[word_idx] = 1.0\n",
"    return one_hot"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss at epo 0: 4.3920796749847275\n",
"Loss at epo 10: 4.044160846940109\n",
"Loss at epo 20: 3.8000677623919077\n",
"Loss at epo 30: 3.6165322759321756\n",
"Loss at epo 40: 3.472428318858147\n",
"Loss at epo 50: 3.3555546615804945\n",
"Loss at epo 60: 3.2583227664232255\n",
"Loss at epo 70: 3.1757891476154327\n",
"Loss at epo 80: 3.104614196930613\n",
"Loss at epo 90: 3.042463874391147\n",
"Loss at epo 100: 2.9876448248113903\n"
]
}
],
"source": [
"embedding_dims = 5\n",
"num_epochs = 101\n",
"learning_rate = 0.001\n",
"\n",
"torch.manual_seed(0)  # fixed seed so the random initialisation and loss curve are reproducible\n",
"# W1: one-hot center word -> embedding; W2: embedding -> one score per vocabulary word\n",
"W1 = torch.randn(embedding_dims, vocabulary_size, dtype=torch.float32, requires_grad=True)\n",
"W2 = torch.randn(vocabulary_size, embedding_dims, dtype=torch.float32, requires_grad=True)\n",
"\n",
"for epo in range(num_epochs):\n",
"    loss_val = 0.0\n",
"    for data, target in idx_pairs:\n",
"        x = get_input_layer(data)\n",
"        y_true = torch.tensor([target], dtype=torch.long)\n",
"\n",
"        z1 = torch.matmul(W1, x)   # embedding of the center word\n",
"        z2 = torch.matmul(W2, z1)  # unnormalised score for every vocabulary word\n",
"\n",
"        log_probs = F.log_softmax(z2, dim=0)\n",
"        loss = F.nll_loss(log_probs.view(1, -1), y_true)\n",
"        # .item() returns a Python float; loss.data[0] raises IndexError on 0-dim tensors in PyTorch >= 0.5\n",
"        loss_val += loss.item()\n",
"        loss.backward()\n",
"\n",
"        # manual SGD step; no_grad() keeps the in-place update out of the autograd graph\n",
"        with torch.no_grad():\n",
"            W1 -= learning_rate * W1.grad\n",
"            W2 -= learning_rate * W2.grad\n",
"            W1.grad.zero_()\n",
"            W2.grad.zero_()\n",
"    if epo % 10 == 0:\n",
"        print(f'Loss at epo {epo}: {loss_val / max(len(idx_pairs), 1)}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment