Created
December 2, 2024 08:38
-
-
Save myui/dd260fdef600f4971b74ba74c45c1380 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "27cddd06-2fc7-4a96-aa57-226109f57442", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"\n", | |
"sys.path.append(\"/home/td-user/rtrec\")\n", | |
"#sys.path.append(\"/Users/myui/workspace/myui/rtrec/\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "ef77cc0a-833f-4d0c-adf7-6c507f93dfe9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#from rtrec._lowlevel import set_notebook_mode\n", | |
"#set_notebook_mode(True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "07765c5e-6af2-47d7-8320-38dc269a7986", | |
"metadata": {}, | |
"source": [ | |
"# Movielens 1M" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "827b4cc7-1f28-4e44-91e9-6d5cd8cde4dc", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using existing ratings.dat file.\n" | |
] | |
} | |
], | |
"source": [ | |
"from rtrec.experiments.datasets import load_dataset\n", | |
"\n", | |
"df = load_dataset(name='movielens_1m')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "c89ffe93-0b31-4cac-a004-8ada940add7b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>tstamp</th>\n", | |
" <th>rating</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>1193</td>\n", | |
" <td>978300760</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>661</td>\n", | |
" <td>978302109</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>1</td>\n", | |
" <td>914</td>\n", | |
" <td>978301968</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>1</td>\n", | |
" <td>3408</td>\n", | |
" <td>978300275</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>1</td>\n", | |
" <td>2355</td>\n", | |
" <td>978824291</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000204</th>\n", | |
" <td>6040</td>\n", | |
" <td>1091</td>\n", | |
" <td>956716541</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000205</th>\n", | |
" <td>6040</td>\n", | |
" <td>1094</td>\n", | |
" <td>956704887</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000206</th>\n", | |
" <td>6040</td>\n", | |
" <td>562</td>\n", | |
" <td>956704746</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000207</th>\n", | |
" <td>6040</td>\n", | |
" <td>1096</td>\n", | |
" <td>956715648</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1000208</th>\n", | |
" <td>6040</td>\n", | |
" <td>1097</td>\n", | |
" <td>956715569</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>1000209 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item tstamp rating\n", | |
"0 1 1193 978300760 5\n", | |
"1 1 661 978302109 3\n", | |
"2 1 914 978301968 3\n", | |
"3 1 3408 978300275 4\n", | |
"4 1 2355 978824291 5\n", | |
"... ... ... ... ...\n", | |
"1000204 6040 1091 956716541 1\n", | |
"1000205 6040 1094 956704887 5\n", | |
"1000206 6040 562 956704746 5\n", | |
"1000207 6040 1096 956715648 4\n", | |
"1000208 6040 1097 956715569 4\n", | |
"\n", | |
"[1000209 rows x 4 columns]" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df[['user', 'item', 'tstamp', 'rating']]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "2a6f1883-1ebb-435f-be35-5b183a8812ac", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3706" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df['item'].nunique()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "eb5d4826-2a26-47dd-b695-603099f4d7fb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"6040" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df['user'].nunique()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "aa0bfa94-4007-4295-81aa-6b55489139e3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# from rtrec.experiments.split import leave_one_last_item\n", | |
"# train_df, test_df = leave_one_last_item(df)\n", | |
"\n", | |
"from rtrec.experiments.split import temporal_user_split\n", | |
"train_df, test_df = temporal_user_split(df)\n", | |
"\n", | |
"#from rtrec.experiments.split import random_split\n", | |
"#train_df, test_df = random_split(df)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "6a6f03d6-b204-4f71-adae-d190de5e37f8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>rating</th>\n", | |
" <th>tstamp</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>6040</td>\n", | |
" <td>858</td>\n", | |
" <td>4</td>\n", | |
" <td>956703932</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>6040</td>\n", | |
" <td>593</td>\n", | |
" <td>5</td>\n", | |
" <td>956703954</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>6040</td>\n", | |
" <td>2384</td>\n", | |
" <td>4</td>\n", | |
" <td>956703954</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>6040</td>\n", | |
" <td>1961</td>\n", | |
" <td>4</td>\n", | |
" <td>956703977</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>6040</td>\n", | |
" <td>2019</td>\n", | |
" <td>5</td>\n", | |
" <td>956703977</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797753</th>\n", | |
" <td>736</td>\n", | |
" <td>1278</td>\n", | |
" <td>4</td>\n", | |
" <td>1045711206</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797754</th>\n", | |
" <td>736</td>\n", | |
" <td>3671</td>\n", | |
" <td>4</td>\n", | |
" <td>1045711217</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797755</th>\n", | |
" <td>3840</td>\n", | |
" <td>1196</td>\n", | |
" <td>3</td>\n", | |
" <td>1046106127</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797756</th>\n", | |
" <td>3840</td>\n", | |
" <td>1225</td>\n", | |
" <td>3</td>\n", | |
" <td>1046106162</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>797757</th>\n", | |
" <td>3840</td>\n", | |
" <td>2502</td>\n", | |
" <td>5</td>\n", | |
" <td>1046106198</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>797758 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item rating tstamp\n", | |
"0 6040 858 4 956703932\n", | |
"1 6040 593 5 956703954\n", | |
"2 6040 2384 4 956703954\n", | |
"3 6040 1961 4 956703977\n", | |
"4 6040 2019 5 956703977\n", | |
"... ... ... ... ...\n", | |
"797753 736 1278 4 1045711206\n", | |
"797754 736 3671 4 1045711217\n", | |
"797755 3840 1196 3 1046106127\n", | |
"797756 3840 1225 3 1046106162\n", | |
"797757 3840 2502 5 1046106198\n", | |
"\n", | |
"[797758 rows x 4 columns]" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "2fce7c8c-ba85-46aa-8af0-95abed71d265", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[90m[\u001b[0m2024-12-02T08:22:08Z \u001b[32mINFO \u001b[0m rtrec::slim\u001b[90m]\u001b[0m Creating SlimMSE model with optimizer: adagrad_rda, alpha: 0.001, lambda1: 0.0002, lambda2: 0.0001, rating_range: (-5.0, 10.0), decay_in_days: None, n_recent: 50\n" | |
] | |
} | |
], | |
"source": [ | |
"from rtrec.recommender import Recommender\n", | |
"#from rtrec.models import SLIM_MSE as SlimMSE\n", | |
"from rtrec.models import Fast_SLIM_MSE as SlimMSE\n", | |
"\n", | |
"#model = SlimMSE()\n", | |
"model = SlimMSE(decay_in_days=None, n_recent=50)\n", | |
"#model = SlimMSE(optimizer='sgd', alpha=0.0001, decay_in_days=None, n_recent=None)\n", | |
"#model = SlimMSE(alpha=0.001)\n", | |
"#model = SlimMSE(optimizer=\"adagrad_rda\", alpha=0.001, n_recent=50, lambda1 = 0.0002, lambda2 = 0.0001)\n", | |
"recommender = Recommender(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "90eac400-bb74-4b21-a86f-bb63bece783c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from scipy.sparse import csr_matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "79f22c5e-c083-4b24-a6c5-f2fd0a3d47e7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 0%| | 0/10 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting epoch 1/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 10%|████▍ | 1/10 [00:28<04:16, 28.45s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1 completed in 28.42 seconds\n", | |
"Throughput: 28070.33 samples/sec\n", | |
"Empirical loss after epoch 1: 3.4905147552490234\n", | |
"Starting epoch 2/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 20%|████████▊ | 2/10 [01:06<04:33, 34.13s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 2 completed in 38.10 seconds\n", | |
"Throughput: 20939.14 samples/sec\n", | |
"Empirical loss after epoch 2: 3.4076671600341797\n", | |
"Starting epoch 3/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 30%|█████████████▏ | 3/10 [01:43<04:09, 35.62s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 3 completed in 37.36 seconds\n", | |
"Throughput: 21352.26 samples/sec\n", | |
"Empirical loss after epoch 3: 3.34891939163208\n", | |
"Starting epoch 4/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 40%|█████████████████▌ | 4/10 [02:21<03:39, 36.57s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 4 completed in 38.02 seconds\n", | |
"Throughput: 20984.56 samples/sec\n", | |
"Empirical loss after epoch 4: 3.3069534301757812\n", | |
"Starting epoch 5/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 50%|██████████████████████ | 5/10 [02:59<03:05, 37.06s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 5 completed in 37.88 seconds\n", | |
"Throughput: 21059.36 samples/sec\n", | |
"Empirical loss after epoch 5: 3.27057147026062\n", | |
"Starting epoch 6/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 60%|██████████████████████████▍ | 6/10 [03:39<02:31, 37.77s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 6 completed in 39.13 seconds\n", | |
"Throughput: 20389.93 samples/sec\n", | |
"Empirical loss after epoch 6: 3.2380311489105225\n", | |
"Starting epoch 7/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 70%|██████████████████████████████▊ | 7/10 [04:16<01:53, 37.71s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 7 completed in 37.56 seconds\n", | |
"Throughput: 21238.49 samples/sec\n", | |
"Empirical loss after epoch 7: 3.197824478149414\n", | |
"Starting epoch 8/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 80%|███████████████████████████████████▏ | 8/10 [04:54<01:15, 37.78s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 8 completed in 37.82 seconds\n", | |
"Throughput: 21091.29 samples/sec\n", | |
"Empirical loss after epoch 8: 3.1531848907470703\n", | |
"Starting epoch 9/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 90%|███████████████████████████████████████▌ | 9/10 [05:33<00:38, 38.00s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 9 completed in 38.47 seconds\n", | |
"Throughput: 20736.90 samples/sec\n", | |
"Empirical loss after epoch 9: 3.1153697967529297\n", | |
"Starting epoch 10/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|███████████████████████████████████████████| 10/10 [06:11<00:00, 37.10s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 10 completed in 37.95 seconds\n", | |
"Throughput: 21021.12 samples/sec\n", | |
"Empirical loss after epoch 10: 3.0825929641723633\n", | |
"CPU times: user 6min 8s, sys: 5min 2s, total: 11min 11s\n", | |
"Wall time: 6min 13s\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recommender.fit(train_df, epochs=10, bulk_identify=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "34ae8015-aadd-42ce-9e8c-661c8a1faa09", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#%%time\n", | |
"#recommender.fit(test_df, epochs=10, bulk_identify=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "6fcf1cf3-be44-434b-84af-2209e288cd67", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|███████████████████████████████████████████| 61/61 [00:19<00:00, 3.08it/s]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.11735099337748679,\n", | |
" 'recall': 0.048998880307564885,\n", | |
" 'f1': 0.056674428155052727,\n", | |
" 'ndcg': 0.12906595112419084,\n", | |
" 'hit_rate': 0.5190397350993378,\n", | |
" 'mrr': 0.2556857064017651,\n", | |
" 'map': 0.06567731500252794,\n", | |
" 'tp': 7088,\n", | |
" 'auc': 0.287015353989278}" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"recommender.evaluate(test_df, recommend_size=10, filter_interacted=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "34cc6956-e5c4-4c0a-99bc-f4c8ae20f34b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|███████████████████████████████████████████| 61/61 [00:20<00:00, 2.91it/s]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.044453642384105,\n", | |
" 'recall': 0.030181948847140663,\n", | |
" 'f1': 0.03054455166525799,\n", | |
" 'ndcg': 0.048561648452451256,\n", | |
" 'hit_rate': 0.31440397350993377,\n", | |
" 'mrr': 0.11065062283191433,\n", | |
" 'map': 0.018361158247944716,\n", | |
" 'tp': 2685,\n", | |
" 'auc': 0.16280730973404822}" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"recommender.evaluate(test_df, recommend_size=10, filter_interacted=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "58458ff4-1414-4d6e-ac07-2257248d9e51", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[2858, 1198, 1196, 593, 1210, 318, 2396, 858, 356, 1]" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.recommend(user=1, top_k=10, filter_interacted=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "ede6a584-7db4-4690-8d71-78816d24665f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[2858, 260, 1198, 1196, 593, 1210, 318, 2396, 608, 858]" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.recommend(user=1, top_k=10, filter_interacted=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "3838c017-4674-4bce-91ff-11f491a98b7b", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[527, 260, 608, 858, 2762, 296, 1617, 1270, 1197, 1097]" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.recommend(user=2, top_k=10, filter_interacted=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "1fc146b4-873a-468b-8a6f-5bc2b56756bb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4.026906967163086" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.predict_rating(user=1, item=260)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "6cdac121-0190-4719-9b0d-90587aa1db67", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3.486403226852417" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.predict_rating(user=1, item=527)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "f21aa0ad-27de-4735-8bf5-5cfce234042f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[150,\n", | |
" 260,\n", | |
" 527,\n", | |
" 531,\n", | |
" 594,\n", | |
" 608,\n", | |
" 661,\n", | |
" 720,\n", | |
" 914,\n", | |
" 919,\n", | |
" 938,\n", | |
" 1022,\n", | |
" 1028,\n", | |
" 1029,\n", | |
" 1035,\n", | |
" 1097,\n", | |
" 1193,\n", | |
" 1197,\n", | |
" 1207,\n", | |
" 1246,\n", | |
" 1270,\n", | |
" 1287,\n", | |
" 1545,\n", | |
" 1721,\n", | |
" 1836,\n", | |
" 1961,\n", | |
" 1962,\n", | |
" 2018,\n", | |
" 2028,\n", | |
" 2321,\n", | |
" 2340,\n", | |
" 2398,\n", | |
" 2692,\n", | |
" 2762,\n", | |
" 2791,\n", | |
" 2797,\n", | |
" 2804,\n", | |
" 2918,\n", | |
" 3105,\n", | |
" 3114,\n", | |
" 3186,\n", | |
" 3408]" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"user1_items = train_df[train_df['user']==1]['item'].tolist()\n", | |
"user1_items.sort()\n", | |
"\n", | |
"user1_items" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "8951e54a-d9fa-4462-b6b9-f5b069dec912", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[1, 48, 588, 595, 745, 783, 1566, 1907, 2294, 2355, 2687]" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"user1_items_test = test_df[test_df['user']==1]['item'].tolist()\n", | |
"user1_items_test.sort()\n", | |
"user1_items_test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "7b7e4310-9eeb-43f7-aa5f-55b1155fc3b7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3.3724117279052734" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.predict_rating(user=1, item=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "88724b56-1e66-4957-b592-f1a9d5cb8ac3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4.026906967163086" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.predict_rating(user=1, item=260)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "20c61d11-a225-4810-97cb-5b8c26a6b726", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.0" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.predict_rating(user=1, item=5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "62fa2d09-d0e9-4f5d-b42f-edac08387dde", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>rating</th>\n", | |
" <th>tstamp</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>40</th>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>5</td>\n", | |
" <td>978824268</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item rating tstamp\n", | |
"40 1 1 5 978824268" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.query('user==1 & item in (1, 2, 3, 5, 1137, 253)')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"id": "62af9456-8e66-405b-9d7e-64e03a93e274", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>user</th>\n", | |
" <th>item</th>\n", | |
" <th>rating</th>\n", | |
" <th>tstamp</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>633</th>\n", | |
" <td>8</td>\n", | |
" <td>253</td>\n", | |
" <td>5</td>\n", | |
" <td>978230943</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>914</th>\n", | |
" <td>10</td>\n", | |
" <td>253</td>\n", | |
" <td>5</td>\n", | |
" <td>978228886</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7787</th>\n", | |
" <td>53</td>\n", | |
" <td>253</td>\n", | |
" <td>3</td>\n", | |
" <td>977980504</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10276</th>\n", | |
" <td>73</td>\n", | |
" <td>253</td>\n", | |
" <td>3</td>\n", | |
" <td>981315509</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11640</th>\n", | |
" <td>90</td>\n", | |
" <td>253</td>\n", | |
" <td>3</td>\n", | |
" <td>993875867</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>992870</th>\n", | |
" <td>5998</td>\n", | |
" <td>253</td>\n", | |
" <td>3</td>\n", | |
" <td>1001832240</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>994430</th>\n", | |
" <td>6005</td>\n", | |
" <td>253</td>\n", | |
" <td>5</td>\n", | |
" <td>956794779</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>994687</th>\n", | |
" <td>6007</td>\n", | |
" <td>253</td>\n", | |
" <td>2</td>\n", | |
" <td>956790310</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>997581</th>\n", | |
" <td>6025</td>\n", | |
" <td>253</td>\n", | |
" <td>4</td>\n", | |
" <td>956730684</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>998401</th>\n", | |
" <td>6035</td>\n", | |
" <td>253</td>\n", | |
" <td>4</td>\n", | |
" <td>956712491</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>738 rows × 4 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" user item rating tstamp\n", | |
"633 8 253 5 978230943\n", | |
"914 10 253 5 978228886\n", | |
"7787 53 253 3 977980504\n", | |
"10276 73 253 3 981315509\n", | |
"11640 90 253 3 993875867\n", | |
"... ... ... ... ...\n", | |
"992870 5998 253 3 1001832240\n", | |
"994430 6005 253 5 956794779\n", | |
"994687 6007 253 2 956790310\n", | |
"997581 6025 253 4 956730684\n", | |
"998401 6035 253 4 956712491\n", | |
"\n", | |
"[738 rows x 4 columns]" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df.query('item == 253')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "8808e32d-318e-4f0d-a91b-8640f6c549ad", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>item</th>\n", | |
" <th>co_occurrence_items</th>\n", | |
" <th>co_occurrence_counts</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1193</td>\n", | |
" <td>[2858, 1196, 260, 608]</td>\n", | |
" <td>[1171, 1127, 1125, 1107]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>661</td>\n", | |
" <td>[480, 1580, 1196, 1]</td>\n", | |
" <td>[397, 394, 384, 383]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>914</td>\n", | |
" <td>[260, 1196, 919, 2396]</td>\n", | |
" <td>[472, 469, 461, 446]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>3408</td>\n", | |
" <td>[2858, 2762, 2396, 3578]</td>\n", | |
" <td>[1014, 838, 805, 790]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>2355</td>\n", | |
" <td>[2858, 1, 1196, 1210]</td>\n", | |
" <td>[1169, 1081, 1029, 1028]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3701</th>\n", | |
" <td>2198</td>\n", | |
" <td>[1193, 1653, 39, 1089, 1215]</td>\n", | |
" <td>[2, 2, 2, 2, 2]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3702</th>\n", | |
" <td>2703</td>\n", | |
" <td>[2600, 2023, 1589, 3250]</td>\n", | |
" <td>[1, 1, 1, 1]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3703</th>\n", | |
" <td>2845</td>\n", | |
" <td>[1840, 58, 1621, 4]</td>\n", | |
" <td>[1, 1, 1, 1]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3704</th>\n", | |
" <td>3607</td>\n", | |
" <td>[2683, 1517, 908, 3386]</td>\n", | |
" <td>[1, 1, 1, 1]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3705</th>\n", | |
" <td>2909</td>\n", | |
" <td>[357, 3786, 1966, 2908]</td>\n", | |
" <td>[1, 1, 1, 1]</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>3706 rows × 3 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" item co_occurrence_items co_occurrence_counts\n", | |
"0 1193 [2858, 1196, 260, 608] [1171, 1127, 1125, 1107]\n", | |
"1 661 [480, 1580, 1196, 1] [397, 394, 384, 383]\n", | |
"2 914 [260, 1196, 919, 2396] [472, 469, 461, 446]\n", | |
"3 3408 [2858, 2762, 2396, 3578] [1014, 838, 805, 790]\n", | |
"4 2355 [2858, 1, 1196, 1210] [1169, 1081, 1029, 1028]\n", | |
"... ... ... ...\n", | |
"3701 2198 [1193, 1653, 39, 1089, 1215] [2, 2, 2, 2, 2]\n", | |
"3702 2703 [2600, 2023, 1589, 3250] [1, 1, 1, 1]\n", | |
"3703 2845 [1840, 58, 1621, 4] [1, 1, 1, 1]\n", | |
"3704 3607 [2683, 1517, 908, 3386] [1, 1, 1, 1]\n", | |
"3705 2909 [357, 3786, 1966, 2908] [1, 1, 1, 1]\n", | |
"\n", | |
"[3706 rows x 3 columns]" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import pandas as pd\n", | |
"from scipy.sparse import coo_matrix\n", | |
"import numpy as np\n", | |
"\n", | |
"TOP_K = 5  # number of co-occurring neighbours kept per item\n", | |
"\n", | |
"# Map item IDs to contiguous indices for sparse matrix construction\n", | |
"item_mapping = {item: idx for idx, item in enumerate(df['item'].unique())}\n", | |
"reverse_mapping = {idx: item for item, idx in item_mapping.items()}\n", | |
"\n", | |
"# Map user and item to their indices\n", | |
"df['item_idx'] = df['item'].map(item_mapping)\n", | |
"df['user_idx'] = pd.factorize(df['user'])[0]\n", | |
"\n", | |
"# Binary user-item matrix: tocsr() sums duplicate (user, item) rows, so reset\n", | |
"# every stored value to 1 to count distinct users rather than raw interactions\n", | |
"user_item_matrix = coo_matrix(\n", | |
"    (np.ones(len(df), dtype=np.int64), (df['user_idx'], df['item_idx']))\n", | |
").tocsr()\n", | |
"user_item_matrix.data[:] = 1\n", | |
"\n", | |
"# Item-item co-occurrence: entry (i, j) = number of users who interacted with\n", | |
"# both i and j; the diagonal holds each item's own user count.\n", | |
"# Force CSR with sorted column indices so the indptr/indices/data row slicing\n", | |
"# below is valid and tied counts are ordered by internal item index.\n", | |
"co_occurrence_matrix = (user_item_matrix.T @ user_item_matrix).tocsr()\n", | |
"co_occurrence_matrix.sort_indices()\n", | |
"\n", | |
"# Build top-K co-occurrence lists\n", | |
"result = []\n", | |
"for item in range(co_occurrence_matrix.shape[0]):\n", | |
"    # Slice the CSR row of the current item\n", | |
"    start_idx = co_occurrence_matrix.indptr[item]\n", | |
"    end_idx = co_occurrence_matrix.indptr[item + 1]\n", | |
"    neighbors = co_occurrence_matrix.indices[start_idx:end_idx]\n", | |
"    counts = co_occurrence_matrix.data[start_idx:end_idx]\n", | |
"\n", | |
"    # Drop the self-co-occurrence BEFORE truncating to TOP_K. The diagonal is at\n", | |
"    # least as large as any other count in its row, so filtering it out after\n", | |
"    # the slice left only TOP_K - 1 neighbours per item.\n", | |
"    not_self = neighbors != item\n", | |
"    neighbors, counts = neighbors[not_self], counts[not_self]\n", | |
"\n", | |
"    # Descending by count; a stable sort keeps the order of ties deterministic\n", | |
"    top_indices = np.argsort(-counts, kind='stable')[:TOP_K]\n", | |
"\n", | |
"    result.append({\n", | |
"        'item': reverse_mapping[item],\n", | |
"        'co_occurrence_items': [reverse_mapping[idx] for idx in neighbors[top_indices]],\n", | |
"        'co_occurrence_counts': counts[top_indices].tolist(),\n", | |
"    })\n", | |
"\n", | |
"# Convert result to a DataFrame\n", | |
"co_occurrence_df = pd.DataFrame(result)\n", | |
"co_occurrence_df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "b18b84ad-684e-477b-a044-96b84e069936", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[[549, 1683, 3190, 1303, 2290],\n", | |
" [1833, 1840, 1148, 950, 2696],\n", | |
" [1307, 1196, 260, 1097, 1036]]" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Top-5 similar items for each query item, one list per query item\n", | |
"# (filter_query_items=True presumably drops the query items themselves from the results)\n", | |
"model.similar_items(top_k=5, filter_query_items=True, query_items=[242, 302, 1674])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "43fb9402-8b2a-41a7-bed6-fd023f646110", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>item</th>\n", | |
" <th>co_occurrence_items</th>\n", | |
" <th>co_occurrence_counts</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>124</th>\n", | |
" <td>1196</td>\n", | |
" <td>[260, 1210, 1198, 2571]</td>\n", | |
" <td>[2355, 2228, 1999, 1920]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>162</th>\n", | |
" <td>2728</td>\n", | |
" <td>[1196, 260, 858, 608]</td>\n", | |
" <td>[374, 356, 331, 320]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>167</th>\n", | |
" <td>318</td>\n", | |
" <td>[2858, 593, 608, 2028]</td>\n", | |
" <td>[1684, 1653, 1616, 1531]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>328</th>\n", | |
" <td>16</td>\n", | |
" <td>[593, 608, 2858, 296]</td>\n", | |
" <td>[592, 591, 555, 544]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>382</th>\n", | |
" <td>357</td>\n", | |
" <td>[1265, 2858, 2396, 356]</td>\n", | |
" <td>[995, 917, 916, 889]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>428</th>\n", | |
" <td>1674</td>\n", | |
" <td>[1196, 1210, 1198, 1097]</td>\n", | |
" <td>[837, 762, 757, 754]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>547</th>\n", | |
" <td>1307</td>\n", | |
" <td>[1270, 1196, 1265, 1197]</td>\n", | |
" <td>[1214, 1153, 1119, 1110]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1032</th>\n", | |
" <td>302</td>\n", | |
" <td>[1094, 2858, 2291, 265]</td>\n", | |
" <td>[64, 64, 64, 60]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1944</th>\n", | |
" <td>242</td>\n", | |
" <td>[2858, 608, 589, 2396]</td>\n", | |
" <td>[64, 58, 55, 54]</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" item co_occurrence_items co_occurrence_counts\n", | |
"124 1196 [260, 1210, 1198, 2571] [2355, 2228, 1999, 1920]\n", | |
"162 2728 [1196, 260, 858, 608] [374, 356, 331, 320]\n", | |
"167 318 [2858, 593, 608, 2028] [1684, 1653, 1616, 1531]\n", | |
"328 16 [593, 608, 2858, 296] [592, 591, 555, 544]\n", | |
"382 357 [1265, 2858, 2396, 356] [995, 917, 916, 889]\n", | |
"428 1674 [1196, 1210, 1198, 1097] [837, 762, 757, 754]\n", | |
"547 1307 [1270, 1196, 1265, 1197] [1214, 1153, 1119, 1110]\n", | |
"1032 302 [1094, 2858, 2291, 265] [64, 64, 64, 60]\n", | |
"1944 242 [2858, 608, 589, 2396] [64, 58, 55, 54]" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Raw co-occurrence neighbours for the similar_items query items (242, 302, 1674)\n", | |
"# plus a few other item ids, for side-by-side comparison with the model output\n", | |
"co_occurrence_df[co_occurrence_df['item'].isin([242, 302, 1674, 357, 318, 16, 2728, 1307, 1196])]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "b2d3b189-b967-4657-b8d0-94b3a47c7793", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[2858, 260, 1198, 1196, 593, 1210, 318, 2396, 608, 858]" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Top-10 recommendations for user 1\n", | |
"# (filter_interacted=False presumably keeps items the user already interacted with)\n", | |
"model.recommend(user=1, filter_interacted=False, top_k=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bc448d55-b9d0-4123-b5b6-fe82aa3c168b", | |
"metadata": {}, | |
"source": [ | |
"# Factorization Machines" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "305a8409-c1ea-4fdc-b0d7-34bd59fa4865", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from rtrec.recommender import Recommender\n", | |
"from rtrec.models import FactorizationMachines\n", | |
"\n", | |
"# NOTE: this rebinds `model`; the similar_items / recommend cells above will use\n", | |
"# this FM model instead of the previous one if they are re-run after this cell\n", | |
"model = FactorizationMachines(n_factors=10, alpha=0.001, decay_in_days=None)\n", | |
"recommender = Recommender(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"id": "9c853742-680d-4bd2-9525-3f4380683185", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 0%| | 0/10 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Starting epoch 1/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 10%|████▍ | 1/10 [00:13<02:04, 13.81s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1 completed in 13.79 seconds\n", | |
"Throughput: 57867.36 samples/sec\n", | |
"Empirical loss after epoch 1: 0.9040387595036169\n", | |
"Starting epoch 2/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 20%|████████▊ | 2/10 [00:28<01:54, 14.29s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 2 completed in 14.61 seconds\n", | |
"Throughput: 54586.31 samples/sec\n", | |
"Empirical loss after epoch 2: 0.8848439280853644\n", | |
"Starting epoch 3/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 30%|█████████████▏ | 3/10 [00:43<01:41, 14.45s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 3 completed in 14.63 seconds\n", | |
"Throughput: 54535.84 samples/sec\n", | |
"Empirical loss after epoch 3: 0.8708736921358707\n", | |
"Starting epoch 4/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 40%|█████████████████▌ | 4/10 [00:57<01:27, 14.58s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 4 completed in 14.75 seconds\n", | |
"Throughput: 54094.24 samples/sec\n", | |
"Empirical loss after epoch 4: 0.8596646970054016\n", | |
"Starting epoch 5/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 50%|██████████████████████ | 5/10 [01:12<01:13, 14.64s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 5 completed in 14.74 seconds\n", | |
"Throughput: 54124.66 samples/sec\n", | |
"Empirical loss after epoch 5: 0.8503504694891765\n", | |
"Starting epoch 6/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 60%|██████████████████████████▍ | 6/10 [01:27<00:58, 14.75s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 6 completed in 14.94 seconds\n", | |
"Throughput: 53392.38 samples/sec\n", | |
"Empirical loss after epoch 6: 0.8424220223559435\n", | |
"Starting epoch 7/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 70%|██████████████████████████████▊ | 7/10 [01:42<00:44, 14.74s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 7 completed in 14.72 seconds\n", | |
"Throughput: 54177.82 samples/sec\n", | |
"Empirical loss after epoch 7: 0.8355521437773976\n", | |
"Starting epoch 8/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 80%|███████████████████████████████████▏ | 8/10 [01:56<00:29, 14.72s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 8 completed in 14.64 seconds\n", | |
"Throughput: 54500.23 samples/sec\n", | |
"Empirical loss after epoch 8: 0.8295195833481861\n", | |
"Starting epoch 9/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 90%|███████████████████████████████████████▌ | 9/10 [02:11<00:14, 14.70s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 9 completed in 14.65 seconds\n", | |
"Throughput: 54467.15 samples/sec\n", | |
"Empirical loss after epoch 9: 0.8241643619978707\n", | |
"Starting epoch 10/10\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|███████████████████████████████████████████| 10/10 [02:26<00:00, 14.62s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 10 completed in 14.62 seconds\n", | |
"Throughput: 54575.20 samples/sec\n", | |
"Empirical loss after epoch 10: 0.8193697607758701\n", | |
"CPU times: user 2min 26s, sys: 0 ns, total: 2min 26s\n", | |
"Wall time: 2min 26s\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"recommender.fit(train_df, epochs=10, bulk_identify=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "f18b2390-d753-4852-9b94-dd4add348b81", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|███████████████████████████████████████████| 61/61 [01:57<00:00, 1.93s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': 0.0759933774834448,\n", | |
" 'recall': 0.029059448527411654,\n", | |
" 'f1': 0.034670981879104466,\n", | |
" 'ndcg': 0.08084015799931689,\n", | |
" 'hit_rate': 0.378476821192053,\n", | |
" 'mrr': 0.16370637548617717,\n", | |
" 'map': 0.03806029982050462,\n", | |
" 'tp': 4590,\n", | |
" 'auc': 0.19877746241984673}" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Evaluate the FM recommender on test_df using top-10 recommendation lists\n", | |
"# (filter_interacted=True presumably excludes items already seen in training)\n", | |
"recommender.evaluate(test_df, filter_interacted=True, recommend_size=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9c65d696-0ee6-40db-8d9e-c8cfd6cb2b9a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment