Skip to content

Instantly share code, notes, and snippets.

Created December 2, 2024 08:38
Show Gist options
  • Save myui/dd260fdef600f4971b74ba74c45c1380 to your computer and use it in GitHub Desktop.
Save myui/dd260fdef600f4971b74ba74c45c1380 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
"cells": [
"cell_type": "code",
"execution_count": 1,
"id": "27cddd06-2fc7-4a96-aa57-226109f57442",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"cell_type": "code",
"execution_count": 2,
"id": "ef77cc0a-833f-4d0c-adf7-6c507f93dfe9",
"metadata": {},
"outputs": [],
"source": [
"#from rtrec._lowlevel import set_notebook_mode\n",
"cell_type": "markdown",
"id": "07765c5e-6af2-47d7-8320-38dc269a7986",
"metadata": {},
"source": [
"# MovieLens 1M"
"cell_type": "code",
"execution_count": 3,
"id": "827b4cc7-1f28-4e44-91e9-6d5cd8cde4dc",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Using existing ratings.dat file.\n"
"source": [
"from rtrec.experiments.datasets import load_dataset\n",
"df = load_dataset(name='movielens_1m')"
"cell_type": "code",
"execution_count": 4,
"id": "c89ffe93-0b31-4cac-a004-8ada940add7b",
"metadata": {},
"outputs": [
"data": {
"text/html": [
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>tstamp</th>\n",
" <th>rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>1193</td>\n",
" <td>978300760</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>661</td>\n",
" <td>978302109</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>914</td>\n",
" <td>978301968</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>3408</td>\n",
" <td>978300275</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>2355</td>\n",
" <td>978824291</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000204</th>\n",
" <td>6040</td>\n",
" <td>1091</td>\n",
" <td>956716541</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000205</th>\n",
" <td>6040</td>\n",
" <td>1094</td>\n",
" <td>956704887</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000206</th>\n",
" <td>6040</td>\n",
" <td>562</td>\n",
" <td>956704746</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000207</th>\n",
" <td>6040</td>\n",
" <td>1096</td>\n",
" <td>956715648</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000208</th>\n",
" <td>6040</td>\n",
" <td>1097</td>\n",
" <td>956715569</td>\n",
" <td>4</td>\n",
" </tr>\n",
" </tbody>\n",
"<p>1000209 rows × 4 columns</p>\n",
"text/plain": [
" user item tstamp rating\n",
"0 1 1193 978300760 5\n",
"1 1 661 978302109 3\n",
"2 1 914 978301968 3\n",
"3 1 3408 978300275 4\n",
"4 1 2355 978824291 5\n",
"... ... ... ... ...\n",
"1000204 6040 1091 956716541 1\n",
"1000205 6040 1094 956704887 5\n",
"1000206 6040 562 956704746 5\n",
"1000207 6040 1096 956715648 4\n",
"1000208 6040 1097 956715569 4\n",
"[1000209 rows x 4 columns]"
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
"source": [
"df[['user', 'item', 'tstamp', 'rating']]"
"cell_type": "code",
"execution_count": 5,
"id": "2a6f1883-1ebb-435f-be35-5b183a8812ac",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 6,
"id": "eb5d4826-2a26-47dd-b695-603099f4d7fb",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 7,
"id": "aa0bfa94-4007-4295-81aa-6b55489139e3",
"metadata": {},
"outputs": [],
"source": [
"# Train/test split. Three strategies from rtrec.experiments.split were tried;\n",
"# only temporal_user_split is active, the other two are left commented out.\n",
"# NOTE(review): with this split train_df holds 797,758 of the 1,000,209 ratings\n",
"# (~80%); confirm the default ratio/ordering in rtrec before comparing runs.\n",
"# from rtrec.experiments.split import leave_one_last_item\n",
"# train_df, test_df = leave_one_last_item(df)\n",
"from rtrec.experiments.split import temporal_user_split\n",
"train_df, test_df = temporal_user_split(df)\n",
"#from rtrec.experiments.split import random_split\n",
"#train_df, test_df = random_split(df)"
"cell_type": "code",
"execution_count": 8,
"id": "6a6f03d6-b204-4f71-adae-d190de5e37f8",
"metadata": {},
"outputs": [
"data": {
"text/html": [
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>rating</th>\n",
" <th>tstamp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6040</td>\n",
" <td>858</td>\n",
" <td>4</td>\n",
" <td>956703932</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>6040</td>\n",
" <td>593</td>\n",
" <td>5</td>\n",
" <td>956703954</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6040</td>\n",
" <td>2384</td>\n",
" <td>4</td>\n",
" <td>956703954</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>6040</td>\n",
" <td>1961</td>\n",
" <td>4</td>\n",
" <td>956703977</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>6040</td>\n",
" <td>2019</td>\n",
" <td>5</td>\n",
" <td>956703977</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797753</th>\n",
" <td>736</td>\n",
" <td>1278</td>\n",
" <td>4</td>\n",
" <td>1045711206</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797754</th>\n",
" <td>736</td>\n",
" <td>3671</td>\n",
" <td>4</td>\n",
" <td>1045711217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797755</th>\n",
" <td>3840</td>\n",
" <td>1196</td>\n",
" <td>3</td>\n",
" <td>1046106127</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797756</th>\n",
" <td>3840</td>\n",
" <td>1225</td>\n",
" <td>3</td>\n",
" <td>1046106162</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797757</th>\n",
" <td>3840</td>\n",
" <td>2502</td>\n",
" <td>5</td>\n",
" <td>1046106198</td>\n",
" </tr>\n",
" </tbody>\n",
"<p>797758 rows × 4 columns</p>\n",
"text/plain": [
" user item rating tstamp\n",
"0 6040 858 4 956703932\n",
"1 6040 593 5 956703954\n",
"2 6040 2384 4 956703954\n",
"3 6040 1961 4 956703977\n",
"4 6040 2019 5 956703977\n",
"... ... ... ... ...\n",
"797753 736 1278 4 1045711206\n",
"797754 736 3671 4 1045711217\n",
"797755 3840 1196 3 1046106127\n",
"797756 3840 1225 3 1046106162\n",
"797757 3840 2502 5 1046106198\n",
"[797758 rows x 4 columns]"
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 9,
"id": "2fce7c8c-ba85-46aa-8af0-95abed71d265",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[90m[\u001b[0m2024-12-02T08:22:08Z \u001b[32mINFO \u001b[0m rtrec::slim\u001b[90m]\u001b[0m Creating SlimMSE model with optimizer: adagrad_rda, alpha: 0.001, lambda1: 0.0002, lambda2: 0.0001, rating_range: (-5.0, 10.0), decay_in_days: None, n_recent: 50\n"
"source": [
"from rtrec.recommender import Recommender\n",
"# Both SLIM variants are imported under the alias SlimMSE, so swapping the\n",
"# commented import for the active one needs no change to the lines below.\n",
"#from rtrec.models import SLIM_MSE as SlimMSE\n",
"from rtrec.models import Fast_SLIM_MSE as SlimMSE\n",
"# Alternative configurations that were tried are left commented out; only the\n",
"# uncommented SlimMSE(...) call is active.\n",
"#model = SlimMSE()\n",
"# Active config: decay_in_days=None (presumably disables time decay) and\n",
"# n_recent=50. The run log for this cell reports the remaining settings as\n",
"# optimizer=adagrad_rda, alpha=0.001, lambda1=0.0002, lambda2=0.0001,\n",
"# rating_range=(-5.0, 10.0) -- presumably library defaults; confirm in rtrec.\n",
"model = SlimMSE(decay_in_days=None, n_recent=50)\n",
"#model = SlimMSE(optimizer='sgd', alpha=0.0001, decay_in_days=None, n_recent=None)\n",
"#model = SlimMSE(alpha=0.001)\n",
"#model = SlimMSE(optimizer=\"adagrad_rda\", alpha=0.001, n_recent=50, lambda1 = 0.0002, lambda2 = 0.0001)\n",
"# The later fit/evaluate cells go through this Recommender wrapper.\n",
"recommender = Recommender(model)"
"cell_type": "code",
"execution_count": 10,
"id": "90eac400-bb74-4b21-a86f-bb63bece783c",
"metadata": {},
"outputs": [],
"source": [
"from scipy.sparse import csr_matrix"
"cell_type": "code",
"execution_count": 11,
"id": "79f22c5e-c083-4b24-a6c5-f2fd0a3d47e7",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/10 [00:00<?, ?it/s]"
"name": "stdout",
"output_type": "stream",
"text": [
"Starting epoch 1/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|████▍ | 1/10 [00:28<04:16, 28.45s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 completed in 28.42 seconds\n",
"Throughput: 28070.33 samples/sec\n",
"Empirical loss after epoch 1: 3.4905147552490234\n",
"Starting epoch 2/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 20%|████████▊ | 2/10 [01:06<04:33, 34.13s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2 completed in 38.10 seconds\n",
"Throughput: 20939.14 samples/sec\n",
"Empirical loss after epoch 2: 3.4076671600341797\n",
"Starting epoch 3/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 30%|█████████████▏ | 3/10 [01:43<04:09, 35.62s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3 completed in 37.36 seconds\n",
"Throughput: 21352.26 samples/sec\n",
"Empirical loss after epoch 3: 3.34891939163208\n",
"Starting epoch 4/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 40%|█████████████████▌ | 4/10 [02:21<03:39, 36.57s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4 completed in 38.02 seconds\n",
"Throughput: 20984.56 samples/sec\n",
"Empirical loss after epoch 4: 3.3069534301757812\n",
"Starting epoch 5/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|██████████████████████ | 5/10 [02:59<03:05, 37.06s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5 completed in 37.88 seconds\n",
"Throughput: 21059.36 samples/sec\n",
"Empirical loss after epoch 5: 3.27057147026062\n",
"Starting epoch 6/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|██████████████████████████▍ | 6/10 [03:39<02:31, 37.77s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6 completed in 39.13 seconds\n",
"Throughput: 20389.93 samples/sec\n",
"Empirical loss after epoch 6: 3.2380311489105225\n",
"Starting epoch 7/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 70%|██████████████████████████████▊ | 7/10 [04:16<01:53, 37.71s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7 completed in 37.56 seconds\n",
"Throughput: 21238.49 samples/sec\n",
"Empirical loss after epoch 7: 3.197824478149414\n",
"Starting epoch 8/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|███████████████████████████████████▏ | 8/10 [04:54<01:15, 37.78s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8 completed in 37.82 seconds\n",
"Throughput: 21091.29 samples/sec\n",
"Empirical loss after epoch 8: 3.1531848907470703\n",
"Starting epoch 9/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 90%|███████████████████████████████████████▌ | 9/10 [05:33<00:38, 38.00s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9 completed in 38.47 seconds\n",
"Throughput: 20736.90 samples/sec\n",
"Empirical loss after epoch 9: 3.1153697967529297\n",
"Starting epoch 10/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████| 10/10 [06:11<00:00, 37.10s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10 completed in 37.95 seconds\n",
"Throughput: 21021.12 samples/sec\n",
"Empirical loss after epoch 10: 3.0825929641723633\n",
"CPU times: user 6min 8s, sys: 5min 2s, total: 11min 11s\n",
"Wall time: 6min 13s\n"
"name": "stderr",
"output_type": "stream",
"text": [
"source": [
", epochs=10, bulk_identify=True)"
"cell_type": "code",
"execution_count": 12,
"id": "34ae8015-aadd-42ce-9e8c-661c8a1faa09",
"metadata": {},
"outputs": [],
"source": [
", epochs=10, bulk_identify=True)"
"cell_type": "code",
"execution_count": 13,
"id": "6fcf1cf3-be44-434b-84af-2209e288cd67",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████| 61/61 [00:19<00:00, 3.08it/s]\n"
"data": {
"text/plain": [
"{'precision': 0.11735099337748679,\n",
" 'recall': 0.048998880307564885,\n",
" 'f1': 0.056674428155052727,\n",
" 'ndcg': 0.12906595112419084,\n",
" 'hit_rate': 0.5190397350993378,\n",
" 'mrr': 0.2556857064017651,\n",
" 'map': 0.06567731500252794,\n",
" 'tp': 7088,\n",
" 'auc': 0.287015353989278}"
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
"source": [
"recommender.evaluate(test_df, recommend_size=10, filter_interacted=True)"
"cell_type": "code",
"execution_count": 14,
"id": "34cc6956-e5c4-4c0a-99bc-f4c8ae20f34b",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████| 61/61 [00:20<00:00, 2.91it/s]\n"
"data": {
"text/plain": [
"{'precision': 0.044453642384105,\n",
" 'recall': 0.030181948847140663,\n",
" 'f1': 0.03054455166525799,\n",
" 'ndcg': 0.048561648452451256,\n",
" 'hit_rate': 0.31440397350993377,\n",
" 'mrr': 0.11065062283191433,\n",
" 'map': 0.018361158247944716,\n",
" 'tp': 2685,\n",
" 'auc': 0.16280730973404822}"
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
"source": [
"recommender.evaluate(test_df, recommend_size=10, filter_interacted=False)"
"cell_type": "code",
"execution_count": 15,
"id": "58458ff4-1414-4d6e-ac07-2257248d9e51",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"[2858, 1198, 1196, 593, 1210, 318, 2396, 858, 356, 1]"
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.recommend(user=1, top_k=10, filter_interacted=True)"
"cell_type": "code",
"execution_count": 16,
"id": "ede6a584-7db4-4690-8d71-78816d24665f",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"[2858, 260, 1198, 1196, 593, 1210, 318, 2396, 608, 858]"
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.recommend(user=1, top_k=10, filter_interacted=False)"
"cell_type": "code",
"execution_count": 17,
"id": "3838c017-4674-4bce-91ff-11f491a98b7b",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"[527, 260, 608, 858, 2762, 296, 1617, 1270, 1197, 1097]"
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.recommend(user=2, top_k=10, filter_interacted=True)"
"cell_type": "code",
"execution_count": 18,
"id": "1fc146b4-873a-468b-8a6f-5bc2b56756bb",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.predict_rating(user=1, item=260)"
"cell_type": "code",
"execution_count": 19,
"id": "6cdac121-0190-4719-9b0d-90587aa1db67",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.predict_rating(user=1, item=527)"
"cell_type": "code",
"execution_count": 20,
"id": "f21aa0ad-27de-4735-8bf5-5cfce234042f",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
" 260,\n",
" 527,\n",
" 531,\n",
" 594,\n",
" 608,\n",
" 661,\n",
" 720,\n",
" 914,\n",
" 919,\n",
" 938,\n",
" 1022,\n",
" 1028,\n",
" 1029,\n",
" 1035,\n",
" 1097,\n",
" 1193,\n",
" 1197,\n",
" 1207,\n",
" 1246,\n",
" 1270,\n",
" 1287,\n",
" 1545,\n",
" 1721,\n",
" 1836,\n",
" 1961,\n",
" 1962,\n",
" 2018,\n",
" 2028,\n",
" 2321,\n",
" 2340,\n",
" 2398,\n",
" 2692,\n",
" 2762,\n",
" 2791,\n",
" 2797,\n",
" 2804,\n",
" 2918,\n",
" 3105,\n",
" 3114,\n",
" 3186,\n",
" 3408]"
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
"source": [
"user1_items = train_df[train_df['user']==1]['item'].tolist()\n",
"cell_type": "code",
"execution_count": 21,
"id": "8951e54a-d9fa-4462-b6b9-f5b069dec912",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"[1, 48, 588, 595, 745, 783, 1566, 1907, 2294, 2355, 2687]"
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
"source": [
"user1_items_test = test_df[test_df['user']==1]['item'].tolist()\n",
"cell_type": "code",
"execution_count": 22,
"id": "7b7e4310-9eeb-43f7-aa5f-55b1155fc3b7",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.predict_rating(user=1, item=1)"
"cell_type": "code",
"execution_count": 23,
"id": "88724b56-1e66-4957-b592-f1a9d5cb8ac3",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.predict_rating(user=1, item=260)"
"cell_type": "code",
"execution_count": 24,
"id": "20c61d11-a225-4810-97cb-5b8c26a6b726",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.predict_rating(user=1, item=5)"
"cell_type": "code",
"execution_count": 25,
"id": "62fa2d09-d0e9-4f5d-b42f-edac08387dde",
"metadata": {},
"outputs": [
"data": {
"text/html": [
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>rating</th>\n",
" <th>tstamp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>40</th>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>5</td>\n",
" <td>978824268</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" user item rating tstamp\n",
"40 1 1 5 978824268"
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
"source": [
"df.query('user==1 & item in (1, 2, 3, 5, 1137, 253)')"
"cell_type": "code",
"execution_count": 26,
"id": "62af9456-8e66-405b-9d7e-64e03a93e274",
"metadata": {},
"outputs": [
"data": {
"text/html": [
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>rating</th>\n",
" <th>tstamp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>633</th>\n",
" <td>8</td>\n",
" <td>253</td>\n",
" <td>5</td>\n",
" <td>978230943</td>\n",
" </tr>\n",
" <tr>\n",
" <th>914</th>\n",
" <td>10</td>\n",
" <td>253</td>\n",
" <td>5</td>\n",
" <td>978228886</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7787</th>\n",
" <td>53</td>\n",
" <td>253</td>\n",
" <td>3</td>\n",
" <td>977980504</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10276</th>\n",
" <td>73</td>\n",
" <td>253</td>\n",
" <td>3</td>\n",
" <td>981315509</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11640</th>\n",
" <td>90</td>\n",
" <td>253</td>\n",
" <td>3</td>\n",
" <td>993875867</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>992870</th>\n",
" <td>5998</td>\n",
" <td>253</td>\n",
" <td>3</td>\n",
" <td>1001832240</td>\n",
" </tr>\n",
" <tr>\n",
" <th>994430</th>\n",
" <td>6005</td>\n",
" <td>253</td>\n",
" <td>5</td>\n",
" <td>956794779</td>\n",
" </tr>\n",
" <tr>\n",
" <th>994687</th>\n",
" <td>6007</td>\n",
" <td>253</td>\n",
" <td>2</td>\n",
" <td>956790310</td>\n",
" </tr>\n",
" <tr>\n",
" <th>997581</th>\n",
" <td>6025</td>\n",
" <td>253</td>\n",
" <td>4</td>\n",
" <td>956730684</td>\n",
" </tr>\n",
" <tr>\n",
" <th>998401</th>\n",
" <td>6035</td>\n",
" <td>253</td>\n",
" <td>4</td>\n",
" <td>956712491</td>\n",
" </tr>\n",
" </tbody>\n",
"<p>738 rows × 4 columns</p>\n",
"text/plain": [
" user item rating tstamp\n",
"633 8 253 5 978230943\n",
"914 10 253 5 978228886\n",
"7787 53 253 3 977980504\n",
"10276 73 253 3 981315509\n",
"11640 90 253 3 993875867\n",
"... ... ... ... ...\n",
"992870 5998 253 3 1001832240\n",
"994430 6005 253 5 956794779\n",
"994687 6007 253 2 956790310\n",
"997581 6025 253 4 956730684\n",
"998401 6035 253 4 956712491\n",
"[738 rows x 4 columns]"
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
"source": [
"df.query('item == 253')"
"cell_type": "code",
"execution_count": 27,
"id": "8808e32d-318e-4f0d-a91b-8640f6c549ad",
"metadata": {},
"outputs": [
"data": {
"text/html": [
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>item</th>\n",
" <th>co_occurrence_items</th>\n",
" <th>co_occurrence_counts</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1193</td>\n",
" <td>[2858, 1196, 260, 608]</td>\n",
" <td>[1171, 1127, 1125, 1107]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>661</td>\n",
" <td>[480, 1580, 1196, 1]</td>\n",
" <td>[397, 394, 384, 383]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>914</td>\n",
" <td>[260, 1196, 919, 2396]</td>\n",
" <td>[472, 469, 461, 446]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3408</td>\n",
" <td>[2858, 2762, 2396, 3578]</td>\n",
" <td>[1014, 838, 805, 790]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2355</td>\n",
" <td>[2858, 1, 1196, 1210]</td>\n",
" <td>[1169, 1081, 1029, 1028]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3701</th>\n",
" <td>2198</td>\n",
" <td>[1193, 1653, 39, 1089, 1215]</td>\n",
" <td>[2, 2, 2, 2, 2]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3702</th>\n",
" <td>2703</td>\n",
" <td>[2600, 2023, 1589, 3250]</td>\n",
" <td>[1, 1, 1, 1]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3703</th>\n",
" <td>2845</td>\n",
" <td>[1840, 58, 1621, 4]</td>\n",
" <td>[1, 1, 1, 1]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3704</th>\n",
" <td>3607</td>\n",
" <td>[2683, 1517, 908, 3386]</td>\n",
" <td>[1, 1, 1, 1]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3705</th>\n",
" <td>2909</td>\n",
" <td>[357, 3786, 1966, 2908]</td>\n",
" <td>[1, 1, 1, 1]</td>\n",
" </tr>\n",
" </tbody>\n",
"<p>3706 rows × 3 columns</p>\n",
"text/plain": [
" item co_occurrence_items co_occurrence_counts\n",
"0 1193 [2858, 1196, 260, 608] [1171, 1127, 1125, 1107]\n",
"1 661 [480, 1580, 1196, 1] [397, 394, 384, 383]\n",
"2 914 [260, 1196, 919, 2396] [472, 469, 461, 446]\n",
"3 3408 [2858, 2762, 2396, 3578] [1014, 838, 805, 790]\n",
"4 2355 [2858, 1, 1196, 1210] [1169, 1081, 1029, 1028]\n",
"... ... ... ...\n",
"3701 2198 [1193, 1653, 39, 1089, 1215] [2, 2, 2, 2, 2]\n",
"3702 2703 [2600, 2023, 1589, 3250] [1, 1, 1, 1]\n",
"3703 2845 [1840, 58, 1621, 4] [1, 1, 1, 1]\n",
"3704 3607 [2683, 1517, 908, 3386] [1, 1, 1, 1]\n",
"3705 2909 [357, 3786, 1966, 2908] [1, 1, 1, 1]\n",
"[3706 rows x 3 columns]"
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
"source": [
"import pandas as pd\n",
"from scipy.sparse import coo_matrix\n",
"import numpy as np\n",
"\n",
"# Item-item co-occurrence baseline: for every item, the TOP_K_CO_OCCURRENCES\n",
"# other items rated by the most users in common, with those user counts.\n",
"TOP_K_CO_OCCURRENCES = 5\n",
"\n",
"# Map raw item/user IDs to contiguous 0-based indices for the sparse matrix.\n",
"# Local arrays are used instead of new df columns so df itself is not mutated.\n",
"item_mapping = {item: idx for idx, item in enumerate(df['item'].unique())}\n",
"reverse_mapping = {idx: item for item, idx in item_mapping.items()}\n",
"item_idx = df['item'].map(item_mapping).to_numpy()\n",
"user_idx, user_ids = pd.factorize(df['user'])\n",
"\n",
"# Binary user x item matrix. tocsr() sums duplicate (user, item) pairs;\n",
"# resetting data to 1 makes each cell mean 'user rated item', so the product\n",
"# below counts users in common rather than number of ratings.\n",
"user_item_matrix = coo_matrix(\n",
"    (np.ones(len(df), dtype=np.int64), (user_idx, item_idx)),\n",
"    shape=(len(user_ids), len(item_mapping)),\n",
").tocsr()\n",
"user_item_matrix.data[:] = 1\n",
"\n",
"# Item x item co-occurrence counts. Forced to CSR because the row walk below\n",
"# uses indptr/indices/data; sorted indices make tie-breaking deterministic.\n",
"co_occurrence_matrix = (user_item_matrix.T @ user_item_matrix).tocsr()\n",
"co_occurrence_matrix.sort_indices()\n",
"\n",
"\n",
"def _top_co_occurrences(matrix, item, top_k):\n",
"    \"\"\"Return (neighbor_indices, counts) of the top_k neighbours of row `item`.\n",
"\n",
"    The diagonal entry (an item co-occurring with itself) is removed BEFORE\n",
"    ranking; filtering after truncation would return fewer than top_k items\n",
"    whenever the item itself ranks inside the top_k.\n",
"    \"\"\"\n",
"    start, end = matrix.indptr[item], matrix.indptr[item + 1]\n",
"    neighbors = matrix.indices[start:end]\n",
"    counts = matrix.data[start:end]\n",
"    is_other = neighbors != item\n",
"    neighbors, counts = neighbors[is_other], counts[is_other]\n",
"    # Stable sort: equal counts keep ascending index order -> reproducible output.\n",
"    order = np.argsort(-counts, kind='stable')[:top_k]\n",
"    return neighbors[order], counts[order]\n",
"\n",
"\n",
"# Build one record per item, then a DataFrame in a single step.\n",
"result = []\n",
"for item in range(co_occurrence_matrix.shape[0]):\n",
"    top_neighbors, top_counts = _top_co_occurrences(\n",
"        co_occurrence_matrix, item, TOP_K_CO_OCCURRENCES\n",
"    )\n",
"    result.append({\n",
"        'item': reverse_mapping[item],\n",
"        'co_occurrence_items': [reverse_mapping[idx] for idx in top_neighbors],\n",
"        'co_occurrence_counts': top_counts.tolist(),\n",
"    })\n",
"# Convert result to a DataFrame\n",
"co_occurrence_df = pd.DataFrame(result)\n",
"cell_type": "code",
"execution_count": 28,
"id": "b18b84ad-684e-477b-a044-96b84e069936",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"[[549, 1683, 3190, 1303, 2290],\n",
" [1833, 1840, 1148, 950, 2696],\n",
" [1307, 1196, 260, 1097, 1036]]"
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.similar_items(query_items=[242, 302, 1674], top_k=5, filter_query_items=True)"
"cell_type": "code",
"execution_count": 29,
"id": "43fb9402-8b2a-41a7-bed6-fd023f646110",
"metadata": {},
"outputs": [
"data": {
"text/html": [
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>item</th>\n",
" <th>co_occurrence_items</th>\n",
" <th>co_occurrence_counts</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>124</th>\n",
" <td>1196</td>\n",
" <td>[260, 1210, 1198, 2571]</td>\n",
" <td>[2355, 2228, 1999, 1920]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>162</th>\n",
" <td>2728</td>\n",
" <td>[1196, 260, 858, 608]</td>\n",
" <td>[374, 356, 331, 320]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>167</th>\n",
" <td>318</td>\n",
" <td>[2858, 593, 608, 2028]</td>\n",
" <td>[1684, 1653, 1616, 1531]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>328</th>\n",
" <td>16</td>\n",
" <td>[593, 608, 2858, 296]</td>\n",
" <td>[592, 591, 555, 544]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>382</th>\n",
" <td>357</td>\n",
" <td>[1265, 2858, 2396, 356]</td>\n",
" <td>[995, 917, 916, 889]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>428</th>\n",
" <td>1674</td>\n",
" <td>[1196, 1210, 1198, 1097]</td>\n",
" <td>[837, 762, 757, 754]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>1307</td>\n",
" <td>[1270, 1196, 1265, 1197]</td>\n",
" <td>[1214, 1153, 1119, 1110]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1032</th>\n",
" <td>302</td>\n",
" <td>[1094, 2858, 2291, 265]</td>\n",
" <td>[64, 64, 64, 60]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1944</th>\n",
" <td>242</td>\n",
" <td>[2858, 608, 589, 2396]</td>\n",
" <td>[64, 58, 55, 54]</td>\n",
" </tr>\n",
" </tbody>\n",
"text/plain": [
" item co_occurrence_items co_occurrence_counts\n",
"124 1196 [260, 1210, 1198, 2571] [2355, 2228, 1999, 1920]\n",
"162 2728 [1196, 260, 858, 608] [374, 356, 331, 320]\n",
"167 318 [2858, 593, 608, 2028] [1684, 1653, 1616, 1531]\n",
"328 16 [593, 608, 2858, 296] [592, 591, 555, 544]\n",
"382 357 [1265, 2858, 2396, 356] [995, 917, 916, 889]\n",
"428 1674 [1196, 1210, 1198, 1097] [837, 762, 757, 754]\n",
"547 1307 [1270, 1196, 1265, 1197] [1214, 1153, 1119, 1110]\n",
"1032 302 [1094, 2858, 2291, 265] [64, 64, 64, 60]\n",
"1944 242 [2858, 608, 589, 2396] [64, 58, 55, 54]"
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
"source": [
"co_occurrence_df.query('item in (242, 302, 1674, 357, 318, 16, 2728, 1307, 1196)')"
"cell_type": "code",
"execution_count": 30,
"id": "b2d3b189-b967-4657-b8d0-94b3a47c7793",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"[2858, 260, 1198, 1196, 593, 1210, 318, 2396, 608, 858]"
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
"source": [
"model.recommend(user=1, top_k=10, filter_interacted=False)"
"cell_type": "markdown",
"id": "bc448d55-b9d0-4123-b5b6-fe82aa3c168b",
"metadata": {},
"source": [
"# Factorization Machines"
"cell_type": "code",
"execution_count": 31,
"id": "305a8409-c1ea-4fdc-b0d7-34bd59fa4865",
"metadata": {},
"outputs": [],
"source": [
"from rtrec.recommender import Recommender\n",
"from rtrec.models import FactorizationMachines\n",
"#model = FactorizationMachines(n_factors=10, alpha=0.001, decay_in_days=None)\n",
"model = FactorizationMachines(n_factors=10, alpha=0.001, decay_in_days=None)\n",
"recommender = Recommender(model)"
"cell_type": "code",
"execution_count": 32,
"id": "9c853742-680d-4bd2-9525-3f4380683185",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/10 [00:00<?, ?it/s]"
"name": "stdout",
"output_type": "stream",
"text": [
"Starting epoch 1/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|████▍ | 1/10 [00:13<02:04, 13.81s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 completed in 13.79 seconds\n",
"Throughput: 57867.36 samples/sec\n",
"Empirical loss after epoch 1: 0.9040387595036169\n",
"Starting epoch 2/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 20%|████████▊ | 2/10 [00:28<01:54, 14.29s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2 completed in 14.61 seconds\n",
"Throughput: 54586.31 samples/sec\n",
"Empirical loss after epoch 2: 0.8848439280853644\n",
"Starting epoch 3/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 30%|█████████████▏ | 3/10 [00:43<01:41, 14.45s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3 completed in 14.63 seconds\n",
"Throughput: 54535.84 samples/sec\n",
"Empirical loss after epoch 3: 0.8708736921358707\n",
"Starting epoch 4/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 40%|█████████████████▌ | 4/10 [00:57<01:27, 14.58s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4 completed in 14.75 seconds\n",
"Throughput: 54094.24 samples/sec\n",
"Empirical loss after epoch 4: 0.8596646970054016\n",
"Starting epoch 5/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|██████████████████████ | 5/10 [01:12<01:13, 14.64s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5 completed in 14.74 seconds\n",
"Throughput: 54124.66 samples/sec\n",
"Empirical loss after epoch 5: 0.8503504694891765\n",
"Starting epoch 6/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|██████████████████████████▍ | 6/10 [01:27<00:58, 14.75s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6 completed in 14.94 seconds\n",
"Throughput: 53392.38 samples/sec\n",
"Empirical loss after epoch 6: 0.8424220223559435\n",
"Starting epoch 7/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 70%|██████████████████████████████▊ | 7/10 [01:42<00:44, 14.74s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7 completed in 14.72 seconds\n",
"Throughput: 54177.82 samples/sec\n",
"Empirical loss after epoch 7: 0.8355521437773976\n",
"Starting epoch 8/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 80%|███████████████████████████████████▏ | 8/10 [01:56<00:29, 14.72s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8 completed in 14.64 seconds\n",
"Throughput: 54500.23 samples/sec\n",
"Empirical loss after epoch 8: 0.8295195833481861\n",
"Starting epoch 9/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
" 90%|███████████████████████████████████████▌ | 9/10 [02:11<00:14, 14.70s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9 completed in 14.65 seconds\n",
"Throughput: 54467.15 samples/sec\n",
"Empirical loss after epoch 9: 0.8241643619978707\n",
"Starting epoch 10/10\n"
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████| 10/10 [02:26<00:00, 14.62s/it]"
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10 completed in 14.62 seconds\n",
"Throughput: 54575.20 samples/sec\n",
"Empirical loss after epoch 10: 0.8193697607758701\n",
"CPU times: user 2min 26s, sys: 0 ns, total: 2min 26s\n",
"Wall time: 2min 26s\n"
"name": "stderr",
"output_type": "stream",
"text": [
"source": [
"#%pdb on\n",
", epochs=10, bulk_identify=False)"
"cell_type": "code",
"execution_count": 33,
"id": "f18b2390-d753-4852-9b94-dd4add348b81",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████| 61/61 [01:57<00:00, 1.93s/it]\n"
"data": {
"text/plain": [
"{'precision': 0.0759933774834448,\n",
" 'recall': 0.029059448527411654,\n",
" 'f1': 0.034670981879104466,\n",
" 'ndcg': 0.08084015799931689,\n",
" 'hit_rate': 0.378476821192053,\n",
" 'mrr': 0.16370637548617717,\n",
" 'map': 0.03806029982050462,\n",
" 'tp': 4590,\n",
" 'auc': 0.19877746241984673}"
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
"source": [
"recommender.evaluate(test_df, recommend_size=10, filter_interacted=True)"
"cell_type": "code",
"execution_count": null,
"id": "9c65d696-0ee6-40db-8d9e-c8cfd6cb2b9a",
"metadata": {},
"outputs": [],
"source": []
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"nbformat": 4,
"nbformat_minor": 5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment