Last active
June 27, 2018 10:17
-
-
Save twolodzko/cd0ae4a06c23cb8cdc51ae46079bdb3f to your computer and use it in GitHub Desktop.
Matrix Factorization Using TensorFlow on The Movies Dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Matrix Factorization Using TensorFlow (ver 0.4)\n", | |
| "Tymoteusz Wolodzko\n", | |
| "\n", | |
| "This notebook shows results of my TensorFlow implementation of matrix factorization (it is equivalent to [Spark's ALS module](http://spark.apache.org/docs/2.2.0/api/python/pyspark.ml.html#module-pyspark.ml.recommendation) that I tested [in my previous notebook](http://nbviewer.jupyter.org/gist/twolodzko/7becd98ff256ef826b56945de297700d)). As in the previous notebook, I am using a subset (`ratings_small.csv`) of [The Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset) hosted on Kaggle, which contains metadata on over 45,000 movies and 26 million ratings from over 270,000 users." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "import random\n", | |
| "\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>id</th>\n", | |
| " <th>genres</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>862</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>8844</td>\n", | |
| " <td>[Adventure, Fantasy, Family]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>15602</td>\n", | |
| " <td>[Romance, Comedy]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>31357</td>\n", | |
| " <td>[Comedy, Drama, Romance]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>11862</td>\n", | |
| " <td>[Comedy]</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " id genres\n", | |
| "0 862 [Animation, Comedy, Family]\n", | |
| "1 8844 [Adventure, Fantasy, Family]\n", | |
| "2 15602 [Romance, Comedy]\n", | |
| "3 31357 [Comedy, Drama, Romance]\n", | |
| "4 11862 [Comedy]" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from ast import literal_eval\n", | |
| "\n", | |
| "\n", | |
| "def _genre_names(raw):\n", | |
| "    # each cell holds a Python-literal list of dicts with a 'name' key;\n", | |
| "    # anything that does not evaluate to a list yields no genres\n", | |
| "    parsed = literal_eval(raw)\n", | |
| "    if not isinstance(parsed, list):\n", | |
| "        return []\n", | |
| "    return [entry['name'] for entry in parsed]\n", | |
| "\n", | |
| "\n", | |
| "movies = pd.read_csv('data/movies_metadata.csv', low_memory=False).loc[:, ['id', 'genres']]\n", | |
| "# missing genres are treated as an empty list\n", | |
| "movies['genres'] = movies['genres'].fillna('[]').map(_genre_names)\n", | |
| "movies.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>id</th>\n", | |
| " <th>genres</th>\n", | |
| " <th>GoHands</th>\n", | |
| " <th>Sentai Filmworks</th>\n", | |
| " <th>BROSTA TV</th>\n", | |
| " <th>The Cartel</th>\n", | |
| " <th>Science Fiction</th>\n", | |
| " <th>Telescene Film Group Productions</th>\n", | |
| " <th>Drama</th>\n", | |
| " <th>Adventure</th>\n", | |
| " <th>...</th>\n", | |
| " <th>Western</th>\n", | |
| " <th>Rogue State</th>\n", | |
| " <th>Romance</th>\n", | |
| " <th>Action</th>\n", | |
| " <th>Thriller</th>\n", | |
| " <th>Fantasy</th>\n", | |
| " <th>Carousel Productions</th>\n", | |
| " <th>Pulser Productions</th>\n", | |
| " <th>War</th>\n", | |
| " <th>Music</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>862</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>8844</td>\n", | |
| " <td>[Adventure, Fantasy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>15602</td>\n", | |
| " <td>[Romance, Comedy]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>31357</td>\n", | |
| " <td>[Comedy, Drama, Romance]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>11862</td>\n", | |
| " <td>[Comedy]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>5 rows × 34 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " id genres GoHands Sentai Filmworks BROSTA TV \\\n", | |
| "0 862 [Animation, Comedy, Family] 0 0 0 \n", | |
| "1 8844 [Adventure, Fantasy, Family] 0 0 0 \n", | |
| "2 15602 [Romance, Comedy] 0 0 0 \n", | |
| "3 31357 [Comedy, Drama, Romance] 0 0 0 \n", | |
| "4 11862 [Comedy] 0 0 0 \n", | |
| "\n", | |
| " The Cartel Science Fiction Telescene Film Group Productions Drama \\\n", | |
| "0 0 0 0 0 \n", | |
| "1 0 0 0 0 \n", | |
| "2 0 0 0 0 \n", | |
| "3 0 0 0 1 \n", | |
| "4 0 0 0 0 \n", | |
| "\n", | |
| " Adventure ... Western Rogue State Romance Action Thriller Fantasy \\\n", | |
| "0 0 ... 0 0 0 0 0 0 \n", | |
| "1 1 ... 0 0 0 0 0 1 \n", | |
| "2 0 ... 0 0 1 0 0 0 \n", | |
| "3 0 ... 0 0 1 0 0 0 \n", | |
| "4 0 ... 0 0 0 0 0 0 \n", | |
| "\n", | |
| " Carousel Productions Pulser Productions War Music \n", | |
| "0 0 0 0 0 \n", | |
| "1 0 0 0 0 \n", | |
| "2 0 0 0 0 \n", | |
| "3 0 0 0 0 \n", | |
| "4 0 0 0 0 \n", | |
| "\n", | |
| "[5 rows x 34 columns]" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# sorted() fixes the column order -- and therefore the genre bitmaps and the\n", | |
| "# genreId codes derived from them later -- across kernel restarts; iterating a\n", | |
| "# raw set of strings depends on the per-process hash seed.\n", | |
| "# NOTE(review): some values (e.g. 'GoHands', 'Sentai Filmworks') look like\n", | |
| "# production companies leaking in from malformed metadata rows.\n", | |
| "genres = sorted({genre for row_genres in movies['genres'] for genre in row_genres})\n", | |
| "\n", | |
| "# one 0/1 indicator column per genre; re-running the cell simply overwrites them\n", | |
| "for genre in genres:\n", | |
| "    movies[genre] = movies['genres'].map(lambda row_genres, genre=genre: int(genre in row_genres))\n", | |
| "\n", | |
| "movies.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(359, 2)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Select the indicator columns by name rather than by position: iloc[:, 2:] would\n", | |
| "# also pick up 'genres_bitmap' (and later merged columns) when the cell is re-run.\n", | |
| "genre_columns = sorted(genres)\n", | |
| "movies['genres_bitmap'] = movies.loc[:, genre_columns].apply(lambda row: ''.join(str(x) for x in row), axis = 1)\n", | |
| "\n", | |
| "# minimum number of movies sharing a genre combination for it to be kept\n", | |
| "MIN_GENRE_COMBINATION_COUNT = 10\n", | |
| "\n", | |
| "# rename_axis + reset_index yield two plain columns. Building the frame from\n", | |
| "# counts.index instead leaves 'genres_bitmap' as both an index level and a column\n", | |
| "# on recent pandas, which makes the later merge on that key ambiguous.\n", | |
| "counts = (\n", | |
| "    movies['genres_bitmap'].value_counts()\n", | |
| "    .rename_axis('genres_bitmap')\n", | |
| "    .reset_index(name = 'genres_counts')\n", | |
| ")\n", | |
| "\n", | |
| "# leave only the popular combinations of genres\n", | |
| "counts = counts.loc[counts['genres_counts'] >= MIN_GENRE_COMBINATION_COUNT, :]\n", | |
| "counts.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>id</th>\n", | |
| " <th>genres</th>\n", | |
| " <th>GoHands</th>\n", | |
| " <th>Sentai Filmworks</th>\n", | |
| " <th>BROSTA TV</th>\n", | |
| " <th>The Cartel</th>\n", | |
| " <th>Science Fiction</th>\n", | |
| " <th>Telescene Film Group Productions</th>\n", | |
| " <th>Drama</th>\n", | |
| " <th>Adventure</th>\n", | |
| " <th>...</th>\n", | |
| " <th>Thriller</th>\n", | |
| " <th>Fantasy</th>\n", | |
| " <th>Carousel Productions</th>\n", | |
| " <th>Pulser Productions</th>\n", | |
| " <th>War</th>\n", | |
| " <th>Music</th>\n", | |
| " <th>genres_bitmap</th>\n", | |
| " <th>genres_counts</th>\n", | |
| " <th>genreId</th>\n", | |
| " <th>movieId</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>862</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>862</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>12233</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>12233</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>532</td>\n", | |
| " <td>[Family, Animation, Comedy]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>532</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>531</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>531</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>40688</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>40688</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>5 rows × 38 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " id genres GoHands Sentai Filmworks BROSTA TV \\\n", | |
| "0 862 [Animation, Comedy, Family] 0 0 0 \n", | |
| "1 12233 [Animation, Comedy, Family] 0 0 0 \n", | |
| "2 532 [Family, Animation, Comedy] 0 0 0 \n", | |
| "3 531 [Animation, Comedy, Family] 0 0 0 \n", | |
| "4 40688 [Animation, Comedy, Family] 0 0 0 \n", | |
| "\n", | |
| " The Cartel Science Fiction Telescene Film Group Productions Drama \\\n", | |
| "0 0 0 0 0 \n", | |
| "1 0 0 0 0 \n", | |
| "2 0 0 0 0 \n", | |
| "3 0 0 0 0 \n", | |
| "4 0 0 0 0 \n", | |
| "\n", | |
| " Adventure ... Thriller Fantasy Carousel Productions \\\n", | |
| "0 0 ... 0 0 0 \n", | |
| "1 0 ... 0 0 0 \n", | |
| "2 0 ... 0 0 0 \n", | |
| "3 0 ... 0 0 0 \n", | |
| "4 0 ... 0 0 0 \n", | |
| "\n", | |
| " Pulser Productions War Music genres_bitmap \\\n", | |
| "0 0 0 0 00000000011100000000000000000000 \n", | |
| "1 0 0 0 00000000011100000000000000000000 \n", | |
| "2 0 0 0 00000000011100000000000000000000 \n", | |
| "3 0 0 0 00000000011100000000000000000000 \n", | |
| "4 0 0 0 00000000011100000000000000000000 \n", | |
| "\n", | |
| " genres_counts genreId movieId \n", | |
| "0 112 122 862 \n", | |
| "1 112 122 12233 \n", | |
| "2 112 122 532 \n", | |
| "3 112 122 531 \n", | |
| "4 112 122 40688 \n", | |
| "\n", | |
| "[5 rows x 38 columns]" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "movies = movies.merge(counts, on = 'genres_bitmap', how = 'inner')\n", | |
| "\n", | |
| "# Convert 'id' with to_numeric(errors='coerce') so that a non-numeric id drops its\n", | |
| "# row instead of aborting an int64 cast (TODO confirm whether any such rows survive\n", | |
| "# the merge above). When every id is numeric nothing is dropped.\n", | |
| "numeric_ids = pd.to_numeric(movies['id'], errors = 'coerce')\n", | |
| "movies = movies.loc[numeric_ids.notna(), :].reset_index(drop = True)\n", | |
| "numeric_ids = numeric_ids.dropna().reset_index(drop = True)\n", | |
| "\n", | |
| "# dense integer code per surviving genre combination\n", | |
| "movies['genreId'] = movies['genres_bitmap'].astype('category').cat.codes\n", | |
| "movies['movieId'] = numeric_ids.astype('int64')\n", | |
| "\n", | |
| "movies.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(100004, 4)" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# the small ratings sample shipped with the dataset\n", | |
| "RATINGS_PATH = 'data/ratings_small.csv'\n", | |
| "ratings = pd.read_csv(RATINGS_PATH)\n", | |
| "\n", | |
| "ratings.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>movieId</th>\n", | |
| " <th>rating</th>\n", | |
| " <th>timestamp</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>31</td>\n", | |
| " <td>2.5</td>\n", | |
| " <td>1260759144</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1029</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1260759179</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1061</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1260759182</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1129</td>\n", | |
| " <td>2.0</td>\n", | |
| " <td>1260759185</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1172</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>1260759205</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId movieId rating timestamp\n", | |
| "0 1 31 2.5 1260759144\n", | |
| "1 1 1029 3.0 1260759179\n", | |
| "2 1 1061 3.0 1260759182\n", | |
| "3 1 1129 2.0 1260759185\n", | |
| "4 1 1172 4.0 1260759205" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# first rows of the raw ratings table\n", | |
| "ratings.head(n = 5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>movieId</th>\n", | |
| " <th>rating</th>\n", | |
| " <th>genreId</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>2.5</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>4</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>7</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>19</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>21</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId movieId rating genreId\n", | |
| "0 1 1371 2.5 173\n", | |
| "1 4 1371 4.0 173\n", | |
| "2 7 1371 3.0 173\n", | |
| "3 19 1371 4.0 173\n", | |
| "4 21 1371 3.0 173" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Keep one genreId per movieId so a rating can never be duplicated by repeated\n", | |
| "# metadata rows; drop_duplicates is a no-op when the ids are already unique, and\n", | |
| "# validate='many_to_one' makes the merge fail loudly if that ever breaks.\n", | |
| "# NOTE(review): ratings_small.csv movieIds may come from a different id space than\n", | |
| "# the movies_metadata 'id' column (the dataset also ships a links file mapping\n", | |
| "# between them) -- confirm the join key.\n", | |
| "movie_genres = movies.loc[:, ['movieId', 'genreId']].drop_duplicates(subset = 'movieId')\n", | |
| "\n", | |
| "ratings_genres = (\n", | |
| "    ratings.loc[:, ['userId', 'movieId', 'rating']]\n", | |
| "    .merge(movie_genres, on = 'movieId', how = 'inner', validate = 'many_to_one')\n", | |
| ")\n", | |
| "\n", | |
| "ratings_genres.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>genreId</th>\n", | |
| " <th>counts</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>67</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>75</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>173</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>248</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1</td>\n", | |
| " <td>257</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId genreId counts\n", | |
| "0 1 67 1\n", | |
| "1 1 75 1\n", | |
| "2 1 173 1\n", | |
| "3 1 248 1\n", | |
| "4 1 257 1" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# number of rated movies per (userId, genreId) pair\n", | |
| "ratings_counts = (\n", | |
| "    ratings_genres\n", | |
| "    .groupby(['userId', 'genreId'])\n", | |
| "    .size()\n", | |
| "    .reset_index(name = 'counts')\n", | |
| ")\n", | |
| "\n", | |
| "ratings_counts.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "((18424, 3), (4606, 3))" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from sklearn.model_selection import train_test_split\n", | |
| "\n", | |
| "TEST_SIZE = 0.2\n", | |
| "SPLIT_SEED = 42\n", | |
| "\n", | |
| "# Reassign rather than mutate in place so re-running the cell has no hidden side effects.\n", | |
| "# groupby().count() should not produce missing values, so this is only a safety net.\n", | |
| "ratings_counts = ratings_counts.dropna(axis = 0)\n", | |
| "# NOTE(review): `ratings` is not used by the split below; the cleanup is kept in case\n", | |
| "# later cells rely on it.\n", | |
| "ratings = ratings.dropna(axis = 0)\n", | |
| "\n", | |
| "ratings_train, ratings_test = train_test_split(ratings_counts, test_size = TEST_SIZE, random_state = SPLIT_SEED)\n", | |
| "# every (user, genre) row ends up in exactly one of the two parts\n", | |
| "assert len(ratings_train) + len(ratings_test) == len(ratings_counts)\n", | |
| "\n", | |
| "ratings_train.shape, ratings_test.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "from sklearn.base import BaseEstimator, RegressorMixin\n", | |
| "from tqdm import tqdm, trange\n", | |
| "from scipy.sparse import coo_matrix, csr_matrix\n", | |
| "import json\n", | |
| "\n", | |
| "\n", | |
| "def sparse_matrix(rows, cols, values, shape=None, mode='dok'):\n", | |
| "    \"\"\"Build a scipy sparse matrix from (row, col, value) triplets.\n", | |
| "\n", | |
| "    rows, cols : array-like of int\n", | |
| "        Row and column coordinates of the entries.\n", | |
| "    values : array-like\n", | |
| "        Entry values, aligned with rows and cols.\n", | |
| "    shape : (n_rows, n_cols) or None\n", | |
| "        Matrix shape; when None it is inferred as (max(rows) + 1, max(cols) + 1).\n", | |
| "    mode : str, default 'dok'\n", | |
| "        'dok' returns a DOK matrix; any other value returns a CSR matrix.\n", | |
| "    \"\"\"\n", | |
| "    if shape is None:\n", | |
| "        shape = (int(np.max(rows)) + 1, int(np.max(cols)) + 1)\n", | |
| "    if mode == 'dok':\n", | |
| "        return coo_matrix((values, (rows, cols)), shape=shape).todok()\n", | |
| "    return csr_matrix((values, (rows, cols)), shape=shape)\n", | |
| "\n", | |
| "\n", | |
| "class MatrixFactorizer(BaseEstimator):\n", | |
| " \n", | |
| " \"\"\"Matrix Factorizer\n", | |
| " \n", | |
| " Factorize the matrix R (n, m) into P (n, n_components) and Q (n_components, m) weights\n", | |
| " matrices:\n", | |
| " \n", | |
| " R[i,j] = P[i,:] * Q[:,j]\n", | |
| " \n", | |
| " Additional intercepts mu, bi, bj can be included, leading to the following model:\n", | |
| " \n", | |
| " R[i,j] = mu + bi[i] + bj[j] + P[i,:] * Q[:,j]\n", | |
| " \n", | |
| " The model is commonly used for recommender systems, where the matrix R contains of\n", | |
| " ratings by n users of m products. When users rate products using some kind of rating\n", | |
| " system (e.g. \"likes\", 1 to 5 stars), we are talking about explicit ratings (Koren et al,\n", | |
| " 2009). When ratings are not available and instead we use indirect measures of preferences\n", | |
| " (e.g. clicks, purchases), we are talking about implicit ratings (Hu et al, 2008). For\n", | |
| " implicit ratings we use modified model, where we model the indicator variable:\n", | |
| " \n", | |
| " D[i,j] = 1 if R[i,j] > 0 else 0\n", | |
| " \n", | |
| " and define additional weights:\n", | |
| " \n", | |
| " C[i,j] = 1 + alpha * R[i, j]\n", | |
| " \n", | |
| " or log weights:\n", | |
| " \n", | |
| " C[i,j] = 1 + alpha * log(1 + R[i, j])\n", | |
| " \n", | |
| " The model is defined in terms of minimizing the loss function (squared, logistic) between\n", | |
| " D[i,j] indicators and the values predicted using matrix factorization, where the loss is weighted\n", | |
| " using the C[i,j] weights (see Hu et al, 2008 for details). When using logistic loss, the predictions\n", | |
| " are passed through the sigmoid function to squeze them into the (0, 1) range.\n", | |
| " \n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " \n", | |
| " n_components : int, default : 5\n", | |
| " Number of latent components to be estimated. The estimated latent matrices P and Q\n", | |
| " have (n, n_components) and (n_components, m) shapes subsequently.\n", | |
| " \n", | |
| " n_iter : int, default : 500\n", | |
| " Number of training epochs, the actual number of iterations is n_samples * n_epoch.\n", | |
| " \n", | |
| " batch_size : int, default : 500\n", | |
| " Size of the random batch to be used during training. The batch_size is the number of\n", | |
| " cells that are randomly sampled from the factorized matrix.\n", | |
| " \n", | |
| " learning_rate : float, default : 0.01\n", | |
| " Learning rate parameter.\n", | |
| " \n", | |
| " regularization_rate : float, default : 0.02\n", | |
| " Regularization parameter.\n", | |
| " \n", | |
| " alpha : float, default : 1.0\n", | |
| " Weighting parameter in matrix factorization with implicit ratings.\n", | |
| " \n", | |
| " implicit : bool, default : False\n", | |
| " Use matrix factorization with explicit (default) or implicit ratings. \n", | |
| " \n", | |
| " loss : 'squared', 'logistic', default: 'squared'\n", | |
| " Loss function to be used. For implicit=True 'logistic' loss may be preferable.\n", | |
| " \n", | |
| " log_weights : bool, default : False\n", | |
| " Only for implicit=True, use log weighting in the loss function instead of standard weights.\n", | |
| " \n", | |
| " fit_intercepts : bool, default : True\n", | |
| " When set to True, the mu, bi, bj intercepts are fitted, otherwise\n", | |
| " only the P and Q latent matrices are fitted.\n", | |
| " \n", | |
| " warm_start : bool, optional\n", | |
| " When set to True, reuse the solution of the previous call to fit as initialization,\n", | |
| " otherwise, just erase the previous solution.\n", | |
| " \n", | |
| " optimizer : 'Adam', 'Ftrl', default : 'Adam'\n", | |
| " Optimizer to be used, see TensorFlow documentation for more details.\n", | |
| " \n", | |
| " random_state : int, or None, default : None\n", | |
| " The seed of the pseudo random number generator to use when shuffling the data. If int,\n", | |
| " random_state is the seed used by the random number generator.\n", | |
| " \n", | |
| " show_progress : bool, default : False\n", | |
| " Show the progress bar.\n", | |
| " \n", | |
| " References\n", | |
| " ----------\n", | |
| " \n", | |
| " Koren, Y., Bell, R., & Volinsky, C. (2009).\n", | |
| " Matrix factorization techniques for recommender systems. Computer, 42(8).\n", | |
| " \n", | |
| " Yu, H. F., Hsieh, C. J., Si, S., & Dhillon, I. (2012, December).\n", | |
| " Scalable coordinate descent approaches to parallel matrix factorization for recommender systems.\n", | |
| " In Data Mining (ICDM), 2012 IEEE 12th International Conference on (pp. 765-774). IEEE.\n", | |
| " \n", | |
| " Hu, Y., Koren, Y., & Volinsky, C. (2008, December).\n", | |
| " Collaborative filtering for implicit feedback datasets.\n", | |
| " In Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on (pp. 263-272). IEEE.\n", | |
| " \n", | |
| " \"\"\" \n", | |
| " \n", | |
| "    class TFModel(object):\n", | |
| "        '''Private TensorFlow (1.x graph API) model used by the outer estimator.\n", | |
| "\n", | |
| "        Builds placeholders, parameters, predictions, loss and optimizer inside\n", | |
| "        its own tf.Graph, and owns the tf.Session and tf.train.Saver bound to it.\n", | |
| "        '''\n", | |
| "        # Define and initialize the TensorFlow model, its weights, initialize session and saver\n", | |
| "\n", | |
| "        def __init__(self, shape, learning_rate, alpha, regularization_rate,\n", | |
| "                     implicit, loss, log_weights, fit_intercepts, optimizer,\n", | |
| "                     random_state=None):\n", | |
| "            '''Build the graph and initialize all variables in a new session.\n", | |
| "\n", | |
| "            Parameters\n", | |
| "            ----------\n", | |
| "            shape : tuple (n, k, d)\n", | |
| "                Number of rows, number of columns and number of latent\n", | |
| "                components; the factor matrices are (n, d) and (k, d).\n", | |
| "            learning_rate : float\n", | |
| "                Step size passed to the optimizer.\n", | |
| "            alpha : float\n", | |
| "                Scale of the confidence weights used when implicit=True.\n", | |
| "            regularization_rate : float\n", | |
| "                Multiplier of the L2 penalty added to the loss.\n", | |
| "            implicit : bool\n", | |
| "                If True, targets are the values clipped to [0, 1] and each\n", | |
| "                observation is weighted by 1 + alpha * value (or by\n", | |
| "                1 + alpha * log1p(value) when log_weights=True).\n", | |
| "            loss : str\n", | |
| "                'logistic' gives sigmoid predictions with log-loss; any other\n", | |
| "                value gives linear predictions with mean squared error.\n", | |
| "            log_weights : bool\n", | |
| "                Use log1p(value) in the implicit confidence weights.\n", | |
| "            fit_intercepts : bool\n", | |
| "                Add global, row and column bias terms.\n", | |
| "            optimizer : str\n", | |
| "                'Ftrl' selects FtrlOptimizer; any other value selects Adam.\n", | |
| "            random_state : int or None\n", | |
| "                Graph-level seed passed to tf.set_random_seed.\n", | |
| "            '''\n", | |
| "\n", | |
| "            self.shape = shape\n", | |
| "            self.learning_rate = learning_rate\n", | |
| "            self.implicit = implicit\n", | |
| "            self.loss = loss\n", | |
| "            self.log_weights = log_weights\n", | |
| "            self.fit_intercepts = fit_intercepts\n", | |
| "            self.optimizer = optimizer\n", | |
| "            self.random_state = random_state\n", | |
| "\n", | |
| "            # n rows, k columns, d latent components\n", | |
| "            n, k, d = self.shape\n", | |
| "\n", | |
| "            # build everything in a private graph rather than the default one\n", | |
| "            self.graph = tf.Graph()\n", | |
| "\n", | |
| "            with self.graph.as_default():\n", | |
| "\n", | |
| "                # graph-level seed, so random initializers are reproducible when set\n", | |
| "                tf.set_random_seed(self.random_state)\n", | |
| "\n", | |
| "                with tf.name_scope('constants'):\n", | |
| "                    self.alpha = tf.constant(alpha, dtype=tf.float32)\n", | |
| "                    self.regularization_rate = tf.constant(regularization_rate, dtype=tf.float32,\n", | |
| "                                                           name='regularization_rate')\n", | |
| "\n", | |
| "                with tf.name_scope('inputs'):\n", | |
| "                    # one entry per sampled cell: row index, column index, cell value\n", | |
| "                    self.row_ids = tf.placeholder(tf.int32, shape=[None], name='row_ids')\n", | |
| "                    self.col_ids = tf.placeholder(tf.int32, shape=[None], name='col_ids')\n", | |
| "                    self.values = tf.placeholder(tf.float32, shape=[None], name='values')\n", | |
| "\n", | |
| "                    if self.implicit:\n", | |
| "                        # implicit feedback: 0/1 preference target, confidence-weighted by the value\n", | |
| "                        targets = tf.clip_by_value(self.values, 0, 1, name='targets')\n", | |
| "                        \n", | |
| "                        if self.log_weights:\n", | |
| "                            data_weights = tf.add(1.0, self.alpha * tf.log1p(self.values), name='data_weights')\n", | |
| "                        else:\n", | |
| "                            data_weights = tf.add(1.0, self.alpha * self.values, name='data_weights')\n", | |
| "                    else:\n", | |
| "                        # explicit feedback: fit the raw values with uniform weight 1\n", | |
| "                        targets = tf.identity(self.values, name='targets')\n", | |
| "                        data_weights = tf.constant(1.0, name='data_weights')\n", | |
| "\n", | |
| "                with tf.name_scope('parameters'):\n", | |
| "                    \n", | |
| "                    # biases start at zero\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        self.global_bias = tf.get_variable('global_bias', shape=[], dtype=tf.float32,\n", | |
| "                                                           initializer=tf.zeros_initializer())\n", | |
| "                        self.row_biases = tf.get_variable('row_biases', shape=[n], dtype=tf.float32,\n", | |
| "                                                          initializer=tf.zeros_initializer())\n", | |
| "                        self.col_biases = tf.get_variable('col_biases', shape=[k], dtype=tf.float32,\n", | |
| "                                                          initializer=tf.zeros_initializer())\n", | |
| "\n", | |
| "                    # latent factors start as small N(0, 0.01) noise\n", | |
| "                    self.row_weights = tf.get_variable('row_weights', shape=[n, d], dtype=tf.float32,\n", | |
| "                                                       initializer = tf.random_normal_initializer(mean=0, stddev=0.01))\n", | |
| "                    self.col_weights = tf.get_variable('col_weights', shape=[k, d], dtype=tf.float32,\n", | |
| "                                                       initializer = tf.random_normal_initializer(mean=0, stddev=0.01))\n", | |
| "\n", | |
| "                with tf.name_scope('prediction'):\n", | |
| "                    \n", | |
| "                    # gather the parameters belonging to the sampled rows / columns\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        batch_row_biases = tf.nn.embedding_lookup(self.row_biases, self.row_ids, name='row_bias')\n", | |
| "                        batch_col_biases = tf.nn.embedding_lookup(self.col_biases, self.col_ids, name='col_bias')\n", | |
| "\n", | |
| "                    batch_row_weights = tf.nn.embedding_lookup(self.row_weights, self.row_ids, name='row_weights')\n", | |
| "                    batch_col_weights = tf.nn.embedding_lookup(self.col_weights, self.col_ids, name='col_weights')\n", | |
| "\n", | |
| "                    # per-cell dot product of the row and column factors\n", | |
| "                    weights = tf.reduce_sum(tf.multiply(batch_row_weights, batch_col_weights), axis=1, name='weights')\n", | |
| "\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        biases = tf.add(batch_row_biases, batch_col_biases)\n", | |
| "                        biases = tf.add(self.global_bias, biases, name='biases')\n", | |
| "                        linear_predictor = tf.add(biases, weights, name='linear_predictor')\n", | |
| "                    else:\n", | |
| "                        linear_predictor = tf.identity(weights, name='linear_predictor')\n", | |
| "\n", | |
| "                    # log-loss expects probabilities, so squash with a sigmoid\n", | |
| "                    if self.loss == 'logistic':\n", | |
| "                        self.pred = tf.sigmoid(linear_predictor, name='predictions')\n", | |
| "                    else:\n", | |
| "                        self.pred = tf.identity(linear_predictor, name='predictions')\n", | |
| "\n", | |
| "                with tf.name_scope('loss'):\n", | |
| "                    \n", | |
| "                    # NOTE(review): the weight penalty covers the full factor matrices on every\n", | |
| "                    # step, while the bias penalty below only covers the biases looked up for\n", | |
| "                    # the current batch; confirm this asymmetry is intended.\n", | |
| "                    l2_weights = tf.add(tf.nn.l2_loss(self.row_weights),\n", | |
| "                                        tf.nn.l2_loss(self.col_weights), name='l2_weights')\n", | |
| "                    \n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        l2_biases = tf.add(tf.nn.l2_loss(batch_row_biases),\n", | |
| "                                           tf.nn.l2_loss(batch_col_biases), name='l2_biases')\n", | |
| "                        l2_term = tf.add(l2_weights, l2_biases)\n", | |
| "                    else:\n", | |
| "                        l2_term = l2_weights\n", | |
| "                    \n", | |
| "                    l2_term = tf.multiply(self.regularization_rate, l2_term, name='regularization')\n", | |
| "\n", | |
| "                    # data term, weighted per observation by data_weights\n", | |
| "                    if self.loss == 'logistic':\n", | |
| "                        loss_raw = tf.losses.log_loss(predictions=self.pred, labels=targets,\n", | |
| "                                                      weights=data_weights)\n", | |
| "                    else:\n", | |
| "                        loss_raw = tf.losses.mean_squared_error(predictions=self.pred, labels=targets,\n", | |
| "                                                                weights=data_weights) \n", | |
| "\n", | |
| "                    self.cost = tf.add(loss_raw, l2_term, name='loss')\n", | |
| "\n", | |
| "                # 'Ftrl' selects FTRL; any other value falls back to Adam\n", | |
| "                if self.optimizer == 'Ftrl':\n", | |
| "                    self.train_step = tf.train.FtrlOptimizer(self.learning_rate).minimize(self.cost)\n", | |
| "                else:\n", | |
| "                    self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.cost)\n", | |
| "\n", | |
| "                # created inside the graph so it checkpoints this graph's variables\n", | |
| "                self.saver = tf.train.Saver()\n", | |
| "\n", | |
| "                init = tf.global_variables_initializer()\n", | |
| "\n", | |
| "            # initialize TF session\n", | |
| "            self.sess = tf.Session(graph=self.graph)\n", | |
| "            self.sess.run(init)\n", | |
| "        \n", | |
| "        \n", | |
| "        def train(self, rows, cols, values): \n", | |
| "            '''Run one optimizer step on a batch and return its loss (data term + L2 term).'''\n", | |
| "            batch = {\n", | |
| "                self.row_ids : rows,\n", | |
| "                self.col_ids : cols,\n", | |
| "                self.values : values\n", | |
| "            }\n", | |
| "            _, loss_value = self.sess.run(fetches=[self.train_step, self.cost], feed_dict=batch)\n", | |
| "            return loss_value\n", | |
| "        \n", | |
| "        \n", | |
| "        def predict(self, rows, cols):\n", | |
| "            '''Return predictions for (row, col) index pairs; cell values are not needed.'''\n", | |
| "            batch = {\n", | |
| "                self.row_ids : rows,\n", | |
| "                self.col_ids : cols\n", | |
| "            }\n", | |
| "            return self.pred.eval(feed_dict=batch, session=self.sess)\n", | |
| "        \n", | |
| "        \n", | |
| "        def coef(self):\n", | |
| "            '''Return the current parameter values as a dict keyed by parameter name.\n", | |
| "\n", | |
| "            Bias entries are included only when fit_intercepts is True.\n", | |
| "            '''\n", | |
| "            if self.fit_intercepts:\n", | |
| "                return self.sess.run(fetches={\n", | |
| "                    'global_bias' : self.global_bias,\n", | |
| "                    'row_bias' : self.row_biases,\n", | |
| "                    'col_bias' : self.col_biases,\n", | |
| "                    'row_weights' : self.row_weights,\n", | |
| "                    'col_weights' : self.col_weights\n", | |
| "                })\n", | |
| "            else:\n", | |
| "                return self.sess.run(fetches={\n", | |
| "                    'row_weights' : self.row_weights,\n", | |
| "                    'col_weights' : self.col_weights\n", | |
| "                })\n", | |
| "        \n", | |
| "        \n", | |
| "        def save(self, path):\n", | |
| "            '''Write a checkpoint of the session variables to path.'''\n", | |
| "            self.saver.save(self.sess, path)\n", | |
| "        \n", | |
| "        \n", | |
| "        def restore(self, path):\n", | |
| "            '''Load the session variables from a checkpoint written by save().'''\n", | |
| "            self.saver.restore(self.sess, path)\n", | |
| " \n", | |
| " def __init__(self, n_components=5, n_iter=500, batch_size=500, learning_rate=0.01,\n", | |
| " regularization_rate=0.02, alpha=1.0, implicit=False, loss='squared',\n", | |
| " log_weights=False, fit_intercepts=True, warm_start=False, optimizer='Adam',\n", | |
| " random_state=None, show_progress=True):\n", | |
| " \n", | |
| " self.n_components = n_components\n", | |
| " self.shape = (None, None, self.n_components)\n", | |
| " self._data = None\n", | |
| " self.n_iter = n_iter\n", | |
| " self.batch_size = batch_size\n", | |
| " self.learning_rate = float(learning_rate)\n", | |
| " self.alpha = float(alpha)\n", | |
| " self.regularization_rate = float(regularization_rate)\n", | |
| " self.implicit = implicit\n", | |
| " self.loss = loss\n", | |
| " self.log_weights = log_weights\n", | |
| " self.fit_intercepts = fit_intercepts\n", | |
| " self.optimizer = optimizer\n", | |
| " self.random_state = random_state\n", | |
| " self.warm_start = warm_start\n", | |
| " self.show_progress = show_progress\n", | |
| " \n", | |
| " np.random.seed(self.random_state)\n", | |
| " self._fresh_session()\n", | |
| " \n", | |
| " \n", | |
| " def _fresh_session(self):\n", | |
| " # reset the session, to start from the scratch \n", | |
| " self._tf = None\n", | |
| " self.history = []\n", | |
| " \n", | |
| " \n", | |
| " def _tf_init(self, shape=None):\n", | |
| " # define the TensorFlow model and initialize variables, session, saver\n", | |
| " if shape is None:\n", | |
| " shape = self.shape\n", | |
| " self._tf = self.TFModel(shape=self.shape, learning_rate=self.learning_rate,\n", | |
| " alpha=self.alpha, regularization_rate=self.regularization_rate,\n", | |
| " implicit=self.implicit, loss=self.loss, log_weights=self.log_weights,\n", | |
| " fit_intercepts=self.fit_intercepts, optimizer=self.optimizer,\n", | |
| " random_state=self.random_state)\n", | |
| " \n", | |
| " \n", | |
| " def _get_batch(self, data, batch_size=1):\n", | |
| " # create single batch for training\n", | |
| " \n", | |
| " batch_rows = np.random.randint(self.shape[0], size=batch_size)\n", | |
| " batch_cols = np.random.randint(self.shape[1], size=batch_size)\n", | |
| " batch_vals = data[batch_rows, batch_cols].A.flatten()\n", | |
| " \n", | |
| " return batch_rows, batch_cols, batch_vals\n", | |
| " \n", | |
| " \n", | |
| " def set_shape(self, n, k):\n", | |
| " '''Manually set the shape parameters\n", | |
| " '''\n", | |
| " self.shape = (int(n), int(k), int(self.n_components))\n", | |
| " \n", | |
| " \n", | |
| " def fit(self, sparse_matrix):\n", | |
| " '''Fit the model\n", | |
| " \n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " \n", | |
| " sparse_matrix : sparse-matrix, shape (n_users, n_items)\n", | |
| " \n", | |
| " '''\n", | |
| " if not self.warm_start:\n", | |
| " self._fresh_session()\n", | |
| " return self.partial_fit(sparse_matrix)\n", | |
| " \n", | |
| " \n", | |
| " def partial_fit(self, sparse_matrix):\n", | |
| " '''Fit the model\n", | |
| " \n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " \n", | |
| " sparse_matrix : sparse-matrix, shape (n_users, n_items)\n", | |
| " \n", | |
| " '''\n", | |
| " \n", | |
| " if self._tf is None:\n", | |
| " self.set_shape(*sparse_matrix.shape)\n", | |
| " self._tf_init(self.shape)\n", | |
| " \n", | |
| " for _ in trange(self.n_iter, disable=not self.show_progress, desc='training'):\n", | |
| " # sample random batch, train, save batch loss\n", | |
| " batch_X0, batch_X1, batch_y = self._get_batch(sparse_matrix, self.batch_size)\n", | |
| " loss_value = self._tf.train(batch_X0, batch_X1, batch_y)\n", | |
| " self.history.append(loss_value)\n", | |
| " \n", | |
| " return self\n", | |
| " \n", | |
| " \n", | |
| " def predict(self, rows, cols):\n", | |
| " '''Predict using the model\n", | |
| " \n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " \n", | |
| " rows : array, shape (n_samples,)\n", | |
| " \n", | |
| " cols : array, shape (n_samples,)\n", | |
| " \n", | |
| " '''\n", | |
| " \n", | |
| " return self._tf.predict(rows, cols)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Prepare the sparse-matrix representation of the training data (rows: users, columns: genres, values: counts)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "mtx = sparse_matrix(\n", | |
| " ratings_train.loc[:, 'userId'].values,\n", | |
| " ratings_train.loc[:, 'genreId'].values,\n", | |
| " ratings_train.loc[:, 'counts'].values\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We use the same settings as in the [Spark ALS module notebook](http://nbviewer.jupyter.org/gist/twolodzko/7becd98ff256ef826b56945de297700d)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "training: 100%|██████████| 2500/2500 [00:07<00:00, 343.15it/s]\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[<matplotlib.lines.Line2D at 0x7ff9e9fcc160>]" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAAIABJREFUeJzt3Xl8VNX5+PHPk50lLCFhXxIgyC5LBFRAQAQEBK2ton4tfl3QX4tbWy24oKXVUluXb1utRWvr0orWFSuKoCBo2YJsArIjBGQx7FvW8/tj7iSTWe8kM5nJzPN+vfJi5sy5956bCc/cOfc554gxBqWUUvEhIdINUEopVXs06CulVBzRoK+UUnFEg75SSsURDfpKKRVHNOgrpVQc0aCvlFJxRIO+UkrFEQ36SikVR5Ii3QB3mZmZJjs7O9LNUEqpOmX16tXfG2OyAtWLuqCfnZ1Nfn5+pJuhlFJ1ioh8a6eedu8opVQc0aCvlFJxRIO+UkrFEQ36SikVRzToK6VUHNGgr5RScUSDvlJKxZGYCfqnikp5asFW1u49FummKKVU1IqZoF9SWs4fP93G2j1HI90UpZSKWjET9OulJAJwpqQswi1RSqnoFTNBPzXJcSpPfLyFY2eKI9wapZSKTjET9EWk4vFX2sWjlFJe2Qr6IjJGRLaIyHYRmebl9adFZK31s1VEjrm8Vuby2txQNl4ppVRwAs6yKSKJwLPAZUABsEpE5hpjNjnrGGPudal/J9DXZRdnjTF9QtdkpZRS1WXnSn8AsN0Ys9MYUwzMASb6qX8d8HooGlddggSupJRScchO0G8D7HV5XmCVeRCRDkAO8JlLcZqI5IvIchG5stotVUopVWOhXkRlEvCWMcY1b7KDMWafiHQEPhORDcaYHa4bicgUYApA+/bta9yIvUfP1HgfSikVi+xc6e8D2rk8b2uVeTMJt64dY8w+69+dwGKq9vc768w2xuQZY/KysgKu9hXQjPc31ngfSikVi+wE/VVArojkiEgKjsDukYUjIl2BpsAyl7KmIpJqPc4ELgY2uW8bKokJ2pevlFL+BAz6xphSYCowH9gMvGmM2SgiM0VkgkvVScAcY4xxKesG5IvIOmARMMs16yfU3rz9worHZ4pLw3UYpZSqs6RqjI68vLw8U92F0Y0x5EyfV/F896xxoWqWUkpFNRFZbYzJC1QvZkbkgmNUblZ6asXz0rLyCLZGKaWiT0wFfYBEl+kYVu4+UuW1/cfOMv5PS7n3jbW13SyllIoKMRf0U5MrT+n6F1Zw95w1GGN4b80+Lpr1GV/vO8G7a3wlHymlVGyLvaCfVPWU3l+7n/3Hz3lc9SulVDyKuaCfkuR5Ss8v3sHZYp1nXymlYi7oP3JFD4+yV5d/q106SilFDAb9C7IzIt0EpZSKWjEX9AE6ZjWIdBOUUioqxWTQ/+znwyLdBKWUikoxGfQBZozvHukmKKVU1InZoH95r5aRboJSSkWdmA36jdKS/b5+ukgnZFNKxZ+YDfoNUpPYPWscN12U7fX1I6eL+e74WYpLdX4epVT8CPXKWVHH22AtgCFPLALgyj6teWaSx7ouSikVk2L2St8pKcDCKh9vPEBZueHkuZJaapFSSkVOzAf95ET/p3iupJxH526k16OfcK5Ep2pQSsU2W0FfRMaIyBYR2S4i07y8/rSIrLV+torIMZfXJovINutncigbb0fbpvUC1nl1+bcAGvSVUjEv4MpZIpIIbAUuAwpwrJl7na9lD0XkTqCvMeZmEckA8oE8wACrgf7GmKO+jleTlbO8McbwyaaDdG7ekEuf/DxgfV1tSylVF4Vy5awBwHZjzE5jTDEwB5jop/51wOvW49HAAmPMESvQLwDG2DhmyIgIo3u0pFNWQ0Z0bR6w/vtrdWI2pVTsshP02wB7XZ4XWGUeRKQDkAN8Fuy2teGlmy4IWOfuOWu5Z84aTeVUSsWk
UN/InQS8ZYwJqnNcRKaISL6I5B8+fDjETarqsat68vb/u8hvnffW7ufDDfvD2g6llIoEO0F/H9DO5Xlbq8ybSVR27dje1hgz2xiTZ4zJy8rKstGk6rthYAd6tmkUsN7rK/YGrKOUUnWNnaC/CsgVkRwRScER2Oe6VxKRrkBTYJlL8XxglIg0FZGmwCirLKJcF0/3RZdXVErFooBB3xhTCkzFEaw3A28aYzaKyEwRmeBSdRIwx7ikAxljjgC/xvHBsQqYaZVFVGKAAVtOh06co6hU0ziVUrEjYMpmbQt1yqYve4+c4bZX8vnmwEm/9UZ2a86LkwPfAFZKqUgKZcpmTGqXUZ/urQP37S/cfKgWWqOUUrUjboM+BJ5+2embAycoLStnxc7CirJ/fLmL387bHK6mKaVUWMT8LJv+lNvs2hrzzNKKx3+bnMel3Vrw6AeOAcnTx3YLS9uUUioc4vpKvzq3M255Ofz3G5RSKlziO+jjiPozJ/bgjks6Rbg1SikVfnEd9K/s45gRYnDnTK4f0D7CrVFKqfCL66Cfl53B7lnj6JjVkPbN6rNs+gjspPC/Zk3FrJRSdU1cB313rRrXo9xGP/9D731d8fjiWZ/xyrLdYWuTUkqFkgZ9N/+6bSC3Ds7hrzf2t1V/37GzzHh/Y5hbpZRSoaFB381FnTJ5aHx32zn8TjsOnwpTi5RSKnQ06PuQlGhvfh6nK//8ZZhaopRSoaNB34fm6alB1T9ZVBqmliilVOho0PehQ7MGNKkfXBePUkpFOw36fvRu2wSAP17XlztHdA5Y3xjDYx9u8lhn93RRKXNW7iHaZjRVSsWfuJ57J5A//LA3ryz7lvG9WvHKqaKA9d/M38sLS3cBMLJbCxqkOn69j87dyL9XF5Cd2YBBHZuFtc1KKeWPXun70bxRGr8YfR4JNhdd+eXbGyoel5YZth86ybmSMg6cOAdAkS62rpSKMFtBX0TGiMgWEdkuItN81LlGRDaJyEYR+ZdLeZmIrLV+PJZZrCvG9m5F+4z6tuv/4C9fMvKpJdw9Zw1l1oivJJsfHkopFS4Bg76IJALPApcD3YHrRKS7W51cYDpwsTGmB3CPy8tnjTF9rB/X5RXrlObpaSy5f7jt+jsOnwbgvzsKKbWCfoKNtXmVUiqc7FzpDwC2G2N2GmOKgTnARLc6twHPGmOOAhhjdLkpy8lzpazc5VgWONjcf6WUCjU7Qb8NsNfleYFV5qoL0EVEvhSR5SIyxuW1NBHJt8qvrGF7I27a5V0B6JTVIOhtz5WUcVrz+ZVSERSq7J0kIBcYBrQFlohIL2PMMaCDMWafiHQEPhORDcaYHa4bi8gUYApA+/bRPcXxHZd04o5LOvH1vuOM/9MXQW17499WkpQgbH98bJhap5RS/tm50t8HtHN53tYqc1UAzDXGlBhjdgFbcXwIYIzZZ/27E1gM9HU/gDFmtjEmzxiTl5WVFfRJRELrJvWqtV2pnWk8lVIqTOwE/VVArojkiEgKMAlwz8J5D8dVPiKSiaO7Z6eINBWRVJfyi4FNIWp7RGU0SGHZ9BG0bpwW6aYopZRtAYO+MaYUmArMBzYDbxpjNorITBFxZuPMBwpFZBOwCLjPGFMIdAPyRWSdVT7LGBMTQR8c8+/fdHF2tbc/W1xGSZnm7iulao9E29QAeXl5Jj+/7iw+vvXgSUY9vSSobXbPGgdA9rQPOb9tY96fOjgcTVNKxRERWW2MyQtUT6dhqKEuLdKD3ub2V/NZuNmR1bqu4HhFeXFpOSlJOkhaKRU+GmFCYPn0S5l/z1B2PD6WEV2bB6w/f+PBilG6Tu+uKaDLQx+x6/vT4WqmUkrplX4otGycRkvrhu6FHZvx2TfBjU276rnKBVi2HjxJTmbwYwCUUsoOvdIPsZHdWwS9zZo9x1iz5xgAOmZXKRVOGvRDLCezATddlF3t7cWan6dc8/mVUmGg3Tth4LriVoOURE4Xl9ne9rNvDlJuDLe/uponf3Q+
i7YcomebxtxxSadwNFUpFWc06IeBMwv2rktzaVo/mV99YH9owusr93L4pGPBlp//ex0A/1n/nQZ9pVRIaPdOmGU0SAl6G9EpmJVSYaJBP8yu6N066G2W7SgMQ0uUUkqDflikJjt+ralJCSQkCHOmDApq+1M6/bJSKky0Tz8Mbr44hzNFZdwyOAeA8hBNdbFoyyFOnC3h2JkS+rRrQrdWjXQEr1IqKBr0wyAtOZFfjD6v4nlJWc2D/t+/3OVxQ7hf+ya885OLa7xvpVT80MvEWjAwJ4N7R3apeL7oF8OC3oe3DKCvrAFdSilllwb9WpCWnMidIzqTmCDMnNiDnMwG7J41jmev7xfppiml4owG/VqSkCDseHwsP74wu6JsXO9WkWuQUiouadCPsB/1bxvS/RWVlhFtayQopaKHraAvImNEZIuIbBeRaT7qXCMim0Rko4j8y6V8sohss34mh6rhseKxq3rVaPvsaR/yl8WOdebPFpdx3kMf8+QnW0PRNKVUDAoY9EUkEXgWuBzoDlwnIt3d6uQC04GLjTE9gHus8gzgEWAgMAB4RESahvQM6rhQpFz+7uNvADh+tgSAN/P31nifSqnYZCfiDAC2G2N2GmOKgTnARLc6twHPGmOOAhhjnBPKjwYWGGOOWK8tAMaEpumx4/P7hlWZpK26Sssd6+0mJeg0Dkop7+wE/TaA66VjgVXmqgvQRUS+FJHlIjImiG3jXodmDVj5wMga78e5GldioiPov/NVASfPldR4v0qp2BGqG7lJQC4wDLgOeEFEmtjdWESmiEi+iOQfPnw4RE2qW1KSElj54KUVzx8e391PbU/Z0z7kuUWOvv2khATWFxzjZ2+u44F3v/aoO+IPi7n2r8tq1mClVJ1kJ+jvA9q5PG9rlbkqAOYaY0qMMbuArTg+BOxsizFmtjEmzxiTl5WVFUz7Y0rz9DT+ddtALumSxY2DOgS9/RtWX35ignC6yDGH/8ET5zzq7fz+NCt2HalZY5VSdZKdoL8KyBWRHBFJASYBc93qvIfjKh8RycTR3bMTmA+MEpGm1g3cUVaZ8uGiTpm8fPMAUpISWPXgSK7JCz6lMylBMDi6erR7XynlKmDQN8aUAlNxBOvNwJvGmI0iMlNEJljV5gOFIrIJWATcZ4wpNMYcAX6N44NjFTDTKlM2ZKWn0rJxvaC3++bASa5/YQUAB08U8dzi7UHl7q/cdYRd358O+rhKqehna8I1Y8w8YJ5b2QyXxwb4mfXjvu1LwEs1a6aqrl3fn+aJj7cwqnsLOjdPt7XNNVZ//+5Z48LZNKVUBOiI3DhRVFrOQ+9tYLdewSsV13Rq5Tixcf8JXlu+h7V7K2fm3HfsLAeOn6V/hwyMMUz++yomnB/8Sl9KqbpDg36cuP+t9QDsP1aZzTP0iUWUlRt2zxrHhn3HWbL1MEu2xmfKrFLxQrt34syR08UVj52DuUCXaFQqXmjQj3LJmnOplAohDfpR7pYhOdx0UbZH+c8v6+JZOcTOFpfZrnuuxH5dpVTkaNCPcvVTknh0Qg/+37BOJCcKD4/vTv2URJo1TA37sS94bCGTZi+jpKzcb7331+6j68Mfs+3gybC3SSlVMxr064hfjunKtsfGcsvgHDbNHENacvjfulNFpSzfeYQDxz2ncnC1YNNBADZ9dyLsbVJK1YwG/ToqMdR9/TVYbMu5qYjef1Aq2mnQr6M6ZTUM2b6OumT0VIsV9TXkKxX9NOjXUT3bNObz+4aRGYK+/b6/XsD1L66o9vbOyd30Ql+p6KdBvw7r0KxBxePbL+kYsXaYiit931H/+NkSth86VUstUkr5okG/jvvNlT1pl1GPCzpkhO0YQ55YxNV/+S8DH1/IIWt+/uLScn72xlq+Laycy0fEkbrpLdvnque+ZORTn4etjUopezTo13FjerZk6f0jKpZIDJfV3x7l4Iki/rujkDPFpazcdYR31uzjgXc34Dprc9eHP+aGFzy7inYe1onelIoG
GvRjzNAuWcywllq8fmB73v3JRSHd/8dfH6D7jPlsttIzE6RywRbnx87K3bpkglLRSidci2FpSYn0bd80pPtcvPUQAOsKHLN1Lt32PU3qJ1tlxyvqnThXQqO05JAeWylVc7aCvoiMAf4PSAReNMbMcnv9JuD3VK5/+2djzIvWa2XABqt8jzFmAiqsOjV3pHN2a2Vv0ZRgJCcmcK6kvEq//bEzJQA8//mOirLej37Cg2O78f3popC3QSlVfQGDvogkAs8Cl+FYAH2ViMw1xmxyq/qGMWaql12cNcb0qXlTlT89WjcC4MeDOnBJlyw+uXcouc1Dl8vvdPKcYzbObTYycR6btzlExyzhmwMnuSA7fDerlYoXdvr0BwDbjTE7jTHFwBxgYnibpYLVPD2N3bPGMbJ7CwC6tEivGCHr/EAIpdq8MXvRbz/jR88v47j1jcLdMwu36rw/StlkJ+i3Afa6PC+wytxdLSLrReQtEWnnUp4mIvkislxErqxJY1X1/PPWgZFuglfl5YYH393A9kO+A/b3p4o4ac31X1TmOZPnqaJSnlm4rWJdX6WUf6HK3vkAyDbG9AYWAC+7vNbBGJMHXA88IyKd3DcWkSnWB0P+4cO6clOoNamfwshuLSLdDAD+9Om2isfbDp3inyv2cO1fl9Nx+od8se17j/qBpmx2LgRTWlaDyYOUiiN2gv4+wPXKvS2VN2wBMMYUGmOcd+xeBPq7vLbP+ncnsBjo634AY8xsY0yeMSYvKysrqBNQdcuTC7ayy1qcffQzSwAoPF1MuYHnFm/3qJ8QaG4H52hgnQJCKVvsBP1VQK6I5IhICjAJmOtaQURauTydAGy2ypuKSKr1OBO4GHC/AaxqwT0jc2nbtF6kmwHA8D8s9tql4y1wVwn6xjFI7AfPfUlRqeMbQLlxzvsTXNTfcuAki7ccCmobpWJBwKBvjCkFpgLzcQTzN40xG0Vkpog40y/vEpGNIrIOuAu4ySrvBuRb5YuAWV6yflQt6NmmMV/8coRH+YNju0Wk6+fJT7Z6lH29z3M+ftcZpMuMYfo76/lqz7GKbwuV0zoHd/zRzyzhpr+vCm4jpWKArTx9Y8w8YJ5b2QyXx9OB6V62+y/Qq4ZtVGHQslEaB06c47ahHbltaEeyp31Yq8f/6OsDHmXHz5awcNNBcls0rDKZnNO+o2c5cdZxU9c5uZuxrvQDdgMppQAdkRu35t09hIMnKlfE6taqUcXUCpF06yv5JAjs/O04wDGy1+mHz1dm6DhjfLnO5a9UUHTunTiV0SCFbq0q8/ffuH1QBFtTVblLIs6dr6/1WzfQXP57j5yh8JSOClbKSYO+Aoi6eXJ6PjKfxVsOsfOw95G/f/18JwDl1mwQvm7kDnliEXmPLQxLG5WqizToKw9/uq4vV/ZpHdE2nCoq9Xuj9e2vCoDK7J3DJ4vInvah1w8Joyn8SlXQoK88XHF+axrVi44r/6JSzwVZXJWVV43oztk/lVLeadBXFVo1TuO+0ecBcO/ILh6vL7lvOF89fFltN8sv96v4M8VlLPpG8++V8kWzd+LMg2O70aWl9ymXl02/tOJx0wYpFY/fmDKIsnJD+2b1KS+Prr6Screo/+C7XwPwys0DGNolNKO7X1y6k8u6t/CaRupu2Y5CcjIb0LJxWkiOrVSoadCPM7cNDX4B9YEdm1U8jrZ0+DIfHfYHT5xj/7GzNd7/8bMl/ObDzfz9y918Oc1zcJu7615YTtP6yayZMarGx1YqHDToK58+vGtwxfz5Tr6yZBLEMTq2tm+aGj8HdG97TfZ/qsj+vo76mAJaqWigffrKpx6tGzPI5SrfnxaN0lj082HhbZAXh08Wey0v8TPrpr8PCrsWbTnE0CcWUXiqKOBMoEpFEw36KmSyMwP3eYfavW94H7z1wLsbKgZuufvnij0B91tWbvx+ODzy/kb2HDlD/98s5LKnP/d4PX/3EZZsPUz2tA8jNtL5+JkSXlv+bUg+5FTs0KCv
qu3Tn19S8TgSXf3PLtpOoZ81eIt9pHt+sulgxePth05x1+tryJ72IcfOVH5r6PTAPHKmz/O2OUCVD5S9RzzvHXx/qoj5Gx3zC+XvPuL7JMLovrfW8dB7X7Nh3/HAlVXc0KCvqq1TVkOPm5sL7h1aa8f//fwtfrtxJvz5S6/lrh9QI5/6nLnr9gMw6LefetR1v0jeUHCcuev2V4wE9ifSN70LTzs+xAKNdVCRtXxnIaVltfce6Y1cVS03X5wDePaP57ZIJys9lcMno3e+G1/B+FxJOT/552oGd/ZM9RSBmR9s4qUvdwV9vEh1rkRZopXy4qs9R5k0ezl3XNKJaZd3rZVj6pW+CtruWeOYcUV3AJrWd+Tz33hhdsXrA7IzAFj5wKU8NK5brbcvEH/BcN6GAzzw7gavr/kL+MfPlvDI+1+7Hcc5/XPQTYxJxhjeXVNASS1e1Ua7762LI3/rRIeaXumrGmmQmsTuWeOqlP3+R725eXA2zRulRVXAe3dNAfe+sS6obb47fi5wJeDpBVt5edm3Qbcpe9qHjOzWnBcnXxCw7usr95CYIAzunEnrJvZXQYuW92Duuv3c+8Y69h09y9QRuZFuTq35+Ovv6NqykddEB2cKdG2+R7au9EVkjIhsEZHtIjLNy+s3ichhEVlr/dzq8tpkEdlm/UwOZeNVdKqfkkT/Do6r/dIoGsEbbMAHGPvHpQAcC5B7/4//7q7yPJj/xAs3B5424vDJIqa/s4H731rP5f+31NZ+nd1Y4czeufFvK5j6r69s1T1i3WOI5q6/cLjjta+8ZnhBZLrgAgZ9EUkEngUuB7oD14lIdy9V3zDG9LF+XrS2zQAeAQYCA4BHRKRpyFqvol6ZnTuetSASaYvOoPvkJ1tqvC/XieWOn42ewV9Lt33Pf9Z/F9Q2Ly/7lt99/E2YWhSd/CUcQO3e97FzpT8A2G6M2WmMKQbmABNt7n80sMAYc8QYcxRYAIypXlNVXeTsvm0d4blo/rUycG5+uJywRgaXlJUzZ+WeivmLvg4ildLXmAN/JIpv5f5l8Q6v5Qs2HWTFzsJabk1wHp27MWTLi0Yiw8tO0G8D7HV5XmCVubtaRNaLyFsi0i7IbVWMcl7pJyZGNgBtO+h9MZZwOXyqyGNQ1uwlO5n2zgbestYCmPnBpmrv/+jpYno9Op/V3x71eG3FzkKe/9x7UHVljPGYmtqugyfOeXzj+Hf+XgY9/mmNvlXd9ko+185eXu3ta4N7V14o1OY30VBl73wAZBtjeuO4mn85mI1FZIqI5ItI/uHDh0PUJBUNnPPyZ9RPCVAzvMLxH9WfGe9vZNXuyoBsjKHgqGMQ1+GTRby1uoCVNRi0tWr3EU6eK/V6xXzt7OXM+ihw98mUV1fT6QHfA9Bcvb5yD9nTPuTQSceN7YGPf8rg331Wpc4v317PgRPnCMVtnHBn+Bw7U8w/vtxV42D7/tp9LNtRyOpvj7Jub/BrOUTiSt9O9s4+oJ3L87ZWWQVjjOv3sReBJ1y2Hea27WL3AxhjZgOzAfLy8qLnzp+qsZsuyqZBquPPbF2B91RIgOf/pz93vLa6tppV615bsYfXrS6m388Pvo+/JrHJ16YLXEYmB/LvfMcX9j2FZ2ie7uiq8zWhnSOQekazYM7h9ldX89JNgTOaquu+t9azYNNBjp0t4ejpYn41sWe19nP3nKrTgKybMYrG9YNfgCja+vRXAbkikiMiKcAkYK5rBRFp5fJ0ArDZejwfGCUiTa0buKOsMhUnkhITuG5AexIDXNKM6dmy1ganRMJ/rFG//hw6cY4tB06ycX+Ipk2oyN4Jze7smvaO7w93uz4L80I4zik3nlm4rVqptgC7vz/tUbb36Jmgvj1E4r5LwKBvjCkFpuII1puBN40xG0VkpohMsKrdJSIbRWQdcBdwk7XtEeDXOD44VgEzrTIVZ9LTqn6pvCC7Kf+8dWCVsjsu6cRvrqzeFVe0CxQG3lpdwIDHP2X0M0sY98cv
QnLMmoaTkrLyipvOzvb7+uzefuhURc75W6sLanjk4O09ElywrY4tB06yoaDyA3nYHxZ71Bn/py+44s/Bv3+1+cFsa3CWMWYeMM+tbIbL4+nAdB/bvgS8VIM2qhgwpmdLurZM55sDJzm/bWOeuqYP7TLqR7pZUeMX/646hiB72odkNEjxuTzllFcru8IWbjrIudIyxve2t5i9McZjkrjH521m9pKdVQba5T74EQBbfjOGNXuc/dXeo/7yCGXclJcbNn13gvF/+sLv7ysURj+zxFa9r/cFMauq89tYNdpTXToiV9UKEeGju4dQcPRstYP95/cN45LfLw5tw2rJyl3Bf8F1DmYK5NZX8gF8Bn33dM+3Vhdw31vrq5TNXrITgD99uo3TxWVVutpmf74zYBsSxLOjYu66/ZwrLuOaC9p53SYUOj4wj8yGjiSBI6eLMcZwqqiU9LSq/epHThdzuqi04m8vWtJZo3JwllKhIiIBA763K55P7h3K0vuH0ygt+Btk8ayiK8ZQZW3jtX6yTJ5csNUj3XPptu8DHivRSyS56/U13P925YeL+3v75fbA+/XH2Z3z/anKD8eX/7ubXo9+wt4jZ6rUvWjWpwx5YhHg6LaqSeZUONTFlE2lwqZLi3TaZdQnISE6rs5q0+wljgDsKyQs3FyZgbPr+9NeUzWvf3EFHa3UzKOni4PK2gFsXY4ePVMS9JQbN7y4gonV6P928hYn5290nNset6B/rqQyBfT1CA7Uc+dr+dFw0qCvIur9n17MG1MG2aqbGIdB//F5jiB+1EZXz68+2BhwUFbfXy/gUJBz37h2TfmKUXbGBXizruA498xZQ8fpwY9wLfdzdezvL8XX4jr+hOtKvOLLWLRNuKZUuJzfrgkDXdbh7dm6kc+6vv4j5zZvGOJWRZei0jLG/ynwFXHD1Kq36PYUnvFRMzB/q32dPBfc3D/HzhSz/5jn6mJO763dT7kJPhh7+2Lhev/CGMNfP9/Bd8erHntdQfApsbu8pGdWV2lZOc8u2s6Z4tKonYZBqVrTt73v+ficV/rNGlQd3RvOQTzR4LyHPrZVzz3o77c5LbQ3ldk6VQlw5bPeVyTz5aJZn3HRrM8C1nvxC+83jJduO0zBUc8PMH9SYtVGAAAWh0lEQVRX+gjsLjzDbz/6hlv+kV9RvP/YWT6wMWbiwPFzVaa4qO50Fd68t3Y/v5+/hacXbK0oq87cStWl2TuqzkhLTuS1WwbSqF5SlaUQNfXTIZT9w4/N2+zztR2Hg7vqPVNcBgTuIlm8xfsULDf+bSWpSQls+c3lVcq97c61zBmoN7nMgXTX62t8Hv+ZhVt5ZuE2ds8ax/A/LOZsiaPdH909JKRdi2eKHSOZz5aUsb4a3zpqSoO+ilr92jepmMLBaXBuJt8WVgadp645v7abFbU2fec/P/xxH4HcfWI4f8J549FfWqu3dX79XR37SsnM9zJBndMzC7c59mtMRcAHeG35t/xzRXA3fwtPFdGsYarX15wfRokiFVNyRN3gLKUi4Z2fXOy13PkfpEOz+vygX9tabFF0CzThlzMX393l/7eUpASxlX3jr2/e1a0v5weuVEPe+/Sr55xLkF+8teo3Djspq+6W7zzCuN6tvH67cQZ912y0/+6ovcFt2qev6px2GfX5Qd82PHdDP4/X7ht9ntdtfjq8U7ibVac1TLN3/ee3H92FayppdY1+egm//ch3N5O3tpwuKvX5mj+vLNtd8fh//76qyms1+XLjrwvKfT6quev2e53PJ9Q06Ks6JzFBeOraPvRo3djjtZ8O7+x1m9zm6X73+YcfxXc3UaDlIJ2SatC3HWwg3nLwJH/1MxrYeEn22bjf0VX1RJAzmfrLHKpJh5brGTvHXJRZv4cjZ6qm4d71+hqv8/mEmgZ9FXWG5GbSs43v1E1v/nJDP48J3ILxw/7aTWRP9UPgEx/XbNnIZxZurfLc34dIdea29yVU9zGcYy6c3TvvfLXPX/Ww0T59
FXVevSX44H15r1aBK6kac00zDFawI3bdPbNwG/eM7FLxPNhvDtVVnZDvvMnsrU+/PITpn9WhV/oq5gw/L8ujzE4e9BNX9+a2ITnhaFLM2HLwZESP/1+X+XoCxc6g5rUPU1aSt/mGIhzzNeir2PP8jf1Z+cClVcqc///bZdSrKBvXu+q3g2suaMcF2Rlhb5+qvjmr9lJUai/v/7Kn7U2FHFA1Pg+cKaPuTbzhxRUVffqRokFfxZzUpESaN0qrUuYcwOUa1Eec19xj2+p2QTSymf2iambuuv3c+4ZjicJIXzFXV53o3hGRMSKyRUS2i8g0P/WuFhEjInnW82wROSsia62f50PVcKWC0b99Uz6/bxj3Wn3CmQ1TvY6y7NAs8OjeCed7zlt/wsd6sSr05m04wLEzxRw7a2+9gZraGeQIZHDp0/fSrRjpK/2Alycikgg8C1wGFACrRGSuMWaTW7104G5ghdsudhhj+oSovUoF7Y/X9SUhQejQrAEA/77jQnq1acwnXqYY9pYG6m1/c23M36LCp8/MBZFugl87Dp3mlWW7uSbPcwGZ2roB7Yud76QDgO3GmJ0AIjIHmAhscqv3a+B3wH0hbaFSIebs4unbromt+oM7Z7Jx/3GO2sxlV3VPKCdUA3jaSi99xcui63Whe6cNsNfleYFVVkFE+gHtjDHeJsXOEZE1IvK5iAypflOVCk6nLMeVff8O3mfudPbzX+vlaqymxvfWFNK65KkapKL6s/3QKY+yF5buCsux7Krx3ScRSQCeAm7y8vJ3QHtjTKGI9AfeE5EexpgTbvuYAkwBaN++fU2bpBQAn/58WMA6rguBO82c2IMZ72+seO4v3TM1KcHrZGCRWBFJKTvsXOnvA1wvhdpaZU7pQE9gsYjsBgYBc0UkzxhTZIwpBDDGrAZ2AF1wY4yZbYzJM8bkZWV55lgrVZsudFnUxckZxH9zZU8A8h8ayRe/HO5zXhYN+Spa2Qn6q4BcEckRkRRgEjDX+aIx5rgxJtMYk22MyQaWAxOMMfkikmXdCEZEOgK5gO/JNJSKcmN6tgQc2T9tm9ZnbE9HN457IpBe6KtoFTDoG2NKganAfGAz8KYxZqOIzBSRCQE2HwqsF5G1wFvAHcaY6FqGXikbfMXwx67qRcfMBvzrtkEe9edMGcTtQzvy4o/z+ON1fblhoKPrckhuJo3rJYe3wUr5YKtP3xgzD5jnVjbDR91hLo/fBt6uQfuUqnX1UhKrPHfNsHPPtquXkshnvxgGwF0jOvPHz7YDkCDCoI7NGOTSVZSWlMA/V+whNSmRdY+MInta8IuBK1VTOiJXKTdtm1Zv+cUWjStHAXdp6W8qZ8cnxwdTB/PcDf3I9jMgrG97z7TS5uneV2RSyg4N+kp54VxkvE2Teky7vKutPnrXJfomXeCZBuqe0dOrbWPG9mrlczqBcb1a8ebtF9Ixs0GV8paN07xvoJQNGvSV8uOje4bQu629QVyuMd3bGq2+Jgi71u0DIjUpgQX3DuXJa84nOTHBo7uptMx3CumvJvQgS78JKD806CvlxVV9HeMPU5Oq/hfxvxi3XVVr/mRY1aUc504dTG6LdNKSqwZ7J3+jR89z61YKdjEaFfs06CvlxaMTerD+0VGkJjkDb+CQXuUKO4iUTfduH9fpnx2vO/79kbW6V/1Uzw+Dri3TuTavHf07NK1ys9l9HVZ/huRm2q6r6i4N+kp5kZggNErzklbpZ9qUEV0rp2r2Fmuzrb75Czt5Dv5y5d41lGDtzHno7q0a8cKP86rU6dehKb/7YW+SExOqNlKElQ9WXVvAF2dKqYptGvSVsqFbK0e3SUqS7/8yIkJ6qu8s6C4t0lk2fQQ3X5zt8drS+4e77Mdtv251DXBZ9xYVz89v14T7Rp3n9ZgJAs3To+fG7y9GddEupwjToK+UDc/d0I/XbxtEk/optur7mj23VeN6XuflaZdR3+P+gVPn5o4PnHS3hVq6tGgIwKwf9KJpg8p2
PTqhR8Xji/x8q3Bf+KU2Zvxt1bhe4EoqrHS5H6VsSE9LDtgtA4Rl0p3HrurJVX3bsLvQsZiHe3B2/wwZ37s143u35tvC037HHLgvIlMbE/6KVHZXqcjQK32lwiGEETQtOZHBuZmM7dWKri3TuX1oR8chrGN4Sw8F6NCsgdfVwZw8gn6Qbe5jcz0CVyI6GV2kadBXKoRqEtBusvr6k3wE6owGKXx8z9CKG8IVx7Rx0Bnju3uUtXH7FuAvHdWbW4fkMCAnuIXkfX1AqdqjQV+pMAg2gAJMG9OVnY+PJSnR3n/LYI7gXrdH60b8bXIe/89ljIC39P/7x5xHswYpDMjJ4KaLsqu8lpKYwO+u7l3x/O5LcwO2Q0TXGog0DfpKhVBNApqIkOCnO8adc4SvnS3cbwJPOL+1Y3F4l/Z6GzGcIMLqhy/jzdsvJDmxsu7NF+cwomtzcjIb8PWvRrP0/uE0cBs/MPnCDh77ExGddjrCNOgrVUdV9OnbiKJX92vrtdx1U299+q57vsvlSn7GFd0rvpE0TE2iXUZ9Stymh5hxRQ/c+ZqKwmnVgyM9yh4a183vNio4GvSVCoPaSH+c0Kc1AJkNA6eR+ruh61TupdGuHwrp3garuXBPOfV2xP+s/47z/cxl5G3eoFuHdPQos3POyjsN+kqFUG5zR+58YmL4+zDuGpHLxl+Ntj12wJ/G9ZK99ukHc+P1xxdmc/+YykFi3r6AFJWW88BY/1fuGS5jDv7wo/NtH1/ZYyvoi8gYEdkiIttFZJqfeleLiBGRPJey6dZ2W0RkdCgarVS0enFyHq/cPMD7FA4hlpAgNPAzAtgf58Rszrh888U5Aa/0A0lJSuAnwzq7bOu5cXm58TuqGeAf/3sBAL3aNOaH/b13S1V3zYPaFK0jjwMGfWuN22eBy4HuwHUi4pH/JSLpwN3ACpey7jjW1O0BjAGec66Zq1QsalI/haFdsiLdDL8eHt+dYec193zBJeY3axCe7hN/M4RWNMNG19jfJucFrLPm4cvsNKnC4M6hm3DuivNb06NV45DtL5TsXOkPALYbY3YaY4qBOcBEL/V+DfwOOOdSNhGYY4wpMsbsArZb+1NKRUiHDO9Xyc400z7tmnDF+Y77BaFOryyzEdG7tWrEiK7Nq6SDumvWMDXgt5CmQXxwje3VktduHcj/DKrepHO/vrKnR1l1fnWje7QIXKmG7AT9NsBel+cFVlkFEekHtDPGuC/6GXBbpVTt8haMDKaiT79ry3T+Z1AHmtRPZmyvliE9disbq36lJCXw0k0X0L219+6Rlo0c+/A1ncOVfVrzyBWeg9H8efJHfQC4rLu9833kiu782CUl1T1YV/ejsjamqKjx3DsikgA8BdxUg31MAaYAtG+v07sqVWtcgoyzT19E6Ny8IWtnjArZYTo0q8+3hWeYOrxzlfJ6yYmcLSmzvZ+l9w+nUT3H/ZIEAW9bDjuvOVf2De7a0n11MnAslfnBnYP5wXNfsrvwDACrHxqJATIbplJebnhl2bde9+cYhOZZ/p87B7N8ZyG/+XAzAL8c05XxvVsx5IlFQbW3Juxc6e8DXNdza2uVOaUDPYHFIrIbGATMtW7mBtoWAGPMbGNMnjEmLysruvtDlarrfF1M9m7jSKW8xM89iSX3Deeju4dU+9juo40X/WJYUHn47TLq09gK+j1aV/aZ33FJJ5+zlFZH43rJvD/1YjIapFRpc7OGqWQ2dKSV+rso99WL1bNN4yopqJ2bN6SdS3dbbaT62vktrQJyRSRHRFJw3Jid63zRGHPcGJNpjMk2xmQDy4EJxph8q94kEUkVkRwgF1gZ8rNQStVYr7aN2TRzNGN6+u7iaN+sPt1aBZ+Vcv0Axzd49/z6lo3TvObh2/Hy/w7gwo6OmU9TkxIY1cPRbrs9JM5uIm96t21cEdx98Xe/w30EtF3BfkOpjoBB3xhTCkwF5gObgTeNMRtFZKaITAiw7UbgTWAT8DHwU2OM/e9ySqmQ
c82972H1m3e3Ann9lPDMtn77JZ3YPWtcwAFewWhcP5n+HZoCjsFnrt1TdqQmO8Kf6/KUJsh9OLmPZ7hzROB5iBzbOTg/YPq1D37m0mDZeoeNMfOAeW5lM3zUHeb2/DHgsWq2TykVIq0bp7H/+LkqdxlH92jJkvuG075ZdOe9v3rLgCqDtpw6W4PhurRoyJYDJ4HAN1ETE4T3fnIxv3x7PQDPXNun4jVn70oQUyB55RiLEHgntbGGgTsdkatUnMhtke61PNQB39tEa4E4++l9GZKbVaUP32lin9b8587BjOnZqiLlNNBFeov0VHq1rdxXalLlTVxvk9gFmi/IG2NMUCmbky5w3PpsWM1uoWDoyllKxYnauqr81cSe/GqiZ966P0vuG86p4tKgjyUi9GxT9cOgJnP2O+O7t9RJOwPCnOwGb+dRfj6qC3de2rnKB1C46JW+UnGmNmc2fv22QSz82SUB6zWun0ybJjVbPzcUmS/OsQquffrOdQQu8LNgjOtnxH/uHBx08BaRWgn4oFf6SsWN6nRT1JStdYVDpPIqvfr7qLwZXFl244XZ3Hhhtu19+JrRNMXL4jiR6NPXoK9UnImGlavevP1Cvjt+NqT79BawvfF3/qH44HAud+m+C+ckdwBDcjNZuu376h+kBrR7R6k4E/mQDwNyMpjYJ7Q56dcNdIwFOL8aC7Y7Vd7Irf5vyXml7/7Z4rqE5pXWuXdp0bDax6kuvdJXSsWE4ec1Z/escT5fd15d/+yyLj7rVKRs1uByuJmXQV3jerfiLpfc/av7t2VCn9Yk21wPOZQ06CsVJ+olO24UJtU0Cb0OWnr/cLLSU0lLrrxZmmxN2+B6RV5ezSt95xoB57dt7DX99Nnr+3mURSLggwZ9peLGb3/Qi+6tGzGoY+3dXI0W7bxMJ/3n6/ry2opvK0YjA1zWvQU/6NeG+0d3DWr/jdKSeeXmAfRuG51z6LvSoK9UnGjWMJV7Rvru2og37TLqM/3yqpO9pSYl8tQ1fXxs4enh8d0ZaKVyui+eUy85OteL0qCvlIpZ7/zkIr757mTY9n/L4Byfr90zsguFp4u5dXD1JpQLF4lE7q4/eXl5Jj8/P9LNUEqpOkVEVhtjAg4b1pRNpZSKIxr0lVIqjmjQV0qpOKJBXyml4ogGfaWUiiO2gr6IjBGRLSKyXUSmeXn9DhHZICJrReQLEelulWeLyFmrfK2IPB/qE1BKKWVfwDx9EUkEngUuAwqAVSIy1xizyaXav4wxz1v1JwBPAWOs13YYY+yPdlBKKRU2dq70BwDbjTE7jTHFwBxgomsFY8wJl6cNiMw00UoppQKwMyK3DbDX5XkBMNC9koj8FPgZkAKMcHkpR0TWACeAh4wxS71sOwWYYj09JSJb7DXfq0wgMhNVR068nXO8nS/oOceLmpyzrcWJQzYNgzHmWeBZEbkeeAiYDHwHtDfGFIpIf+A9Eenh9s0AY8xsYHYo2iEi+XZGpcWSeDvneDtf0HOOF7Vxzna6d/YB7Vyet7XKfJkDXAlgjCkyxhRaj1cDOwCd8UkppSLETtBfBeSKSI6IpACTgLmuFUQk1+XpOGCbVZ5l3QhGRDoCucDOUDRcKaVU8AJ27xhjSkVkKjAfSAReMsZsFJGZQL4xZi4wVURGAiXAURxdOwBDgZkiUgKUA3cYY46E40RchKSbqI6Jt3OOt/MFPed4EfZzjrpZNpVSSoWPjshVSqk4EjNBP9Co4bpMRHa7jHjOt8oyRGSBiGyz/m1qlYuI/NH6PawXEc/FOaOQiLwkIodE5GuXsqDPUUQmW/W3ichkb8eKFj7O+VER2ecyin2sy2vTrXPeIiKjXcrrxN++iLQTkUUisklENorI3VZ5zL7Pfs45cu+zMabO/+C417AD6IhjnMA6oHuk2xXC89sNZLqVPQFMsx5PA35nPR4LfAQIMAhYEen22zzHoUA/4OvqniOQgSNRIANoaj1uGulzC/KcHwV+4aVud+vvOhXI
sf7eE+vS3z7QCuhnPU4HtlrnFbPvs59zjtj7HCtX+gFHDcegicDL1uOXsdJkrfJXjMNyoImItIpEA4NhjFkCuN/kD/YcRwMLjDFHjDFHgQVUTgcSdXycsy8TgTnGkQa9C9iO4+++zvztG2O+M8Z8ZT0+CWzGMfgzZt9nP+fsS9jf51gJ+t5GDfv7xdY1BvhERFZbo5cBWhhjvrMeHwBaWI9j6XcR7DnGyrlPtbozXnJ2dRBj5ywi2UBfYAVx8j67nTNE6H2OlaAf6wYbY/oBlwM/FZGhri8ax/fCmE7DiodztPwF6AT0wTGi/cnINif0RKQh8DZwj/EcnR+T77OXc47Y+xwrQT/YUcN1ijFmn/XvIeBdHF/1Djq7bax/D1nVY+l3Eew51vlzN8YcNMaUGWPKgRdwvNcQI+csIsk4gt8/jTHvWMUx/T57O+dIvs+xEvQDjhquq0SkgYikOx8Do4CvcZyfM2thMvC+9Xgu8GMr82EQcNzlq3NdE+w5zgdGiUhT6+vyKKusznC7/3IVjvcaHOc8SURSRSQHx+j2ldShv30REeBvwGZjzFMuL8Xs++zrnCP6Pkf67naofnDc6d+K4w73g5FuTwjPqyOOO/XrgI3OcwOaAZ/imPJiIZBhlQuO9Q92ABuAvEifg83zfB3H19wSHP2Vt1TnHIGbcdz82g78b6TPqxrn/Kp1Tuut/9StXOo/aJ3zFuByl/I68bcPDMbRdbMeWGv9jI3l99nPOUfsfdYRuUopFUdipXtHKaWUDRr0lVIqjmjQV0qpOKJBXyml4ogGfaWUiiMa9JVSKo5o0FdKqTiiQV8ppeLI/wdMaaBzrVh3ewAAAABJRU5ErkJggg==\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "mf = MatrixFactorizer(n_components=5, learning_rate=0.001, regularization_rate=0.1,\n", | |
| " implicit=True, loss='logistic', log_weights=True,\n", | |
| " n_iter=2500, batch_size=1024, random_state=42)\n", | |
| "\n", | |
| "mf.fit(mtx)\n", | |
| "\n", | |
| "plt.plot(mf.history)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>genreId</th>\n", | |
| " <th>counts</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>4</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1</td>\n", | |
| " <td>5</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId genreId counts\n", | |
| "0 1 0 0.0\n", | |
| "1 1 1 0.0\n", | |
| "2 1 3 0.0\n", | |
| "3 1 4 0.0\n", | |
| "4 1 5 0.0" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "users = np.unique(ratings_test.loc[:, 'userId'].values)\n", | |
| "genres = np.unique(ratings_test.loc[:, 'genreId'].values)\n", | |
| "\n", | |
| "users_genres = np.array([(u, g) for u in users for g in genres])\n", | |
| "users_genres = pd.DataFrame(users_genres, columns = ['userId', 'genreId'])\n", | |
| "\n", | |
| "test_full = (\n", | |
| " users_genres\n", | |
| " .merge(ratings_test, on=['userId', 'genreId'], how='left')\n", | |
| " .fillna(0)\n", | |
| ")\n", | |
| "\n", | |
| "test_full.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "145692" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Number of (user, genre) rows in the full test grid\n", | |
| "len(test_full)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Anti-join: keep only the (userId, genreId) pairs of `test_full` that never\n", | |
| "# occur in the training data, see:\n", | |
| "# https://stackoverflow.com/a/38516887/3986320\n", | |
| "#\n", | |
| "# A vectorized left merge with `indicator=True` replaces a row-by-row\n", | |
| "# `iterrows()` loop (minutes on ~145k rows, deprecated positional `row[0]`\n", | |
| "# Series indexing, and ids silently upcast to float because `iterrows` does\n", | |
| "# not preserve dtypes). A left merge keeps the row order of `test_full`.\n", | |
| "\n", | |
| "# Deduplicate so the merge cannot multiply test rows.\n", | |
| "train_keys = ratings_train.loc[:, ['userId', 'genreId']].drop_duplicates()\n", | |
| "\n", | |
| "flagged = test_full.merge(\n", | |
| "    train_keys, on=['userId', 'genreId'], how='left', indicator=True\n", | |
| ")\n", | |
| "# 'left_only' marks pairs with no match in the training keys.\n", | |
| "test_only = flagged.loc[flagged['_merge'] == 'left_only']\n", | |
| "\n", | |
| "test_users = test_only['userId'].tolist()\n", | |
| "test_genres = test_only['genreId'].tolist()\n", | |
| "test_counts = test_only['counts'].tolist()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Note: the row counts differ between the Spark and Pandas notebooks, presumably because of different versions of one of the packages (sklearn, pandas, or numpy)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.17300653, 0.17291747, 0.17295071, ..., 0.1735287 , 0.17319265,\n", | |
| " 0.17340496], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Score each held-out (user, genre) pair with the model `mf`.\n", | |
| "predictions = mf.predict(test_users, test_genres)\n", | |
| "predictions" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# One row per held-out (user, genre) pair; within each user the genres are\n", | |
| "# ordered from the highest to the lowest predicted score.\n", | |
| "scored = pd.DataFrame(dict(\n", | |
| "    userId=test_users,\n", | |
| "    genreId=test_genres,\n", | |
| "    counts=test_counts,\n", | |
| "    predictions=predictions,\n", | |
| "))\n", | |
| "test_out = scored.sort_values(by=['userId', 'predictions'], ascending=[True, False])\n", | |
| "\n", | |
| "# 1-based position of each genre in its user's ranking\n", | |
| "test_out['rank'] = test_out.groupby(['userId']).cumcount().add(1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(639, 228, 127680)" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# n: distinct test users, k: distinct test genres\n", | |
| "n, k = (len(set(ids)) for ids in (test_users, test_genres))\n", | |
| "\n", | |
| "n, k, len(test_out)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(0.15968302214502822, 36.40772904906643)" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Mean percentile rank (MPR) of the positive pairs -- lower is better.\n", | |
| "# NOTE: ranks are normalized by the global genre count k, not by each\n", | |
| "# user's own number of candidate genres.\n", | |
| "positive_ranks = test_out.loc[test_out['counts'] > 0, 'rank']\n", | |
| "MPR = (positive_ranks / k).mean()\n", | |
| "\n", | |
| "MPR, MPR * k" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(0.6404786508541415, 1.5613322921324564)" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Mean reciprocal rank (MRR) of each user's best-ranked positive genre --\n", | |
| "# higher is better.\n", | |
| "first_hit_rank = (\n", | |
| "    test_out[test_out['counts'] > 0]\n", | |
| "    .groupby('userId')['rank']\n", | |
| "    .min()\n", | |
| ")\n", | |
| "MRR = (1 / first_hit_rank).mean()\n", | |
| "\n", | |
| "MRR, 1 / MRR" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Based on these two metrics, the performance is close to the results [obtained using Spark's ALS module](http://nbviewer.jupyter.org/gist/twolodzko/7becd98ff256ef826b56945de297700d)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Last updated: 2018-06-27 12:16:14.052402\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import datetime\n", | |
| "\n", | |
| "# Timestamp of the notebook's last run\n", | |
| "last_updated = datetime.datetime.now()\n", | |
| "print('Last updated: %s' % last_updated)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment