Created
September 1, 2017 10:04
-
-
Save zygm0nt/473770c2e19f1a61e8658a095427fdd3 to your computer and use it in GitHub Desktop.
ALS + XGBoost tutorial
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "raw", | |
"metadata": {}, | |
"source": [ | |
"## This IPython notebook is a slightly modified version of the tutorial by Jesse Steinweg-Woods, available at https://jessesw.com/Rec-System/" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import scipy.sparse as sparse\n", | |
"import numpy as np\n", | |
"from scipy.sparse.linalg import spsolve\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#website_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00352/Online%20Retail.xlsx'\n", | |
"retail_data = pd.read_excel('Online Retail.xlsx') # This may take a couple minutes" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>InvoiceNo</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>InvoiceDate</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>Country</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>536365</td>\n", | |
" <td>85123A</td>\n", | |
" <td>WHITE HANGING HEART T-LIGHT HOLDER</td>\n", | |
" <td>6</td>\n", | |
" <td>2010-12-01 08:26:00</td>\n", | |
" <td>2.55</td>\n", | |
" <td>17850.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>536365</td>\n", | |
" <td>71053</td>\n", | |
" <td>WHITE METAL LANTERN</td>\n", | |
" <td>6</td>\n", | |
" <td>2010-12-01 08:26:00</td>\n", | |
" <td>3.39</td>\n", | |
" <td>17850.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>536365</td>\n", | |
" <td>84406B</td>\n", | |
" <td>CREAM CUPID HEARTS COAT HANGER</td>\n", | |
" <td>8</td>\n", | |
" <td>2010-12-01 08:26:00</td>\n", | |
" <td>2.75</td>\n", | |
" <td>17850.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>536365</td>\n", | |
" <td>84029G</td>\n", | |
" <td>KNITTED UNION FLAG HOT WATER BOTTLE</td>\n", | |
" <td>6</td>\n", | |
" <td>2010-12-01 08:26:00</td>\n", | |
" <td>3.39</td>\n", | |
" <td>17850.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>536365</td>\n", | |
" <td>84029E</td>\n", | |
" <td>RED WOOLLY HOTTIE WHITE HEART.</td>\n", | |
" <td>6</td>\n", | |
" <td>2010-12-01 08:26:00</td>\n", | |
" <td>3.39</td>\n", | |
" <td>17850.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" InvoiceNo StockCode Description Quantity \\\n", | |
"0 536365 85123A WHITE HANGING HEART T-LIGHT HOLDER 6 \n", | |
"1 536365 71053 WHITE METAL LANTERN 6 \n", | |
"2 536365 84406B CREAM CUPID HEARTS COAT HANGER 8 \n", | |
"3 536365 84029G KNITTED UNION FLAG HOT WATER BOTTLE 6 \n", | |
"4 536365 84029E RED WOOLLY HOTTIE WHITE HEART. 6 \n", | |
"\n", | |
" InvoiceDate UnitPrice CustomerID Country \n", | |
"0 2010-12-01 08:26:00 2.55 17850.0 United Kingdom \n", | |
"1 2010-12-01 08:26:00 3.39 17850.0 United Kingdom \n", | |
"2 2010-12-01 08:26:00 2.75 17850.0 United Kingdom \n", | |
"3 2010-12-01 08:26:00 3.39 17850.0 United Kingdom \n", | |
"4 2010-12-01 08:26:00 3.39 17850.0 United Kingdom " | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"retail_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"RangeIndex: 541909 entries, 0 to 541908\n", | |
"Data columns (total 8 columns):\n", | |
"InvoiceNo 541909 non-null object\n", | |
"StockCode 541909 non-null object\n", | |
"Description 540455 non-null object\n", | |
"Quantity 541909 non-null int64\n", | |
"InvoiceDate 541909 non-null datetime64[ns]\n", | |
"UnitPrice 541909 non-null float64\n", | |
"CustomerID 406829 non-null float64\n", | |
"Country 541909 non-null object\n", | |
"dtypes: datetime64[ns](1), float64(2), int64(1), object(4)\n", | |
"memory usage: 33.1+ MB\n" | |
] | |
} | |
], | |
"source": [ | |
"retail_data.info()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Keep only the rows that have a customer ID. The explicit .copy() makes cleaned_retail an\n", | |
"# independent frame, so later column assignments do not raise SettingWithCopyWarning.\n", | |
"cleaned_retail = retail_data.loc[retail_data.CustomerID.notnull()].copy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"Int64Index: 406829 entries, 0 to 541908\n", | |
"Data columns (total 8 columns):\n", | |
"InvoiceNo 406829 non-null object\n", | |
"StockCode 406829 non-null object\n", | |
"Description 406829 non-null object\n", | |
"Quantity 406829 non-null int64\n", | |
"InvoiceDate 406829 non-null datetime64[ns]\n", | |
"UnitPrice 406829 non-null float64\n", | |
"CustomerID 406829 non-null float64\n", | |
"Country 406829 non-null object\n", | |
"dtypes: datetime64[ns](1), float64(2), int64(1), object(4)\n", | |
"memory usage: 27.9+ MB\n" | |
] | |
} | |
], | |
"source": [ | |
"cleaned_retail.info()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Lookup table of unique item/description pairs, with stock codes stored as strings\n", | |
"# so that later lookups do not depend on the mixed int/str codes in the raw data.\n", | |
"item_lookup = (\n", | |
"    cleaned_retail[['StockCode', 'Description']]\n", | |
"    .drop_duplicates()\n", | |
"    .assign(StockCode=lambda items: items.StockCode.astype(str))\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>85123A</td>\n", | |
" <td>WHITE HANGING HEART T-LIGHT HOLDER</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>71053</td>\n", | |
" <td>WHITE METAL LANTERN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>84406B</td>\n", | |
" <td>CREAM CUPID HEARTS COAT HANGER</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>84029G</td>\n", | |
" <td>KNITTED UNION FLAG HOT WATER BOTTLE</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>84029E</td>\n", | |
" <td>RED WOOLLY HOTTIE WHITE HEART.</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"0 85123A WHITE HANGING HEART T-LIGHT HOLDER\n", | |
"1 71053 WHITE METAL LANTERN\n", | |
"2 84406B CREAM CUPID HEARTS COAT HANGER\n", | |
"3 84029G KNITTED UNION FLAG HOT WATER BOTTLE\n", | |
"4 84029E RED WOOLLY HOTTIE WHITE HEART." | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"item_lookup.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n", | |
"A value is trying to be set on a copy of a slice from a DataFrame\n", | |
"\n", | |
"See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", | |
" \"\"\"Entry point for launching an IPython kernel.\n" | |
] | |
} | |
], | |
"source": [ | |
"# Convert the CustomerID column to int. The column must be addressed by name:\n", | |
"# `.loc['CustomerID'] = ...` targets a *row* label, which appends an all-NaN row,\n", | |
"# leaves CustomerID as float and upcasts Quantity to float.\n", | |
"# .assign returns a new frame, so no SettingWithCopyWarning is raised either.\n", | |
"cleaned_retail = cleaned_retail.assign(CustomerID=cleaned_retail.CustomerID.astype(int))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>InvoiceNo</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>InvoiceDate</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>Country</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>51024</th>\n", | |
" <td>C540634</td>\n", | |
" <td>85232B</td>\n", | |
" <td>SET OF 3 BABUSHKA STACKING TINS</td>\n", | |
" <td>-3.0</td>\n", | |
" <td>2011-01-10 12:02:00</td>\n", | |
" <td>4.95</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>51025</th>\n", | |
" <td>C540634</td>\n", | |
" <td>21655</td>\n", | |
" <td>HANGING RIDGE GLASS T-LIGHT HOLDER</td>\n", | |
" <td>-12.0</td>\n", | |
" <td>2011-01-10 12:02:00</td>\n", | |
" <td>1.69</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>51026</th>\n", | |
" <td>C540634</td>\n", | |
" <td>79323P</td>\n", | |
" <td>PINK CHERRY LIGHTS</td>\n", | |
" <td>-4.0</td>\n", | |
" <td>2011-01-10 12:02:00</td>\n", | |
" <td>6.75</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>53203</th>\n", | |
" <td>540825</td>\n", | |
" <td>79323P</td>\n", | |
" <td>PINK CHERRY LIGHTS</td>\n", | |
" <td>4.0</td>\n", | |
" <td>2011-01-11 13:54:00</td>\n", | |
" <td>6.75</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>53204</th>\n", | |
" <td>540825</td>\n", | |
" <td>21655</td>\n", | |
" <td>HANGING RIDGE GLASS T-LIGHT HOLDER</td>\n", | |
" <td>12.0</td>\n", | |
" <td>2011-01-11 13:54:00</td>\n", | |
" <td>1.69</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>53205</th>\n", | |
" <td>540825</td>\n", | |
" <td>85232B</td>\n", | |
" <td>SET OF 3 BABUSHKA STACKING TINS</td>\n", | |
" <td>3.0</td>\n", | |
" <td>2011-01-11 13:54:00</td>\n", | |
" <td>4.95</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88001</th>\n", | |
" <td>C543744</td>\n", | |
" <td>85232B</td>\n", | |
" <td>SET OF 3 BABUSHKA STACKING TINS</td>\n", | |
" <td>-3.0</td>\n", | |
" <td>2011-02-11 13:43:00</td>\n", | |
" <td>4.95</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88002</th>\n", | |
" <td>C543744</td>\n", | |
" <td>21655</td>\n", | |
" <td>HANGING RIDGE GLASS T-LIGHT HOLDER</td>\n", | |
" <td>-12.0</td>\n", | |
" <td>2011-02-11 13:43:00</td>\n", | |
" <td>1.69</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88003</th>\n", | |
" <td>C543744</td>\n", | |
" <td>79323P</td>\n", | |
" <td>PINK CHERRY LIGHTS</td>\n", | |
" <td>-4.0</td>\n", | |
" <td>2011-02-11 13:43:00</td>\n", | |
" <td>6.75</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88004</th>\n", | |
" <td>C543745</td>\n", | |
" <td>79323B</td>\n", | |
" <td>BLACK CHERRY LIGHTS</td>\n", | |
" <td>-4.0</td>\n", | |
" <td>2011-02-11 13:46:00</td>\n", | |
" <td>6.75</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88005</th>\n", | |
" <td>C543745</td>\n", | |
" <td>85232B</td>\n", | |
" <td>SET OF 3 BABUSHKA STACKING TINS</td>\n", | |
" <td>-3.0</td>\n", | |
" <td>2011-02-11 13:46:00</td>\n", | |
" <td>4.95</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88006</th>\n", | |
" <td>C543745</td>\n", | |
" <td>21655</td>\n", | |
" <td>HANGING RIDGE GLASS T-LIGHT HOLDER</td>\n", | |
" <td>-12.0</td>\n", | |
" <td>2011-02-11 13:46:00</td>\n", | |
" <td>1.69</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88018</th>\n", | |
" <td>543747</td>\n", | |
" <td>79323B</td>\n", | |
" <td>BLACK CHERRY LIGHTS</td>\n", | |
" <td>8.0</td>\n", | |
" <td>2011-02-11 13:53:00</td>\n", | |
" <td>6.75</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88019</th>\n", | |
" <td>543747</td>\n", | |
" <td>85232B</td>\n", | |
" <td>SET OF 3 BABUSHKA STACKING TINS</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2011-02-11 13:53:00</td>\n", | |
" <td>4.95</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88020</th>\n", | |
" <td>543747</td>\n", | |
" <td>21655</td>\n", | |
" <td>HANGING RIDGE GLASS T-LIGHT HOLDER</td>\n", | |
" <td>24.0</td>\n", | |
" <td>2011-02-11 13:53:00</td>\n", | |
" <td>1.69</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88028</th>\n", | |
" <td>C543749</td>\n", | |
" <td>M</td>\n", | |
" <td>Manual</td>\n", | |
" <td>-1.0</td>\n", | |
" <td>2011-02-11 13:55:00</td>\n", | |
" <td>71.46</td>\n", | |
" <td>13672.0</td>\n", | |
" <td>United Kingdom</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" InvoiceNo StockCode Description Quantity \\\n", | |
"51024 C540634 85232B SET OF 3 BABUSHKA STACKING TINS -3.0 \n", | |
"51025 C540634 21655 HANGING RIDGE GLASS T-LIGHT HOLDER -12.0 \n", | |
"51026 C540634 79323P PINK CHERRY LIGHTS -4.0 \n", | |
"53203 540825 79323P PINK CHERRY LIGHTS 4.0 \n", | |
"53204 540825 21655 HANGING RIDGE GLASS T-LIGHT HOLDER 12.0 \n", | |
"53205 540825 85232B SET OF 3 BABUSHKA STACKING TINS 3.0 \n", | |
"88001 C543744 85232B SET OF 3 BABUSHKA STACKING TINS -3.0 \n", | |
"88002 C543744 21655 HANGING RIDGE GLASS T-LIGHT HOLDER -12.0 \n", | |
"88003 C543744 79323P PINK CHERRY LIGHTS -4.0 \n", | |
"88004 C543745 79323B BLACK CHERRY LIGHTS -4.0 \n", | |
"88005 C543745 85232B SET OF 3 BABUSHKA STACKING TINS -3.0 \n", | |
"88006 C543745 21655 HANGING RIDGE GLASS T-LIGHT HOLDER -12.0 \n", | |
"88018 543747 79323B BLACK CHERRY LIGHTS 8.0 \n", | |
"88019 543747 85232B SET OF 3 BABUSHKA STACKING TINS 6.0 \n", | |
"88020 543747 21655 HANGING RIDGE GLASS T-LIGHT HOLDER 24.0 \n", | |
"88028 C543749 M Manual -1.0 \n", | |
"\n", | |
" InvoiceDate UnitPrice CustomerID Country \n", | |
"51024 2011-01-10 12:02:00 4.95 13672.0 United Kingdom \n", | |
"51025 2011-01-10 12:02:00 1.69 13672.0 United Kingdom \n", | |
"51026 2011-01-10 12:02:00 6.75 13672.0 United Kingdom \n", | |
"53203 2011-01-11 13:54:00 6.75 13672.0 United Kingdom \n", | |
"53204 2011-01-11 13:54:00 1.69 13672.0 United Kingdom \n", | |
"53205 2011-01-11 13:54:00 4.95 13672.0 United Kingdom \n", | |
"88001 2011-02-11 13:43:00 4.95 13672.0 United Kingdom \n", | |
"88002 2011-02-11 13:43:00 1.69 13672.0 United Kingdom \n", | |
"88003 2011-02-11 13:43:00 6.75 13672.0 United Kingdom \n", | |
"88004 2011-02-11 13:46:00 6.75 13672.0 United Kingdom \n", | |
"88005 2011-02-11 13:46:00 4.95 13672.0 United Kingdom \n", | |
"88006 2011-02-11 13:46:00 1.69 13672.0 United Kingdom \n", | |
"88018 2011-02-11 13:53:00 6.75 13672.0 United Kingdom \n", | |
"88019 2011-02-11 13:53:00 4.95 13672.0 United Kingdom \n", | |
"88020 2011-02-11 13:53:00 1.69 13672.0 United Kingdom \n", | |
"88028 2011-02-11 13:55:00 71.46 13672.0 United Kingdom " | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"cleaned_retail.query('CustomerID==13672')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Total Quantity and UnitPrice per customer/item pair. Only the numeric columns are summed:\n", | |
"# a bare .sum() would also try to add up InvoiceNo, InvoiceDate, etc., which recent pandas\n", | |
"# releases reject instead of silently dropping those columns.\n", | |
"grouped_cleaned = cleaned_retail.groupby(['CustomerID', 'StockCode'])[['Quantity', 'UnitPrice']].sum().reset_index() # Group together" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# We will use this later.\n", | |
"# Only the numeric columns are summed; a bare .sum() would also try to add up InvoiceNo,\n", | |
"# InvoiceDate, etc., which recent pandas releases reject instead of silently dropping them.\n", | |
"xgboost_data = cleaned_retail.groupby(['CustomerID', 'StockCode', 'Country'])[['Quantity', 'UnitPrice']].sum().reset_index() # Group together\n", | |
"xgboost_data.loc[xgboost_data.Quantity==0,'Quantity'] = 1 # Replace a sum of zero purchases with a one to\n", | |
"# indicate purchased\n", | |
"xgboost_data = xgboost_data.query('Quantity > 0') # Only get customers where purchase totals were positive" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice\n", | |
"0 12346.0 23166 United Kingdom 1.0 2.08\n", | |
"1 12347.0 16008 Iceland 24.0 0.25\n", | |
"2 12347.0 17021 Iceland 36.0 0.30\n", | |
"3 12347.0 20665 Iceland 6.0 2.95\n", | |
"4 12347.0 20719 Iceland 40.0 3.40" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" CustomerID StockCode Country Quantity UnitPrice\n", | |
"59631 13672.0 21655 United Kingdom 1.0 8.45\n", | |
"59632 13672.0 79323B United Kingdom 4.0 13.50\n", | |
"59634 13672.0 85232B United Kingdom 1.0 24.75\n" | |
] | |
} | |
], | |
"source": [ | |
"print(xgboost_data.query('CustomerID==13672'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>0.0</td>\n", | |
" <td>2.08</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Quantity UnitPrice\n", | |
"0 12346.0 23166 0.0 2.08\n", | |
"1 12347.0 16008 24.0 0.25\n", | |
"2 12347.0 17021 36.0 0.30\n", | |
"3 12347.0 20665 6.0 2.95\n", | |
"4 12347.0 20719 40.0 3.40" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grouped_cleaned.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# A summed quantity of exactly zero is recorded as one unit, so the customer/item pair\n", | |
"# still counts as purchased.\n", | |
"grouped_cleaned.loc[grouped_cleaned['Quantity'].eq(0), 'Quantity'] = 1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Keep only the customer/item pairs whose total purchased quantity is positive.\n", | |
"grouped_purchased = grouped_cleaned[grouped_cleaned['Quantity'] > 0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Quantity UnitPrice\n", | |
"0 12346.0 23166 1.0 2.08\n", | |
"1 12347.0 16008 24.0 0.25\n", | |
"2 12347.0 17021 36.0 0.30\n", | |
"3 12347.0 20665 6.0 2.95\n", | |
"4 12347.0 20719 40.0 3.40" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grouped_purchased.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Quantity UnitPrice\n", | |
"0 12346.0 23166 1.0 2.08\n", | |
"1 12347.0 16008 24.0 0.25\n", | |
"2 12347.0 17021 36.0 0.30\n", | |
"3 12347.0 20665 6.0 2.95\n", | |
"4 12347.0 20719 40.0 3.40" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grouped_cleaned.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# From here on only the item, its quantity and the customer are needed.\n", | |
"grouped_purchased = grouped_purchased.loc[:, ['StockCode', 'Quantity', 'CustomerID']]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>CustomerID</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>59621</th>\n", | |
" <td>21655</td>\n", | |
" <td>1.0</td>\n", | |
" <td>13672.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>59622</th>\n", | |
" <td>79323B</td>\n", | |
" <td>4.0</td>\n", | |
" <td>13672.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>59624</th>\n", | |
" <td>85232B</td>\n", | |
" <td>1.0</td>\n", | |
" <td>13672.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Quantity CustomerID\n", | |
"59621 21655 1.0 13672.0\n", | |
"59622 79323B 4.0 13672.0\n", | |
"59624 85232B 1.0 13672.0" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grouped_purchased.query('CustomerID==13672')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"customers = list(np.sort(grouped_purchased.CustomerID.unique())) # Unique customers, sorted; defines the matrix row order\n", | |
"products = list(grouped_purchased.StockCode.unique()) # Unique purchased products; defines the matrix column order\n", | |
"quantity = list(grouped_purchased.Quantity) # All of our purchases\n", | |
"\n", | |
"# Map every purchase to its row/column position in the matrix. pd.Categorical(..., categories=...)\n", | |
"# replaces Series.astype('category', categories=...), which newer pandas releases no longer accept.\n", | |
"rows = pd.Categorical(grouped_purchased.CustomerID, categories=customers).codes # Row (customer) indices\n", | |
"cols = pd.Categorical(grouped_purchased.StockCode, categories=products).codes # Column (product) indices\n", | |
"purchases_sparse = sparse.csr_matrix((quantity, (rows, cols)), shape=(len(customers), len(products)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<4338x3664 sparse matrix of type '<class 'numpy.float64'>'\n", | |
"\twith 266723 stored elements in Compressed Sparse Row format>" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"purchases_sparse" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"98.32190920694744" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Percentage of customer/product pairs that have no recorded purchase.\n", | |
"matrix_size = purchases_sparse.shape[0] * purchases_sparse.shape[1] # Every possible customer/product pair\n", | |
"num_purchases = purchases_sparse.nonzero()[0].size # Pairs with a non-zero purchase entry\n", | |
"sparsity = 100 * (1 - num_purchases / matrix_size)\n", | |
"sparsity" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def make_train(purchases, pct_test = 0.2, seed = 0):\n", | |
"    '''\n", | |
"    Split a user-item matrix into a training matrix with some interactions hidden and a binarized test matrix.\n", | |
"\n", | |
"    A fraction pct_test of the non-zero user-item pairs is drawn at random, without replacement, and set back to\n", | |
"    zero in the training copy. The test copy keeps every original interaction, stored as 1.\n", | |
"\n", | |
"    parameters:\n", | |
"\n", | |
"    purchases - the original purchases matrix in the form of a sparse csr_matrix. It is copied, never modified.\n", | |
"\n", | |
"    pct_test - The fraction (between 0 and 1) of user-item pairs with an interaction that is masked in the\n", | |
"    training set. The number of masked pairs is rounded up. Default is 0.2.\n", | |
"\n", | |
"    seed - Seed of the private random generator that picks the masked pairs. The default of 0 gives the same\n", | |
"    split as seeding the random module with 0, but the global random state is left untouched.\n", | |
"\n", | |
"    returns:\n", | |
"\n", | |
"    training_set - Copy of purchases in which the sampled user-item pairs are set back to zero.\n", | |
"\n", | |
"    test_set - Copy of purchases in which every non-zero entry is replaced by 1 (binary preference matrix).\n", | |
"\n", | |
"    user_inds - The unique user rows that had at least one pair masked in training_set.\n", | |
"    This will be necessary later when evaluating the performance via AUC.\n", | |
"\n", | |
"    raises:\n", | |
"\n", | |
"    ValueError - if pct_test is not between 0 and 1.\n", | |
"    '''\n", | |
"    if not 0 <= pct_test <= 1:\n", | |
"        raise ValueError('pct_test must be between 0 and 1, got %r' % (pct_test,))\n", | |
"    test_set = purchases.copy() # The test set starts as a full copy of the original data\n", | |
"    test_set[test_set != 0] = 1 # Store the test set as a binary preference matrix\n", | |
"    training_set = purchases.copy() # Separate copy that is altered to become the training set\n", | |
"    nonzero_inds = training_set.nonzero() # Row and column indices of every existing interaction\n", | |
"    nonzero_pairs = list(zip(nonzero_inds[0], nonzero_inds[1])) # (user, item) index pairs\n", | |
"    rng = random.Random(seed) # Private generator: reproducible without reseeding the global random module\n", | |
"    num_samples = int(np.ceil(pct_test*len(nonzero_pairs))) # Number of pairs to mask, rounded up\n", | |
"    samples = rng.sample(nonzero_pairs, num_samples) # Sample user-item pairs without replacement\n", | |
"    user_inds = [index[0] for index in samples] # Get the user row indices\n", | |
"    item_inds = [index[1] for index in samples] # Get the item column indices\n", | |
"    if samples: # Nothing to mask when pct_test is 0 or the matrix has no interactions\n", | |
"        training_set[user_inds, item_inds] = 0 # Assign all of the randomly chosen user-item pairs to zero\n", | |
"        training_set.eliminate_zeros() # Drop the explicit zeros from sparse storage to save space\n", | |
"    return training_set, test_set, list(set(user_inds)) # Output the unique list of user rows that were altered" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"product_train, product_test, product_users_altered = make_train(purchases_sparse, pct_test = 0.2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(4042,)" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.shape(product_users_altered)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import implicit\n", | |
"from implicit import alternating_least_squares" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"env: OPENBLAS_NUM_THREADS=1\n", | |
"item_user_data\n", | |
"(3664, 4338)\n" | |
] | |
} | |
], | |
"source": [ | |
"%env OPENBLAS_NUM_THREADS=1\n", | |
"\n", | |
"# Factor applied to the purchase quantities before they are handed to ALS\n", | |
"alpha = 15\n", | |
"\n", | |
"# Replaces the older functional call\n", | |
"#   alternating_least_squares(..., factors=20, regularization=0.1, iterations=50)\n", | |
"# here only the number of factors is set explicitly\n", | |
"model = implicit.als.AlternatingLeastSquares(factors=20)\n", | |
"\n", | |
"# Items as rows, users as columns: the transposed training matrix multiplied by alpha\n", | |
"item_user_data = (product_train*alpha).astype('double').T\n", | |
"print(\"item_user_data\")\n", | |
"print(item_user_data.shape)\n", | |
"\n", | |
"# Fit the factorization on the item/user matrix\n", | |
"model.fit(item_user_data)\n", | |
"\n", | |
"# Learned factors, reused by the evaluation and recommendation cells\n", | |
"user_vecs = model.user_factors\n", | |
"item_vecs = model.item_factors\n", | |
"\n", | |
"# Example: recommendations for user index 1 (the result is stored, not displayed)\n", | |
"userid = 1\n", | |
"user_items = item_user_data.T.tocsr()\n", | |
"recommendations = model.recommend(userid, user_items)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn import metrics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def auc_score(predictions, test):\n", | |
"    '''\n", | |
"    Compute the area under the ROC curve for one set of scores, using sklearn's metrics.\n", | |
"\n", | |
"    parameters:\n", | |
"\n", | |
"    - predictions: the predicted scores\n", | |
"\n", | |
"    - test: the actual target values the scores are compared to\n", | |
"\n", | |
"    returns:\n", | |
"\n", | |
"    - AUC (area under the Receiver Operating Characteristic curve)\n", | |
"    '''\n", | |
"    # roc_curve also returns the thresholds, which are not needed for the area\n", | |
"    false_pos_rate, true_pos_rate, _ = metrics.roc_curve(test, predictions)\n", | |
"    area = metrics.auc(false_pos_rate, true_pos_rate)\n", | |
"    return area" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def calc_mean_auc(training_set, altered_users, predictions, test_set):\n", | |
"    '''\n", | |
"    Average the per-user AUC of the model, and of a popularity benchmark, over the users that were altered.\n", | |
"\n", | |
"    parameters:\n", | |
"\n", | |
"    training_set - The training set resulting from make_train, where a certain percentage of the original\n", | |
"    user/item interactions are reset to zero to hide them from the model\n", | |
"\n", | |
"    altered_users - The indices of the users where at least one user/item pair was altered by make_train\n", | |
"\n", | |
"    predictions - A two-element sequence of sparse matrices: the user vectors at position zero and the item\n", | |
"    vectors at position one, arranged so that a user row times the item matrix gives one score per item\n", | |
"\n", | |
"    test_set - The test set constructed earlier by make_train\n", | |
"\n", | |
"    returns:\n", | |
"\n", | |
"    A tuple (mean model AUC, mean popularity AUC), each rounded to three decimal places. For every user only\n", | |
"    the items that are zero in the training row are scored.\n", | |
"    '''\n", | |
"    user_factors, item_factors = predictions[0], predictions[1]\n", | |
"    # Popularity benchmark: column sums of the test set, i.e. total interactions per item\n", | |
"    popularity = np.array(test_set.sum(axis = 0)).reshape(-1)\n", | |
"    model_aucs = [] # AUC of the factorization, one entry per altered user\n", | |
"    benchmark_aucs = [] # AUC of the popularity ranking, one entry per altered user\n", | |
"    for user in altered_users:\n", | |
"        seen = training_set[user,:].toarray().reshape(-1) # The user's row of the training set\n", | |
"        candidates = np.where(seen == 0) # Items without an interaction in training\n", | |
"        # Predicted score of every candidate item from the user/item vectors\n", | |
"        scores = user_factors[user,:].dot(item_factors).toarray()[0,candidates].reshape(-1)\n", | |
"        # Test-set values for the same candidate items\n", | |
"        truth = test_set[user,:].toarray()[0,candidates].reshape(-1)\n", | |
"        model_aucs.append(auc_score(scores, truth))\n", | |
"        benchmark_aucs.append(auc_score(popularity[candidates], truth))\n", | |
"\n", | |
"    # Both means are rounded to three decimal places\n", | |
"    return float('%.3f'%np.mean(model_aucs)), float('%.3f'%np.mean(benchmark_aucs))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(4042,)" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.shape(product_users_altered)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.868, 0.814)" | |
] | |
}, | |
"execution_count": 36, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"calc_mean_auc(product_train, product_users_altered, \n", | |
" [sparse.csr_matrix(user_vecs), sparse.csr_matrix(item_vecs.T)], product_test)\n", | |
"# AUC for our recommender system" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"customers_arr = np.array(customers) # Array of customer IDs from the ratings matrix\n", | |
"products_arr = np.array(products) # Array of product IDs from the ratings matrix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(['23166', '16008', '17021', ..., '84613C', '84206B', '21414'],\n", | |
" dtype='<U21')" | |
] | |
}, | |
"execution_count": 38, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"products_arr" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_items_purchased(customer_id, mf_train, customers_list, products_list, item_lookup):\n", | |
"    '''\n", | |
"    Look up the items a specific customer has already purchased according to the training set.\n", | |
"\n", | |
"    parameters:\n", | |
"\n", | |
"    customer_id - The id of the customer whose prior purchases should be listed\n", | |
"\n", | |
"    mf_train - The initial ratings training set used (without weights applied)\n", | |
"\n", | |
"    customers_list - The array of customer ids, in the row order of the ratings matrix\n", | |
"\n", | |
"    products_list - The array of product ids, in the column order of the ratings matrix\n", | |
"\n", | |
"    item_lookup - A pandas dataframe with a StockCode column (product ID/product descriptions)\n", | |
"\n", | |
"    returns:\n", | |
"\n", | |
"    The rows of item_lookup whose StockCode belongs to an item with a non-zero entry in the customer's row\n", | |
"    '''\n", | |
"    matches = np.where(customers_list == customer_id)[0]\n", | |
"    row = matches[0] # First matrix row that carries this customer id\n", | |
"    bought_cols = mf_train[row,:].nonzero()[1] # Column indices of the non-zero items in that row\n", | |
"    bought_codes = products_list[bought_cols] # Stock codes behind those columns\n", | |
"    already_bought = item_lookup.StockCode.isin(bought_codes)\n", | |
"    return item_lookup.loc[already_bought]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 12346., 12347., 12348., 12349., 12350.])" | |
] | |
}, | |
"execution_count": 40, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"customers_arr[:5]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>61619</th>\n", | |
" <td>23166</td>\n", | |
" <td>MEDIUM CERAMIC TOP STORAGE JAR</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"61619 23166 MEDIUM CERAMIC TOP STORAGE JAR" | |
] | |
}, | |
"execution_count": 41, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_items_purchased(12346, product_train, customers_arr, products_arr, item_lookup)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.preprocessing import MinMaxScaler" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def rec_items(customer_id, mf_train, user_vecs, item_vecs, customer_list, item_list, item_lookup, num_items = 10):\n", | |
"    '''\n", | |
"    Return the top recommended items for one customer, never including an item the customer already purchased.\n", | |
"\n", | |
"    parameters:\n", | |
"\n", | |
"    customer_id - Input the customer's id number that you want to get recommendations for\n", | |
"\n", | |
"    mf_train - The training matrix you used for matrix factorization fitting\n", | |
"\n", | |
"    user_vecs - the user vectors from your fitted matrix factorization\n", | |
"\n", | |
"    item_vecs - the item vectors from your fitted matrix factorization\n", | |
"\n", | |
"    customer_list - an array of the customer's ID numbers that make up the rows of your ratings matrix\n", | |
"    (in order of matrix)\n", | |
"\n", | |
"    item_list - an array of the products that make up the columns of your ratings matrix\n", | |
"    (in order of matrix)\n", | |
"\n", | |
"    item_lookup - A simple pandas dataframe of the unique product ID/product descriptions available.\n", | |
"    When a StockCode occurs more than once, the first description is used.\n", | |
"\n", | |
"    num_items - The number of items you want to recommend in order of best recommendations. Default is 10.\n", | |
"\n", | |
"    returns:\n", | |
"\n", | |
"    - A dataframe with the columns StockCode and Description, best recommendation first. It holds fewer than\n", | |
"    num_items rows when the customer has fewer items left that were never purchased.\n", | |
"    '''\n", | |
"    cust_ind = np.where(customer_list == customer_id)[0][0] # Returns the index row of our customer id\n", | |
"    purchased = mf_train[cust_ind,:].toarray().reshape(-1) != 0 # True for every item already purchased\n", | |
"    # Dot product of the user vector with all item vectors, copied as floats so it can be masked\n", | |
"    scores = np.array(user_vecs[cust_ind,:].dot(item_vecs.T), dtype = float).reshape(-1)\n", | |
"    # Purchased items get -inf so they always rank below every unpurchased item. Multiplying a 0-1 scaled\n", | |
"    # score by zero is not enough: the lowest unpurchased item also scales to zero and ties with them.\n", | |
"    # No min-max scaling is applied because it is monotonic and only the rank order is returned.\n", | |
"    scores[purchased] = -np.inf\n", | |
"    ranked = np.argsort(-scores, kind = 'mergesort') # Best score first; stable, so ties keep column order\n", | |
"    num_candidates = int((~purchased).sum()) # Unpurchased items all rank ahead of the purchased ones\n", | |
"    product_idx = ranked[:num_candidates][:num_items]\n", | |
"    codes = [item_list[index] for index in product_idx]\n", | |
"    # First description found for each recommended stock code\n", | |
"    descriptions = [item_lookup.Description.loc[item_lookup.StockCode == code].iloc[0] for code in codes]\n", | |
"    return pd.DataFrame({'StockCode': codes, 'Description': descriptions}, columns = ['StockCode', 'Description'])\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>23167</td>\n", | |
" <td>SMALL CERAMIC TOP STORAGE JAR</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>23165</td>\n", | |
" <td>LARGE CERAMIC TOP STORAGE JAR</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>22963</td>\n", | |
" <td>JAM JAR WITH GREEN LID</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>22950</td>\n", | |
" <td>36 DOILIES VINTAGE CHRISTMAS</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>23340</td>\n", | |
" <td>VINTAGE CHRISTMAS CAKE FRILL</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>23247</td>\n", | |
" <td>BISCUIT TIN 50'S CHRISTMAS</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>22978</td>\n", | |
" <td>PANTRY ROLLING PIN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>22962</td>\n", | |
" <td>JAM JAR WITH PINK LID</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>22918</td>\n", | |
" <td>HERB MARKER PARSLEY</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>23294</td>\n", | |
" <td>SET OF 6 SNACK LOAF BAKING CASES</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"0 23167 SMALL CERAMIC TOP STORAGE JAR \n", | |
"1 23165 LARGE CERAMIC TOP STORAGE JAR\n", | |
"2 22963 JAM JAR WITH GREEN LID\n", | |
"3 22950 36 DOILIES VINTAGE CHRISTMAS\n", | |
"4 23340 VINTAGE CHRISTMAS CAKE FRILL\n", | |
"5 23247 BISCUIT TIN 50'S CHRISTMAS\n", | |
"6 22978 PANTRY ROLLING PIN\n", | |
"7 22962 JAM JAR WITH PINK LID\n", | |
"8 22918 HERB MARKER PARSLEY\n", | |
"9 23294 SET OF 6 SNACK LOAF BAKING CASES" | |
] | |
}, | |
"execution_count": 44, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rec_items(12346, product_train, user_vecs, item_vecs, customers_arr, products_arr, item_lookup,\n", | |
" num_items = 10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>2148</th>\n", | |
" <td>37446</td>\n", | |
" <td>MINI CAKE STAND WITH HANGING CAKES</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2149</th>\n", | |
" <td>37449</td>\n", | |
" <td>CERAMIC CAKE STAND + HANGING CAKES</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4859</th>\n", | |
" <td>37450</td>\n", | |
" <td>CERAMIC CAKE BOWL + HANGING CAKES</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5108</th>\n", | |
" <td>22890</td>\n", | |
" <td>NOVELTY BISCUITS CAKE STAND 3 TIER</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"2148 37446 MINI CAKE STAND WITH HANGING CAKES\n", | |
"2149 37449 CERAMIC CAKE STAND + HANGING CAKES\n", | |
"4859 37450 CERAMIC CAKE BOWL + HANGING CAKES\n", | |
"5108 22890 NOVELTY BISCUITS CAKE STAND 3 TIER" | |
] | |
}, | |
"execution_count": 45, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_items_purchased(12353, product_train, customers_arr, products_arr, item_lookup)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>22055</td>\n", | |
" <td>MINI CAKE STAND HANGING STRAWBERY</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>37447</td>\n", | |
" <td>CERAMIC CAKE DESIGN SPOTTED PLATE</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>22063</td>\n", | |
" <td>CERAMIC BOWL WITH STRAWBERRY DESIGN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>22057</td>\n", | |
" <td>CERAMIC PLATE STRAWBERRY DESIGN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>37448</td>\n", | |
" <td>CERAMIC CAKE DESIGN SPOTTED MUG</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>22059</td>\n", | |
" <td>CERAMIC STRAWBERRY DESIGN MUG</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>22649</td>\n", | |
" <td>STRAWBERRY FAIRY CAKE TEAPOT</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>22061</td>\n", | |
" <td>LARGE CAKE STAND HANGING STRAWBERY</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>22644</td>\n", | |
" <td>CERAMIC CHERRY CAKE MONEY BANK</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>22645</td>\n", | |
" <td>CERAMIC HEART FAIRY CAKE MONEY BANK</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"0 22055 MINI CAKE STAND HANGING STRAWBERY\n", | |
"1 37447 CERAMIC CAKE DESIGN SPOTTED PLATE\n", | |
"2 22063 CERAMIC BOWL WITH STRAWBERRY DESIGN\n", | |
"3 22057 CERAMIC PLATE STRAWBERRY DESIGN\n", | |
"4 37448 CERAMIC CAKE DESIGN SPOTTED MUG\n", | |
"5 22059 CERAMIC STRAWBERRY DESIGN MUG\n", | |
"6 22649 STRAWBERRY FAIRY CAKE TEAPOT\n", | |
"7 22061 LARGE CAKE STAND HANGING STRAWBERY\n", | |
"8 22644 CERAMIC CHERRY CAKE MONEY BANK\n", | |
"9 22645 CERAMIC HEART FAIRY CAKE MONEY BANK" | |
] | |
}, | |
"execution_count": 46, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rec_items(12353, product_train, user_vecs, item_vecs, customers_arr, products_arr, item_lookup,\n", | |
" num_items = 10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>34</th>\n", | |
" <td>22326</td>\n", | |
" <td>ROUND SNACK BOXES SET OF4 WOODLAND</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>35</th>\n", | |
" <td>22629</td>\n", | |
" <td>SPACEBOY LUNCH BOX</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>37</th>\n", | |
" <td>22631</td>\n", | |
" <td>CIRCUS PARADE LUNCH BOX</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>93</th>\n", | |
" <td>20725</td>\n", | |
" <td>LUNCH BAG RED RETROSPOT</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>369</th>\n", | |
" <td>22382</td>\n", | |
" <td>LUNCH BAG SPACEBOY DESIGN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>547</th>\n", | |
" <td>22328</td>\n", | |
" <td>ROUND SNACK BOXES SET OF 4 FRUITS</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>549</th>\n", | |
" <td>22630</td>\n", | |
" <td>DOLLY GIRL LUNCH BOX</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1241</th>\n", | |
" <td>22555</td>\n", | |
" <td>PLASTERS IN TIN STRONGMAN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>58132</th>\n", | |
" <td>20725</td>\n", | |
" <td>LUNCH BAG RED SPOTTY</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"34 22326 ROUND SNACK BOXES SET OF4 WOODLAND \n", | |
"35 22629 SPACEBOY LUNCH BOX \n", | |
"37 22631 CIRCUS PARADE LUNCH BOX \n", | |
"93 20725 LUNCH BAG RED RETROSPOT\n", | |
"369 22382 LUNCH BAG SPACEBOY DESIGN \n", | |
"547 22328 ROUND SNACK BOXES SET OF 4 FRUITS \n", | |
"549 22630 DOLLY GIRL LUNCH BOX\n", | |
"1241 22555 PLASTERS IN TIN STRONGMAN\n", | |
"58132 20725 LUNCH BAG RED SPOTTY" | |
] | |
}, | |
"execution_count": 47, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_items_purchased(12361, product_train, customers_arr, products_arr, item_lookup)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 48, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Description</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>23254</td>\n", | |
" <td>CHILDRENS CUTLERY DOLLY GIRL</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>22899</td>\n", | |
" <td>CHILDREN'S APRON DOLLY GIRL</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>23256</td>\n", | |
" <td>CHILDRENS CUTLERY SPACEBOY</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>22551</td>\n", | |
" <td>PLASTERS IN TIN SPACEBOY</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>20726</td>\n", | |
" <td>LUNCH BAG WOODLAND</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>22662</td>\n", | |
" <td>LUNCH BAG DOLLY GIRL DESIGN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>21559</td>\n", | |
" <td>STRAWBERRY LUNCH BOX WITH CUTLERY</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>22367</td>\n", | |
" <td>CHILDRENS APRON SPACEBOY DESIGN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>22383</td>\n", | |
" <td>LUNCH BAG SUKI DESIGN</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>20727</td>\n", | |
" <td>LUNCH BAG BLACK SKULL.</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" StockCode Description\n", | |
"0 23254 CHILDRENS CUTLERY DOLLY GIRL \n", | |
"1 22899 CHILDREN'S APRON DOLLY GIRL \n", | |
"2 23256 CHILDRENS CUTLERY SPACEBOY \n", | |
"3 22551 PLASTERS IN TIN SPACEBOY\n", | |
"4 20726 LUNCH BAG WOODLAND\n", | |
"5 22662 LUNCH BAG DOLLY GIRL DESIGN\n", | |
"6 21559 STRAWBERRY LUNCH BOX WITH CUTLERY\n", | |
"7 22367 CHILDRENS APRON SPACEBOY DESIGN\n", | |
"8 22383 LUNCH BAG SUKI DESIGN \n", | |
"9 20727 LUNCH BAG BLACK SKULL." | |
] | |
}, | |
"execution_count": 48, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"rec_items(12361, product_train, user_vecs, item_vecs, customers_arr, products_arr, item_lookup,\n", | |
" num_items = 10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def prediction_matrix(mf_train, user_vecs, item_vecs, customer_list, item_list, item_lookup):\n", | |
"    '''\n", | |
"    Build the min-max scaled prediction for every user and every item.\n", | |
"\n", | |
"    The raw prediction is the dot product of the user vectors with the transposed item vectors. MinMaxScaler\n", | |
"    then rescales each column of that matrix to the range 0-1.\n", | |
"    NOTE(review): assuming rows are users and columns are items, the scaling is per item across all users,\n", | |
"    not per user - confirm this is the intended normalization.\n", | |
"\n", | |
"    parameters:\n", | |
"\n", | |
"    mf_train - The training matrix you used for matrix factorization fitting (not used by this function)\n", | |
"\n", | |
"    user_vecs - the user vectors from your fitted matrix factorization\n", | |
"\n", | |
"    item_vecs - the item vectors from your fitted matrix factorization\n", | |
"\n", | |
"    customer_list - an array of the customer's ID numbers that make up the rows of your ratings matrix\n", | |
"    (in order of matrix; not used by this function)\n", | |
"\n", | |
"    item_list - an array of the products that make up the columns of your ratings matrix\n", | |
"    (in order of matrix; not used by this function)\n", | |
"\n", | |
"    item_lookup - A simple pandas dataframe of the unique product ID/product descriptions available\n", | |
"    (not used by this function)\n", | |
"\n", | |
"    returns:\n", | |
"\n", | |
"    - The full matrix of scaled predictions for all users and items\n", | |
"    '''\n", | |
"    raw_scores = user_vecs.dot(item_vecs.T) # One row per user vector, one column per item vector\n", | |
"    return MinMaxScaler().fit_transform(raw_scores) # Fit and apply the 0-1 scaling in one step" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"output_predictions = prediction_matrix(product_train, user_vecs, item_vecs, customers_arr, products_arr, item_lookup)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## XGBOOST Ensemble ##" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 53, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" CustomerID StockCode Country Quantity UnitPrice\n", | |
"59631 13672.0 21655 United Kingdom 1.0 8.45\n", | |
"59632 13672.0 79323B United Kingdom 4.0 13.50\n", | |
"59634 13672.0 85232B United Kingdom 1.0 24.75\n" | |
] | |
} | |
], | |
"source": [ | |
"print(xgboost_data.query('CustomerID==13672'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 54, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"Int64Index: 266733 entries, 0 to 267624\n", | |
"Data columns (total 5 columns):\n", | |
"CustomerID 266733 non-null float64\n", | |
"StockCode 266733 non-null object\n", | |
"Country 266733 non-null object\n", | |
"Quantity 266733 non-null float64\n", | |
"UnitPrice 266733 non-null float64\n", | |
"dtypes: float64(3), object(2)\n", | |
"memory usage: 12.2+ MB\n" | |
] | |
} | |
], | |
"source": [ | |
"xgboost_data.info()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#there is something wrong with this user\n", | |
"#xgboost_data = xgboost_data.query('CustomerID!=15802')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#TODO: Turn country into a numerical variable, remove other columns from retail_data except unit price, add a column for ALS value, add column for target, set all to 1. Introduce ALS value from result matrix. Add random negatives by replicating every row, changing to a random user ID and getting their ALS value for that item (ideally we would change item, but we'd need to know the unit price, which is not immediate)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice\n", | |
"0 12346.0 23166 United Kingdom 1.0 2.08\n", | |
"1 12347.0 16008 Iceland 24.0 0.25\n", | |
"2 12347.0 17021 Iceland 36.0 0.30\n", | |
"3 12347.0 20665 Iceland 6.0 2.95\n", | |
"4 12347.0 20719 Iceland 40.0 3.40" | |
] | |
}, | |
"execution_count": 55, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 56, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#Let's add a column for the value of ALS, and one for the positive negative label\n", | |
"xgboost_data = xgboost_data.assign(ALS=0., label=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 57, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"0 12346.0 23166 United Kingdom 1.0 2.08 0.0 1\n", | |
"1 12347.0 16008 Iceland 24.0 0.25 0.0 1\n", | |
"2 12347.0 17021 Iceland 36.0 0.30 0.0 1\n", | |
"3 12347.0 20665 Iceland 6.0 2.95 0.0 1\n", | |
"4 12347.0 20719 Iceland 40.0 3.40 0.0 1" | |
] | |
}, | |
"execution_count": 57, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 58, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(266733, 7)" | |
] | |
}, | |
"execution_count": 58, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.shape(xgboost_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#now let's prepare a matrix of random negatives\n", | |
"xgboost_negatives = xgboost_data.copy()\n", | |
"random_list = np.random.choice(products_arr,266733)\n", | |
"xgboost_negatives['StockCode'] = random_list\n", | |
"xgboost_negatives['label'] =0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23405</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>21043</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>21873</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>22535</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>22674</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"0 12346.0 23405 United Kingdom 1.0 2.08 0.0 0\n", | |
"1 12347.0 21043 Iceland 24.0 0.25 0.0 0\n", | |
"2 12347.0 21873 Iceland 36.0 0.30 0.0 0\n", | |
"3 12347.0 22535 Iceland 6.0 2.95 0.0 0\n", | |
"4 12347.0 22674 Iceland 40.0 3.40 0.0 0" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_negatives.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 63, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"0 12346.0 23166 United Kingdom 1.0 2.08 0.0 1\n", | |
"1 12347.0 16008 Iceland 24.0 0.25 0.0 1\n", | |
"2 12347.0 17021 Iceland 36.0 0.30 0.0 1\n", | |
"3 12347.0 20665 Iceland 6.0 2.95 0.0 1\n", | |
"4 12347.0 20719 Iceland 40.0 3.40 0.0 1" | |
] | |
}, | |
"execution_count": 63, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"frames = [xgboost_data,xgboost_negatives]\n", | |
"xgboost_final_data = pd.concat(frames)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>12346.0</td>\n", | |
" <td>23166</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>1.0</td>\n", | |
" <td>2.08</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>16008</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.25</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>17021</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>36.0</td>\n", | |
" <td>0.30</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20665</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.95</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>12347.0</td>\n", | |
" <td>20719</td>\n", | |
" <td>Iceland</td>\n", | |
" <td>40.0</td>\n", | |
" <td>3.40</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"0 12346.0 23166 United Kingdom 1.0 2.08 0.0 1\n", | |
"1 12347.0 16008 Iceland 24.0 0.25 0.0 1\n", | |
"2 12347.0 17021 Iceland 36.0 0.30 0.0 1\n", | |
"3 12347.0 20665 Iceland 6.0 2.95 0.0 1\n", | |
"4 12347.0 20719 Iceland 40.0 3.40 0.0 1" | |
] | |
}, | |
"execution_count": 65, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_final_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>267620</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>22053</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>24.0</td>\n", | |
" <td>2.10</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267621</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>84678</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.55</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267622</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>22161</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>96.0</td>\n", | |
" <td>2.90</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267623</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>23356</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>120.0</td>\n", | |
" <td>4.55</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267624</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>35957</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>48.0</td>\n", | |
" <td>3.30</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"267620 18287.0 22053 United Kingdom 24.0 2.10 0.0 0\n", | |
"267621 18287.0 84678 United Kingdom 6.0 2.55 0.0 0\n", | |
"267622 18287.0 22161 United Kingdom 96.0 2.90 0.0 0\n", | |
"267623 18287.0 23356 United Kingdom 120.0 4.55 0.0 0\n", | |
"267624 18287.0 35957 United Kingdom 48.0 3.30 0.0 0" | |
] | |
}, | |
"execution_count": 66, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_final_data.tail()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"users_index = pd.DataFrame({'User':customers_arr,'Pos':np.arange(customers_arr.size)}, index=customers_arr)\n", | |
"items_index = pd.DataFrame({'Item':products_arr,'Pos':np.arange(products_arr.size)}, index=products_arr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def set_ALS_values_fast(data,customers,products,predictions):\n", | |
" x=np.array(customers.Pos[data.CustomerID]).astype(int)\n", | |
" y=np.array(products.Pos[data.StockCode.astype(str)]).astype(int)\n", | |
" data.ALS = predictions[x,y] " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"customers_inxgboostdata = np.array(users_index.Pos[xgboost_data.CustomerID]).astype(int)\n", | |
"items_inxgboostdata = np.array(items_index.Pos[xgboost_data.StockCode]).astype(int)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Item 10002.0\n", | |
"Pos 0.0\n", | |
"dtype: float64" | |
] | |
}, | |
"execution_count": 71, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"items_index.min()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"set_ALS_values_fast(xgboost_final_data, users_index, items_index, output_predictions)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>267620</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>22053</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>24.0</td>\n", | |
" <td>2.10</td>\n", | |
" <td>0.345061</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267621</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>84678</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.55</td>\n", | |
" <td>0.607426</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267622</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>22161</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>96.0</td>\n", | |
" <td>2.90</td>\n", | |
" <td>0.398117</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267623</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>23356</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>120.0</td>\n", | |
" <td>4.55</td>\n", | |
" <td>0.677318</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267624</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>35957</td>\n", | |
" <td>United Kingdom</td>\n", | |
" <td>48.0</td>\n", | |
" <td>3.30</td>\n", | |
" <td>0.543847</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS \\\n", | |
"267620 18287.0 22053 United Kingdom 24.0 2.10 0.345061 \n", | |
"267621 18287.0 84678 United Kingdom 6.0 2.55 0.607426 \n", | |
"267622 18287.0 22161 United Kingdom 96.0 2.90 0.398117 \n", | |
"267623 18287.0 23356 United Kingdom 120.0 4.55 0.677318 \n", | |
"267624 18287.0 35957 United Kingdom 48.0 3.30 0.543847 \n", | |
"\n", | |
" label \n", | |
"267620 0 \n", | |
"267621 0 \n", | |
"267622 0 \n", | |
"267623 0 \n", | |
"267624 0 " | |
] | |
}, | |
"execution_count": 73, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_final_data.tail()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(533466, 7)" | |
] | |
}, | |
"execution_count": 74, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.shape(xgboost_final_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"xgboost_final_data.to_csv('Xgboostdata_with_neagatives.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"Int64Index: 533466 entries, 0 to 267624\n", | |
"Data columns (total 7 columns):\n", | |
"CustomerID 533466 non-null float64\n", | |
"StockCode 533466 non-null object\n", | |
"Country 533466 non-null object\n", | |
"Quantity 533466 non-null float64\n", | |
"UnitPrice 533466 non-null float64\n", | |
"ALS 533466 non-null float64\n", | |
"label 533466 non-null int64\n", | |
"dtypes: float64(4), int64(1), object(2)\n", | |
"memory usage: 52.6+ MB\n" | |
] | |
} | |
], | |
"source": [ | |
"xgboost_final_data.info()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We need to convert all features to numerical. We will use categorical encoding" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"for feature in xgboost_final_data.columns: # Loop through all columns in the dataframe\n", | |
" if xgboost_final_data[feature].dtype == 'object': # Only apply for columns with categorical strings\n", | |
" xgboost_final_data[feature] = pd.Categorical(xgboost_final_data[feature]).codes # Replace strings with an integer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"Int64Index: 533466 entries, 0 to 267624\n", | |
"Data columns (total 7 columns):\n", | |
"CustomerID 533466 non-null float64\n", | |
"StockCode 533466 non-null int16\n", | |
"Country 533466 non-null int8\n", | |
"Quantity 533466 non-null float64\n", | |
"UnitPrice 533466 non-null float64\n", | |
"ALS 533466 non-null float64\n", | |
"label 533466 non-null int64\n", | |
"dtypes: float64(4), int16(1), int64(1), int8(1)\n", | |
"memory usage: 45.9 MB\n" | |
] | |
} | |
], | |
"source": [ | |
"xgboost_final_data.info()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>267620</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>3740</td>\n", | |
" <td>35</td>\n", | |
" <td>24.0</td>\n", | |
" <td>2.10</td>\n", | |
" <td>0.345061</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267621</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>5734</td>\n", | |
" <td>35</td>\n", | |
" <td>6.0</td>\n", | |
" <td>2.55</td>\n", | |
" <td>0.607426</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267622</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>3839</td>\n", | |
" <td>35</td>\n", | |
" <td>96.0</td>\n", | |
" <td>2.90</td>\n", | |
" <td>0.398117</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267623</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>4959</td>\n", | |
" <td>35</td>\n", | |
" <td>120.0</td>\n", | |
" <td>4.55</td>\n", | |
" <td>0.677318</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>267624</th>\n", | |
" <td>18287.0</td>\n", | |
" <td>5260</td>\n", | |
" <td>35</td>\n", | |
" <td>48.0</td>\n", | |
" <td>3.30</td>\n", | |
" <td>0.543847</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"267620 18287.0 3740 35 24.0 2.10 0.345061 0\n", | |
"267621 18287.0 5734 35 6.0 2.55 0.607426 0\n", | |
"267622 18287.0 3839 35 96.0 2.90 0.398117 0\n", | |
"267623 18287.0 4959 35 120.0 4.55 0.677318 0\n", | |
"267624 18287.0 5260 35 48.0 3.30 0.543847 0" | |
] | |
}, | |
"execution_count": 79, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_final_data.tail()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We will now split into train and test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#first we need to shuffle rows since we have now all positive and negatives together\n", | |
"xgboost_final_data = xgboost_final_data.sample(frac=1).reset_index(drop=True)\n", | |
"xgboost_final_data = xgboost_final_data.sample(frac=1).reset_index(drop=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>533461</th>\n", | |
" <td>14466.0</td>\n", | |
" <td>3935</td>\n", | |
" <td>35</td>\n", | |
" <td>14.0</td>\n", | |
" <td>9.90</td>\n", | |
" <td>0.632180</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>533462</th>\n", | |
" <td>17511.0</td>\n", | |
" <td>3472</td>\n", | |
" <td>35</td>\n", | |
" <td>216.0</td>\n", | |
" <td>1.08</td>\n", | |
" <td>0.826440</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>533463</th>\n", | |
" <td>16389.0</td>\n", | |
" <td>2620</td>\n", | |
" <td>35</td>\n", | |
" <td>12.0</td>\n", | |
" <td>1.25</td>\n", | |
" <td>0.497573</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>533464</th>\n", | |
" <td>14456.0</td>\n", | |
" <td>5488</td>\n", | |
" <td>35</td>\n", | |
" <td>1.0</td>\n", | |
" <td>0.55</td>\n", | |
" <td>0.649902</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>533465</th>\n", | |
" <td>14581.0</td>\n", | |
" <td>4003</td>\n", | |
" <td>35</td>\n", | |
" <td>24.0</td>\n", | |
" <td>0.65</td>\n", | |
" <td>0.322136</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"533461 14466.0 3935 35 14.0 9.90 0.632180 0\n", | |
"533462 17511.0 3472 35 216.0 1.08 0.826440 0\n", | |
"533463 16389.0 2620 35 12.0 1.25 0.497573 1\n", | |
"533464 14456.0 5488 35 1.0 0.55 0.649902 0\n", | |
"533465 14581.0 4003 35 24.0 0.65 0.322136 0" | |
] | |
}, | |
"execution_count": 81, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_final_data.tail()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>13381.0</td>\n", | |
" <td>1259</td>\n", | |
" <td>35</td>\n", | |
" <td>30.0</td>\n", | |
" <td>10.20</td>\n", | |
" <td>0.667535</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>14737.0</td>\n", | |
" <td>5811</td>\n", | |
" <td>35</td>\n", | |
" <td>20.0</td>\n", | |
" <td>1.25</td>\n", | |
" <td>0.351160</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>17126.0</td>\n", | |
" <td>6443</td>\n", | |
" <td>35</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.25</td>\n", | |
" <td>0.818149</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>15318.0</td>\n", | |
" <td>4566</td>\n", | |
" <td>35</td>\n", | |
" <td>12.0</td>\n", | |
" <td>0.79</td>\n", | |
" <td>0.491165</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>14481.0</td>\n", | |
" <td>1127</td>\n", | |
" <td>35</td>\n", | |
" <td>3.0</td>\n", | |
" <td>5.95</td>\n", | |
" <td>0.730605</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS label\n", | |
"0 13381.0 1259 35 30.0 10.20 0.667535 1\n", | |
"1 14737.0 5811 35 20.0 1.25 0.351160 0\n", | |
"2 17126.0 6443 35 1.0 1.25 0.818149 0\n", | |
"3 15318.0 4566 35 12.0 0.79 0.491165 0\n", | |
"4 14481.0 1127 35 3.0 5.95 0.730605 1" | |
] | |
}, | |
"execution_count": 82, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_final_data.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"426772\n" | |
] | |
} | |
], | |
"source": [ | |
"# Hold out the final 20% of rows for testing; the first 80% become the training set\n", | |
"train_samples = int(0.8 * len(xgboost_final_data))\n", | |
"print(train_samples)\n", | |
"xgboost_train = xgboost_final_data.iloc[:train_samples].copy()\n", | |
"xgboost_test = xgboost_final_data.iloc[train_samples:].copy()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Split the target off the features: pop() removes the 'label' column from each frame in place and returns it\n", | |
"label_train, label_test = xgboost_train.pop('label'), xgboost_test.pop('label')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"RangeIndex: 426772 entries, 0 to 426771\n", | |
"Data columns (total 6 columns):\n", | |
"CustomerID 426772 non-null float64\n", | |
"StockCode 426772 non-null int16\n", | |
"Country 426772 non-null int8\n", | |
"Quantity 426772 non-null float64\n", | |
"UnitPrice 426772 non-null float64\n", | |
"ALS 426772 non-null float64\n", | |
"dtypes: float64(4), int16(1), int8(1)\n", | |
"memory usage: 14.2 MB\n" | |
] | |
} | |
], | |
"source": [ | |
"xgboost_train.info()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<class 'pandas.core.frame.DataFrame'>\n", | |
"RangeIndex: 106694 entries, 426772 to 533465\n", | |
"Data columns (total 6 columns):\n", | |
"CustomerID 106694 non-null float64\n", | |
"StockCode 106694 non-null int16\n", | |
"Country 106694 non-null int8\n", | |
"Quantity 106694 non-null float64\n", | |
"UnitPrice 106694 non-null float64\n", | |
"ALS 106694 non-null float64\n", | |
"dtypes: float64(4), int16(1), int8(1)\n", | |
"memory usage: 3.6 MB\n" | |
] | |
} | |
], | |
"source": [ | |
"xgboost_test.info()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 87, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style>\n", | |
" .dataframe thead tr:only-child th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: left;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>CustomerID</th>\n", | |
" <th>StockCode</th>\n", | |
" <th>Country</th>\n", | |
" <th>Quantity</th>\n", | |
" <th>UnitPrice</th>\n", | |
" <th>ALS</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>426772</th>\n", | |
" <td>17965.0</td>\n", | |
" <td>4476</td>\n", | |
" <td>35</td>\n", | |
" <td>10.0</td>\n", | |
" <td>2.50</td>\n", | |
" <td>0.849737</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>426773</th>\n", | |
" <td>16020.0</td>\n", | |
" <td>4051</td>\n", | |
" <td>35</td>\n", | |
" <td>6.0</td>\n", | |
" <td>3.25</td>\n", | |
" <td>0.162656</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>426774</th>\n", | |
" <td>17049.0</td>\n", | |
" <td>1977</td>\n", | |
" <td>35</td>\n", | |
" <td>12.0</td>\n", | |
" <td>1.25</td>\n", | |
" <td>0.740949</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>426775</th>\n", | |
" <td>15301.0</td>\n", | |
" <td>1695</td>\n", | |
" <td>35</td>\n", | |
" <td>12.0</td>\n", | |
" <td>0.85</td>\n", | |
" <td>0.756989</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>426776</th>\n", | |
" <td>15800.0</td>\n", | |
" <td>2595</td>\n", | |
" <td>35</td>\n", | |
" <td>12.0</td>\n", | |
" <td>1.25</td>\n", | |
" <td>0.772001</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" CustomerID StockCode Country Quantity UnitPrice ALS\n", | |
"426772 17965.0 4476 35 10.0 2.50 0.849737\n", | |
"426773 16020.0 4051 35 6.0 3.25 0.162656\n", | |
"426774 17049.0 1977 35 12.0 1.25 0.740949\n", | |
"426775 15301.0 1695 35 12.0 0.85 0.756989\n", | |
"426776 15800.0 2595 35 12.0 1.25 0.772001" | |
] | |
}, | |
"execution_count": 87, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgboost_test.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import xgboost as xgb\n", | |
"from sklearn.model_selection import GridSearchCV" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"xgboost_train.to_csv('xgboost_train.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"label_train.to_csv('label_train.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 131, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# XGBoost parameter reference: https://xgboost.readthedocs.io/en/latest/parameter.html\n", | |
"# Hyper-parameters explored by the grid search (3 x 3 = 9 candidates)\n", | |
"cv_params = {'max_depth': [3, 5, 7], 'min_child_weight': [1, 3, 5]}\n", | |
"\n", | |
"# Parameters held fixed for every candidate. 'random_state' replaces the deprecated 'seed'\n", | |
"# argument, which made the XGBoost sklearn wrapper emit a DeprecationWarning on every fit.\n", | |
"ind_params = {'learning_rate': 0.1, 'n_estimators': 1000, 'random_state': 0, 'subsample': 0.8,\n", | |
"              'colsample_bytree': 0.8, 'objective': 'binary:logistic'}\n", | |
"\n", | |
"# We optimize for AUC to compare with previous results.\n", | |
"# 5-fold CV over 9 candidates means 45 fits of 1000 boosting rounds each, run serially (n_jobs=1).\n", | |
"optimized_XGB = GridSearchCV(xgb.XGBClassifier(**ind_params),\n", | |
"                             cv_params,\n", | |
"                             scoring='roc_auc', cv=5, n_jobs=1, verbose=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 132, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Fitting 5 folds for each of 9 candidates, totalling 45 fits\n", | |
"[CV] max_depth=3, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=1, total= 1.3min\n", | |
"[CV] max_depth=3, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 1.4min remaining: 0.0s\n", | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=1, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=1, total= 1.1min\n", | |
"[CV] max_depth=3, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=1, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=1, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=3, total= 1.1min\n", | |
"[CV] max_depth=3, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=3, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=3, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=3, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=3, total= 1.1min\n", | |
"[CV] max_depth=3, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=5, total= 1.1min\n", | |
"[CV] max_depth=3, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=5, total= 1.2min\n", | |
"[CV] max_depth=3, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=5, total= 1.1min\n", | |
"[CV] max_depth=3, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=5, total= 1.1min\n", | |
"[CV] max_depth=3, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=3, min_child_weight=5, total= 1.1min\n", | |
"[CV] max_depth=5, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=1, total= 1.8min\n", | |
"[CV] max_depth=5, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=1, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=1, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=1, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=1, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=3, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=3, total= 1.8min\n", | |
"[CV] max_depth=5, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=3, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=3, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=3, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=5, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=5, total= 1.8min\n", | |
"[CV] max_depth=5, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=5, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=5, total= 1.7min\n", | |
"[CV] max_depth=5, min_child_weight=5 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=5, min_child_weight=5, total= 1.7min\n", | |
"[CV] max_depth=7, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=7, min_child_weight=1, total= 2.4min\n", | |
"[CV] max_depth=7, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=7, min_child_weight=1, total= 2.3min\n", | |
"[CV] max_depth=7, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=7, min_child_weight=1, total= 2.3min\n", | |
"[CV] max_depth=7, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=7, min_child_weight=1, total= 2.3min\n", | |
"[CV] max_depth=7, min_child_weight=1 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=7, min_child_weight=1, total= 2.3min\n", | |
"[CV] max_depth=7, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[CV] .................. max_depth=7, min_child_weight=3, total= 2.3min\n", | |
"[CV] max_depth=7, min_child_weight=3 .................................\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/marcin.cylke/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py:203: DeprecationWarning: The seed parameter is deprecated as of version .6.Please use random_state instead.seed is deprecated.\n", | |
" 'seed is deprecated.', DeprecationWarning)\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-132-4cbe90b538a0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#train XGB\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0moptimized_XGB\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mselected_xgboost_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0meval_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'auc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m 636\u001b[0m error_score=self.error_score)\n\u001b[1;32m 637\u001b[0m for parameters, (train, test) in product(candidate_params,\n\u001b[0;32m--> 638\u001b[0;31m cv.split(X, y, groups)))\n\u001b[0m\u001b[1;32m 639\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 640\u001b[0m \u001b[0;31m# if one choose to see train score, \"out\" will contain train score info\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 777\u001b[0m \u001b[0;31m# was dispatched. In particular this covers the edge\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[0;31m# case of Parallel used with an exhausted iterator.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 779\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 780\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterating\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 781\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py\u001b[0m in \u001b[0;36mdispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 625\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtasks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py\u001b[0m in \u001b[0;36m_dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 586\u001b[0m \u001b[0mdispatch_timestamp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 587\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBatchCompletionCallBack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdispatch_timestamp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 588\u001b[0;31m \u001b[0mjob\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 589\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjob\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 590\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mapply_async\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mapply_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;34m\"\"\"Schedule a func to be run\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImmediateResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 330\u001b[0m \u001b[0;31m# Don't delay the application, to avoid keeping the input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;31m# arguments in memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 334\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__len__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__len__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/model_selection/_validation.py\u001b[0m in \u001b[0;36m_fit_and_score\u001b[0;34m(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, error_score)\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 436\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 437\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 438\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 439\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/sklearn.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight, eval_set, eval_metric, early_stopping_rounds, verbose)\u001b[0m\n\u001b[1;32m 498\u001b[0m \u001b[0mearly_stopping_rounds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mearly_stopping_rounds\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 499\u001b[0m \u001b[0mevals_result\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mevals_result\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeval\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 500\u001b[0;31m verbose_eval=verbose)\n\u001b[0m\u001b[1;32m 501\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobjective\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxgb_options\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"objective\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/training.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(params, dtrain, num_boost_round, evals, obj, feval, maximize, early_stopping_rounds, evals_result, verbose_eval, xgb_model, callbacks, learning_rates)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0mevals\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mevals\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeval\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 204\u001b[0;31m xgb_model=xgb_model, callbacks=callbacks)\n\u001b[0m\u001b[1;32m 205\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/training.py\u001b[0m in \u001b[0;36m_train_internal\u001b[0;34m(params, dtrain, num_boost_round, evals, obj, feval, xgb_model, callbacks)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;31m# Skip the first update if it is a recovery step.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mversion\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mbst\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0mbst\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_rabit_checkpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0mversion\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/xgboost-0.6-py3.5.egg/xgboost/core.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, dtrain, iteration, fobj)\u001b[0m\n\u001b[1;32m 884\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfobj\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 885\u001b[0m _check_call(_LIB.XGBoosterUpdateOneIter(self.handle, ctypes.c_int(iteration),\n\u001b[0;32m--> 886\u001b[0;31m dtrain.handle))\n\u001b[0m\u001b[1;32m 887\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtrain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"#train XGB\n", | |
"optimized_XGB.fit(selected_xgboost_train, label_train,eval_metric='auc')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"ename": "NotFittedError", | |
"evalue": "This GridSearchCV instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mNotFittedError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-105-a0f4753666d2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#train error\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0moptimized_XGB\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrid_scores_\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mgrid_scores_\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrid_scores_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 739\u001b[0;31m \u001b[0mcheck_is_fitted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'cv_results_'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 740\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultimetric_\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 741\u001b[0m raise AttributeError(\"grid_scores_ attribute is not available for\"\n", | |
"\u001b[0;32m~/Virtualenvs/xavier_intro_tutorial/lib/python3.5/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_is_fitted\u001b[0;34m(estimator, attributes, msg, all_or_any)\u001b[0m\n\u001b[1;32m 735\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 736\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall_or_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mattr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mattributes\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mNotFittedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'name'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 738\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 739\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mNotFittedError\u001b[0m: This GridSearchCV instance is not fitted yet. Call 'fit' with appropriate arguments before using this method." | |
] | |
} | |
], | |
"source": [ | |
"#train error\n", | |
"optimized_XGB.grid_scores_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 176, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"columns = [\"Quantity\", \"Country\", \"UnitPrice\", \"ALS\"]\n", | |
"columns = [\"CustomerID\", \"StockCode\"]\n", | |
"\n", | |
"selected_xgboost_train = xgboost_train[columns]\n", | |
"selected_xgboost_test = xgboost_test[columns]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 177, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#we take our best model given cross validation results\n", | |
"best_params = {'eta': 0.1, 'seed':0, 'subsample': 0.8, 'colsample_bytree': 0.8, \n", | |
" 'objective': 'binary:logistic', 'max_depth':5, 'min_child_weight':1} \n", | |
"\n", | |
"train_xgb_matrix = xgb.DMatrix(selected_xgboost_train, label_train) # Create our DMatrix to make XGBoost more efficient\n", | |
"best_model = xgb.train(best_params, train_xgb_matrix, num_boost_round = 432)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Test-set evaluation with the raw XGBoost model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 178, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 3.31898540e-04, 2.34217325e-04, 9.99991059e-01, ...,\n", | |
" 9.99991059e-01, 6.64029494e-02, 3.33350268e-04], dtype=float32)" | |
] | |
}, | |
"execution_count": 178, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"#now let's check error on test set\n", | |
"test_xgb_matrix = xgb.DMatrix(selected_xgboost_test)\n", | |
"from sklearn.metrics import roc_auc_score\n", | |
"label_pred = best_model.predict(test_xgb_matrix)\n", | |
"label_pred" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 179, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#need to convert values to binary label\n", | |
"label_pred[label_pred > 0.5] = 1\n", | |
"label_pred[label_pred <= 0.5] = 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 180, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 0., 0., 1., ..., 1., 0., 0.], dtype=float32)" | |
] | |
}, | |
"execution_count": 180, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"label_pred" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 181, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.96973339195568597" | |
] | |
}, | |
"execution_count": 181, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"roc_auc_score(label_pred.astype(int), label_test.astype(int))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Bonus follow-ups\n", | |
"\n", | |
"* Improve the overall AUC by adding more features to the XGBoost ensemble\n", | |
"* Fine-tune the XGBoost ensemble by optimizing other hyperparameters (e.g. the learning rate)\n", | |
"* Analyze how the hyperparameters affect metrics other than AUC" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment