Last active
June 27, 2018 10:17
-
-
Save twolodzko/cd0ae4a06c23cb8cdc51ae46079bdb3f to your computer and use it in GitHub Desktop.
Matrix Factorization Using TensorFlow on The Movies Dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Matrix Factorization Using TensorFlow (ver 0.4)\n", | |
| "Tymoteusz Wolodzko\n", | |
| "\n", | |
| "This notebook shows the results of my TensorFlow implementation of matrix factorization (it is equivalent to [Spark's ALS module](http://spark.apache.org/docs/2.2.0/api/python/pyspark.ml.html#module-pyspark.ml.recommendation), which I tested [in my previous notebook](http://nbviewer.jupyter.org/gist/twolodzko/7becd98ff256ef826b56945de297700d)). As in the previous notebook, I am using a subset of [The Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset) hosted on Kaggle; the full dataset contains metadata on over 45,000 movies and 26 million ratings from over 270,000 users." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "import random\n", | |
| "\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>id</th>\n", | |
| " <th>genres</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>862</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>8844</td>\n", | |
| " <td>[Adventure, Fantasy, Family]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>15602</td>\n", | |
| " <td>[Romance, Comedy]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>31357</td>\n", | |
| " <td>[Comedy, Drama, Romance]</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>11862</td>\n", | |
| " <td>[Comedy]</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " id genres\n", | |
| "0 862 [Animation, Comedy, Family]\n", | |
| "1 8844 [Adventure, Fantasy, Family]\n", | |
| "2 15602 [Romance, Comedy]\n", | |
| "3 31357 [Comedy, Drama, Romance]\n", | |
| "4 11862 [Comedy]" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from ast import literal_eval\n", | |
| "\n", | |
| "movies = pd.read_csv('data/movies_metadata.csv', low_memory = False)\n", | |
| "movies = movies.loc[:, ['id', 'genres']]\n", | |
| "movies['genres'] = (\n", | |
| " movies['genres']\n", | |
| " .fillna('[]')\n", | |
| " .apply(literal_eval).apply(lambda x: [i['name'] for i in x] if isinstance(x, list) else [])\n", | |
| ")\n", | |
| "movies.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>id</th>\n", | |
| " <th>genres</th>\n", | |
| " <th>GoHands</th>\n", | |
| " <th>Sentai Filmworks</th>\n", | |
| " <th>BROSTA TV</th>\n", | |
| " <th>The Cartel</th>\n", | |
| " <th>Science Fiction</th>\n", | |
| " <th>Telescene Film Group Productions</th>\n", | |
| " <th>Drama</th>\n", | |
| " <th>Adventure</th>\n", | |
| " <th>...</th>\n", | |
| " <th>Western</th>\n", | |
| " <th>Rogue State</th>\n", | |
| " <th>Romance</th>\n", | |
| " <th>Action</th>\n", | |
| " <th>Thriller</th>\n", | |
| " <th>Fantasy</th>\n", | |
| " <th>Carousel Productions</th>\n", | |
| " <th>Pulser Productions</th>\n", | |
| " <th>War</th>\n", | |
| " <th>Music</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>862</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>8844</td>\n", | |
| " <td>[Adventure, Fantasy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>15602</td>\n", | |
| " <td>[Romance, Comedy]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>31357</td>\n", | |
| " <td>[Comedy, Drama, Romance]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>11862</td>\n", | |
| " <td>[Comedy]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>5 rows × 34 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " id genres GoHands Sentai Filmworks BROSTA TV \\\n", | |
| "0 862 [Animation, Comedy, Family] 0 0 0 \n", | |
| "1 8844 [Adventure, Fantasy, Family] 0 0 0 \n", | |
| "2 15602 [Romance, Comedy] 0 0 0 \n", | |
| "3 31357 [Comedy, Drama, Romance] 0 0 0 \n", | |
| "4 11862 [Comedy] 0 0 0 \n", | |
| "\n", | |
| " The Cartel Science Fiction Telescene Film Group Productions Drama \\\n", | |
| "0 0 0 0 0 \n", | |
| "1 0 0 0 0 \n", | |
| "2 0 0 0 0 \n", | |
| "3 0 0 0 1 \n", | |
| "4 0 0 0 0 \n", | |
| "\n", | |
| " Adventure ... Western Rogue State Romance Action Thriller Fantasy \\\n", | |
| "0 0 ... 0 0 0 0 0 0 \n", | |
| "1 1 ... 0 0 0 0 0 1 \n", | |
| "2 0 ... 0 0 1 0 0 0 \n", | |
| "3 0 ... 0 0 1 0 0 0 \n", | |
| "4 0 ... 0 0 0 0 0 0 \n", | |
| "\n", | |
| " Carousel Productions Pulser Productions War Music \n", | |
| "0 0 0 0 0 \n", | |
| "1 0 0 0 0 \n", | |
| "2 0 0 0 0 \n", | |
| "3 0 0 0 0 \n", | |
| "4 0 0 0 0 \n", | |
| "\n", | |
| "[5 rows x 34 columns]" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "genres = set([item for row in movies['genres'].values for item in row])\n", | |
| "\n", | |
| "for g in genres:\n", | |
| " movies[g] = 0\n", | |
| "\n", | |
| "for row in range(movies.shape[0]):\n", | |
| " for g in movies.loc[row, 'genres']:\n", | |
| " movies.at[row, g] = 1\n", | |
| " \n", | |
| "movies.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(359, 2)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "movies['genres_bitmap'] = movies.iloc[:, 2:].apply(lambda row : ''.join([str(x) for x in row]), axis = 1)\n", | |
| "counts = movies['genres_bitmap'].value_counts()\n", | |
| "\n", | |
| "counts = pd.DataFrame({\n", | |
| " 'genres_bitmap' : counts.index,\n", | |
| " 'genres_counts' : counts\n", | |
| "})\n", | |
| "\n", | |
| "# leave only the popular combinations of genres\n", | |
| "counts = counts.loc[counts['genres_counts'] >= 10, :]\n", | |
| "counts.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>id</th>\n", | |
| " <th>genres</th>\n", | |
| " <th>GoHands</th>\n", | |
| " <th>Sentai Filmworks</th>\n", | |
| " <th>BROSTA TV</th>\n", | |
| " <th>The Cartel</th>\n", | |
| " <th>Science Fiction</th>\n", | |
| " <th>Telescene Film Group Productions</th>\n", | |
| " <th>Drama</th>\n", | |
| " <th>Adventure</th>\n", | |
| " <th>...</th>\n", | |
| " <th>Thriller</th>\n", | |
| " <th>Fantasy</th>\n", | |
| " <th>Carousel Productions</th>\n", | |
| " <th>Pulser Productions</th>\n", | |
| " <th>War</th>\n", | |
| " <th>Music</th>\n", | |
| " <th>genres_bitmap</th>\n", | |
| " <th>genres_counts</th>\n", | |
| " <th>genreId</th>\n", | |
| " <th>movieId</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>862</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>862</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>12233</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>12233</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>532</td>\n", | |
| " <td>[Family, Animation, Comedy]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>532</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>531</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>531</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>40688</td>\n", | |
| " <td>[Animation, Comedy, Family]</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>...</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0</td>\n", | |
| " <td>00000000011100000000000000000000</td>\n", | |
| " <td>112</td>\n", | |
| " <td>122</td>\n", | |
| " <td>40688</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "<p>5 rows × 38 columns</p>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " id genres GoHands Sentai Filmworks BROSTA TV \\\n", | |
| "0 862 [Animation, Comedy, Family] 0 0 0 \n", | |
| "1 12233 [Animation, Comedy, Family] 0 0 0 \n", | |
| "2 532 [Family, Animation, Comedy] 0 0 0 \n", | |
| "3 531 [Animation, Comedy, Family] 0 0 0 \n", | |
| "4 40688 [Animation, Comedy, Family] 0 0 0 \n", | |
| "\n", | |
| " The Cartel Science Fiction Telescene Film Group Productions Drama \\\n", | |
| "0 0 0 0 0 \n", | |
| "1 0 0 0 0 \n", | |
| "2 0 0 0 0 \n", | |
| "3 0 0 0 0 \n", | |
| "4 0 0 0 0 \n", | |
| "\n", | |
| " Adventure ... Thriller Fantasy Carousel Productions \\\n", | |
| "0 0 ... 0 0 0 \n", | |
| "1 0 ... 0 0 0 \n", | |
| "2 0 ... 0 0 0 \n", | |
| "3 0 ... 0 0 0 \n", | |
| "4 0 ... 0 0 0 \n", | |
| "\n", | |
| " Pulser Productions War Music genres_bitmap \\\n", | |
| "0 0 0 0 00000000011100000000000000000000 \n", | |
| "1 0 0 0 00000000011100000000000000000000 \n", | |
| "2 0 0 0 00000000011100000000000000000000 \n", | |
| "3 0 0 0 00000000011100000000000000000000 \n", | |
| "4 0 0 0 00000000011100000000000000000000 \n", | |
| "\n", | |
| " genres_counts genreId movieId \n", | |
| "0 112 122 862 \n", | |
| "1 112 122 12233 \n", | |
| "2 112 122 532 \n", | |
| "3 112 122 531 \n", | |
| "4 112 122 40688 \n", | |
| "\n", | |
| "[5 rows x 38 columns]" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "movies = movies.merge(counts, on = 'genres_bitmap', how = 'inner')\n", | |
| "movies.rename(columns = {'movie id' : 'movie_id'}, inplace = True)\n", | |
| "movies['genreId'] = movies['genres_bitmap'].astype(\"category\").cat.codes\n", | |
| "movies['movieId'] = movies['id'].astype('int64')\n", | |
| "\n", | |
| "movies.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(100004, 4)" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ratings = pd.read_csv('data/ratings_small.csv')\n", | |
| "\n", | |
| "ratings.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>movieId</th>\n", | |
| " <th>rating</th>\n", | |
| " <th>timestamp</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>31</td>\n", | |
| " <td>2.5</td>\n", | |
| " <td>1260759144</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1029</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1260759179</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1061</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>1260759182</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1129</td>\n", | |
| " <td>2.0</td>\n", | |
| " <td>1260759185</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1172</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>1260759205</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId movieId rating timestamp\n", | |
| "0 1 31 2.5 1260759144\n", | |
| "1 1 1029 3.0 1260759179\n", | |
| "2 1 1061 3.0 1260759182\n", | |
| "3 1 1129 2.0 1260759185\n", | |
| "4 1 1172 4.0 1260759205" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ratings.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>movieId</th>\n", | |
| " <th>rating</th>\n", | |
| " <th>genreId</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>2.5</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>4</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>7</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>19</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>4.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>21</td>\n", | |
| " <td>1371</td>\n", | |
| " <td>3.0</td>\n", | |
| " <td>173</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId movieId rating genreId\n", | |
| "0 1 1371 2.5 173\n", | |
| "1 4 1371 4.0 173\n", | |
| "2 7 1371 3.0 173\n", | |
| "3 19 1371 4.0 173\n", | |
| "4 21 1371 3.0 173" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ratings_genres = (\n", | |
| " ratings.loc[:, ['userId', 'movieId', 'rating']]\n", | |
| " .merge(movies.loc[:, ['movieId', 'genreId']], on = 'movieId', how = 'inner')\n", | |
| ")\n", | |
| "\n", | |
| "ratings_genres.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>genreId</th>\n", | |
| " <th>counts</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>67</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>75</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>173</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>248</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1</td>\n", | |
| " <td>257</td>\n", | |
| " <td>1</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId genreId counts\n", | |
| "0 1 67 1\n", | |
| "1 1 75 1\n", | |
| "2 1 173 1\n", | |
| "3 1 248 1\n", | |
| "4 1 257 1" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ratings_counts = (\n", | |
| " ratings_genres.groupby(['userId', 'genreId'], as_index = False)['movieId'].count()\n", | |
| " .rename(columns = {'movieId' : 'counts'})\n", | |
| ")\n", | |
| "\n", | |
| "ratings_counts.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "((18424, 3), (4606, 3))" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "from sklearn.model_selection import train_test_split\n", | |
| "\n", | |
| "ratings_counts.dropna(axis = 0, inplace = True)\n", | |
| "ratings.dropna(axis = 0, inplace = True)\n", | |
| "\n", | |
| "ratings_train, ratings_test = train_test_split(ratings_counts, test_size=0.2, random_state=42)\n", | |
| "\n", | |
| "ratings_train.shape, ratings_test.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "from sklearn.base import BaseEstimator, RegressorMixin\n", | |
| "from tqdm import tqdm, trange\n", | |
| "from scipy.sparse import coo_matrix, csr_matrix\n", | |
| "import json\n", | |
| "\n", | |
| "\n", | |
| "def sparse_matrix(rows, cols, values, shape=None, mode='dok'):\n", | |
| " if shape is None:\n", | |
| " n = np.max(rows) + 1\n", | |
| " k = np.max(cols) + 1\n", | |
| " if mode == 'dok':\n", | |
| " return coo_matrix((values, (rows, cols)), shape=(n, k)).todok()\n", | |
| " else:\n", | |
| " return csr_matrix((values, (rows, cols)), shape=(n, k))\n", | |
| "\n", | |
| "\n", | |
| "class MatrixFactorizer(BaseEstimator):\n", | |
| " \n", | |
| " \"\"\"Matrix Factorizer\n", | |
| " \n", | |
| " Factorize the matrix R (n, m) into P (n, n_components) and Q (n_components, m) weight\n", | |
| " matrices:\n", | |
| " \n", | |
| " R[i,j] = P[i,:] * Q[:,j]\n", | |
| " \n", | |
| " Additional intercepts mu, bi, bj can be included, leading to the following model:\n", | |
| " \n", | |
| " R[i,j] = mu + bi[i] + bj[j] + P[i,:] * Q[:,j]\n", | |
| " \n", | |
| " The model is commonly used for recommender systems, where the matrix R consists of\n", | |
| " ratings by n users of m products. When users rate products using some kind of rating\n", | |
| " system (e.g. \"likes\", 1 to 5 stars), we are talking about explicit ratings (Koren et al.,\n", | |
| " 2009). When ratings are not available and instead we use indirect measures of preferences\n", | |
| " (e.g. clicks, purchases), we are talking about implicit ratings (Hu et al., 2008). For\n", | |
| " implicit ratings we use a modified model, where we model the indicator variable:\n", | |
| " \n", | |
| " D[i,j] = 1 if R[i,j] > 0 else 0\n", | |
| " \n", | |
| " and define additional weights:\n", | |
| " \n", | |
| " C[i,j] = 1 + alpha * R[i, j]\n", | |
| " \n", | |
| " or log weights:\n", | |
| " \n", | |
| " C[i,j] = 1 + alpha * log(1 + R[i, j])\n", | |
| " \n", | |
| " The model is defined in terms of minimizing a loss function (squared, logistic) between\n", | |
| " the D[i,j] indicators and the values predicted using matrix factorization, where the loss is weighted\n", | |
| " using the C[i,j] weights (see Hu et al., 2008 for details). When using logistic loss, the predictions\n", | |
| " are passed through the sigmoid function to squeeze them into the (0, 1) range.\n", | |
| " \n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " \n", | |
| " n_components : int, default : 5\n", | |
| " Number of latent components to be estimated. The estimated latent matrices P and Q\n", | |
| " have (n, n_components) and (n_components, m) shapes, respectively.\n", | |
| " \n", | |
| " n_iter : int, default : 500\n", | |
| " Number of training epochs; the actual number of iterations is n_samples * n_iter.\n", | |
| " \n", | |
| " batch_size : int, default : 500\n", | |
| " Size of the random batch to be used during training. The batch_size is the number of\n", | |
| " cells that are randomly sampled from the factorized matrix.\n", | |
| " \n", | |
| " learning_rate : float, default : 0.01\n", | |
| " Learning rate parameter.\n", | |
| " \n", | |
| " regularization_rate : float, default : 0.02\n", | |
| " Regularization parameter.\n", | |
| " \n", | |
| " alpha : float, default : 1.0\n", | |
| " Weighting parameter in matrix factorization with implicit ratings.\n", | |
| " \n", | |
| " implicit : bool, default : False\n", | |
| " Use matrix factorization with explicit (default) or implicit ratings. \n", | |
| " \n", | |
| " loss : 'squared', 'logistic', default: 'squared'\n", | |
| " Loss function to be used. For implicit=True 'logistic' loss may be preferable.\n", | |
| " \n", | |
| " log_weights : bool, default : False\n", | |
| " Only for implicit=True, use log weighting in the loss function instead of standard weights.\n", | |
| " \n", | |
| " fit_intercepts : bool, default : True\n", | |
| " When set to True, the mu, bi, bj intercepts are fitted, otherwise\n", | |
| " only the P and Q latent matrices are fitted.\n", | |
| " \n", | |
| " warm_start : bool, optional\n", | |
| " When set to True, reuse the solution of the previous call to fit as initialization,\n", | |
| " otherwise, just erase the previous solution.\n", | |
| " \n", | |
| " optimizer : 'Adam', 'Ftrl', default : 'Adam'\n", | |
| " Optimizer to be used, see TensorFlow documentation for more details.\n", | |
| " \n", | |
| " random_state : int, or None, default : None\n", | |
| " The seed of the pseudo random number generator to use when shuffling the data. If int,\n", | |
| " random_state is the seed used by the random number generator.\n", | |
| " \n", | |
| " show_progress : bool, default : False\n", | |
| " Show the progress bar.\n", | |
| " \n", | |
| " References\n", | |
| " ----------\n", | |
| " \n", | |
| " Koren, Y., Bell, R., & Volinsky, C. (2009).\n", | |
| " Matrix factorization techniques for recommender systems. Computer, 42(8).\n", | |
| " \n", | |
| " Yu, H. F., Hsieh, C. J., Si, S., & Dhillon, I. (2012, December).\n", | |
| " Scalable coordinate descent approaches to parallel matrix factorization for recommender systems.\n", | |
| " In Data Mining (ICDM), 2012 IEEE 12th International Conference on (pp. 765-774). IEEE.\n", | |
| " \n", | |
| " Hu, Y., Koren, Y., & Volinsky, C. (2008, December).\n", | |
| " Collaborative filtering for implicit feedback datasets.\n", | |
| " In Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on (pp. 263-272). IEEE.\n", | |
| " \n", | |
| " \"\"\" \n", | |
| " \n", | |
| "    class TFModel(object):\n", | |
| "        '''TensorFlow graph, parameters, session and saver of the factorization model.\n", | |
| "\n", | |
| "        The model is built inside its own tf.Graph; a tf.Session bound to that\n", | |
| "        graph is opened, and the variables are initialized, at the end of __init__.\n", | |
| "        '''\n", | |
| "\n", | |
| "        def __init__(self, shape, learning_rate, alpha, regularization_rate,\n", | |
| "                     implicit, loss, log_weights, fit_intercepts, optimizer,\n", | |
| "                     random_state=None):\n", | |
| "            '''Build the graph, initialize the variables and open the session.\n", | |
| "\n", | |
| "            Parameters\n", | |
| "            ----------\n", | |
| "\n", | |
| "            shape : tuple (n, k, d)\n", | |
| "                n and k size the row and column parameters, d is the width of\n", | |
| "                both weight matrices.\n", | |
| "\n", | |
| "            learning_rate\n", | |
| "                Passed to the optimizer.\n", | |
| "\n", | |
| "            alpha\n", | |
| "                Multiplier of the values in the per-example weights; only has\n", | |
| "                an effect when implicit is true.\n", | |
| "\n", | |
| "            regularization_rate\n", | |
| "                Multiplier of the L2 penalty that is added to the loss.\n", | |
| "\n", | |
| "            implicit\n", | |
| "                If true, the targets are the values clipped to [0, 1] and every\n", | |
| "                example is weighted by 1 + alpha * value.\n", | |
| "\n", | |
| "            loss\n", | |
| "                'logistic' gives sigmoid predictions trained with log loss; any\n", | |
| "                other value gives linear predictions trained with mean squared\n", | |
| "                error.\n", | |
| "\n", | |
| "            log_weights\n", | |
| "                If true (and implicit is true), log1p(value) replaces value in\n", | |
| "                the per-example weights.\n", | |
| "\n", | |
| "            fit_intercepts\n", | |
| "                If true, a global bias and per-row and per-column biases are\n", | |
| "                added to the linear predictor.\n", | |
| "\n", | |
| "            optimizer\n", | |
| "                'Ftrl' selects tf.train.FtrlOptimizer; any other value selects\n", | |
| "                tf.train.AdamOptimizer.\n", | |
| "\n", | |
| "            random_state\n", | |
| "                Graph-level seed passed to tf.set_random_seed.\n", | |
| "\n", | |
| "            '''\n", | |
| "\n", | |
| "            self.shape = shape\n", | |
| "            self.learning_rate = learning_rate\n", | |
| "            self.implicit = implicit\n", | |
| "            self.loss = loss\n", | |
| "            self.log_weights = log_weights\n", | |
| "            self.fit_intercepts = fit_intercepts\n", | |
| "            self.optimizer = optimizer\n", | |
| "            self.random_state = random_state\n", | |
| "\n", | |
| "            # n and k size the row / column parameters, d is the width of the weight matrices\n", | |
| "            n, k, d = self.shape\n", | |
| "\n", | |
| "            self.graph = tf.Graph()\n", | |
| "\n", | |
| "            with self.graph.as_default():\n", | |
| "\n", | |
| "                # seed the graph before any randomly initialized variable is created\n", | |
| "                tf.set_random_seed(self.random_state)\n", | |
| "\n", | |
| "                with tf.name_scope('constants'):\n", | |
| "                    self.alpha = tf.constant(alpha, dtype=tf.float32)\n", | |
| "                    self.regularization_rate = tf.constant(regularization_rate, dtype=tf.float32,\n", | |
| "                                                           name='regularization_rate')\n", | |
| "\n", | |
| "                with tf.name_scope('inputs'):\n", | |
| "                    # 1-D batches: one (row id, column id, value) triple per example\n", | |
| "                    self.row_ids = tf.placeholder(tf.int32, shape=[None], name='row_ids')\n", | |
| "                    self.col_ids = tf.placeholder(tf.int32, shape=[None], name='col_ids')\n", | |
| "                    self.values = tf.placeholder(tf.float32, shape=[None], name='values')\n", | |
| "\n", | |
| "                    if self.implicit:\n", | |
| "                        # the target is the value clipped to [0, 1]; the value itself\n", | |
| "                        # only enters the loss through the per-example weight\n", | |
| "                        targets = tf.clip_by_value(self.values, 0, 1, name='targets')\n", | |
| "\n", | |
| "                        if self.log_weights:\n", | |
| "                            data_weights = tf.add(1.0, self.alpha * tf.log1p(self.values), name='data_weights')\n", | |
| "                        else:\n", | |
| "                            data_weights = tf.add(1.0, self.alpha * self.values, name='data_weights')\n", | |
| "                    else:\n", | |
| "                        # fit the raw values, every example weighted equally\n", | |
| "                        targets = tf.identity(self.values, name='targets')\n", | |
| "                        data_weights = tf.constant(1.0, name='data_weights')\n", | |
| "\n", | |
| "                with tf.name_scope('parameters'):\n", | |
| "\n", | |
| "                    # biases start at zero, weights at random normal values with stddev 0.01\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        self.global_bias = tf.get_variable('global_bias', shape=[], dtype=tf.float32,\n", | |
| "                                                           initializer=tf.zeros_initializer())\n", | |
| "                        self.row_biases = tf.get_variable('row_biases', shape=[n], dtype=tf.float32,\n", | |
| "                                                          initializer=tf.zeros_initializer())\n", | |
| "                        self.col_biases = tf.get_variable('col_biases', shape=[k], dtype=tf.float32,\n", | |
| "                                                          initializer=tf.zeros_initializer())\n", | |
| "\n", | |
| "                    self.row_weights = tf.get_variable('row_weights', shape=[n, d], dtype=tf.float32,\n", | |
| "                                                       initializer = tf.random_normal_initializer(mean=0, stddev=0.01))\n", | |
| "                    self.col_weights = tf.get_variable('col_weights', shape=[k, d], dtype=tf.float32,\n", | |
| "                                                       initializer = tf.random_normal_initializer(mean=0, stddev=0.01))\n", | |
| "\n", | |
| "                with tf.name_scope('prediction'):\n", | |
| "\n", | |
| "                    # pick the parameters that belong to the examples of the batch\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        batch_row_biases = tf.nn.embedding_lookup(self.row_biases, self.row_ids, name='row_bias')\n", | |
| "                        batch_col_biases = tf.nn.embedding_lookup(self.col_biases, self.col_ids, name='col_bias')\n", | |
| "\n", | |
| "                    batch_row_weights = tf.nn.embedding_lookup(self.row_weights, self.row_ids, name='row_weights')\n", | |
| "                    batch_col_weights = tf.nn.embedding_lookup(self.col_weights, self.col_ids, name='col_weights')\n", | |
| "\n", | |
| "                    # per-example dot product of the row and column weight vectors\n", | |
| "                    weights = tf.reduce_sum(tf.multiply(batch_row_weights, batch_col_weights), axis=1, name='weights')\n", | |
| "\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        biases = tf.add(batch_row_biases, batch_col_biases)\n", | |
| "                        biases = tf.add(self.global_bias, biases, name='biases')\n", | |
| "                        linear_predictor = tf.add(biases, weights, name='linear_predictor')\n", | |
| "                    else:\n", | |
| "                        linear_predictor = tf.identity(weights, name='linear_predictor')\n", | |
| "\n", | |
| "                    if self.loss == 'logistic':\n", | |
| "                        self.pred = tf.sigmoid(linear_predictor, name='predictions')\n", | |
| "                    else:\n", | |
| "                        self.pred = tf.identity(linear_predictor, name='predictions')\n", | |
| "\n", | |
| "                with tf.name_scope('loss'):\n", | |
| "\n", | |
| "                    # NOTE: the weight penalty covers the full weight matrices, while the\n", | |
| "                    # bias penalty only covers the biases looked up for the current batch;\n", | |
| "                    # the global bias is not penalized\n", | |
| "                    l2_weights = tf.add(tf.nn.l2_loss(self.row_weights),\n", | |
| "                                        tf.nn.l2_loss(self.col_weights), name='l2_weights')\n", | |
| "\n", | |
| "                    if self.fit_intercepts:\n", | |
| "                        l2_biases = tf.add(tf.nn.l2_loss(batch_row_biases),\n", | |
| "                                           tf.nn.l2_loss(batch_col_biases), name='l2_biases')\n", | |
| "                        l2_term = tf.add(l2_weights, l2_biases)\n", | |
| "                    else:\n", | |
| "                        l2_term = l2_weights\n", | |
| "\n", | |
| "                    l2_term = tf.multiply(self.regularization_rate, l2_term, name='regularization')\n", | |
| "\n", | |
| "                    if self.loss == 'logistic':\n", | |
| "                        loss_raw = tf.losses.log_loss(predictions=self.pred, labels=targets,\n", | |
| "                                                      weights=data_weights)\n", | |
| "                    else:\n", | |
| "                        loss_raw = tf.losses.mean_squared_error(predictions=self.pred, labels=targets,\n", | |
| "                                                                weights=data_weights)\n", | |
| "\n", | |
| "                    # objective that is minimized: weighted data loss plus the L2 penalty\n", | |
| "                    self.cost = tf.add(loss_raw, l2_term, name='loss')\n", | |
| "\n", | |
| "                # any optimizer value other than 'Ftrl' falls back to Adam\n", | |
| "                if self.optimizer == 'Ftrl':\n", | |
| "                    self.train_step = tf.train.FtrlOptimizer(self.learning_rate).minimize(self.cost)\n", | |
| "                else:\n", | |
| "                    self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.cost)\n", | |
| "\n", | |
| "                self.saver = tf.train.Saver()\n", | |
| "\n", | |
| "                init = tf.global_variables_initializer()\n", | |
| "\n", | |
| "            # initialize TF session\n", | |
| "            self.sess = tf.Session(graph=self.graph)\n", | |
| "            self.sess.run(init)\n", | |
| "\n", | |
| "\n", | |
| "        def train(self, rows, cols, values):\n", | |
| "            '''Run one optimizer step on a batch and return the loss fetched with it.\n", | |
| "\n", | |
| "            rows, cols and values are fed to the row_ids, col_ids and values\n", | |
| "            placeholders (all declared with shape [None]).\n", | |
| "            '''\n", | |
| "            batch = {\n", | |
| "                self.row_ids : rows,\n", | |
| "                self.col_ids : cols,\n", | |
| "                self.values : values\n", | |
| "            }\n", | |
| "            _, loss_value = self.sess.run(fetches=[self.train_step, self.cost], feed_dict=batch)\n", | |
| "            return loss_value\n", | |
| "\n", | |
| "\n", | |
| "        def predict(self, rows, cols):\n", | |
| "            '''Evaluate the predictions tensor for the given row and column ids.\n", | |
| "\n", | |
| "            Only the id placeholders are fed; the values placeholder is not.\n", | |
| "            '''\n", | |
| "            batch = {\n", | |
| "                self.row_ids : rows,\n", | |
| "                self.col_ids : cols\n", | |
| "            }\n", | |
| "            return self.pred.eval(feed_dict=batch, session=self.sess)\n", | |
| "\n", | |
| "\n", | |
| "        def coef(self):\n", | |
| "            '''Fetch the current parameter values from the session as a dict.\n", | |
| "\n", | |
| "            The bias entries are only present when fit_intercepts is true.\n", | |
| "            '''\n", | |
| "            if self.fit_intercepts:\n", | |
| "                return self.sess.run(fetches={\n", | |
| "                    'global_bias' : self.global_bias,\n", | |
| "                    'row_bias' : self.row_biases,\n", | |
| "                    'col_bias' : self.col_biases,\n", | |
| "                    'row_weights' : self.row_weights,\n", | |
| "                    'col_weights' : self.col_weights\n", | |
| "                })\n", | |
| "            else:\n", | |
| "                return self.sess.run(fetches={\n", | |
| "                    'row_weights' : self.row_weights,\n", | |
| "                    'col_weights' : self.col_weights\n", | |
| "                })\n", | |
| "\n", | |
| "\n", | |
| "        def save(self, path):\n", | |
| "            '''Write the session variables to path with the tf.train.Saver.'''\n", | |
| "            self.saver.save(self.sess, path)\n", | |
| "\n", | |
| "\n", | |
| "        def restore(self, path):\n", | |
| "            '''Load the session variables from path with the tf.train.Saver.'''\n", | |
| "            self.saver.restore(self.sess, path)\n", | |
| " \n", | |
| " \n", | |
| "    def __init__(self, n_components=5, n_iter=500, batch_size=500, learning_rate=0.01,\n", | |
| "                 regularization_rate=0.02, alpha=1.0, implicit=False, loss='squared',\n", | |
| "                 log_weights=False, fit_intercepts=True, warm_start=False, optimizer='Adam',\n", | |
| "                 random_state=None, show_progress=True):\n", | |
| "        # Store the hyper-parameters only; the TensorFlow model itself is\n", | |
| "        # created on the first call to partial_fit.\n", | |
| "\n", | |
| "        # model size: rows and columns stay unknown until set_shape is called\n", | |
| "        self.n_components = n_components\n", | |
| "        self.shape = (None, None, n_components)\n", | |
| "        self._data = None\n", | |
| "\n", | |
| "        # training loop settings\n", | |
| "        self.n_iter = n_iter\n", | |
| "        self.batch_size = batch_size\n", | |
| "\n", | |
| "        # numeric hyper-parameters are coerced to float\n", | |
| "        self.learning_rate = float(learning_rate)\n", | |
| "        self.alpha = float(alpha)\n", | |
| "        self.regularization_rate = float(regularization_rate)\n", | |
| "\n", | |
| "        # model variant switches, handed over to TFModel unchanged\n", | |
| "        self.implicit = implicit\n", | |
| "        self.loss = loss\n", | |
| "        self.log_weights = log_weights\n", | |
| "        self.fit_intercepts = fit_intercepts\n", | |
| "        self.optimizer = optimizer\n", | |
| "\n", | |
| "        self.random_state = random_state\n", | |
| "        self.warm_start = warm_start\n", | |
| "        self.show_progress = show_progress\n", | |
| "\n", | |
| "        # seeds numpy's global generator, which _get_batch samples from\n", | |
| "        np.random.seed(random_state)\n", | |
| "        self._fresh_session()\n", | |
| " \n", | |
| " \n", | |
| "    def _fresh_session(self):\n", | |
| "        '''Drop the current TensorFlow model and the loss history.\n", | |
| "\n", | |
| "        If a model exists, its session is closed before the reference is\n", | |
| "        dropped, so the resources held by the session are released right away\n", | |
| "        instead of whenever the object happens to be garbage collected.\n", | |
| "        '''\n", | |
| "        # __init__ calls this before _tf has ever been assigned, hence getattr\n", | |
| "        previous = getattr(self, '_tf', None)\n", | |
| "        if previous is not None:\n", | |
| "            previous.sess.close()\n", | |
| "\n", | |
| "        # start from scratch: no model, empty training history\n", | |
| "        self._tf = None\n", | |
| "        self.history = []\n", | |
| " \n", | |
| " \n", | |
| "    def _tf_init(self, shape=None):\n", | |
| "        '''Create the TensorFlow model (graph, variables, session, saver).\n", | |
| "\n", | |
| "        Parameters\n", | |
| "        ----------\n", | |
| "\n", | |
| "        shape : tuple (n, k, d) or None\n", | |
| "            Shape handed to TFModel; self.shape is used when it is None.\n", | |
| "\n", | |
| "        '''\n", | |
| "        if shape is None:\n", | |
| "            shape = self.shape\n", | |
| "        # pass the resolved shape on (previously self.shape was always used,\n", | |
| "        # so an explicitly given shape was silently ignored)\n", | |
| "        self._tf = self.TFModel(shape=shape, learning_rate=self.learning_rate,\n", | |
| "                                alpha=self.alpha, regularization_rate=self.regularization_rate,\n", | |
| "                                implicit=self.implicit, loss=self.loss, log_weights=self.log_weights,\n", | |
| "                                fit_intercepts=self.fit_intercepts, optimizer=self.optimizer,\n", | |
| "                                random_state=self.random_state)\n", | |
| " \n", | |
| " \n", | |
| "    def _get_batch(self, data, batch_size=1):\n", | |
| "        '''Sample one training batch of random matrix cells.\n", | |
| "\n", | |
| "        Row and column indices are drawn independently with np.random.randint\n", | |
| "        (rows first, then columns) from the ranges given by self.shape, and the\n", | |
| "        values stored in data at those positions are looked up.\n", | |
| "\n", | |
| "        Returns the tuple (row indices, column indices, values).\n", | |
| "        '''\n", | |
| "        row_idx, col_idx = (\n", | |
| "            np.random.randint(dim, size=batch_size) for dim in self.shape[:2]\n", | |
| "        )\n", | |
| "        # indexing with two index arrays gives a matrix-like result; .A turns\n", | |
| "        # it into a plain array that is then flattened to 1-D\n", | |
| "        picked = data[row_idx, col_idx]\n", | |
| "        cell_values = picked.A.flatten()\n", | |
| "\n", | |
| "        return row_idx, col_idx, cell_values\n", | |
| " \n", | |
| " \n", | |
| "    def set_shape(self, n, k):\n", | |
| "        '''Manually set the shape parameters\n", | |
| "\n", | |
| "        Stores (n, k, n_components), each converted with int(), as self.shape.\n", | |
| "        '''\n", | |
| "        dims = (n, k, self.n_components)\n", | |
| "        self.shape = tuple(int(dim) for dim in dims)\n", | |
| " \n", | |
| " \n", | |
| "    def fit(self, sparse_matrix):\n", | |
| "        '''Fit the model\n", | |
| "\n", | |
| "        Unless warm_start is set, the current model and loss history are\n", | |
| "        discarded first; training itself is delegated to partial_fit.\n", | |
| "\n", | |
| "        Parameters\n", | |
| "        ----------\n", | |
| "\n", | |
| "        sparse_matrix : sparse-matrix, shape (n_users, n_items)\n", | |
| "\n", | |
| "        '''\n", | |
| "        if self.warm_start:\n", | |
| "            # keep whatever has been learned so far and continue training\n", | |
| "            return self.partial_fit(sparse_matrix)\n", | |
| "\n", | |
| "        self._fresh_session()\n", | |
| "        return self.partial_fit(sparse_matrix)\n", | |
| " \n", | |
| " \n", | |
| "    def partial_fit(self, sparse_matrix):\n", | |
| "        '''Run n_iter training steps, continuing from the current model if any\n", | |
| "\n", | |
| "        When no model exists yet, the shape is taken from sparse_matrix and a\n", | |
| "        new model is created. The loss of every batch is appended to\n", | |
| "        self.history. Returns self.\n", | |
| "\n", | |
| "        Parameters\n", | |
| "        ----------\n", | |
| "\n", | |
| "        sparse_matrix : sparse-matrix, shape (n_users, n_items)\n", | |
| "\n", | |
| "        '''\n", | |
| "\n", | |
| "        model_missing = self._tf is None\n", | |
| "        if model_missing:\n", | |
| "            self.set_shape(*sparse_matrix.shape)\n", | |
| "            self._tf_init(self.shape)\n", | |
| "\n", | |
| "        steps = trange(self.n_iter, disable=not self.show_progress, desc='training')\n", | |
| "        for _ in steps:\n", | |
| "            # one random batch per step; its loss goes to the history\n", | |
| "            rows, cols, vals = self._get_batch(sparse_matrix, self.batch_size)\n", | |
| "            step_loss = self._tf.train(rows, cols, vals)\n", | |
| "            self.history.append(step_loss)\n", | |
| "\n", | |
| "        return self\n", | |
| " \n", | |
| " \n", | |
| "    def predict(self, rows, cols):\n", | |
| "        '''Predict using the model\n", | |
| "\n", | |
| "        The call is forwarded to the TensorFlow model, so the model has to be\n", | |
| "        fitted first.\n", | |
| "\n", | |
| "        Parameters\n", | |
| "        ----------\n", | |
| "\n", | |
| "        rows : array, shape (n_samples,)\n", | |
| "\n", | |
| "        cols : array, shape (n_samples,)\n", | |
| "\n", | |
| "        '''\n", | |
| "\n", | |
| "        scores = self._tf.predict(rows, cols)\n", | |
| "        return scores\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Preparing sparse matrix data representation." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# user ids, genre ids and counts of the training interactions\n", | |
| "train_users = ratings_train['userId'].values\n", | |
| "train_genres = ratings_train['genreId'].values\n", | |
| "train_counts = ratings_train['counts'].values\n", | |
| "\n", | |
| "mtx = sparse_matrix(train_users, train_genres, train_counts)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We are using the same settings as in the [Spark ALS module notebook](http://nbviewer.jupyter.org/gist/twolodzko/7becd98ff256ef826b56945de297700d)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "training: 100%|██████████| 2500/2500 [00:07<00:00, 343.15it/s]\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[<matplotlib.lines.Line2D at 0x7ff9e9fcc160>]" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# implicit-feedback variant: logistic loss, log-scaled per-example weights\n", | |
| "mf = MatrixFactorizer(\n", | |
| "    n_components=5,\n", | |
| "    n_iter=2500,\n", | |
| "    batch_size=1024,\n", | |
| "    learning_rate=0.001,\n", | |
| "    regularization_rate=0.1,\n", | |
| "    implicit=True,\n", | |
| "    loss='logistic',\n", | |
| "    log_weights=True,\n", | |
| "    random_state=42\n", | |
| ")\n", | |
| "\n", | |
| "mf.fit(mtx)\n", | |
| "\n", | |
| "# loss of every training batch, in training order\n", | |
| "plt.plot(mf.history)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>userId</th>\n", | |
| " <th>genreId</th>\n", | |
| " <th>counts</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>0</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>1</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>4</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>1</td>\n", | |
| " <td>5</td>\n", | |
| " <td>0.0</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " userId genreId counts\n", | |
| "0 1 0 0.0\n", | |
| "1 1 1 0.0\n", | |
| "2 1 3 0.0\n", | |
| "3 1 4 0.0\n", | |
| "4 1 5 0.0" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# ids that occur in the test set\n", | |
| "users = np.unique(ratings_test['userId'].values)\n", | |
| "genres = np.unique(ratings_test['genreId'].values)\n", | |
| "\n", | |
| "# every (user, genre) combination of those ids\n", | |
| "all_pairs = [(u, g) for u in users for g in genres]\n", | |
| "users_genres = pd.DataFrame(np.array(all_pairs), columns=['userId', 'genreId'])\n", | |
| "\n", | |
| "# combinations without a row in ratings_test end up with zero counts\n", | |
| "test_full = users_genres.merge(ratings_test, how='left', on=['userId', 'genreId'])\n", | |
| "test_full = test_full.fillna(0)\n", | |
| "\n", | |
| "test_full.head()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "145692" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# number of (user, genre) combinations\n", | |
| "len(test_full)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 145692/145692 [02:10<00:00, 1117.18it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Anti-join: keep the (user, genre) pairs of test_full that never occur in\n", | |
| "# the training data, see:\n", | |
| "# https://stackoverflow.com/a/38516887/3986320\n", | |
| "\n", | |
| "key_columns = ['userId', 'genreId']\n", | |
| "\n", | |
| "train_keys = set(map(tuple, ratings_train.loc[:, key_columns].values))\n", | |
| "\n", | |
| "# key tuples of test_full, in row order\n", | |
| "test_pairs = list(map(tuple, test_full.loc[:, key_columns].values))\n", | |
| "test_keys = set(test_pairs)\n", | |
| "\n", | |
| "test_only_keys = test_keys.difference(train_keys)\n", | |
| "\n", | |
| "# Select the rows with a boolean mask instead of looping over\n", | |
| "# DataFrame.iterrows(): iterrows builds one Series per row (minutes for this\n", | |
| "# table) and upcasts the integer ids to float. Row order is unchanged; the\n", | |
| "# ids now stay integers.\n", | |
| "is_test_only = np.array([pair in test_only_keys for pair in test_pairs], dtype=bool)\n", | |
| "test_only = test_full.loc[is_test_only, :]\n", | |
| "\n", | |
| "test_users = test_only['userId'].tolist()\n", | |
| "test_genres = test_only['genreId'].tolist()\n", | |
| "test_counts = test_only['counts'].tolist()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Note: for some reason (presumably a different version of one of the packages: sklearn, pandas, or numpy), the row counts differ between the Spark and pandas notebooks." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0.17300653, 0.17291747, 0.17295071, ..., 0.1735287 , 0.17319265,\n", | |
| " 0.17340496], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# model output for the test-only (user, genre) pairs\n", | |
| "predictions = mf.predict(rows=test_users, cols=test_genres)\n", | |
| "predictions" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# one row per test pair; within each user the highest prediction comes first\n", | |
| "test_out = pd.DataFrame({\n", | |
| "    'userId' : test_users,\n", | |
| "    'genreId' : test_genres,\n", | |
| "    'counts' : test_counts,\n", | |
| "    'predictions' : predictions\n", | |
| "})\n", | |
| "test_out = test_out.sort_values(['userId', 'predictions'], ascending=[True, False])\n", | |
| "\n", | |
| "# 1-based position of each row within its user's sorted rows\n", | |
| "position_in_user = test_out.groupby(['userId']).cumcount()\n", | |
| "test_out['rank'] = position_in_user + 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(639, 228, 127680)" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# distinct users and genres among the evaluated pairs, and the number of pairs\n", | |
| "n, k = len(set(test_users)), len(set(test_genres))\n", | |
| "\n", | |
| "(n, k, len(test_out))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(0.15968302214502822, 36.40772904906643)" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "## MPR (smaller better)\n", | |
| "# ranks of the pairs with a positive count, as a fraction of the k genres\n", | |
| "has_counts = test_out['counts'] > 0\n", | |
| "relative_rank = test_out.loc[has_counts, 'rank'] / k\n", | |
| "MPR = relative_rank.mean()\n", | |
| "\n", | |
| "(MPR, MPR * k)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(0.6404786508541415, 1.5613322921324564)" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "## MRR (higher better)\n", | |
| "# best (smallest) rank among each user's pairs with a positive count\n", | |
| "best_rank = (\n", | |
| "    test_out[test_out['counts'] > 0]\n", | |
| "    .groupby('userId')['rank']\n", | |
| "    .min()\n", | |
| ")\n", | |
| "MRR = (1 / best_rank).mean()\n", | |
| "\n", | |
| "(MRR, 1 / MRR)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Based on these two metrics, the performance is close to the results [obtained using Spark's ALS module](http://nbviewer.jupyter.org/gist/twolodzko/7becd98ff256ef826b56945de297700d)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Last updated: 2018-06-27 12:16:14.052402\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import datetime\n", | |
| "\n", | |
| "# time of the last run of this notebook\n", | |
| "now = datetime.datetime.now()\n", | |
| "print('Last updated: ' + str(now))" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.5.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment