Last active
August 29, 2016 10:26
-
-
Save rth/3af30c60bece7db4207821a6dddc5e8d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"# Latent Semantic Indexing (LSI) example\n", | |
"\n", | |
"Reproducing the LSI example from *Information Retrieval, Algorithms and Heuristics* (2004) by Grossman and Frieder [[1]](http://www1.se.cuhk.edu.hk/~seem5680/lecture/LSI-Eg.pdf) using scikit-learn." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"== Document-term matrix == \n", | |
"[[1 0 1 0 1 1 1 1 1 0 0]\n", | |
" [1 1 0 1 0 0 1 1 0 2 1]\n", | |
" [1 1 0 0 0 1 1 1 1 0 1]]\n", | |
"== Query vector == \n", | |
"[[0 0 0 0 0 1 0 0 0 1 1]]\n", | |
"== Dictionary ==\n", | |
"['aa', 'arrived', 'damaged', 'delivery', 'fire', 'gold', 'in', 'of', 'shipment', 'silver', 'truck']\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"from numpy.testing import assert_allclose\n", | |
"import scipy.linalg\n", | |
"\n", | |
"from sklearn.feature_extraction.text import CountVectorizer\n", | |
"from sklearn.metrics.pairwise import cosine_similarity\n", | |
"from sklearn.decomposition import TruncatedSVD\n", | |
"from sklearn.preprocessing import StandardScaler, normalize, Normalizer\n", | |
"from sklearn.pipeline import make_pipeline\n", | |
"\n", | |
"documents = [\"Shipment of gold damaged in aa fire.\",\n", | |
" \"Delivery of silver arrived in aa silver truck.\",\n", | |
" \"Shipment of gold arrived in aa truck.\",\n", | |
" ]\n", | |
"query = \"gold silver truck\"\n", | |
"\n", | |
"# values taken from the example in the book\n", | |
"sim_scores_ref = [-0.0541, 0.9910, 0.4478]\n", | |
"q_proj_ref = [[0.2140, 0.1821]]\n", | |
"\n", | |
"dm_vec = CountVectorizer()\n", | |
"dm_vec.fit(documents)\n", | |
"X = dm_vec.transform(documents)\n", | |
"\n", | |
"\n", | |
"q = dm_vec.transform([query])\n", | |
"print(\"== Document-term matrix == \")\n", | |
"print(X.todense())\n", | |
"print(\"== Query vector == \")\n", | |
"print(q.todense())\n", | |
"print(\"== Dictionary ==\")\n", | |
"print(dm_vec.get_feature_names())\n", | |
"assert X.shape[1] == 11 # check matrix size w/ example in the book\n", | |
"assert X.sum() == 22 # additional check w/ example in the book" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Manually computing LSI using `scipy.linalg.svd`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"OK\n" | |
] | |
} | |
], | |
"source": [ | |
"n_components = 2\n", | |
"\n", | |
"\n", | |
"U_sp, S_sp, Vh_sp = scipy.linalg.svd(X.todense().T, full_matrices=False)\n", | |
"\n", | |
"q_proj_sp = q.dot(U_sp[:,:n_components]).dot(np.diag(1./S_sp[:n_components]))\n", | |
"X_proj_sp = X.dot(U_sp[:,:n_components]).dot(np.diag(1./S_sp[:n_components]))\n", | |
"\n", | |
"\n", | |
"sim_scores_sp = cosine_similarity(X_proj_sp, q_proj_sp)[:,0]\n", | |
"\n", | |
"assert_allclose(sim_scores_sp, sim_scores_ref, atol=1e-2)\n", | |
"assert_allclose(np.abs(q_proj_sp), np.abs(q_proj_ref), atol=1e-2)\n", | |
"print('OK')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Computing LSI with `TruncatedSVD`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"ename": "AssertionError", | |
"evalue": "\nNot equal to tolerance rtol=1e-07, atol=0.01\n\n(mismatch 100.0%)\n x: array([[ 0.877169, 0.429941]])\n y: array([[ 0.214 , 0.1821]])", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[1;31mAssertionError\u001b[0m Traceback (most recent call last)", | |
"\u001b[1;32m<ipython-input-33-0edc31e2d469>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[0mq_proj_sk\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlsi\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mq\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m \u001b[0massert_allclose\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mq_proj_sk\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mq_proj_ref\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0matol\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1e-2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;31m# compute the cosine similarity\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m/usr/lib64/python3.4/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_allclose\u001b[1;34m(actual, desired, rtol, atol, equal_nan, err_msg, verbose)\u001b[0m\n\u001b[0;32m 1357\u001b[0m \u001b[0mheader\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'Not equal to tolerance rtol=%g, atol=%g'\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mrtol\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0matol\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1358\u001b[0m assert_array_compare(compare, actual, desired, err_msg=str(err_msg),\n\u001b[1;32m-> 1359\u001b[1;33m verbose=verbose, header=header)\n\u001b[0m\u001b[0;32m 1360\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1361\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0massert_array_almost_equal_nulp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnulp\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;32m/usr/lib64/python3.4/site-packages/numpy/testing/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[1;34m(comparison, x, y, err_msg, verbose, header, precision)\u001b[0m\n\u001b[0;32m 711\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[0;32m 712\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mcond\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 713\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 714\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 715\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", | |
"\u001b[1;31mAssertionError\u001b[0m: \nNot equal to tolerance rtol=1e-07, atol=0.01\n\n(mismatch 100.0%)\n x: array([[ 0.877169, 0.429941]])\n y: array([[ 0.214 , 0.1821]])" | |
] | |
} | |
], | |
"source": [ | |
"lsi = make_pipeline(TruncatedSVD(n_components=n_components),\n", | |
" #Normalizer(copy=True)\n", | |
" #StandardScaler(copy=True, with_mean=False, with_std=True)\n", | |
" )\n", | |
"\n", | |
"lsi.fit(X)\n", | |
"X_proj_sk = lsi.transform(X)\n", | |
"q_proj_sk = lsi.transform(q)\n", | |
"\n", | |
"assert_allclose(np.abs(q_proj_sk), np.abs(q_proj_ref), atol=1e-2)\n", | |
"\n", | |
"# compute the cosine similarity\n", | |
"X_proj_sk2 = normalize(X_proj_sk2)\n", | |
"q_proj_sk2 = normalize(q_proj_sk2)\n", | |
"sim_scores_sk = X_proj_sk.dot(q_proj_sk.T).T[0]\n", | |
"\n", | |
"assert_allclose(sim_scores_sk, sim_scores_ref, atol=1e-2)\n", | |
"print('OK')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`TruncatedSVD.transform` does not produce the expected results.\n", | |
"\n", | |
"### Computing LSI with `TruncatedSVD` + normalization by the singular values fixes the problem" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"OK\n" | |
] | |
} | |
], | |
"source": [ | |
"X_proj_sk2 = lsi.transform(X).dot(np.diag(1./S_sp[:n_components]))\n", | |
"q_proj_sk2 = lsi.transform(q).dot(np.diag(1./S_sp[:n_components]))\n", | |
"\n", | |
"assert_allclose(np.abs(q_proj_sk2), np.abs(q_proj_ref), atol=1e-2)\n", | |
"\n", | |
"# compute the cosine similarity\n", | |
"X_proj_sk2 = normalize(X_proj_sk2)\n", | |
"q_proj_sk2 = normalize(q_proj_sk2)\n", | |
"sim_scores_sk2 = X_proj_sk2.dot(q_proj_sk2.T).T[0]\n", | |
"\n", | |
"\n", | |
"assert_allclose(sim_scores_sk2, sim_scores_ref, atol=1e-2)\n", | |
"\n", | |
"print('OK')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"but the problem is that the `TruncatedSVD` class does not currently store the singular values." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment