Created
November 28, 2024 05:15
-
-
Save myui/6062a0f08d8382f07b83dc6e85b886a4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 59, | |
"id": "27cddd06-2fc7-4a96-aa57-226109f57442", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"\n", | |
"sys.path.append(\"/home/td-user/rtrec\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"id": "15d4770d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"os.environ[\"OPENBLAS_NUM_THREADS\"] = \"1\"" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "07765c5e-6af2-47d7-8320-38dc269a7986", | |
"metadata": {}, | |
"source": [ | |
"# MovieLens 1M" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"id": "827b4cc7-1f28-4e44-91e9-6d5cd8cde4dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using existing ratings.dat file.\n" | |
] | |
} | |
], | |
"source": [ | |
"from rtrec.experiments.datasets import load_dataset\n", | |
"\n", | |
"df = load_dataset(name='movielens_1m')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"id": "c89ffe93-0b31-4cac-a004-8ada940add7b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>tstamp</th>\n", | |
" <th>rating</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>1193</td>\n", | |
" <td>978300760</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>661</td>\n", | |
" <td>978302109</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>914</td>\n", | |
" <td>978301968</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>3408</td>\n", | |
" <td>978300275</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>2355</td>\n", | |
" <td>978824291</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000204</th>\n", | |
" <td>6040</td>\n", | |
" <td>1091</td>\n", | |
" <td>956716541</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000205</th>\n", | |
" <td>6040</td>\n", | |
" <td>1094</td>\n", | |
" <td>956704887</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000206</th>\n", | |
" <td>6040</td>\n", | |
" <td>562</td>\n", | |
" <td>956704746</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000207</th>\n", | |
" <td>6040</td>\n", | |
" <td>1096</td>\n", | |
" <td>956715648</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000208</th>\n", | |
" <td>6040</td>\n", | |
" <td>1097</td>\n", | |
" <td>956715569</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>1000209 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item tstamp rating\n", | |
"0 1 1193 978300760 5\n", | |
"1 1 661 978302109 3\n", | |
"2 1 914 978301968 3\n", | |
"3 1 3408 978300275 4\n", | |
"4 1 2355 978824291 5\n", | |
"... ... ... ... ...\n", | |
"1000204 6040 1091 956716541 1\n", | |
"1000205 6040 1094 956704887 5\n", | |
"1000206 6040 562 956704746 5\n", | |
"1000207 6040 1096 956715648 4\n", | |
"1000208 6040 1097 956715569 4\n", | |
"\n", | |
"[1000209 rows x 4 columns]" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df[['user', 'item', 'tstamp', 'rating']]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "74f06c6c-535c-45bd-8fe0-d59f8d369656", | |
"metadata": {}, | |
"source": [ | |
"# Temporal User Split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"id": "7b2e53b8-b1e8-4dc2-90c9-40b00d9f9826", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.experiments.split import temporal_user_split\n", | |
"train_df, test_df = temporal_user_split(df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"id": "4c702533-70ce-4e31-9801-bd90c2520ebf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"num train samples=797758, num test samples=202451\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"num train samples={len(train_df)}, num test samples={len(test_df)}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"id": "01eaa43f-823d-4542-bb9c-ce2285c94aa4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"unique number of users: 6040, test: 6040\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"unique number of users: {train_df['user'].nunique()}, test: {test_df['user'].nunique()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"id": "fa215929-9ef2-4b06-896a-c973ef77876c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"unique number of items: 3667, test: 3535\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"unique number of items: {train_df['item'].nunique()}, test: {test_df['item'].nunique()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"id": "07a58003-46c2-44c3-922f-328a993014cf", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 67, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"id": "82fa2779", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 68, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"id": "63a0ff91-aa99-49ef-94b8-e1cd7a10af8e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 69, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"id": "283a8b63-5f9d-42ce-90a2-859fdd4fa612", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rectools import Columns\n", | |
"from rectools.dataset import Dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"id": "26bedafe-a249-4935-9c2f-ecc3d5962d2a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df_rectools = train_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})\n", | |
"\n", | |
"train_dataset = Dataset.construct(train_df_rectools)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"id": "97a9eb39-d38d-4bd0-97da-74a20f7e3b18", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_df_rectools = test_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "ef59872b-8f8e-4590-9cd7-7cfef21b998c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/td-user/.local/lib/python3.12/site-packages/implicit/cpu/als.py:95: RuntimeWarning: OpenBLAS is configured to use 6 threads. It is highly recommended to disable its internal threadpool by setting the environment variable 'OPENBLAS_NUM_THREADS=1' or by calling 'threadpoolctl.threadpool_limits(1, \"blas\")'. Having OpenBLAS use a threadpool can lead to severe performance issues here.\n", | |
" check_blas_config()\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 30.9 s, sys: 145 ms, total: 31.1 s\n", | |
"Wall time: 5.99 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<rectools.models.implicit_als.ImplicitALSWrapperModel at 0xffff20a97380>" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"from implicit.als import AlternatingLeastSquares\n", | |
"from rectools.models import ImplicitALSWrapperModel\n", | |
"\n", | |
"model = ImplicitALSWrapperModel(AlternatingLeastSquares(factors=64, iterations=10))\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "1913fb59-8ef7-456b-a09d-73c563cc66db", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.91 s, sys: 49 µs, total: 1.91 s\n", | |
"Wall time: 326 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "63bfe5b6-739b-4d19-bab7-0dd44a64db91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "e413590f-53e6-4b10-9959-7f35d82310d5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.11690397350993817,\n", | |
" 'recall': 0.06727192248925505,\n", | |
" 'f1': 0.06936100613910212,\n", | |
" 'ndcg': 0.12878704147413883,\n", | |
" 'hit_rate': 0.5922185430463576,\n", | |
" 'mrr': 0.2553735020498256,\n", | |
" 'map': 0.0602965775008802,\n", | |
" 'tp': 7061,\n", | |
" 'auc': 0.30738672737306877}" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e037e048-b283-415e-825c-fbb91cedd936", | |
"metadata": {}, | |
"source": [ | |
"## Test filter_viewed=False" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"id": "30c0ad0e-469a-48bf-8ccd-5d2dbbdd2d44", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.07 s, sys: 0 ns, total: 1.07 s\n", | |
"Wall time: 204 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=False,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"id": "59c9daf0-1deb-43d6-b208-df645e221dd1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"id": "1dd78d14-f71d-4bfe-8bdf-ebaff49e32b5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.01846026490066204,\n", | |
" 'recall': 0.016539939413491602,\n", | |
" 'f1': 0.014715708764502608,\n", | |
" 'ndcg': 0.01938098006289434,\n", | |
" 'hit_rate': 0.15380794701986755,\n", | |
" 'mrr': 0.04177664774519063,\n", | |
" 'map': 0.00634852182226831,\n", | |
" 'tp': 1115,\n", | |
" 'auc': 0.06675355434668337}" | |
] | |
}, | |
"execution_count": 82, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e6095abe-ef29-450a-ae55-43a006c010a1", | |
"metadata": {}, | |
"source": [ | |
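"NDCG is very poor when `filter_viewed=False` on the MovieLens dataset because each user/item pair has only a single rating: an item a user already rated in the training set cannot reappear for that user in the test set, so recommending already-viewed items only wastes slots in the top-k list." | |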
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b9b58755-effe-4791-8620-e9836250fa20", | |
"metadata": {}, | |
"source": [ | |
"# Temporal Global Split" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"id": "36342e92-553b-4c43-abf5-bd891506323a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.experiments.split import temporal_split\n", | |
"train_df, test_df = temporal_split(df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"id": "e47fd10c-5e71-4d97-bd7c-47a5dfd0c382", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"num train samples=800167, num test samples=200042\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"num train samples={len(train_df)}, num test samples={len(test_df)}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"id": "178adb72-e6e6-48f6-a184-8ea5a47c017a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"unique number of items: 3662, test: 3511\n" | |
] | |
} | |
], | |
"source": [ | |
"print(f\"unique number of items: {train_df['item'].nunique()}, test: {test_df['item'].nunique()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"id": "d1f054f5-d1e1-4915-9d66-2e5a5419e5ae", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 86, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 87, | |
"id": "8855698e-64e7-4ec9-be55-6120ba1a6342", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>count</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" count\n", | |
"min 1.0\n", | |
"max 1.0\n", | |
"mean 1.0" | |
] | |
}, | |
"execution_count": 87, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_df.groupby(['user', 'item']).size().reset_index(name='count').agg({\n", | |
" 'count': ['min', 'max', 'mean']\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"id": "54afa1f0-fd25-40f2-b415-ffcb11858f42", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"train_df_rectools = train_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})\n", | |
"\n", | |
"train_dataset = Dataset.construct(train_df_rectools)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"id": "38a0e2e7-b44d-4a8c-80cb-8097e620b024", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"test_df_rectools = test_df.rename(columns={\n", | |
" 'user': Columns.User,\n", | |
" 'item': Columns.Item,\n", | |
" 'rating': Columns.Weight,\n", | |
" 'tstamp': Columns.Datetime,\n", | |
"})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"id": "4273b10d-7453-45f1-8a5b-f414ebf22641", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 25.5 s, sys: 69.8 ms, total: 25.5 s\n", | |
"Wall time: 4.38 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<rectools.models.implicit_als.ImplicitALSWrapperModel at 0xffff1b384350>" | |
] | |
}, | |
"execution_count": 90, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"from implicit.als import AlternatingLeastSquares\n", | |
"from rectools.models import ImplicitALSWrapperModel\n", | |
"\n", | |
"model = ImplicitALSWrapperModel(AlternatingLeastSquares(factors=64, iterations=10))\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"id": "3aab2315-f8d6-4b5a-a2ae-cb27f6c2a0f3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 829 ms, sys: 0 ns, total: 829 ms\n", | |
"Wall time: 148 ms\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/td-user/.local/lib/python3.12/site-packages/rectools/models/base.py:406: UserWarning: \n", | |
" Model `<class 'rectools.models.implicit_als.ImplicitALSWrapperModel'>` doesn't support recommendations for cold users,\n", | |
" but some of given users are cold: they are not in the `dataset.user_id_map`\n", | |
" \n", | |
" warnings.warn(explanation)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset, # NOTE: must be the same dataset the model was fitted on.\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
" on_unsupported_targets='warn'\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 92, | |
"id": "3511808b-4df1-4c74-ba63-efd544f7f267", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 93, | |
"id": "d4167013-81d9-46e2-9328-1154084e804c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.1737532808398942,\n", | |
" 'recall': 0.015518869329745933,\n", | |
" 'f1': 0.02565785882256373,\n", | |
" 'ndcg': 0.1780897874583865,\n", | |
" 'hit_rate': 0.6045494313210849,\n", | |
" 'mrr': 0.3092484967156882,\n", | |
" 'map': 0.10165062700495765,\n", | |
" 'tp': 1986,\n", | |
" 'auc': 0.3208006985237956}" | |
] | |
}, | |
"execution_count": 93, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "50910a72", | |
"metadata": {}, | |
"source": [ | |
"# LightFM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"id": "be4bd465", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<rectools.models.lightfm.LightFMWrapperModel at 0xffff46131910>" | |
] | |
}, | |
"execution_count": 73, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rectools.models import LightFMWrapperModel\n", | |
"from lightfm import LightFM\n", | |
"\n", | |
"#model = LightFMWrapperModel(LightFM(no_components=10, loss=\"bpr\"), epochs=20)\n", | |
"model = LightFMWrapperModel(LightFM(no_components=10, loss=\"warp\"), epochs=20)\n", | |
"model.fit(train_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"id": "5a0fbffb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 173 ms, sys: 37 µs, total: 173 ms\n", | |
"Wall time: 172 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recos = model.recommend(\n", | |
" users=test_df_rectools[Columns.User].unique(),\n", | |
" dataset=train_dataset,\n", | |
" k=10,\n", | |
" filter_viewed=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"id": "74094716", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranked_list = recos.groupby('user_id')['item_id'].apply(list).to_list()\n", | |
"ground_truth = test_df.groupby('user')['item'].apply(list).to_list()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"id": "c6728691", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.12958609271523552,\n", | |
" 'recall': 0.060708003278539355,\n", | |
" 'f1': 0.06777845004152794,\n", | |
" 'ndcg': 0.1410367943661583,\n", | |
" 'hit_rate': 0.5809602649006622,\n", | |
" 'mrr': 0.2735684063912527,\n", | |
" 'map': 0.07049961539680885,\n", | |
" 'tp': 7827,\n", | |
" 'auc': 0.30966887417218547}" | |
] | |
}, | |
"execution_count": 76, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from rtrec.utils.metrics import compute_scores\n", | |
"\n", | |
"compute_scores(zip(ranked_list, ground_truth), recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "86bd64d5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment