Skip to content

Instantly share code, notes, and snippets.

@myui
Created December 2, 2024 07:43
Show Gist options
  • Save myui/3d11f27a071e37b75db829a23d2b97fe to your computer and use it in GitHub Desktop.
Save myui/3d11f27a071e37b75db829a23d2b97fe to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "27cddd06-2fc7-4a96-aa57-226109f57442",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"/home/td-user/rtrec\")"
]
},
{
"cell_type": "markdown",
"id": "07765c5e-6af2-47d7-8320-38dc269a7986",
"metadata": {},
"source": [
"# Movielens 1M"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "827b4cc7-1f28-4e44-91e9-6d5cd8cde4dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using existing ratings.dat file.\n"
]
}
],
"source": [
"from rtrec.experiments.datasets import load_dataset\n",
"\n",
"df = load_dataset(name='movielens_1m')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c89ffe93-0b31-4cac-a004-8ada940add7b",
"metadata": {},
"outputs": [],
"source": [
"df['user'] = df['user'].astype(\"category\")\n",
"df['item'] = df['item'].astype(\"category\")\n",
"df['user_id'] = df['user'].cat.codes\n",
"df['item_id'] = df['item'].cat.codes"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "499497fa",
"metadata": {},
"outputs": [],
"source": [
"iid2id = dict(enumerate(df['item'].cat.categories))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "52bcecc2",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>rating</th>\n",
" <th>tstamp</th>\n",
" <th>user_id</th>\n",
" <th>item_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>1193</td>\n",
" <td>5</td>\n",
" <td>978300760</td>\n",
" <td>0</td>\n",
" <td>1104</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>661</td>\n",
" <td>3</td>\n",
" <td>978302109</td>\n",
" <td>0</td>\n",
" <td>639</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>914</td>\n",
" <td>3</td>\n",
" <td>978301968</td>\n",
" <td>0</td>\n",
" <td>853</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>3408</td>\n",
" <td>4</td>\n",
" <td>978300275</td>\n",
" <td>0</td>\n",
" <td>3177</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1</td>\n",
" <td>2355</td>\n",
" <td>5</td>\n",
" <td>978824291</td>\n",
" <td>0</td>\n",
" <td>2162</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000204</th>\n",
" <td>6040</td>\n",
" <td>1091</td>\n",
" <td>1</td>\n",
" <td>956716541</td>\n",
" <td>6039</td>\n",
" <td>1019</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000205</th>\n",
" <td>6040</td>\n",
" <td>1094</td>\n",
" <td>5</td>\n",
" <td>956704887</td>\n",
" <td>6039</td>\n",
" <td>1022</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000206</th>\n",
" <td>6040</td>\n",
" <td>562</td>\n",
" <td>5</td>\n",
" <td>956704746</td>\n",
" <td>6039</td>\n",
" <td>548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000207</th>\n",
" <td>6040</td>\n",
" <td>1096</td>\n",
" <td>4</td>\n",
" <td>956715648</td>\n",
" <td>6039</td>\n",
" <td>1024</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1000208</th>\n",
" <td>6040</td>\n",
" <td>1097</td>\n",
" <td>4</td>\n",
" <td>956715569</td>\n",
" <td>6039</td>\n",
" <td>1025</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1000209 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" user item rating tstamp user_id item_id\n",
"0 1 1193 5 978300760 0 1104\n",
"1 1 661 3 978302109 0 639\n",
"2 1 914 3 978301968 0 853\n",
"3 1 3408 4 978300275 0 3177\n",
"4 1 2355 5 978824291 0 2162\n",
"... ... ... ... ... ... ...\n",
"1000204 6040 1091 1 956716541 6039 1019\n",
"1000205 6040 1094 5 956704887 6039 1022\n",
"1000206 6040 562 5 956704746 6039 548\n",
"1000207 6040 1096 4 956715648 6039 1024\n",
"1000208 6040 1097 4 956715569 6039 1025\n",
"\n",
"[1000209 rows x 6 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "markdown",
"id": "74f06c6c-535c-45bd-8fe0-d59f8d369656",
"metadata": {},
"source": [
"# Temporal User Split"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7b2e53b8-b1e8-4dc2-90c9-40b00d9f9826",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/td-user/rtrec/rtrec/experiments/split.py:104: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.\n",
" for user, user_df in df.groupby('user'):\n"
]
}
],
"source": [
"from rtrec.experiments.split import temporal_user_split\n",
"train_df, test_df = temporal_user_split(df)"
]
},
{
"cell_type": "markdown",
"id": "8213db54",
"metadata": {},
"source": [
"# Irspack evaluation"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2023a533",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>item</th>\n",
" <th>rating</th>\n",
" <th>tstamp</th>\n",
" <th>user_id</th>\n",
" <th>item_id</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6040</td>\n",
" <td>858</td>\n",
" <td>4</td>\n",
" <td>956703932</td>\n",
" <td>6039</td>\n",
" <td>802</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>6040</td>\n",
" <td>593</td>\n",
" <td>5</td>\n",
" <td>956703954</td>\n",
" <td>6039</td>\n",
" <td>579</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>6040</td>\n",
" <td>2384</td>\n",
" <td>4</td>\n",
" <td>956703954</td>\n",
" <td>6039</td>\n",
" <td>2191</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>6040</td>\n",
" <td>1961</td>\n",
" <td>4</td>\n",
" <td>956703977</td>\n",
" <td>6039</td>\n",
" <td>1781</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>6040</td>\n",
" <td>2019</td>\n",
" <td>5</td>\n",
" <td>956703977</td>\n",
" <td>6039</td>\n",
" <td>1839</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797753</th>\n",
" <td>736</td>\n",
" <td>1278</td>\n",
" <td>4</td>\n",
" <td>1045711206</td>\n",
" <td>735</td>\n",
" <td>1186</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797754</th>\n",
" <td>736</td>\n",
" <td>3671</td>\n",
" <td>4</td>\n",
" <td>1045711217</td>\n",
" <td>735</td>\n",
" <td>3429</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797755</th>\n",
" <td>3840</td>\n",
" <td>1196</td>\n",
" <td>3</td>\n",
" <td>1046106127</td>\n",
" <td>3839</td>\n",
" <td>1106</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797756</th>\n",
" <td>3840</td>\n",
" <td>1225</td>\n",
" <td>3</td>\n",
" <td>1046106162</td>\n",
" <td>3839</td>\n",
" <td>1135</td>\n",
" </tr>\n",
" <tr>\n",
" <th>797757</th>\n",
" <td>3840</td>\n",
" <td>2502</td>\n",
" <td>5</td>\n",
" <td>1046106198</td>\n",
" <td>3839</td>\n",
" <td>2308</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>797758 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" user item rating tstamp user_id item_id\n",
"0 6040 858 4 956703932 6039 802\n",
"1 6040 593 5 956703954 6039 579\n",
"2 6040 2384 4 956703954 6039 2191\n",
"3 6040 1961 4 956703977 6039 1781\n",
"4 6040 2019 5 956703977 6039 1839\n",
"... ... ... ... ... ... ...\n",
"797753 736 1278 4 1045711206 735 1186\n",
"797754 736 3671 4 1045711217 735 3429\n",
"797755 3840 1196 3 1046106127 3839 1106\n",
"797756 3840 1225 3 1046106162 3839 1135\n",
"797757 3840 2502 5 1046106198 3839 2308\n",
"\n",
"[797758 rows x 6 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e6cb85a0",
"metadata": {},
"outputs": [],
"source": [
"from irspack import df_to_sparse\n",
"\n",
"X_train, _, _ = df_to_sparse(\n",
" train_df, 'user_id', 'item_id'\n",
")\n",
"X_test, _, _ = df_to_sparse(\n",
" test_df, 'user_id', 'item_id'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "62bcdbe4",
"metadata": {},
"outputs": [],
"source": [
"import scipy.sparse as sp\n",
"X_train = sp.csr_matrix((train_df['rating'].astype(float), (train_df['user_id'], train_df['item_id'])))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "8f5430f3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6040, 3706)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5f6d1527",
"metadata": {},
"outputs": [],
"source": [
"import scipy.sparse as sp\n",
"X_test = sp.csr_matrix((test_df['rating'].astype(float), (test_df['user_id'], test_df['item_id'])))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "d1a4ccc5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(6040, 3706)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_test.shape"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "ac225b2e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<irspack.recommenders.slim.SLIMRecommender at 0xffff4354d700>"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from irspack import SLIMRecommender \n",
"\n",
"#recommender = SLIMRecommender(X_train, alpha=0.01, l1_ratio=0.001)\n",
"recommender = SLIMRecommender(X_train)\n",
"recommender.learn()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "d6229884",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([6038, 6033, 6034, ..., 5483, 735, 3839], dtype=int16)"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_users = test_df['user_id'].unique()\n",
"test_users"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "dc7a1e7f",
"metadata": {},
"outputs": [],
"source": [
"scores = recommender.get_score(test_users)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "8e61024e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"top_k=10\n",
"\n",
"recos = []\n",
"ground_truths = []\n",
"for row in test_df.groupby('user_id')['item'].apply(list).reset_index(name='ground_truth').itertuples():\n",
" user_scores = scores[row.user_id,:]\n",
" seen_items = X_train[row.user_id].indices\n",
" user_scores[seen_items] = -np.inf\n",
" top_items = np.argsort(user_scores)[-top_k:][::-1]\n",
" recos.append([iid2id[iid] for iid in top_items])\n",
" ground_truths.append(row.ground_truth)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "a9b3623c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': 0.06572847682119268,\n",
" 'recall': 0.024085925107561277,\n",
" 'f1': 0.02913174058081254,\n",
" 'ndcg': 0.07061680286431203,\n",
" 'hit_rate': 0.3423841059602649,\n",
" 'mrr': 0.1497353621360247,\n",
" 'map': 0.032553592676646245,\n",
" 'tp': 3970,\n",
" 'auc': 0.182195758435825}"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from rtrec.utils.metrics import compute_scores\n",
"\n",
"compute_scores(zip(recos, ground_truths), recommend_size=10)"
]
},
{
"cell_type": "markdown",
"id": "6581c1af",
"metadata": {},
"source": [
"NDCG is not good as SLIM elasticnet."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe4a6c6b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment