{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Word Embeddings via PMI Matrix Factorization\n",
    "\n",
    "*Contact TA: emaad[at]cmu.edu, [eyeshalfclosed.com/teaching/](http://www.eyeshalfclosed.com/teaching/)*\n",
    "\n",
" * Based on [Neural Word Embedding as Implicit Matrix Factorization](https://papers.nips.cc/paper/5477-neural-word-embedding-as-implicit-matrix-factorization), by Omar Levy and Yoav Goldberg, NIPS 2014.\n", | |
" * Dataset: https://www.kaggle.com/hacker-news/hacker-news-posts/downloads/HN_posts_year_to_Sep_26_2016.csv\n", | |
" * Notes: http://www.eyeshalfclosed.com/teaching/95865-recitation-word2vec_as_PMI.pdf\n", | |
" * Source material: https://multithreaded.stitchfix.com/blog/2017/10/18/stop-using-word2vec/.\n", | |
" * Source material: https://www.kaggle.com/alexklibisz/simple-word-vectors-with-co-occurrence-pmi-and-svd" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from collections import Counter\n", | |
"from itertools import combinations\n", | |
"from math import log\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from pprint import pformat\n", | |
"from scipy.sparse import csc_matrix\n", | |
"from scipy.sparse.linalg import svds, norm\n", | |
"from string import punctuation" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 0. Load the data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>title</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>You have two days to comment if you want stem ...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>SQLAR the SQLite Archiver</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>What if we just printed a flatscreen televisio...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>algorithmic music</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>How the Data Vault Enables the Next-Gen Data W...</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" title\n", | |
"0 You have two days to comment if you want stem ...\n", | |
"1 SQLAR the SQLite Archiver\n", | |
"2 What if we just printed a flatscreen televisio...\n", | |
"3 algorithmic music\n", | |
"4 How the Data Vault Enables the Next-Gen Data W..." | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df = pd.read_csv('HN_posts_year_to_Sep_26_2016.csv', usecols=['title'])\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 1. Read and preprocess titles from HN posts." | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.11 s, sys: 69.8 ms, total: 1.18 s\n",
      "Wall time: 1.19 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "punctrans = str.maketrans(dict.fromkeys(punctuation))\n",
    "def tokenize(title):\n",
    "    x = title.lower()  # Lowercase.\n",
    "    x = x.encode('ascii', 'ignore').decode()  # Keep only ascii chars.\n",
    "    x = x.translate(punctrans)  # Remove punctuation.\n",
    "    return x.split()  # Return the tokens.\n",
    "\n",
    "texts_tokenized = df['title'].apply(tokenize)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## 2a. Compute unigram and bigram counts.\n", | |
"\n", | |
"A unigram is a single word (x). A bigram is a pair of words (x,y).\n", | |
"Bigrams are counted for any two terms occurring in the same title.\n", | |
"For example, the title \"Foo bar baz\" has unigrams [foo, bar, baz]\n", | |
"and bigrams [(bar, foo), (bar, baz), (baz, foo)]" | |
] | |
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13.2 s, sys: 317 ms, total: 13.5 s\n",
      "Wall time: 14 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "cx = Counter()\n",
    "cxy = Counter()\n",
    "for text in texts_tokenized:\n",
    "    for x in text:\n",
    "        cx[x] += 1\n",
    "    for x, y in map(sorted, combinations(text, 2)):\n",
    "        cxy[(x, y)] += 1"
   ]
  },
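  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick demo (added for clarity, not part of the original pipeline): expanding one title into unigrams and alphabetically sorted bigrams, reusing `tokenize` and the `combinations` import from above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo: expand one title into unigrams and sorted bigram pairs.\n",
    "demo_tokens = tokenize('Foo bar baz')\n",
    "print('unigrams:', demo_tokens)\n",
    "print('bigrams: ', [tuple(sorted(p)) for p in combinations(demo_tokens, 2)])"
   ]
  },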
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## 2b. Remove frequent and infrequent unigrams.\n", | |
"\n", | |
"Pick arbitrary occurrence count thresholds to eliminate unigrams occurring\n", | |
"very frequently or infrequently. This decreases the vocab size substantially." | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99181 tokens before\n",
      "1045 tokens after\n",
      "Most common: [('i', 5577), ('google', 5532), ('be', 5320), ('app', 5124), ('as', 5121), ('its', 5077), ('about', 4927), ('can', 4801), ('using', 4613), ('do', 4534), ('not', 4330), ('us', 4189), ('web', 4134), ('will', 4125), ('we', 4113), ('startup', 3849), ('open', 3828), ('first', 3730), ('code', 3705), ('apple', 3695), ('pdf', 3659), ('more', 3652), ('software', 3558), ('my', 3515), ('this', 3477)]\n",
      "CPU times: user 114 ms, sys: 4.37 ms, total: 118 ms\n",
      "Wall time: 118 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "print('%d tokens before' % len(cx))\n",
    "min_count = (1 / 1000) * len(df)\n",
    "max_count = (1 / 50) * len(df)\n",
    "for x in list(cx.keys()):\n",
    "    if cx[x] < min_count or cx[x] > max_count:\n",
    "        del cx[x]\n",
    "print('%d tokens after' % len(cx))\n",
    "print('Most common:', cx.most_common()[:25])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2c. Remove frequent and infrequent bigrams.\n",
    "\n",
    "Any bigram containing a unigram that was removed must now be removed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.42 s, sys: 34.5 ms, total: 3.45 s\n",
      "Wall time: 3.48 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "for x, y in list(cxy.keys()):\n",
    "    if x not in cx or y not in cx:\n",
    "        del cxy[(x, y)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Build unigram <-> index lookup."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 558 µs, sys: 37 µs, total: 595 µs\n",
      "Wall time: 602 µs\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "x2i, i2x = {}, {}\n",
    "for i, x in enumerate(cx.keys()):\n",
    "    x2i[x] = i\n",
    "    i2x[i] = x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Sum unigram and bigram counts for computing probabilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "sx = sum(cx.values())\n",
    "sxy = sum(cxy.values())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"# 5. Accumulate data, rows, and cols to build sparse PMI matrix\n", | |
"\n", | |
"The PMI value for a bigram with tokens (x, y) is:\n", | |
"$$ \\textrm{PMI}(x,y) = \\frac{\\textrm{log}(p(x,y))}{p(x)p(y)} $$\n", | |
"\n", | |
"The probabilities are computed on the fly using the sums from above." | |
] | |
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "297427 non-zero elements\n",
      "Sample PMI values\n",
      " [(('elon', 'musk'), 6.8839034669891745),\n",
      " (('pi', 'raspberry'), 6.764144313961322),\n",
      " (('street', 'wall'), 6.6818495161706615),\n",
      " (('francisco', 'san'), 6.497633763166996),\n",
      " (('capital', 'venture'), 6.444633948553916),\n",
      " (('basic', 'income'), 6.329754881712322),\n",
      " (('card', 'credit'), 6.257880347803981),\n",
      " (('studio', 'visual'), 6.241828638793474),\n",
      " (('star', 'wars'), 6.176785106834259),\n",
      " (('command', 'line'), 6.125005172051617)]\n",
      "CPU times: user 901 ms, sys: 38.5 ms, total: 940 ms\n",
      "Wall time: 951 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "pmi_samples = Counter()\n",
    "data, rows, cols = [], [], []\n",
    "for (x, y), n in cxy.items():\n",
    "    rows.append(x2i[x])\n",
    "    cols.append(x2i[y])\n",
    "    data.append(log((n / sxy) / (cx[x] / sx) / (cx[y] / sx)))\n",
    "    pmi_samples[(x, y)] = data[-1]\n",
    "PMI = csc_matrix((data, (rows, cols)))\n",
    "print('%d non-zero elements' % PMI.count_nonzero())\n",
    "print('Sample PMI values\\n', pformat(pmi_samples.most_common()[:10]))"
   ]
  },
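  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the formula concrete, here is a hand computation of PMI for a single pair, compared against the stored matrix entry (a sketch added for clarity; it assumes ('elon', 'musk') survived the frequency filtering, as the sample output above indicates)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-compute PMI for one surviving pair and compare with the matrix entry.\n",
    "x, y = 'elon', 'musk'\n",
    "p_xy = cxy[(x, y)] / sxy            # p(x, y)\n",
    "p_x, p_y = cx[x] / sx, cx[y] / sx   # p(x), p(y)\n",
    "print('by hand: %.4f' % log(p_xy / (p_x * p_y)))\n",
    "print('matrix:  %.4f' % PMI[x2i[x], x2i[y]])"
   ]
  },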
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## 6. Factorize the PMI matrix using sparse SVD aka \"learn the unigram/word vectors\".\n", | |
"\n", | |
"This part replaces the stochastic gradient descent used by Word2vec\n", | |
"and other related neural network formulations. We pick an arbitrary vector size k=20." | |
] | |
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 248 ms, sys: 46.7 ms, total: 295 ms\n",
      "Wall time: 389 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "U, _, _ = svds(PMI, k=20)"
   ]
  },
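  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An aside (added; not in the original notebook): the cell above discards the singular values. Levy and Goldberg discuss a symmetric factorization that weights the singular vectors by the square root of the singular values, $U\\sqrt{\\Sigma}$, as an alternative embedding. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: keep the singular values and weight each dimension by sqrt(sigma).\n",
    "U_w, S, _ = svds(PMI, k=20)  # S has shape (k,).\n",
    "U_w = U_w * np.sqrt(S)       # Broadcasts sqrt(S) across the columns of U_w."
   ]
  },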
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"Normalize the vectors to compute cosine similarity." | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "norms = np.sqrt(np.sum(np.square(U), axis=1, keepdims=True))\n",
    "U /= np.maximum(norms, 1e-7)"
   ]
  },
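  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check (added): after the division every word vector should have unit norm, so dot products between rows of `U` are true cosine similarities in [-1, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All rows should now be (numerically) unit vectors.\n",
    "assert np.allclose(np.linalg.norm(U, axis=1), 1.0)"
   ]
  },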
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## 8. Show some nearest neighbor samples as a sanity-check.\n", | |
"\n", | |
"The format is `<unigram> <count>: (<neighbor unigram>, <similarity>), ...`\n", | |
" \n", | |
"From this we can see that the relationships make sense." | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "facebook, 2853\n",
      " (ads, 0.804) (bot, 0.753) (google, 0.747) (instagram, 0.735) (app, 0.706) \n",
      "----------\n",
      "twitter, 1641\n",
      " (tv, 0.822) (time, 0.810) (traffic, 0.782) (tracking, 0.772) (type, 0.765) \n",
      "----------\n",
      "instagram, 391\n",
      " (facebook, 0.735) (links, 0.765) (news, 0.695) (images, 0.691) (search, 0.686) \n",
      "----------\n",
      "messenger, 374\n",
      " (messaging, 0.764) (chat, 0.693) (lets, 0.633) (facebook, 0.612) (platform, 0.608) \n",
      "----------\n",
      "hack, 881\n",
      " (hacked, 0.790) (hackers, 0.922) (hacking, 0.769) (malware, 0.695) (emails, 0.681) \n",
      "----------\n",
      "security, 2425\n",
      " (software, 0.640) (secure, 0.718) (remote, 0.791) (push, 0.651) (process, 0.649) \n",
      "----------\n",
      "deep, 1375\n",
      " (learning, 0.866) (neural, 0.852) (networks, 0.822) (algorithms, 0.820) (machine, 0.779) \n",
      "----------\n",
      "encryption, 968\n",
      " (fbi, 0.849) (government, 0.836) (nsa, 0.739) (attacks, 0.739) (crypto, 0.911) \n",
      "----------\n",
      "cli, 311\n",
      " (command, 0.877) (custom, 0.817) (easy, 0.786) (client, 0.764) (browser, 0.740) \n",
      "----------\n",
      "venture, 393\n",
      " (where, 0.896) (university, 0.939) (valley, 0.936) (vc, 0.951) (were, 0.892) \n",
      "----------\n",
      "paris, 295\n",
      " (police, 0.897) (nsa, 0.894) (national, 0.807) (obama, 0.796) (russian, 0.737) \n",
      "----------\n"
     ]
    }
   ],
   "source": [
    "k = 5\n",
    "for x in ['facebook', 'twitter', 'instagram', 'messenger', 'hack', 'security', \n",
    "          'deep', 'encryption', 'cli', 'venture', 'paris']:\n",
    "    dd = np.dot(U, U[x2i[x]])  # Cosine similarity of this unigram against all others.\n",
    "    s = ''\n",
    "    # Compile the list of nearest-neighbor descriptions.\n",
    "    # Argpartition is faster than argsort; we only need the top k+1, unsorted.\n",
    "    for i in np.argpartition(-1 * dd, k + 1)[:k + 1]:\n",
    "        if i2x[i] == x: continue\n",
    "        s += '(%s, %.3lf) ' % (i2x[i], dd[i])\n",
    "    print('%s, %d\\n %s' % (x, cx[x], s))\n",
    "    print('-' * 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"## 9. Word-vector compositions" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(images, 1.460) (facebook, 1.460) (instagram, 1.426) \n"
     ]
    }
   ],
   "source": [
"composition = U[x2i[\"facebook\"]] - U[x2i[\"ads\"]]\n", | |
"composition /= np.linalg.norm(composition)\n", | |
"\n", | |
"k = 2\n", | |
"composition = U[x2i[\"facebook\"]] + U[x2i[\"images\"]]\n", | |
"dd = np.dot(U, composition) # Cosine similarity for this unigram against all others.\n", | |
"s = ''\n", | |
"for i in np.argpartition(-1 * dd, k + 1)[:k + 1]:\n", | |
" s += '(%s, %.3lf) ' % (i2x[i], dd[i])\n", | |
"print(s)" | |
] | |
} | |
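  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A subtraction sketch (added for illustration; not run, hence no stored output): normalizing the composed vector keeps the dot products interpretable as cosine similarities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: subtractive composition, normalized so dot products are cosine similarities.\n",
    "k = 2\n",
    "composition = U[x2i['facebook']] - U[x2i['ads']]\n",
    "composition /= np.linalg.norm(composition)\n",
    "dd = np.dot(U, composition)\n",
    "s = ''\n",
    "for i in np.argpartition(-1 * dd, k + 1)[:k + 1]:\n",
    "    s += '(%s, %.3lf) ' % (i2x[i], dd[i])\n",
    "print(s)"
   ]
  }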
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:96865]",
   "language": "python",
   "name": "conda-env-96865-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |