Created
December 2, 2024 07:43
-
-
Save myui/3d11f27a071e37b75db829a23d2b97fe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "27cddd06-2fc7-4a96-aa57-226109f57442", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"\n", | |
"sys.path.append(\"/home/td-user/rtrec\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "07765c5e-6af2-47d7-8320-38dc269a7986", | |
"metadata": {}, | |
"source": [ | |
"# Movielens 1M" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "827b4cc7-1f28-4e44-91e9-6d5cd8cde4dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using existing ratings.dat file.\n" | |
] | |
} | |
], | |
"source": [ | |
"from rtrec.experiments.datasets import load_dataset\n", | |
"\n", | |
"df = load_dataset(name='movielens_1m')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c89ffe93-0b31-4cac-a004-8ada940add7b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"df['user'] = df['user'].astype(\"category\")\n", | |
"df['item'] = df['item'].astype(\"category\")\n", | |
"df['user_id'] = df['user'].cat.codes\n", | |
"df['item_id'] = df['item'].cat.codes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "499497fa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"iid2id = dict(enumerate(df['item'].cat.categories))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "52bcecc2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>rating</th>\n", | |
" <th>tstamp</th>\n", | |
" <th>user_id</th>\n", | |
" <th>item_id</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>1193</td>\n", | |
" <td>5</td>\n", | |
" <td>978300760</td>\n", | |
" <td>0</td>\n", | |
" <td>1104</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>661</td>\n", | |
" <td>3</td>\n", | |
" <td>978302109</td>\n", | |
" <td>0</td>\n", | |
" <td>639</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>914</td>\n", | |
" <td>3</td>\n", | |
" <td>978301968</td>\n", | |
" <td>0</td>\n", | |
" <td>853</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>3408</td>\n", | |
" <td>4</td>\n", | |
" <td>978300275</td>\n", | |
" <td>0</td>\n", | |
" <td>3177</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>2355</td>\n", | |
" <td>5</td>\n", | |
" <td>978824291</td>\n", | |
" <td>0</td>\n", | |
" <td>2162</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000204</th>\n", | |
" <td>6040</td>\n", | |
" <td>1091</td>\n", | |
" <td>1</td>\n", | |
" <td>956716541</td>\n", | |
" <td>6039</td>\n", | |
" <td>1019</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000205</th>\n", | |
" <td>6040</td>\n", | |
" <td>1094</td>\n", | |
" <td>5</td>\n", | |
" <td>956704887</td>\n", | |
" <td>6039</td>\n", | |
" <td>1022</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000206</th>\n", | |
" <td>6040</td>\n", | |
" <td>562</td>\n", | |
" <td>5</td>\n", | |
" <td>956704746</td>\n", | |
" <td>6039</td>\n", | |
" <td>548</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000207</th>\n", | |
" <td>6040</td>\n", | |
" <td>1096</td>\n", | |
" <td>4</td>\n", | |
" <td>956715648</td>\n", | |
" <td>6039</td>\n", | |
" <td>1024</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000208</th>\n", | |
" <td>6040</td>\n", | |
" <td>1097</td>\n", | |
" <td>4</td>\n", | |
" <td>956715569</td>\n", | |
" <td>6039</td>\n", | |
" <td>1025</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>1000209 rows × 6 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item rating tstamp user_id item_id\n", | |
"0 1 1193 5 978300760 0 1104\n", | |
"1 1 661 3 978302109 0 639\n", | |
"2 1 914 3 978301968 0 853\n", | |
"3 1 3408 4 978300275 0 3177\n", | |
"4 1 2355 5 978824291 0 2162\n", | |
"... ... ... ... ... ... ...\n", | |
"1000204 6040 1091 1 956716541 6039 1019\n", | |
"1000205 6040 1094 5 956704887 6039 1022\n", | |
"1000206 6040 562 5 956704746 6039 548\n", | |
"1000207 6040 1096 4 956715648 6039 1024\n", | |
"1000208 6040 1097 4 956715569 6039 1025\n", | |
"\n", | |
"[1000209 rows x 6 columns]" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "74f06c6c-535c-45bd-8fe0-d59f8d369656", | |
"metadata": {}, | |
"source": [ | |
"# Temporal User Split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "7b2e53b8-b1e8-4dc2-90c9-40b00d9f9826", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/td-user/rtrec/rtrec/experiments/split.py:104: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.\n", | |
" for user, user_df in df.groupby('user'):\n" | |
] | |
} | |
], | |
"source": [ | |
"from rtrec.experiments.split import temporal_user_split\n", | |
"train_df, test_df = temporal_user_split(df)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8213db54", | |
"metadata": {}, | |
"source": [ | |
"# Irspack evaluation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "2023a533", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>rating</th>\n", | |
" <th>tstamp</th>\n", | |
" <th>user_id</th>\n", | |
" <th>item_id</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>6040</td>\n", | |
" <td>858</td>\n", | |
" <td>4</td>\n", | |
" <td>956703932</td>\n", | |
" <td>6039</td>\n", | |
" <td>802</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>6040</td>\n", | |
" <td>593</td>\n", | |
" <td>5</td>\n", | |
" <td>956703954</td>\n", | |
" <td>6039</td>\n", | |
" <td>579</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>6040</td>\n", | |
" <td>2384</td>\n", | |
" <td>4</td>\n", | |
" <td>956703954</td>\n", | |
" <td>6039</td>\n", | |
" <td>2191</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>6040</td>\n", | |
" <td>1961</td>\n", | |
" <td>4</td>\n", | |
" <td>956703977</td>\n", | |
" <td>6039</td>\n", | |
" <td>1781</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>6040</td>\n", | |
" <td>2019</td>\n", | |
" <td>5</td>\n", | |
" <td>956703977</td>\n", | |
" <td>6039</td>\n", | |
" <td>1839</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797753</th>\n", | |
" <td>736</td>\n", | |
" <td>1278</td>\n", | |
" <td>4</td>\n", | |
" <td>1045711206</td>\n", | |
" <td>735</td>\n", | |
" <td>1186</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797754</th>\n", | |
" <td>736</td>\n", | |
" <td>3671</td>\n", | |
" <td>4</td>\n", | |
" <td>1045711217</td>\n", | |
" <td>735</td>\n", | |
" <td>3429</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797755</th>\n", | |
" <td>3840</td>\n", | |
" <td>1196</td>\n", | |
" <td>3</td>\n", | |
" <td>1046106127</td>\n", | |
" <td>3839</td>\n", | |
" <td>1106</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797756</th>\n", | |
" <td>3840</td>\n", | |
" <td>1225</td>\n", | |
" <td>3</td>\n", | |
" <td>1046106162</td>\n", | |
" <td>3839</td>\n", | |
" <td>1135</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797757</th>\n", | |
" <td>3840</td>\n", | |
" <td>2502</td>\n", | |
" <td>5</td>\n", | |
" <td>1046106198</td>\n", | |
" <td>3839</td>\n", | |
" <td>2308</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>797758 rows × 6 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item rating tstamp user_id item_id\n", | |
"0 6040 858 4 956703932 6039 802\n", | |
"1 6040 593 5 956703954 6039 579\n", | |
"2 6040 2384 4 956703954 6039 2191\n", | |
"3 6040 1961 4 956703977 6039 1781\n", | |
"4 6040 2019 5 956703977 6039 1839\n", | |
"... ... ... ... ... ... ...\n", | |
"797753 736 1278 4 1045711206 735 1186\n", | |
"797754 736 3671 4 1045711217 735 3429\n", | |
"797755 3840 1196 3 1046106127 3839 1106\n", | |
"797756 3840 1225 3 1046106162 3839 1135\n", | |
"797757 3840 2502 5 1046106198 3839 2308\n", | |
"\n", | |
"[797758 rows x 6 columns]" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "e6cb85a0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from irspack import df_to_sparse\n", | |
"\n", | |
"X_train, _, _ = df_to_sparse(\n", | |
" train_df, 'user_id', 'item_id'\n", | |
")\n", | |
"X_test, _, _ = df_to_sparse(\n", | |
" test_df, 'user_id', 'item_id'\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "62bcdbe4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import scipy.sparse as sp\n", | |
"X_train = sp.csr_matrix((train_df['rating'].astype(float), (train_df['user_id'], train_df['item_id'])))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "8f5430f3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(6040, 3706)" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_train.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "5f6d1527", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import scipy.sparse as sp\n", | |
"X_test = sp.csr_matrix((test_df['rating'].astype(float), (test_df['user_id'], test_df['item_id'])))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "d1a4ccc5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(6040, 3706)" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_test.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"id": "ac225b2e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<irspack.recommenders.slim.SLIMRecommender at 0xffff4354d700>" | |
] | |
}, | |
"execution_count": 50, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from irspack import SLIMRecommender \n", | |
"\n", | |
"#recommender = SLIMRecommender(X_train, alpha=0.01, l1_ratio=0.001)\n", | |
"recommender = SLIMRecommender(X_train)\n", | |
"recommender.learn()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"id": "d6229884", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([6038, 6033, 6034, ..., 5483, 735, 3839], dtype=int16)" | |
] | |
}, | |
"execution_count": 51, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_users = test_df['user_id'].unique()\n", | |
"test_users" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 52, | |
"id": "dc7a1e7f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"scores = recommender.get_score(test_users)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"id": "8e61024e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"top_k=10\n", | |
"\n", | |
"recos = []\n", | |
"ground_truths = []\n", | |
"for row in test_df.groupby('user_id')['item'].apply(list).reset_index(name='ground_truth').itertuples():\n", | |
" user_scores = scores[row.user_id,:]\n", | |
" seen_items = X_train[row.user_id].indices\n", | |
" user_scores[seen_items] = -np.inf\n", | |
" top_items = np.argsort(user_scores)[-top_k:][::-1]\n", | |
" recos.append([iid2id[iid] for iid in top_items])\n", | |
" ground_truths.append(row.ground_truth)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 54, | |
"id": "a9b3623c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.06572847682119268,\n", | |
" 'recall': 0.024085925107561277,\n", | |
" 'f1': 0.02913174058081254,\n", | |
" 'ndcg': 0.07061680286431203,\n", | |
" 'hit_rate': 0.3423841059602649,\n", | |
" 'mrr': 0.1497353621360247,\n", | |
" 'map': 0.032553592676646245,\n", | |
" 'tp': 3970,\n", | |
" 'auc': 0.182195758435825}" | |
] | |
}, | |
"execution_count": 54, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(recos, ground_truths), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6581c1af", | |
"metadata": {}, | |
"source": [ | |
"NDCG is not good as SLIM elasticnet." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "fe4a6c6b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment