Created
June 7, 2020 14:50
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Predicting taxi trip durations with creme and chantilly" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"In this example we'll build a model to predict the duration of taxi trips in the city of New-York (dataset [here](https://www.kaggle.com/c/nyc-taxi-trip-duration))." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's first install the necessary dependencies." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!pip install creme chantilly dill" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's now take a look at the data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Taxis dataset\n", | |
"\n", | |
" Task Regression \n", | |
" Number of samples 1,458,644 \n", | |
"Number of features 8 \n", | |
" Sparse False \n", | |
" Path /Users/mhalford/creme_data/Taxis/train.csv \n", | |
" URL https://maxhalford.github.io/files/datasets/nyc_taxis.zip\n", | |
" Size 186.23 MB \n", | |
" Downloaded True " | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from creme import datasets\n", | |
"\n", | |
"trips = datasets.Taxis()\n", | |
"trips" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'vendor_id': '2',\n", | |
" 'pickup_datetime': datetime.datetime(2016, 1, 1, 0, 0, 17),\n", | |
" 'passenger_count': 5,\n", | |
" 'pickup_longitude': -73.98174285888672,\n", | |
" 'pickup_latitude': 40.71915817260742,\n", | |
" 'dropoff_longitude': -73.93882751464845,\n", | |
" 'dropoff_latitude': 40.82918167114258,\n", | |
" 'store_and_fwd_flag': 'N'}" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x, y = next(iter(trips))\n", | |
"x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"849" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"It seems reasonable to use the distance in order to predict the duration.\n", | |
"\n", | |
"With `creme`, we're working with dictionaries. Therefore, a simple way to go about extracting features is to write a function." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"\n", | |
"def distances(trip):\n", | |
" lat_dist = trip['dropoff_latitude'] - trip['pickup_latitude']\n", | |
" lon_dist = trip['dropoff_longitude'] - trip['pickup_longitude']\n", | |
" return {\n", | |
" 'manhattan_distance': abs(lat_dist) + abs(lon_dist),\n", | |
" 'euclidean_distance': math.sqrt(lat_dist ** 2 + lon_dist ** 2)\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can verify that this function works on the first sample." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'manhattan_distance': 0.1529388427734233,\n", | |
" 'euclidean_distance': 0.11809698133739274}" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"distances(trip=x)" | |
] | |
}, | |
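{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Note that these distances are expressed in degrees rather than metres. That's fine for a linear model, because the scale is absorbed by the learned weights. If actual distances were needed, the [haversine formula](https://en.wikipedia.org/wiki/Haversine_formula) could be used instead. The following sketch is purely illustrative and isn't used in the rest of this example." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def haversine(trip):\n", | |
"    # Great-circle distance between pickup and dropoff, in kilometres\n", | |
"    lat1, lon1 = math.radians(trip['pickup_latitude']), math.radians(trip['pickup_longitude'])\n", | |
"    lat2, lon2 = math.radians(trip['dropoff_latitude']), math.radians(trip['dropoff_longitude'])\n", | |
"    a = (\n", | |
"        math.sin((lat2 - lat1) / 2) ** 2\n", | |
"        + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2\n", | |
"    )\n", | |
"    return {'haversine_distance': 2 * 6371 * math.asin(math.sqrt(a))}\n", | |
"\n", | |
"haversine(trip=x)" | |
] | |
}, | |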
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Additionally, it should worthwhile to extract temporal information." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'hour': 0, 'day': 'Friday'}" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import calendar\n", | |
"\n", | |
"def datetime_info(trip):\n", | |
" day_no = trip['pickup_datetime'].weekday()\n", | |
" return {\n", | |
" 'hour': trip['pickup_datetime'].hour,\n", | |
" 'day': calendar.day_name[day_no]\n", | |
" }\n", | |
"\n", | |
"datetime_info(trip=x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can now assemble these steps into a `TransformerUnion`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"TransformerUnion (\n", | |
" FuncTransformer (\n", | |
" func=\"distances\"\n", | |
" ),\n", | |
" FuncTransformer (\n", | |
" func=\"datetime_info\"\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from creme import compose\n", | |
"\n", | |
"extract_features = compose.TransformerUnion(distances, datetime_info)\n", | |
"extract_features" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`TransformerUnion` is a `Transformer`, which means that it has a `transform_one` method." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'hour': 0,\n", | |
" 'day': 'Friday',\n", | |
" 'manhattan_distance': 0.1529388427734233,\n", | |
" 'euclidean_distance': 0.11809698133739274}" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"features = extract_features.transform_one(x)\n", | |
"features" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can also call `fit_one`, but in this case it is unnecessary because our feature extractors are stateless.\n", | |
"\n", | |
"We would now like to train a linear regression. The problem is that the `day` feature is categorical, whilst a linear regression only accepts numeric data. A simple way circumvent this issue is to use one-hot encoding, which involves replacing the `day` feature with a binary feature per day of the week." | |
] | |
}, | |
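{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# fit_one is a no-op for these stateless transformers; assuming it returns\n", | |
"# the transformer itself, it can be chained with transform_one as a sanity check\n", | |
"extract_features.fit_one(x).transform_one(x)" | |
] | |
}, | |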
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'hour': 0.0,\n", | |
" 'manhattan_distance': 0.1529388427734233,\n", | |
" 'euclidean_distance': 0.11809698133739274,\n", | |
" 'day_Friday': 1}" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import numbers \n", | |
"from creme import preprocessing\n", | |
"\n", | |
"cat = compose.SelectType(str) | preprocessing.OneHotEncoder()\n", | |
"num = compose.SelectType(numbers.Number) | preprocessing.StandardScaler()\n", | |
"\n", | |
"preprocess = compose.TransformerUnion(cat, num)\n", | |
"preprocess.transform_one(features)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can now assemble these steps into a pipeline." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'hour': 0.0,\n", | |
" 'manhattan_distance': 0.1529388427734233,\n", | |
" 'euclidean_distance': 0.11809698133739274,\n", | |
" 'day_Friday': 1}" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipeline = compose.Pipeline(\n", | |
" extract_features,\n", | |
" preprocess\n", | |
")\n", | |
"\n", | |
"pipeline.transform_one(x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We're now ready to append a linear regression to our pipeline." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from creme import linear_model\n", | |
"\n", | |
"pipeline |= linear_model.LinearRegression()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's take a look at what our pipeline looks like." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.42.3 (20191010.1750)\n", | |
" -->\n", | |
"<!-- Title: %3 Pages: 1 -->\n", | |
"<svg width=\"288pt\" height=\"404pt\"\n", | |
" viewBox=\"0.00 0.00 288.21 404.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 400)\">\n", | |
"<title>%3</title>\n", | |
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-400 284.21,-400 284.21,4 -4,4\"/>\n", | |
"<!-- x -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>x</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"141.57\" cy=\"-378\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"141.57\" y=\"-373.8\" font-family=\"Times,serif\" font-size=\"14.00\">x</text>\n", | |
"</g>\n", | |
"<!-- distances -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>distances</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"81.57\" cy=\"-306\" rx=\"42.55\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"81.57\" y=\"-301.8\" font-family=\"Times,serif\" font-size=\"14.00\">distances</text>\n", | |
"</g>\n", | |
"<!-- x->distances -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>x->distances</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M128.55,-361.81C120.83,-352.8 110.86,-341.18 102.1,-330.95\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"104.56,-328.45 95.4,-323.13 99.25,-333 104.56,-328.45\"/>\n", | |
"</g>\n", | |
"<!-- datetime_info -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>datetime_info</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"201.57\" cy=\"-306\" rx=\"59.44\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"201.57\" y=\"-301.8\" font-family=\"Times,serif\" font-size=\"14.00\">datetime_info</text>\n", | |
"</g>\n", | |
"<!-- x->datetime_info -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>x->datetime_info</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M154.59,-361.81C162.2,-352.93 171.99,-341.5 180.66,-331.39\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"183.46,-333.5 187.32,-323.63 178.15,-328.94 183.46,-333.5\"/>\n", | |
"</g>\n", | |
"<!-- Select(str) -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>Select(str)</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"76.57\" cy=\"-234\" rx=\"46.4\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"76.57\" y=\"-229.8\" font-family=\"Times,serif\" font-size=\"14.00\">Select(str)</text>\n", | |
"</g>\n", | |
"<!-- distances->Select(str) -->\n", | |
"<g id=\"edge7\" class=\"edge\">\n", | |
"<title>distances->Select(str)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M80.33,-287.7C79.78,-279.98 79.12,-270.71 78.5,-262.11\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"81.99,-261.83 77.79,-252.1 75.01,-262.33 81.99,-261.83\"/>\n", | |
"</g>\n", | |
"<!-- Select(Number) -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>Select(Number)</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"207.57\" cy=\"-234\" rx=\"66.67\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"207.57\" y=\"-229.8\" font-family=\"Times,serif\" font-size=\"14.00\">Select(Number)</text>\n", | |
"</g>\n", | |
"<!-- distances->Select(Number) -->\n", | |
"<g id=\"edge9\" class=\"edge\">\n", | |
"<title>distances->Select(Number)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M106.46,-291.17C124.83,-280.97 150.12,-266.92 170.81,-255.42\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"172.61,-258.42 179.66,-250.51 169.21,-252.3 172.61,-258.42\"/>\n", | |
"</g>\n", | |
"<!-- datetime_info->Select(str) -->\n", | |
"<g id=\"edge10\" class=\"edge\">\n", | |
"<title>datetime_info->Select(str)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M174.44,-289.81C155.76,-279.35 130.8,-265.37 110.76,-254.15\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"112.37,-251.04 101.93,-249.21 108.95,-257.14 112.37,-251.04\"/>\n", | |
"</g>\n", | |
"<!-- datetime_info->Select(Number) -->\n", | |
"<g id=\"edge11\" class=\"edge\">\n", | |
"<title>datetime_info->Select(Number)</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M203.05,-287.7C203.71,-279.98 204.51,-270.71 205.24,-262.11\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"208.73,-262.37 206.1,-252.1 201.76,-261.77 208.73,-262.37\"/>\n", | |
"</g>\n", | |
"<!-- OneHotEncoder -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>OneHotEncoder</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"67.57\" cy=\"-162\" rx=\"67.64\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"67.57\" y=\"-157.8\" font-family=\"Times,serif\" font-size=\"14.00\">OneHotEncoder</text>\n", | |
"</g>\n", | |
"<!-- Select(str)->OneHotEncoder -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>Select(str)->OneHotEncoder</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M74.34,-215.7C73.35,-207.98 72.16,-198.71 71.05,-190.11\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"74.51,-189.58 69.77,-180.1 67.57,-190.47 74.51,-189.58\"/>\n", | |
"</g>\n", | |
"<!-- LinearRegression -->\n", | |
"<g id=\"node8\" class=\"node\">\n", | |
"<title>LinearRegression</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"141.57\" cy=\"-90\" rx=\"72.46\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"141.57\" y=\"-85.8\" font-family=\"Times,serif\" font-size=\"14.00\">LinearRegression</text>\n", | |
"</g>\n", | |
"<!-- OneHotEncoder->LinearRegression -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>OneHotEncoder->LinearRegression</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M85.1,-144.41C94.44,-135.58 106.08,-124.57 116.36,-114.84\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"119.04,-117.13 123.9,-107.71 114.23,-112.04 119.04,-117.13\"/>\n", | |
"</g>\n", | |
"<!-- StandardScaler -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>StandardScaler</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"216.57\" cy=\"-162\" rx=\"63.78\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"216.57\" y=\"-157.8\" font-family=\"Times,serif\" font-size=\"14.00\">StandardScaler</text>\n", | |
"</g>\n", | |
"<!-- Select(Number)->StandardScaler -->\n", | |
"<g id=\"edge8\" class=\"edge\">\n", | |
"<title>Select(Number)->StandardScaler</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M209.79,-215.7C210.78,-207.98 211.98,-198.71 213.08,-190.11\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"216.56,-190.47 214.37,-180.1 209.62,-189.58 216.56,-190.47\"/>\n", | |
"</g>\n", | |
"<!-- StandardScaler->LinearRegression -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>StandardScaler->LinearRegression</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M198.79,-144.41C189.17,-135.43 177.13,-124.19 166.58,-114.34\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"168.91,-111.73 159.21,-107.47 164.14,-116.85 168.91,-111.73\"/>\n", | |
"</g>\n", | |
"<!-- y -->\n", | |
"<g id=\"node9\" class=\"node\">\n", | |
"<title>y</title>\n", | |
"<ellipse fill=\"none\" stroke=\"black\" cx=\"141.57\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", | |
"<text text-anchor=\"middle\" x=\"141.57\" y=\"-13.8\" font-family=\"Times,serif\" font-size=\"14.00\">y</text>\n", | |
"</g>\n", | |
"<!-- LinearRegression->y -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>LinearRegression->y</title>\n", | |
"<path fill=\"none\" stroke=\"black\" d=\"M141.57,-71.7C141.57,-63.98 141.57,-54.71 141.57,-46.11\"/>\n", | |
"<polygon fill=\"black\" stroke=\"black\" points=\"145.07,-46.1 141.57,-36.1 138.07,-46.1 145.07,-46.1\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.dot.Digraph at 0x1a30d1b910>" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pipeline.draw()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let us now use progressive validation to evaluate the performance of our model. This will loop through the data and make a prediction for each sample before learning from it. This is the canonical way of evaluating online machine learning models." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[10,000] MAE: 534.984054\n", | |
"[20,000] MAE: 537.327384\n", | |
"[30,000] MAE: 865.921832\n", | |
"[40,000] MAE: 759.319743\n", | |
"[50,000] MAE: 903.466296\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"MAE: 903.466296" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from creme import metrics\n", | |
"from creme import model_selection\n", | |
"\n", | |
"model_selection.progressive_val_score(\n", | |
" X_y=trips.take(50_000),\n", | |
" model=pipeline,\n", | |
" metric=metrics.MAE(),\n", | |
" print_every=10_000\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we might want to look at tuning some hyperparameters. This is quite to batch learning, because in our case we want the best parameters on-the-fly. To start off, we can enumerate a list of hyperparameters combinations we want to try out. The `expand_param_grid` function is really practical for doing so." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"10" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from creme import optim\n", | |
"\n", | |
"param_grid = model_selection.expand_param_grid({\n", | |
" 'LinearRegression': {\n", | |
" 'optimizer': [\n", | |
" (optim.SGD, {'lr': [.1, .01, .005]}),\n", | |
" (optim.Adam, {'beta_1': [.01, .001], 'lr': [.1, .01, .001]}),\n", | |
" (optim.Adam, {'beta_1': [.1], 'lr': [.001]}),\n", | |
" ]\n", | |
" }\n", | |
"})\n", | |
"\n", | |
"models = [\n", | |
" pipeline._set_params(params)\n", | |
" for params in param_grid\n", | |
"]\n", | |
"\n", | |
"len(models)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"At of writing this document, the only available model selection tool is [successive halving](https://arxiv.org/pdf/1502.07943.pdf). In our case we're doing regression, so we'll use `SuccessiveHalvingRegressor`. You can treat it like any other model, as it implements `fit_one` and `predict_one`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[10,000] MAE: 550.542469\n", | |
"[20,000] MAE: 556.9258\n", | |
"[30,000] MAE: 901.923797\n", | |
"[40,000] MAE: 799.411683\n", | |
"[50,000] MAE: 873.054589\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"MAE: 873.054589" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sh = model_selection.SuccessiveHalvingRegressor(\n", | |
" models=models,\n", | |
" metric=metrics.MAE(),\n", | |
" budget=10000\n", | |
")\n", | |
"\n", | |
"model_selection.progressive_val_score(\n", | |
" X_y=trips.take(50_000),\n", | |
" model=sh,\n", | |
" metric=metrics.MAE(),\n", | |
" print_every=10_000\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Pipeline (\n", | |
" TransformerUnion (\n", | |
" FuncTransformer (\n", | |
" func=\"distances\"\n", | |
" ),\n", | |
" FuncTransformer (\n", | |
" func=\"datetime_info\"\n", | |
" )\n", | |
" ),\n", | |
" TransformerUnion (\n", | |
" Pipeline (\n", | |
" Select (\n", | |
" <class 'str'>\n", | |
" ),\n", | |
" OneHotEncoder (\n", | |
" sparse=False\n", | |
" )\n", | |
" ),\n", | |
" Pipeline (\n", | |
" Select (\n", | |
" <class 'numbers.Number'>\n", | |
" ),\n", | |
" StandardScaler (\n", | |
" with_mean=True\n", | |
" with_std=True\n", | |
" )\n", | |
" )\n", | |
" ),\n", | |
" LinearRegression (\n", | |
" optimizer=SGD (\n", | |
" lr=Constant (\n", | |
" learning_rate=0.005\n", | |
" )\n", | |
" )\n", | |
" loss=Squared ()\n", | |
" l2=0.\n", | |
" intercept=750.856353\n", | |
" intercept_lr=Constant (\n", | |
" learning_rate=0.01\n", | |
" )\n", | |
" clip_gradient=1e+12\n", | |
" initializer=Zeros ()\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"sh.best_model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now how about deploying our model? Well the `creme` team has developped a little tool called [`chantilly`](https://github.com/creme-ml/chantilly) to simplify the process. It is essentially a [Flask](https://flask.palletsprojects.com/en/1.1.x/) app, and so is very simple to install. The source code is also very easy to delve into." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"/bin/sh: chantilly: command not found\n" | |
] | |
} | |
], | |
"source": [ | |
"!chantilly run # run this in a terminal session" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We first have to tell Chantilly what \"flavor\" we want to use. In this case we're doing regression so we'll use the \"regression\" flavor." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<Response [201]>" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import requests\n", | |
"\n", | |
"host = 'http://localhost:5000'\n", | |
"\n", | |
"requests.post(host + '/api/init', json={'flavor': 'regression'})" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let us now upload the model. We need to make a couple of changes first:\n", | |
"\n", | |
"- At the moment, using external needs in user-defined functions need to be done within each function. This might change in a future release.\n", | |
"- The `creme` dataset already takes care of parsing the datetimes. However, `chantilly` will assume that JSON data is provided, which thus has to be accounted for." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def distances(trip):\n", | |
" import math\n", | |
" \n", | |
" lat_dist = trip['dropoff_latitude'] - trip['pickup_latitude']\n", | |
" lon_dist = trip['dropoff_longitude'] - trip['pickup_longitude']\n", | |
" \n", | |
" return {\n", | |
" 'manhattan_distance': abs(lat_dist) + abs(lon_dist),\n", | |
" 'euclidean_distance': math.sqrt(lat_dist ** 2 + lon_dist ** 2)\n", | |
" }\n", | |
"\n", | |
"def datetime_info(trip):\n", | |
" import calendar\n", | |
" import datetime as dt\n", | |
" \n", | |
" day = dt.datetime.fromisoformat(trip['pickup_datetime'])\n", | |
" \n", | |
" return {\n", | |
" 'hour': day.hour,\n", | |
" 'day': calendar.day_name[day.weekday()]\n", | |
" }" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The model can now be uploaded." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<Response [201]>" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import dill\n", | |
"\n", | |
"extract_features = compose.TransformerUnion(distances, datetime_info)\n", | |
"\n", | |
"pipeline = compose.Pipeline(\n", | |
" extract_features,\n", | |
" preprocess,\n", | |
" linear_model.LinearRegression()\n", | |
")\n", | |
"\n", | |
"requests.post(host + '/api/model', data=dill.dumps(pipeline))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To make things realistic, we'll run a simulation where the taxis leave and arrive in the order as given in the dataset. Indeed, we can reproduce a live workload from a historical dataset, therefore producing an environment which is very close to what happens in a production setting." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"#0000000 departs at 2016-01-01 00:00:17\n", | |
"#0000001 departs at 2016-01-01 00:00:53\n", | |
"#0000002 departs at 2016-01-01 00:01:01\n", | |
"#0000003 departs at 2016-01-01 00:01:14\n", | |
"#0000004 departs at 2016-01-01 00:01:20\n", | |
"#0000005 departs at 2016-01-01 00:01:33\n", | |
"#0000006 departs at 2016-01-01 00:01:37\n", | |
"#0000007 departs at 2016-01-01 00:01:47\n", | |
"#0000008 departs at 2016-01-01 00:02:06\n", | |
"#0000009 departs at 2016-01-01 00:02:45\n", | |
"#0000010 departs at 2016-01-01 00:03:02\n", | |
"#0000006 arrives at 2016-01-01 00:03:31 - average error: 0:01:54\n", | |
"#0000011 departs at 2016-01-01 00:03:31\n", | |
"#0000012 departs at 2016-01-01 00:03:35\n", | |
"#0000013 departs at 2016-01-01 00:04:42\n", | |
"#0000014 departs at 2016-01-01 00:04:57\n", | |
"#0000015 departs at 2016-01-01 00:05:07\n", | |
"#0000016 departs at 2016-01-01 00:05:08\n", | |
"#0000017 departs at 2016-01-01 00:05:18\n", | |
"#0000018 departs at 2016-01-01 00:05:35\n", | |
"#0000019 departs at 2016-01-01 00:05:39\n", | |
"#0000003 arrives at 2016-01-01 00:05:54 - average error: 0:03:17\n", | |
"#0000020 departs at 2016-01-01 00:06:04\n", | |
"#0000021 departs at 2016-01-01 00:06:12\n", | |
"#0000022 departs at 2016-01-01 00:06:22\n", | |
"#0000023 departs at 2016-01-01 00:06:24\n", | |
"#0000024 departs at 2016-01-01 00:06:47\n", | |
"#0000025 departs at 2016-01-01 00:06:56\n", | |
"#0000026 departs at 2016-01-01 00:06:59\n", | |
"#0000027 departs at 2016-01-01 00:07:04\n", | |
"#0000028 departs at 2016-01-01 00:07:06\n", | |
"#0000029 departs at 2016-01-01 00:07:07\n", | |
"#0000021 arrives at 2016-01-01 00:07:13 - average error: 0:02:16.660295\n", | |
"#0000030 departs at 2016-01-01 00:07:22\n", | |
"#0000010 arrives at 2016-01-01 00:07:25 - average error: 0:02:48.245221\n", | |
"#0000031 departs at 2016-01-01 00:07:27\n", | |
"#0000032 departs at 2016-01-01 00:07:29\n", | |
"#0000033 departs at 2016-01-01 00:07:34\n", | |
"#0000034 departs at 2016-01-01 00:07:46\n", | |
"#0000035 departs at 2016-01-01 00:07:47\n", | |
"#0000002 arrives at 2016-01-01 00:07:49 - average error: 0:03:36.196177\n", | |
"#0000036 departs at 2016-01-01 00:07:52\n", | |
"#0000037 departs at 2016-01-01 00:08:07\n", | |
"#0000038 departs at 2016-01-01 00:08:09\n", | |
"#0000039 departs at 2016-01-01 00:08:11\n", | |
"#0000040 departs at 2016-01-01 00:08:15\n", | |
"#0000041 departs at 2016-01-01 00:08:29\n", | |
"#0000014 arrives at 2016-01-01 00:08:37 - average error: 0:03:34.342094\n", | |
"#0000042 departs at 2016-01-01 00:08:37\n", | |
"#0000043 departs at 2016-01-01 00:08:38\n", | |
"#0000044 departs at 2016-01-01 00:08:40\n", | |
"#0000045 departs at 2016-01-01 00:08:46\n", | |
"#0000046 departs at 2016-01-01 00:08:47\n", | |
"#0000047 departs at 2016-01-01 00:08:49\n", | |
"#0000048 departs at 2016-01-01 00:08:52\n", | |
"#0000049 departs at 2016-01-01 00:08:53\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-28-15f9fc6bbe36>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;31m# Wait\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mnap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrip\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'pickup_datetime'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mnow\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mnow\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrip\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'pickup_datetime'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-28-15f9fc6bbe36>\u001b[0m in \u001b[0;36mnap\u001b[0;34m(td)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mnap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtd\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimedelta\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mseconds\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"import datetime as dt\n", | |
"import time\n", | |
"from creme import datasets\n", | |
"from creme import stream\n", | |
"import requests\n", | |
"\n", | |
"\n", | |
"# Use the first trip's departure time as a reference time\n", | |
"taxis = datasets.Taxis()\n", | |
"now = next(iter(taxis))[0]['pickup_datetime']\n", | |
"mae = metrics.MAE() \n", | |
"predictions = {}\n", | |
"\n", | |
"\n", | |
"def nap(td: dt.timedelta):\n", | |
" time.sleep(td.seconds / 10)\n", | |
"\n", | |
"\n", | |
"for trip_no, trip, duration in stream.simulate_qa(\n", | |
" taxis,\n", | |
" moment='pickup_datetime',\n", | |
" delay=lambda _, duration: dt.timedelta(seconds=duration)\n", | |
"):\n", | |
"\n", | |
" trip_no = str(trip_no).zfill(len(str(taxis.n_samples)))\n", | |
"\n", | |
" # Taxi trip starts\n", | |
"\n", | |
" if duration is None:\n", | |
"\n", | |
" # Wait\n", | |
" nap(trip['pickup_datetime'] - now)\n", | |
" now = trip['pickup_datetime']\n", | |
"\n", | |
" # Ask chantilly to make a prediction\n", | |
" r = requests.post(host + '/api/predict', json={\n", | |
" 'id': trip_no,\n", | |
" 'features': {**trip, 'pickup_datetime': trip['pickup_datetime'].isoformat()}\n", | |
" })\n", | |
"\n", | |
" # Store the prediction\n", | |
" predictions[trip_no] = r.json()['prediction']\n", | |
"\n", | |
" print(f'#{trip_no} departs at {now}')\n", | |
" continue\n", | |
"\n", | |
" # Taxi trip ends\n", | |
"\n", | |
" # Wait\n", | |
" arrival_time = trip['pickup_datetime'] + dt.timedelta(seconds=duration)\n", | |
" nap(arrival_time - now)\n", | |
" now = arrival_time\n", | |
"\n", | |
" # Ask chantilly to update the model\n", | |
" requests.post(host + '/api/learn', json={'id': trip_no, 'ground_truth': duration})\n", | |
"\n", | |
" # Update the metric\n", | |
" mae.update(y_true=duration, y_pred=predictions.pop(trip_no))\n", | |
"\n", | |
" msg = f'#{trip_no} arrives at {now} - average error: {dt.timedelta(seconds=mae.get())}'\n", | |
" print(msg)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'MAE': 214.34209431927968,\n", | |
" 'RMSE': 248.1057514493265,\n", | |
" 'SMAPE': 167.4549197133317}" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"requests.get(host + '/api/metrics').json()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'learn': {'ewm_duration': 4426537,\n", | |
" 'ewm_duration_human': '4ms426μs537ns',\n", | |
" 'mean_duration': 4422086,\n", | |
" 'mean_duration_human': '4ms422μs86ns',\n", | |
" 'n_calls': 6},\n", | |
" 'predict': {'ewm_duration': 3974240,\n", | |
" 'ewm_duration_human': '3ms974μs240ns',\n", | |
" 'mean_duration': 4081125,\n", | |
" 'mean_duration_human': '4ms81μs125ns',\n", | |
" 'n_calls': 62}}" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"requests.get(host + '/api/stats').json()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |