Created
December 3, 2024 07:11
-
-
Save myui/893e6b8d3e308ad7b19306cd5e1fe844 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "27cddd06-2fc7-4a96-aa57-226109f57442", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"\n", | |
"sys.path.append(\"/home/td-user/rtrec\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "15d4770d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "07765c5e-6af2-47d7-8320-38dc269a7986", | |
"metadata": {}, | |
"source": [ | |
"# Movielens 1M" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "827b4cc7-1f28-4e44-91e9-6d5cd8cde4dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using existing ratings.dat file.\n" | |
] | |
} | |
], | |
"source": [ | |
"from rtrec.experiments.datasets import load_dataset\n", | |
"\n", | |
"df = load_dataset(name='movielens_1m')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "c89ffe93-0b31-4cac-a004-8ada940add7b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>tstamp</th>\n", | |
" <th>rating</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>1193</td>\n", | |
" <td>978300760</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>661</td>\n", | |
" <td>978302109</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>914</td>\n", | |
" <td>978301968</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>3408</td>\n", | |
" <td>978300275</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>2355</td>\n", | |
" <td>978824291</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000204</th>\n", | |
" <td>6040</td>\n", | |
" <td>1091</td>\n", | |
" <td>956716541</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000205</th>\n", | |
" <td>6040</td>\n", | |
" <td>1094</td>\n", | |
" <td>956704887</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000206</th>\n", | |
" <td>6040</td>\n", | |
" <td>562</td>\n", | |
" <td>956704746</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000207</th>\n", | |
" <td>6040</td>\n", | |
" <td>1096</td>\n", | |
" <td>956715648</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000208</th>\n", | |
" <td>6040</td>\n", | |
" <td>1097</td>\n", | |
" <td>956715569</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>1000209 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item tstamp rating\n", | |
"0 1 1193 978300760 5\n", | |
"1 1 661 978302109 3\n", | |
"2 1 914 978301968 3\n", | |
"3 1 3408 978300275 4\n", | |
"4 1 2355 978824291 5\n", | |
"... ... ... ... ...\n", | |
"1000204 6040 1091 956716541 1\n", | |
"1000205 6040 1094 956704887 5\n", | |
"1000206 6040 562 956704746 5\n", | |
"1000207 6040 1096 956715648 4\n", | |
"1000208 6040 1097 956715569 4\n", | |
"\n", | |
"[1000209 rows x 4 columns]" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df[['user', 'item', 'tstamp', 'rating']]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "74f06c6c-535c-45bd-8fe0-d59f8d369656", | |
"metadata": {}, | |
"source": [ | |
"# Temporal User Split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "7b2e53b8-b1e8-4dc2-90c9-40b00d9f9826", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.experiments.split import temporal_user_split\n", | |
"train_df, test_df = temporal_user_split(df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "4c702533-70ce-4e31-9801-bd90c2520ebf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"num train samples=797758, num test samples=202451\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"num train samples={len(train_df)}, num test samples={len(test_df)}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "01eaa43f-823d-4542-bb9c-ce2285c94aa4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"unique number of users: 6040, test: 6040\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"unique number of users: {train_df['user'].nunique()}, test: {test_df['user'].nunique()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "fa215929-9ef2-4b06-896a-c973ef77876c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"unique number of items: 3667, test: 3535\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"unique number of items: {train_df['item'].nunique()}, test: {test_df['item'].nunique()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "07a58003-46c2-44c3-922f-328a993014cf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "82fa2779", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "63a0ff91-aa99-49ef-94b8-e1cd7a10af8e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "283a8b63-5f9d-42ce-90a2-859fdd4fa612", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rectools import Columns\n", | |
"from rectools.dataset import Dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "26bedafe-a249-4935-9c2f-ecc3d5962d2a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df_rectools = train_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})\n", | |
"\n", | |
"train_dataset = Dataset.construct(train_df_rectools)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "97a9eb39-d38d-4bd0-97da-74a20f7e3b18", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_df_rectools = test_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "ef59872b-8f8e-4590-9cd7-7cfef21b998c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/td-user/.local/lib/python3.12/site-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 6 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", | |
" check_blas_config()\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.61 s, sys: 161 ms, total: 1.77 s\n", | |
"Wall time: 795 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<rectools.models.implicit_als.ImplicitALSWrapperModel at 0xffff808d02c0>" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"from implicit.als import AlternatingLeastSquares\n", | |
"from rectools.models import ImplicitALSWrapperModel\n", | |
"\n", | |
"model = ImplicitALSWrapperModel(AlternatingLeastSquares(factors=10, iterations=10))\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "1913fb59-8ef7-456b-a09d-73c563cc66db", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 358 ms, sys: 97.9 ms, total: 456 ms\n", | |
"Wall time: 371 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "63bfe5b6-739b-4d19-bab7-0dd44a64db91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "e413590f-53e6-4b10-9959-7f35d82310d5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.12337748344371237,\n", | |
" 'recall': 0.057162073497439896,\n", | |
" 'f1': 0.06389984643404804,\n", | |
" 'ndcg': 0.13396903052648818,\n", | |
" 'hit_rate': 0.5586092715231789,\n", | |
" 'mrr': 0.25996274834437033,\n", | |
" 'map': 0.06665208918372337,\n", | |
" 'tp': 7452,\n", | |
" 'auc': 0.29608130978660835}" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "78e31c02", | |
"metadata": {}, | |
"source": [ | |
"## w/ Hyperparameter optimization" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "cc477733", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import optuna\n", | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"def objective(trial):\n", | |
"    \"\"\"Optuna objective: validation NDCG@10 of implicit ALS for one sampled configuration.\n", | |
"\n", | |
"    `train_df` is split again with `temporal_user_split(train_df, test_frac=0.3)`;\n", | |
"    the model is fitted on the first part and scored on the second, so `test_df`\n", | |
"    is never touched during the search.\n", | |
"\n", | |
"    Returns the `ndcg` entry of `compute_scores` (the study maximizes it).\n", | |
"    \"\"\"\n", | |
"    factors = trial.suggest_int(\"factors\", 8, 128)\n", | |
"    regularization = trial.suggest_float(\"regularization\", 1e-4, 1.0, log=True)\n", | |
"    alpha = trial.suggest_float(\"alpha\", 1e-3, 1.0, log=True)\n", | |
"    iterations = trial.suggest_int(\"iterations\", 5, 20)\n", | |
"\n", | |
"    rectools_columns = {\n", | |
"        'user': Columns.User,\n", | |
"        'item': Columns.Item,\n", | |
"        'rating': Columns.Weight,\n", | |
"        'tstamp': Columns.Datetime,\n", | |
"    }\n", | |
"    fit_df, valid_df = temporal_user_split(train_df, test_frac=0.3)\n", | |
"    fit_df = fit_df.rename(columns=rectools_columns)\n", | |
"    valid_df = valid_df.rename(columns=rectools_columns)\n", | |
"    fit_ds = Dataset.construct(fit_df)\n", | |
"\n", | |
"    model = ImplicitALSWrapperModel(AlternatingLeastSquares(\n", | |
"        factors=factors, iterations=iterations, alpha=alpha, regularization=regularization,\n", | |
"    ))\n", | |
"    model.fit(fit_ds)\n", | |
"\n", | |
"    # Use the same protocol as the final evaluation on test_df: with filter_viewed=False\n", | |
"    # already-rated items occupy top-k slots but can never be hits in the hold-out.\n", | |
"    recos = model.recommend(\n", | |
"        users=valid_df[Columns.User].unique(),\n", | |
"        dataset=fit_ds,\n", | |
"        k=10,\n", | |
"        filter_viewed=True,\n", | |
"    )\n", | |
"\n", | |
"    # Pair recommendations with ground truth by user id, not by list position,\n", | |
"    # so a user missing from either side cannot shift the pairing.\n", | |
"    ranked_by_user = recos.groupby(Columns.User)[Columns.Item].apply(list)\n", | |
"    truth_by_user = valid_df.groupby(Columns.User)[Columns.Item].apply(list)\n", | |
"    eval_users = ranked_by_user.index.intersection(truth_by_user.index)\n", | |
"    ranked_list = ranked_by_user.loc[eval_users].to_list()\n", | |
"    ground_truth = truth_by_user.loc[eval_users].to_list()\n", | |
"\n", | |
"    scores = compute_scores(zip(ranked_list, ground_truth), recommend_size=10)\n", | |
"    return scores['ndcg']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "023fda4c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[I 2024-12-03 07:09:51,178] A new study created in memory with name: no-name-12b090d3-f6b9-4710-9dba-60f6eca94647\n", | |
"[I 2024-12-03 07:09:52,190] Trial 0 finished with value: 0.03784099050802478 and parameters: {'factors': 17, 'regularization': 0.00041325159442384685, 'alpha': 0.12488745589212062, 'iterations': 13}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:09:53,791] Trial 1 finished with value: 0.013647232779899779 and parameters: {'factors': 104, 'regularization': 0.00019606233876557524, 'alpha': 0.8309073244181959, 'iterations': 19}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:09:54,610] Trial 2 finished with value: 0.010886846451079195 and parameters: {'factors': 85, 'regularization': 0.00021491462628679163, 'alpha': 0.24525636535158987, 'iterations': 5}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:09:55,670] Trial 3 finished with value: 0.021052719213400215 and parameters: {'factors': 21, 'regularization': 0.02155392639508381, 'alpha': 0.013851510857715266, 'iterations': 20}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:09:56,882] Trial 4 finished with value: 0.007583283489394196 and parameters: {'factors': 83, 'regularization': 0.011327641814034037, 'alpha': 0.004213263012401815, 'iterations': 13}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:09:57,661] Trial 5 finished with value: 0.014279328452246429 and parameters: {'factors': 37, 'regularization': 0.0037601502405578613, 'alpha': 0.0022954361494292586, 'iterations': 6}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:09:59,037] Trial 6 finished with value: 0.011442133099663793 and parameters: {'factors': 57, 'regularization': 0.13122966055034088, 'alpha': 0.025285010234324074, 'iterations': 20}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:10:00,061] Trial 7 finished with value: 0.012268138679778532 and parameters: {'factors': 57, 'regularization': 0.001016333167988013, 'alpha': 0.04849576869711544, 'iterations': 14}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:10:00,949] Trial 8 finished with value: 0.00636779520222651 and parameters: {'factors': 122, 'regularization': 0.12611649283301232, 'alpha': 0.03492411714189907, 'iterations': 5}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:10:01,955] Trial 9 finished with value: 0.03228487433432796 and parameters: {'factors': 16, 'regularization': 0.053915325128136224, 'alpha': 0.054776783704446, 'iterations': 18}. Best is trial 0 with value: 0.03784099050802478.\n", | |
"[I 2024-12-03 07:10:02,789] Trial 10 finished with value: 0.04670429133792384 and parameters: {'factors': 8, 'regularization': 0.8014171720931502, 'alpha': 0.144799436205759, 'iterations': 9}. Best is trial 10 with value: 0.04670429133792384.\n", | |
"[I 2024-12-03 07:10:03,558] Trial 11 finished with value: 0.047288808310857064 and parameters: {'factors': 8, 'regularization': 0.8382184023269297, 'alpha': 0.1896659041445441, 'iterations': 9}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:04,422] Trial 12 finished with value: 0.028312070003995625 and parameters: {'factors': 37, 'regularization': 0.39798618190794277, 'alpha': 0.430714174363913, 'iterations': 9}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:05,210] Trial 13 finished with value: 0.04611991182133565 and parameters: {'factors': 8, 'regularization': 0.8837904966480261, 'alpha': 0.14253500606443725, 'iterations': 9}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:06,056] Trial 14 finished with value: 0.034608210931558234 and parameters: {'factors': 35, 'regularization': 0.9695695756841981, 'alpha': 0.9610649778108706, 'iterations': 9}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:06,953] Trial 15 finished with value: 0.025822878602092286 and parameters: {'factors': 32, 'regularization': 0.2709204554841261, 'alpha': 0.1516527216541496, 'iterations': 11}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:08,048] Trial 16 finished with value: 0.021446578017768773 and parameters: {'factors': 51, 'regularization': 0.045743048934349335, 'alpha': 0.36347970604691954, 'iterations': 16}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:08,780] Trial 17 finished with value: 0.037950635690016554 and parameters: {'factors': 8, 'regularization': 0.002060825136162672, 'alpha': 0.010018420564525867, 'iterations': 7}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:09,660] Trial 18 finished with value: 0.02718838560064843 and parameters: {'factors': 26, 'regularization': 0.25504876432065393, 'alpha': 0.08453177778205802, 'iterations': 11}. Best is trial 11 with value: 0.047288808310857064.\n", | |
"[I 2024-12-03 07:10:10,609] Trial 19 finished with value: 0.015226816496589779 and parameters: {'factors': 74, 'regularization': 0.09072347562880652, 'alpha': 0.42867332345606274, 'iterations': 8}. Best is trial 11 with value: 0.047288808310857064.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 53.7 s, sys: 3.58 s, total: 57.3 s\n", | |
"Wall time: 19.4 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"{'factors': 8,\n", | |
" 'regularization': 0.8382184023269297,\n", | |
" 'alpha': 0.1896659041445441,\n", | |
" 'iterations': 9}" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"study = optuna.create_study(direction=\"maximize\")\n", | |
"study.optimize(objective, n_trials=20)\n", | |
"study.best_params  # hyperparameters of the best trial" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "586d7a0a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 883 ms, sys: 145 ms, total: 1.03 s\n", | |
"Wall time: 214 ms\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<rectools.models.implicit_als.ImplicitALSWrapperModel at 0xffff5ffab440>" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"model = ImplicitALSWrapperModel(AlternatingLeastSquares(**study.best_params))\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "e5f1f2e0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 346 ms, sys: 113 ms, total: 458 ms\n", | |
"Wall time: 376 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "1dcb0593", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "a20bc382", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.129552980132454,\n", | |
" 'recall': 0.0594329503140339,\n", | |
" 'f1': 0.06694009618043677,\n", | |
" 'ndcg': 0.14152567953843076,\n", | |
" 'hit_rate': 0.5665562913907285,\n", | |
" 'mrr': 0.2734034347734669,\n", | |
" 'map': 0.07212871358992476,\n", | |
" 'tp': 7825,\n", | |
" 'auc': 0.3066560890360568}" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e037e048-b283-415e-825c-fbb91cedd936", | |
"metadata": {}, | |
"source": [ | |
"## Test filter_viewed=False" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "30c0ad0e-469a-48bf-8ccd-5d2dbbdd2d44", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "59c9daf0-1deb-43d6-b208-df645e221dd1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1dd78d14-f71d-4bfe-8bdf-ebaff49e32b5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
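{ | |
"cell_type": "markdown", | |
"id": "e6095abe-ef29-450a-ae55-43a006c010a1", | |
"metadata": {}, | |
"source": [ | |
"NDCG is very poor with `filter_viewed=False` on the MovieLens dataset: each user/item pair is rated only once, so items a user already rated in the training data can never appear in that user's test set, and any such item the model recommends only wastes a top-k slot." | |
] | |
}, | |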
{ | |
"cell_type": "markdown", | |
"id": "b9b58755-effe-4791-8620-e9836250fa20", | |
"metadata": {}, | |
"source": [ | |
"# Temporal Global Split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "36342e92-553b-4c43-abf5-bd891506323a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.experiments.split import temporal_split\n", | |
"train_df, test_df = temporal_split(df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e47fd10c-5e71-4d97-bd7c-47a5dfd0c382", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"print(f\"num train samples={len(train_df)}, num test samples={len(test_df)}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "178adb72-e6e6-48f6-a184-8ea5a47c017a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"print(f\"unique number of items: {train_df['item'].nunique()}, test: {test_df['item'].nunique()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d1f054f5-d1e1-4915-9d66-2e5a5419e5ae", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "8855698e-64e7-4ec9-be55-6120ba1a6342", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "54afa1f0-fd25-40f2-b415-ffcb11858f42", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df_rectools = train_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})\n", | |
"\n", | |
"train_dataset = Dataset.construct(train_df_rectools)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "38a0e2e7-b44d-4a8c-80cb-8097e620b024", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_df_rectools = test_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4273b10d-7453-45f1-8a5b-f414ebf22641", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"from implicit.als import AlternatingLeastSquares\n", | |
"from rectools.models import ImplicitALSWrapperModel\n", | |
"\n", | |
"model = ImplicitALSWrapperModel(AlternatingLeastSquares(factors=64, iterations=10))\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3aab2315-f8d6-4b5a-a2ae-cb27f6c2a0f3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset, # note: pass the same dataset the model was fitted on\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
" on_unsupported_targets='warn'\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "3511808b-4df1-4c74-ba63-efd544f7f267", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# recommend(..., on_unsupported_targets='warn') can return no rows for test users that\n", | |
"# are absent from train_dataset, so the two groupby results may cover different users.\n", | |
"# Pair them by user id instead of by list position; users without recommendations\n", | |
"# are excluded from scoring.\n", | |
"recos_by_user = recos.groupby('user_id')['item_id'].apply(list)\n", | |
"truth_by_user = test_df.groupby('user')['item'].apply(list)\n", | |
"eval_users = recos_by_user.index.intersection(truth_by_user.index)\n", | |
"\n", | |
"ranked_list = recos_by_user.loc[eval_users].to_list()\n", | |
"ground_truth = truth_by_user.loc[eval_users].to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d4167013-81d9-46e2-9328-1154084e804c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "50910a72", | |
"metadata": {}, | |
"source": [ | |
"# LightFM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "be4bd465", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rectools.models import LightFMWrapperModel\n", | |
"from lightfm import LightFM\n", | |
"\n", | |
"#model = LightFMWrapperModel(LightFM(no_components=10, loss=\"bpr\"), epochs=20)\n", | |
"model = LightFMWrapperModel(LightFM(no_components=10, loss=\"warp\"), epochs=20)\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5a0fbffb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "74094716", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c6728691", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "86bd64d5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment