{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Analyze Criteo with Dask-ML\n",
"\n",
"This notebook uses the Criteo click logs dataset to stress a few parts of Dask-ML pipelines. It is a continued experiment that branches from earlier work from Tom Augspurger.\n",
"\n",
"- Github issue: https://github.com/dask/dask-ml/issues/295\n",
"- Tom's earlier gist: https://gist.github.com/TomAugspurger/4a058f00b32fc049ab5f2860d03fd579\n",
"\n",
"We read data from parquet or CSV (I happen to have my own parquet files, but CSV should be fine too) then construct a pipeline using column transformers, and fit with SGD.\n",
"\n",
"Results are not good, but things run better than they did before :)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Criteo dataset\n",
"\n",
"We'll work with the Criteo dataset. This has a mixture of numeric and categorical features. It's also a large dataset, which presents some challenges for many pre-processing methods.\n",
"\n",
"The full dataset is from http://labs.criteo.com/2013/12/download-terabyte-click-logs/. We'll work with a sample."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Dask Client"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Client</h3>\n",
"<ul>\n",
" <li><b>Scheduler: </b>tcp://127.0.0.1:43539\n",
" <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Cluster</h3>\n",
"<ul>\n",
" <li><b>Workers: </b>4</li>\n",
" <li><b>Cores: </b>4</li>\n",
" <li><b>Memory: </b>16.68 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: scheduler='tcp://127.0.0.1:43539' processes=4 cores=4>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from dask.distributed import Client\n",
"\n",
"client = Client()\n",
"client"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ordinal_columns = [\n",
" 'category_0', 'category_1', 'category_2', 'category_3',\n",
" 'category_4', 'category_6', 'category_7', 'category_9',\n",
" 'category_10', 'category_11', 'category_13', 'category_14',\n",
" 'category_17', 'category_19', 'category_20', 'category_21',\n",
" 'category_22', 'category_23',\n",
"]\n",
"\n",
"onehot_columns = [\n",
" 'category_5', 'category_8', 'category_12',\n",
" 'category_15', 'category_16', 'category_18',\n",
" 'category_24', 'category_25',\n",
"]\n",
"\n",
"numeric_columns = [f'numeric_{i}' for i in range(13)]\n",
"columns = ['click'] + numeric_columns + onehot_columns + ordinal_columns"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import dask.dataframe as dd\n",
"sample = dd.read_parquet(\"../day-0.parquet/\").partitions[:5]"
]
},
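{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you don't have parquet files handy, the CSV route is a short variation: the raw Criteo day files are tab-separated text with no header row. This is only a sketch; the path `day_0.gz` is a placeholder for wherever your download lives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: read a raw day file instead of parquet, assuming a local\n",
"# 'day_0.gz'. The layout is the label, 13 numeric features, then 26\n",
"# hashed categorical features, tab-separated with no header.\n",
"csv_columns = (['click']\n",
"               + [f'numeric_{i}' for i in range(13)]\n",
"               + [f'category_{i}' for i in range(26)])\n",
"sample_csv = dd.read_csv('day_0.gz', sep='\\t', names=csv_columns,\n",
"                         compression='gzip', blocksize=None)"
]
},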
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our goal is to predict 'click' using the other columns."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>click</th>\n",
" <th>numeric_0</th>\n",
" <th>numeric_1</th>\n",
" <th>numeric_2</th>\n",
" <th>numeric_3</th>\n",
" <th>numeric_4</th>\n",
" <th>numeric_5</th>\n",
" <th>numeric_6</th>\n",
" <th>numeric_7</th>\n",
" <th>numeric_8</th>\n",
" <th>...</th>\n",
" <th>category_16</th>\n",
" <th>category_17</th>\n",
" <th>category_18</th>\n",
" <th>category_19</th>\n",
" <th>category_20</th>\n",
" <th>category_21</th>\n",
" <th>category_22</th>\n",
" <th>category_23</th>\n",
" <th>category_24</th>\n",
" <th>category_25</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>5.0</td>\n",
" <td>110.0</td>\n",
" <td>NaN</td>\n",
" <td>16.0</td>\n",
" <td>NaN</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>14</td>\n",
" <td>7</td>\n",
" <td>...</td>\n",
" <td>b'd20856aa'</td>\n",
" <td>b'b8170bba'</td>\n",
" <td>b'9512c20b'</td>\n",
" <td>b'c38e2f28'</td>\n",
" <td>b'14f65a5d'</td>\n",
" <td>b'25b1b089'</td>\n",
" <td>b'd7c1fc0b'</td>\n",
" <td>b'7caf609c'</td>\n",
" <td>b'30436bfc'</td>\n",
" <td>b'ed10571d'</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>32.0</td>\n",
" <td>3.0</td>\n",
" <td>5.0</td>\n",
" <td>NaN</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>61</td>\n",
" <td>5</td>\n",
" <td>...</td>\n",
" <td>b'd20856aa'</td>\n",
" <td>b'a1eb1511'</td>\n",
" <td>b'9512c20b'</td>\n",
" <td>b'febfd863'</td>\n",
" <td>b'a3323ca1'</td>\n",
" <td>b'c8e1ee56'</td>\n",
" <td>b'1752e9e8'</td>\n",
" <td>b'75350c8a'</td>\n",
" <td>b'991321ea'</td>\n",
" <td>b'b757e957'</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>233.0</td>\n",
" <td>1.0</td>\n",
" <td>146.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>99</td>\n",
" <td>7</td>\n",
" <td>...</td>\n",
" <td>b'd20856aa'</td>\n",
" <td>b'628f1b8d'</td>\n",
" <td>b'9512c20b'</td>\n",
" <td>b'c38e2f28'</td>\n",
" <td>b'14f65a5d'</td>\n",
" <td>b'25b1b089'</td>\n",
" <td>b'd7c1fc0b'</td>\n",
" <td>b'34a9b905'</td>\n",
" <td>b'ff654802'</td>\n",
" <td>b'ed10571d'</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>NaN</td>\n",
" <td>24.0</td>\n",
" <td>NaN</td>\n",
" <td>11.0</td>\n",
" <td>24.0</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>56</td>\n",
" <td>3</td>\n",
" <td>...</td>\n",
" <td>b'1f7fc70b'</td>\n",
" <td>b'a1eb1511'</td>\n",
" <td>b'9512c20b'</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>b'dc209cd3'</td>\n",
" <td>b'b8a81fb0'</td>\n",
" <td>b'30436bfc'</td>\n",
" <td>b'b757e957'</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>60.0</td>\n",
" <td>223.0</td>\n",
" <td>6.0</td>\n",
" <td>15.0</td>\n",
" <td>5.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>8</td>\n",
" <td>...</td>\n",
" <td>b'd20856aa'</td>\n",
" <td>b'd9f758ff'</td>\n",
" <td>b'9512c20b'</td>\n",
" <td>b'c709ec07'</td>\n",
" <td>b'2b07677e'</td>\n",
" <td>b'a89a92a5'</td>\n",
" <td>b'aa137169'</td>\n",
" <td>b'e619743b'</td>\n",
" <td>b'cdc3217e'</td>\n",
" <td>b'ed10571d'</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 40 columns</p>\n",
"</div>"
],
"text/plain": [
" click numeric_0 numeric_1 numeric_2 numeric_3 numeric_4 numeric_5 \\\n",
"0 1 5.0 110.0 NaN 16.0 NaN 1.0 \n",
"1 0 32.0 3.0 5.0 NaN 1.0 0.0 \n",
"2 0 NaN 233.0 1.0 146.0 1.0 0.0 \n",
"3 0 NaN 24.0 NaN 11.0 24.0 NaN \n",
"4 0 60.0 223.0 6.0 15.0 5.0 0.0 \n",
"\n",
" numeric_6 numeric_7 numeric_8 ... category_16 category_17 \\\n",
"0 0.0 14 7 ... b'd20856aa' b'b8170bba' \n",
"1 0.0 61 5 ... b'd20856aa' b'a1eb1511' \n",
"2 0.0 99 7 ... b'd20856aa' b'628f1b8d' \n",
"3 0.0 56 3 ... b'1f7fc70b' b'a1eb1511' \n",
"4 0.0 1 8 ... b'd20856aa' b'd9f758ff' \n",
"\n",
" category_18 category_19 category_20 category_21 category_22 \\\n",
"0 b'9512c20b' b'c38e2f28' b'14f65a5d' b'25b1b089' b'd7c1fc0b' \n",
"1 b'9512c20b' b'febfd863' b'a3323ca1' b'c8e1ee56' b'1752e9e8' \n",
"2 b'9512c20b' b'c38e2f28' b'14f65a5d' b'25b1b089' b'd7c1fc0b' \n",
"3 b'9512c20b' None None None b'dc209cd3' \n",
"4 b'9512c20b' b'c709ec07' b'2b07677e' b'a89a92a5' b'aa137169' \n",
"\n",
" category_23 category_24 category_25 \n",
"0 b'7caf609c' b'30436bfc' b'ed10571d' \n",
"1 b'75350c8a' b'991321ea' b'b757e957' \n",
"2 b'34a9b905' b'ff654802' b'ed10571d' \n",
"3 b'b8a81fb0' b'30436bfc' b'b757e957' \n",
"4 b'e619743b' b'cdc3217e' b'ed10571d' \n",
"\n",
"[5 rows x 40 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sample.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"sample[onehot_columns + ordinal_columns] = sample[onehot_columns + ordinal_columns].fillna(value='')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"y = sample['click']\n",
"X = sample.drop(\"click\", axis='columns')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's lay out our pre-processing pipeline. We have three types of columns\n",
"\n",
"1. Numeric columns\n",
"2. Low-cardinality categorical columns\n",
"3. High-cardinality categorical columns\n",
"\n",
"Each of those will be processed differently.\n",
"\n",
"1. Numeric columns will have missing values filled with the column average and standard scaled\n",
"2. Low-cardinality categorical columns will be one-hot encoded\n",
"3. High-cardinality categorical columns will be deterministically hashed and standard scaled\n",
"\n",
"You'll probably want to quibble with some of these choices, but right now, I'm\n",
"just interested in the ability to do these kinds of transformations at all.\n",
"\n",
"We need to define a couple custom estimators, one for hashing the values of a dask dataframe, and one for converting a dask dataframe to a dask array."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import dask\n",
"from dask_ml.compose import make_column_transformer\n",
"from dask_ml.feature_extraction.text import HashingVectorizer\n",
"from dask_ml.preprocessing import StandardScaler\n",
"from dask_ml.wrappers import Incremental\n",
"from dask_ml.impute import SimpleImputer\n",
"\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.linear_model import SGDClassifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now for the pipeline."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"X, y = dask.persist(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"weights = {0: 1, \n",
" 1: (y.count() / y.sum()).compute()}"
]
},
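{
"cell_type": "markdown",
"metadata": {},
"source": [
"Clicks are rare, so we up-weight the positive class: `weights[1]` is the ratio of all rows to clicked rows, i.e. the inverse of the click-through rate. As a quick, optional sanity check on how imbalanced the sample is:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of rows that are clicks; weights[1] is roughly 1 / this.\n",
"y.mean().compute()"
]
},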
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(memory=None,\n",
" steps=[('columntransformer-1', ColumnTransformer(n_jobs=1, preserve_dataframe=True, remainder='passthrough',\n",
" sparse_threshold=0.3, transformer_weights=None,\n",
" transformers=[('simpleimputer', SimpleImputer(copy=True, fill_value=None, missing_values=nan, strategy='mean',\n",
" verbose...,\n",
" verbose=0, warm_start=False),\n",
" random_state=None, scoring=None, shuffle_blocks=True))])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hashing_vectorizer = HashingVectorizer()\n",
"nan_imputer = SimpleImputer()\n",
"\n",
"to_numeric = make_column_transformer(\n",
" (numeric_columns, nan_imputer),\n",
" remainder='passthrough',\n",
")\n",
"\n",
"scaler = make_column_transformer(\n",
" (list(numeric_columns), StandardScaler()),\n",
" *[(c, hashing_vectorizer) for c in (onehot_columns + ordinal_columns)],\n",
" remainder='passthrough'\n",
")\n",
"\n",
"clf = Incremental(\n",
" SGDClassifier(loss='log',\n",
" random_state=0,\n",
" max_iter=1000, \n",
" class_weight=weights)\n",
")\n",
"\n",
"pipe = make_pipeline(to_numeric, scaler, clf)\n",
"pipe"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Overall it reads pretty similarly to how we described it in prose.\n",
"We specify\n",
"\n",
"1. Onehot the low-cardinality categoricals, hash the others\n",
"2. Fill missing values in the numeric columns\n",
"3. Standard scale the numeric and hashed columns\n",
"4. Fit the incremental SGD\n",
"\n",
"And again, these ColumnTransformers are just estimators so we stick them in a regular scikit-learn `Pipeline` before calling `.fit`:"
]
},
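{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, here's roughly what the hashing step does to one categorical column. This standalone sketch uses scikit-learn's `HashingVectorizer` directly with made-up values; dask-ml's wrapper applies the same transform blockwise to a dask Series."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import HashingVectorizer as SkHashingVectorizer\n",
"\n",
"# Each category value is a one-token 'document' hashed into a\n",
"# fixed-width sparse vector: there is no vocabulary to learn, and\n",
"# unseen categories at predict time are handled for free.\n",
"sk_hv = SkHashingVectorizer(n_features=2 ** 10)\n",
"sk_hv.transform(['a1eb1511', 'b8170bba', 'a1eb1511'])  # sparse, shape (3, 1024)"
]
},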
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\", \"Concatenating\", UserWarning)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 15.3 s, sys: 1.31 s, total: 16.6 s\n",
"Wall time: 1min 32s\n"
]
}
],
"source": [
"%%time\n",
"result = pipe.fit(X, y, incremental__classes=[0, 1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate Results\n",
"\n",
"They're poor. We predict way more clicks than was appropriate."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"748388"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result.predict(X).sum().compute()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"36965"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.sum().compute()"
]
},
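{
"cell_type": "markdown",
"metadata": {},
"source": [
"For one more summary number, dask-ml's metrics operate directly on dask arrays. A sketch, assuming `accuracy_score` is available in your dask-ml version; note that accuracy is dominated by the majority no-click class, so read it alongside the raw counts above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dask_ml.metrics import accuracy_score\n",
"\n",
"# Mostly measures how heavily we over-predict clicks, since\n",
"# non-clicks dominate the data.\n",
"accuracy_score(y.values, result.predict(X))"
]
},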
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Future Work\n",
"\n",
"\n",
"- We currently only do a single pass over our data with `Incremental.fit`. We probably need to replace this with a system that will do multiple passes until we converge. \n",
"- We probably want to choose the `class_weights` value more carefully, and there are likely some other parameters to this pipeline that could use tuning\n",
"- Currently dask-ml has mechanisms for hyper-parameter optimization and incremental training like what we're doing here, but they don't work on full pipelines.\n",
"- It's tricky working with both dask arrays and dask dataframes in column transformer pipelines. We have to use dask arrays in order to support `scipy.sparse` matrices from the HashingVectorizer, but we prefer dealing with column names provided by dask dataframes.\n",
"- Should `HashingVectorizer` support `None` values? https://github.com/scikit-learn/scikit-learn/issues/12347\n"
]
}
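,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the multiple-passes idea from the first bullet: scikit-learn pipelines fit their steps in place, so after `pipe.fit` the transformers above are already fitted and we can keep feeding the classifier. The epoch count here is arbitrary and there is no convergence check."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Transform once with the already-fitted pre-processing steps...\n",
"Xt = scaler.transform(to_numeric.transform(X))\n",
"\n",
"# ...then take extra passes over the blocks with the wrapped SGD.\n",
"for epoch in range(3):  # arbitrary number of extra passes\n",
"    clf.partial_fit(Xt, y, classes=[0, 1])"
]
}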
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}