Last active
June 4, 2019 09:20
-
-
Save ita9naiwa/c4b9adcd3707d98ead72f81bcad0a3cd to your computer and use it in GitHub Desktop.
logistic matrix factorization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd \n", | |
"import os \n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"r = pd.read_csv(\"movielens/u.data\", sep='\\t')\n", | |
"r.columns =['user', 'item', 'rating','ts']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>rating</th>\n", | |
" <th>ts</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>186</td>\n", | |
" <td>302</td>\n", | |
" <td>3</td>\n", | |
" <td>891717742</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>22</td>\n", | |
" <td>377</td>\n", | |
" <td>1</td>\n", | |
" <td>878887116</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>244</td>\n", | |
" <td>51</td>\n", | |
" <td>2</td>\n", | |
" <td>880606923</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>166</td>\n", | |
" <td>346</td>\n", | |
" <td>1</td>\n", | |
" <td>886397596</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>298</td>\n", | |
" <td>474</td>\n", | |
" <td>4</td>\n", | |
" <td>884182806</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item rating ts\n", | |
"0 186 302 3 891717742\n", | |
"1 22 377 1 878887116\n", | |
"2 244 51 2 880606923\n", | |
"3 166 346 1 886397596\n", | |
"4 298 474 4 884182806" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"r.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"uid_to_idx = {y:x for x, y in enumerate(r.user.unique().tolist())}\n", | |
"iid_to_idx = {y:x for x, y in enumerate(r.item.unique().tolist())}\n", | |
"idx_to_uid = {x:y for x, y in enumerate(r.user.unique().tolist())}\n", | |
"idx_to_iid = {x:y for x, y in enumerate(r.item.unique().tolist())}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"r['uid'] = r.user.map(lambda x: uid_to_idx[x])\n", | |
"r['iid'] = r.item.map(lambda x: iid_to_idx[x])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(943, 1682)" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"U = len(uid_to_idx)\n", | |
"M = len(iid_to_idx)\n", | |
"(U, M)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"u_items = dict(r.groupby('uid').iid.apply(np.array))\n", | |
"i_users = dict(r.groupby('iid').uid.apply(np.array))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def calc(u, v, ub=None, vb=None):\n", | |
" ret = np.dot(u,v.T)\n", | |
" if ub is not None and vb is not None:\n", | |
" b = ub + vb\n", | |
" ret += b\n", | |
" return np.exp(ret)\n", | |
" \n", | |
"def update(u, u_vec, i_vec, pos_views, U, M, user=True, lr=0.003, lamb=0.01):\n", | |
" uv = u_vec[u, :-2]\n", | |
" if user:\n", | |
" ub = uv[-1]\n", | |
" else:\n", | |
" ub = uv[-2]\n", | |
" \n", | |
" try:\n", | |
" p_v = pos_views[u]\n", | |
" #p_scores = 1 + np.log(1 + pos_views[u][:, 1])\n", | |
" p_scores = np.ones_like(p_v)\n", | |
" except:\n", | |
" p_v = []\n", | |
" p_scores = []\n", | |
" \n", | |
" n_v = np.random.choice(M, np.min([M // 2, len(p_v) * 10]), replace=False)\n", | |
" p_v = np.hstack([p_v, n_v])\n", | |
" p_scores = np.hstack([p_scores, np.zeros_like(n_v)])\n", | |
" i_v = i_vec[p_v, :-2]\n", | |
" \n", | |
" if user:\n", | |
" ib = i_v[:, -2]\n", | |
" else:\n", | |
" ib = i_v[:, -1] \n", | |
" exp = calc(uv, i_v, ub, ib) \n", | |
" B = (1 + p_scores) * exp\n", | |
" C = np.divide(B, 1 + exp)\n", | |
"\n", | |
" A = np.dot(p_scores, i_v)\n", | |
" D = np.dot(C, i_v)\n", | |
" du = (A - D ) - lamb * uv\n", | |
" dv = (np.sum(p_scores) - np.sum(C)) - lamb * ub\n", | |
" ret = np.ones_like(u_vec[u])\n", | |
" ret[:-2] = (uv + lr * du) \n", | |
" if user:\n", | |
" ret[-1] = (ub + lr * dv)\n", | |
" else:\n", | |
" ret[-2] = (ub + lr * dv)\n", | |
" return ret" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"us = np.random.normal(0, 0.01, size=(U, 20))\n", | |
"vs = np.random.normal(0, 0.01, size=(M, 20))\n", | |
"us[:, -2] = 1.0\n", | |
"vs[:, -1] = 1.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"18.45311222259154 6.606965804260298\n", | |
"0.0007069635913750441\n", | |
"9.76411791693294 3.774571727624158\n", | |
"0.0007069635913750441\n", | |
"6.892110674011095 2.8077871545472375\n", | |
"0.00035348179568752205\n", | |
"5.5052403815692665 2.28960475840313\n", | |
"0.0010604453870625664\n", | |
"4.79944692348174 1.9639428104523216\n", | |
"0.0007069635913750441\n", | |
"4.59048677812209 1.7937426291608674\n", | |
"0.0010604453870625664\n", | |
"4.6062514549870155 1.7530676488786767\n", | |
"0.5330505478967833\n", | |
"4.495584754615364 1.699913457535704\n", | |
"0.5991516436903499\n", | |
"4.568251496663015 1.498026963787225\n", | |
"0.3092965712265818\n", | |
"4.626760772762859 1.321821346978116\n", | |
"0.416401555319901\n", | |
"4.683963123532821 1.2238422815609613\n", | |
"0.5722870272180982\n", | |
"4.749517847836844 1.1723391925712852\n", | |
"0.5638034641215977\n", | |
"4.804234120991001 1.1423121225481827\n", | |
"0.558501237186285\n", | |
"4.8564117416545685 1.1208772777918459\n", | |
"0.5652173913043478\n", | |
"4.90401400731277 1.102984114240213\n", | |
"0.5652173913043478\n", | |
"4.962556749806057 1.0918275514967524\n", | |
"0.5641569459172853\n", | |
"5.018016815109204 1.0809830455094898\n", | |
"0.5539059738423471\n", | |
"5.078056516083909 1.072948230232887\n", | |
"0.5662778366914104\n", | |
"5.1344621115701745 1.065829642540267\n", | |
"0.5680452456698479\n", | |
"5.188347239423277 1.059868401759393\n", | |
"0.5698126546482856\n", | |
"5.244511003839926 1.053877044812167\n", | |
"0.5662778366914104\n", | |
"5.322076176656392 1.0497995231928663\n", | |
"0.5659243548957228\n", | |
"5.382694238855233 1.0455908343524027\n", | |
"0.5694591728525981\n", | |
"5.438910969798368 1.0419320423753113\n", | |
"0.5563803464121597\n", | |
"5.511768074618116 1.0399761063760051\n", | |
"0.5652173913043478\n", | |
"5.567225293720422 1.0370519100806548\n", | |
"0.5676917638741604\n", | |
"5.627514696554446 1.0339099272087968\n", | |
"0.5627430187345351\n", | |
"5.687694915310855 1.0325661132701782\n", | |
"0.5652173913043478\n", | |
"5.753047687460456 1.0302156441159953\n", | |
"0.5645104277129727\n", | |
"5.80741499007184 1.0284133269487519\n", | |
"0.5652173913043478\n", | |
"5.866707220754546 1.0262342582214674\n", | |
"0.5630965005302226\n", | |
"5.915964358758941 1.0254664716738577\n", | |
"0.5645104277129727\n", | |
"5.973055210029575 1.0242211982094316\n", | |
"0.5553199010250972\n", | |
"6.027916067315448 1.0219485114146472\n", | |
"0.5652173913043478\n", | |
"6.074863071455239 1.0208621962433566\n", | |
"0.5581477553905974\n", | |
"6.130656324766482 1.020135384814659\n", | |
"0.5634499823259101\n", | |
"6.185412187877836 1.019224569793171\n", | |
"0.5659243548957228\n", | |
"6.229137304952745 1.0188626421962705\n", | |
"0.5705196182396607\n", | |
"6.28340675799909 1.0184204816128182\n", | |
"0.5659243548957228\n", | |
"6.326114415423463 1.0174666603073765\n", | |
"0.546482856132909\n", | |
"6.3740138178674695 1.0171226860789873\n", | |
"0.5655708731000353\n", | |
"6.42458640585839 1.016829027481995\n", | |
"0.5528455284552846\n", | |
"6.467422640364974 1.0165559684002063\n", | |
"0.5652173913043478\n", | |
"6.52060838230271 1.0160755116423004\n", | |
"0.5556733828207847\n", | |
"6.563432141907908 1.0164130119427142\n", | |
"0.5648639095086603\n", | |
"6.60696314536215 1.0161590850338782\n", | |
"0.5574407917992223\n", | |
"6.657217106499905 1.0157642592435365\n", | |
"0.5648639095086603\n", | |
"6.70232948902955 1.0160164989135207\n", | |
"0.5574407917992225\n", | |
"6.740703431844464 1.0162361366950343\n", | |
"0.55920820077766\n", | |
"6.784919119669208 1.0162324610383728\n", | |
"0.5655708731000353\n", | |
"6.8277866256369215 1.016018558890994\n", | |
"0.5648639095086603\n", | |
"6.874355150160551 1.0163753279847116\n", | |
"0.5592082007776599\n", | |
"6.912767653332144 1.0170418286165273\n", | |
"0.5683987274655355\n", | |
"6.950781013970135 1.0181478721106176\n", | |
"0.5609756097560976\n", | |
"6.994799065653214 1.0185419075189077\n", | |
"0.5669848002827855\n", | |
"7.036546737220167 1.0193205485674008\n", | |
"0.5733474726051608\n", | |
"7.08650726449552 1.0200182896334\n", | |
"0.574054436196536\n", | |
"7.125491524508782 1.020800321006764\n", | |
"0.5956168257334749\n", | |
"7.165805918193476 1.021925544538118\n", | |
"0.5765288087663485\n", | |
"7.199342372138001 1.0230095166623712\n", | |
"0.5804171085189113\n", | |
"7.238828977932087 1.0249046304836948\n", | |
"0.5683987274655355\n", | |
"7.274748374796994 1.026875165523167\n", | |
"0.5973842347119123\n", | |
"7.306434753308298 1.0283718220391616\n", | |
"0.5871332626369743\n", | |
"7.344057225651415 1.030389686549755\n", | |
"0.5995051254860374\n", | |
"7.379214637596634 1.0327720049965796\n", | |
"0.6051608342170378\n", | |
"7.416560604789014 1.0350162628461552\n", | |
"0.6058677978084128\n", | |
"7.4519428799100265 1.0372572388898782\n", | |
"0.5860728172499116\n", | |
"7.490489270819507 1.0401367817685745\n", | |
"0.6246023329798516\n", | |
"7.5178022807536315 1.0432004485580937\n", | |
"0.634499823259102\n", | |
"7.558610620054508 1.0468211281910573\n", | |
"0.630258041710852\n", | |
"7.594056322331437 1.0503815651097952\n", | |
"0.6327324142806645\n", | |
"7.619939210804441 1.055033580123976\n", | |
"0.6560622127960409\n", | |
"7.660712793173402 1.0592961611168363\n", | |
"0.6341463414634146\n", | |
"7.694495099130147 1.0638908056581033\n", | |
"0.6557087310003534\n", | |
"7.720294625758447 1.0688068752790627\n", | |
"0.6334393778720396\n", | |
"7.754548172169112 1.073848192522783\n", | |
"0.6546482856132908\n", | |
"7.7835841037189235 1.0798652830889857\n", | |
"0.6472251679038529\n", | |
"7.813332999251804 1.0860221793068552\n", | |
"0.6585365853658537\n", | |
"7.841287614198128 1.09276365671608\n", | |
"0.6578296217744786\n", | |
"7.874772479927948 1.0997133324111505\n", | |
"0.6553552492046659\n", | |
"7.905298688711339 1.1074603986757043\n", | |
"0.664192294096854\n", | |
"7.92881123515752 1.1151154692303549\n", | |
"0.6670201484623541\n", | |
"7.956179577823236 1.1239593583109877\n", | |
"0.6751502297631673\n", | |
"7.9873016833837065 1.1311207930154308\n", | |
"0.6808059384941676\n", | |
"8.013985948469601 1.1390309756417045\n", | |
"0.6815129020855426\n", | |
"8.04583041278601 1.1470198652242665\n", | |
"0.6864616472251679\n", | |
"8.077031280288058 1.1561225883381838\n", | |
"0.7002474372569811\n", | |
"8.097823430400517 1.1633767036715974\n", | |
"0.6896429833863555\n", | |
"8.121751038254072 1.1714501937120494\n", | |
"0.6949452103216683\n", | |
"8.14295631988417 1.1802750706877245\n", | |
"0.6931778013432308\n", | |
"8.173377256448603 1.1882865250134902\n", | |
"0.6988335100742312\n", | |
"8.193755548585838 1.1982506733343627\n", | |
"0.6988335100742312\n", | |
"8.220839142648206 1.2051574876049813\n", | |
"0.7030752916224814\n", | |
"8.239019539627257 1.2125233398118578\n", | |
"0.6998939554612937\n", | |
"8.260645020711133 1.2207418254981088\n", | |
"0.7034287734181689\n", | |
"8.28629250808858 1.2246431539911202\n", | |
"0.6988335100742312\n", | |
"8.305564237841512 1.2321320291739561\n", | |
"0.7006009190526687\n", | |
"8.329160362956163 1.2388141847849186\n", | |
"0.7030752916224814\n", | |
"8.351474420302138 1.247967791993751\n", | |
"0.7037822552138563\n", | |
"8.37142298532775 1.2532609164105555\n", | |
"0.6995404736656062\n" | |
] | |
} | |
], | |
"source": [ | |
"for _ in range(100):\n", | |
" for u in range(U):\n", | |
" us[u] = update(u, us, vs, u_items, U, M, user=True, lr=0.01 / np.sqrt(1+_), lamb=1.0)\n", | |
" for m in range(M):\n", | |
" vs[m] = update(m, vs, us, i_users, M, U, user=False, lr=0.01 / np.sqrt(1+_), lamb=1.0)\n", | |
" print(np.dot(us[0], us[0]), np.dot(vs[0], vs[0]))\n", | |
" x = np.dot(us, vs.T).argsort()[:, ::-1][:, :3]\n", | |
" ret = []\n", | |
" for u, seen_items in u_items.items():\n", | |
" q = set(x[u]).intersection(seen_items.tolist())\n", | |
" ret.append(len(q) / 3)\n", | |
" print(np.mean(ret))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment