@firmai
Last active October 14, 2022 09:31
Decesion Tree Regression (HW2).ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/firmai/a4c9cc8667311204f41d5a4a702a0005/01_decision_trees.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "su65C3O5ZDur"
},
"source": [
"# How to use decision trees to predict equity returns and price moves"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mNQBuG5ZZDuw"
},
"source": [
"In this notebook, we illustrate how to use tree-based models to gain insight and make predictions. \n",
"\n",
"To demonstrate regression trees, we predict returns; for the classification case, we return to the example of positive and negative asset price moves."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Gsz8G7AZDuy"
},
"source": [
"## Imports & Settings"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:25.352734Z",
"start_time": "2021-04-16T00:32:25.350697Z"
},
"id": "WKdqCz6ZZDu0"
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings('ignore')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:26.702500Z",
"start_time": "2021-04-16T00:32:25.552972Z"
},
"id": "ml-Oh0CMZDu3"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import os, sys\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"from scipy.stats import spearmanr\n",
"import pandas as pd\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.ticker import FuncFormatter\n",
"from matplotlib import cm\n",
"import seaborn as sns\n",
"\n",
"from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, export_graphviz, _tree\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"from sklearn.model_selection import train_test_split, GridSearchCV, learning_curve\n",
"from sklearn.metrics import roc_auc_score, roc_curve, mean_squared_error, make_scorer\n",
"import graphviz\n",
"\n",
"import statsmodels.api as sm"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:26.708683Z",
"start_time": "2021-04-16T00:32:26.703952Z"
},
"id": "zzFVDCnXZDu7"
},
"outputs": [],
"source": [
"sys.path.insert(1, os.path.join(sys.path[0], '..'))\n",
"\n",
"import numpy as np\n",
"np.random.seed(42)\n",
"\n",
"def format_time(t):\n",
"    \"\"\"Return a formatted time string 'HH:MM:SS'\n",
"    based on a numeric time() value\"\"\"\n",
"    m, s = divmod(t, 60)\n",
"    h, m = divmod(m, 60)\n",
"    return f'{h:0>2.0f}:{m:0>2.0f}:{s:0>2.0f}'\n",
"\n",
"# Don't pay too much attention to this piece of code yet;\n",
"# we will look more closely at cross-validation methods later.\n",
"class MultipleTimeSeriesCV:\n",
"    \"\"\"Generates tuples of train_idx, test_idx pairs.\n",
"    Assumes the MultiIndex contains a symbol/ticker level and a 'date' level;\n",
"    purges overlapping outcomes.\"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 n_splits=3,\n",
"                 train_period_length=126,\n",
"                 test_period_length=21,\n",
"                 lookahead=1,  # forecast horizon used for purging; must be an int >= 1\n",
"                 date_idx='date',\n",
"                 shuffle=False):\n",
"        self.n_splits = n_splits\n",
"        self.lookahead = lookahead\n",
"        self.test_length = test_period_length\n",
"        self.train_length = train_period_length\n",
"        self.shuffle = shuffle\n",
"        self.date_idx = date_idx\n",
"\n",
"    def split(self, X, y=None, groups=None):\n",
"        unique_dates = X.index.get_level_values(self.date_idx).unique()\n",
"        # Dates sorted newest first: position 0 is the most recent date\n",
"        days = sorted(unique_dates, reverse=True)\n",
"        split_idx = []\n",
"        for i in range(self.n_splits):\n",
"            test_end_idx = i * self.test_length\n",
"            test_start_idx = test_end_idx + self.test_length\n",
"            train_end_idx = test_start_idx + self.lookahead - 1\n",
"            train_start_idx = train_end_idx + self.train_length + self.lookahead - 1\n",
"            split_idx.append([train_start_idx, train_end_idx,\n",
"                              test_start_idx, test_end_idx])\n",
"\n",
"        dates = X.reset_index()[[self.date_idx]]\n",
"        for train_start, train_end, test_start, test_end in split_idx:\n",
"\n",
"            train_idx = dates[(dates[self.date_idx] > days[train_start])\n",
"                              & (dates[self.date_idx] <= days[train_end])].index\n",
"            test_idx = dates[(dates[self.date_idx] > days[test_start])\n",
"                             & (dates[self.date_idx] <= days[test_end])].index\n",
"            train_idx = train_idx.to_numpy()\n",
"            if self.shuffle:\n",
"                # np.random.permutation returns a shuffled copy of the positions\n",
"                train_idx = np.random.permutation(train_idx)\n",
"            yield train_idx, test_idx.to_numpy()\n",
"\n",
"    def get_n_splits(self, X=None, y=None, groups=None):\n",
"        return self.n_splits"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:26.721121Z",
"start_time": "2021-04-16T00:32:26.709927Z"
},
"id": "b8dIEObMZDu-"
},
"outputs": [],
"source": [
"sns.set_style('white')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:26.730830Z",
"start_time": "2021-04-16T00:32:26.722176Z"
},
"id": "SZ53r0TbZDu_"
},
"outputs": [],
"source": [
"results_path = Path('results', 'decision_trees')\n",
"if not results_path.exists():\n",
" results_path.mkdir(parents=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QmnvurmdZDvB"
},
"source": [
"## Load Model Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MhF6xQYfZDvD"
},
"source": [
"We use a simplified version of the data set constructed in Chapter 4, Alpha factor research. It consists of monthly returns and various engineered features derived from stock prices provided by Quandl; the version loaded below spans late 2006 to 2017. The details can be found in the notebook [data_prep](00_data_prep.ipynb) in the GitHub repo for this chapter."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "afnYWl7UZDvF"
},
"source": [
"The decision tree models in this chapter are not equipped to handle missing values or categorical variables, so we drop observations with missing values and apply dummy encoding to the categorical variables."
]
},
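{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two preprocessing steps can be illustrated on a small hypothetical frame (the values and the `Technology` and `Energy` sector labels are made up for this sketch, and this is not the exact preprocessing used to build the data set loaded below): rows with missing values are dropped first, and the categorical `sector` column is then one-hot encoded with `pd.get_dummies`. Because the `Technology` row is dropped before encoding, no dummy column is created for it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy frame: one numeric feature, a categorical 'sector' column, one missing value\n",
"toy = pd.DataFrame({\n",
"    'rsi': [60.8, np.nan, 39.5, 65.6],\n",
"    'sector': ['Capital Goods', 'Technology', 'Capital Goods', 'Energy'],\n",
"    'target': [-0.08, -0.01, 0.06, 0.02],\n",
"})\n",
"\n",
"# Drop rows with missing values first, then one-hot encode the categorical column\n",
"clean = toy.dropna()\n",
"encoded = pd.get_dummies(clean, columns=['sector'], prefix='sector', dtype=int)\n",
"print(encoded)"
]
},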
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:30.995569Z",
"start_time": "2021-04-16T00:32:30.805133Z"
},
"scrolled": true,
"id": "U_g_q7HoZDvH"
},
"outputs": [],
"source": [
"data = pd.read_csv(\"https://open-data.s3.filebase.com/equities_monthly.csv\")"
]
},
{
"cell_type": "code",
"source": [
"data = data.set_index([\"ticker\",\"date\"])"
],
"metadata": {
"id": "FCmbwHLinZ1j"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"data.head()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 418
},
"id": "4gFErwwMnVlu",
"outputId": "d9773d44-37a5-455e-c0e1-9910a9022c3e"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" atr bb_down bb_high bb_low bb_mid bb_up \\\n",
"ticker date \n",
"A 2006-12-31 -0.442397 0.072112 3.223450 3.137767 3.180609 0.013571 \n",
" 2007-01-31 -0.520562 0.016425 3.207382 3.111725 3.159553 0.079232 \n",
" 2007-02-28 -0.464587 0.008895 3.177726 3.111756 3.144741 0.057075 \n",
" 2007-03-31 -0.497099 0.099084 3.191225 3.078332 3.134779 0.013810 \n",
" 2007-04-30 -0.552364 0.007060 3.234956 3.189514 3.212235 0.038382 \n",
"\n",
" macd natr rsi sector ... \\\n",
"ticker date ... \n",
"A 2006-12-31 0.188915 2.037431 60.772801 Capital Goods ... \n",
" 2007-01-31 -0.366521 1.886539 35.388712 Capital Goods ... \n",
" 2007-02-28 -0.116034 2.141267 39.517738 Capital Goods ... \n",
" 2007-03-31 0.196264 1.886661 65.552439 Capital Goods ... \n",
" 2007-04-30 0.234341 1.630554 48.928443 Capital Goods ... \n",
"\n",
" RMW CMA momentum_3 momentum_6 momentum_3_6 \\\n",
"ticker date \n",
"A 2006-12-31 -1.374298 1.620554 -0.072978 -0.077871 -0.004893 \n",
" 2007-01-31 -1.418516 1.654243 0.046866 0.101630 0.054764 \n",
" 2007-02-28 -1.618912 1.469760 0.006869 0.005676 -0.001193 \n",
" 2007-03-31 -2.336146 0.682374 -0.072323 -0.056068 0.016255 \n",
" 2007-04-30 -2.304566 0.665849 0.003918 -0.026027 -0.029945 \n",
"\n",
" momentum_12 momentum_3_12 year month target \n",
"ticker date \n",
"A 2006-12-31 -0.090712 -0.017733 2006 12 -0.081779 \n",
" 2007-01-31 0.076960 0.030093 2007 1 -0.007812 \n",
" 2007-02-28 -0.002602 -0.009471 2007 2 0.061102 \n",
" 2007-03-31 -0.070101 0.002222 2007 3 0.020184 \n",
" 2007-04-30 -0.029424 -0.033342 2007 4 0.110562 \n",
"\n",
"[5 rows x 27 columns]"
],
"text/html": [
"\n",
" <div id=\"df-e45b18b2-fff9-48db-bbf9-f4c5053538c0\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>atr</th>\n",
" <th>bb_down</th>\n",
" <th>bb_high</th>\n",
" <th>bb_low</th>\n",
" <th>bb_mid</th>\n",
" <th>bb_up</th>\n",
" <th>macd</th>\n",
" <th>natr</th>\n",
" <th>rsi</th>\n",
" <th>sector</th>\n",
" <th>...</th>\n",
" <th>RMW</th>\n",
" <th>CMA</th>\n",
" <th>momentum_3</th>\n",
" <th>momentum_6</th>\n",
" <th>momentum_3_6</th>\n",
" <th>momentum_12</th>\n",
" <th>momentum_3_12</th>\n",
" <th>year</th>\n",
" <th>month</th>\n",
" <th>target</th>\n",
" </tr>\n",
" <tr>\n",
" <th>ticker</th>\n",
" <th>date</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">A</th>\n",
" <th>2006-12-31</th>\n",
" <td>-0.442397</td>\n",
" <td>0.072112</td>\n",
" <td>3.223450</td>\n",
" <td>3.137767</td>\n",
" <td>3.180609</td>\n",
" <td>0.013571</td>\n",
" <td>0.188915</td>\n",
" <td>2.037431</td>\n",
" <td>60.772801</td>\n",
" <td>Capital Goods</td>\n",
" <td>...</td>\n",
" <td>-1.374298</td>\n",
" <td>1.620554</td>\n",
" <td>-0.072978</td>\n",
" <td>-0.077871</td>\n",
" <td>-0.004893</td>\n",
" <td>-0.090712</td>\n",
" <td>-0.017733</td>\n",
" <td>2006</td>\n",
" <td>12</td>\n",
" <td>-0.081779</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2007-01-31</th>\n",
" <td>-0.520562</td>\n",
" <td>0.016425</td>\n",
" <td>3.207382</td>\n",
" <td>3.111725</td>\n",
" <td>3.159553</td>\n",
" <td>0.079232</td>\n",
" <td>-0.366521</td>\n",
" <td>1.886539</td>\n",
" <td>35.388712</td>\n",
" <td>Capital Goods</td>\n",
" <td>...</td>\n",
" <td>-1.418516</td>\n",
" <td>1.654243</td>\n",
" <td>0.046866</td>\n",
" <td>0.101630</td>\n",
" <td>0.054764</td>\n",
" <td>0.076960</td>\n",
" <td>0.030093</td>\n",
" <td>2007</td>\n",
" <td>1</td>\n",
" <td>-0.007812</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2007-02-28</th>\n",
" <td>-0.464587</td>\n",
" <td>0.008895</td>\n",
" <td>3.177726</td>\n",
" <td>3.111756</td>\n",
" <td>3.144741</td>\n",
" <td>0.057075</td>\n",
" <td>-0.116034</td>\n",
" <td>2.141267</td>\n",
" <td>39.517738</td>\n",
" <td>Capital Goods</td>\n",
" <td>...</td>\n",
" <td>-1.618912</td>\n",
" <td>1.469760</td>\n",
" <td>0.006869</td>\n",
" <td>0.005676</td>\n",
" <td>-0.001193</td>\n",
" <td>-0.002602</td>\n",
" <td>-0.009471</td>\n",
" <td>2007</td>\n",
" <td>2</td>\n",
" <td>0.061102</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2007-03-31</th>\n",
" <td>-0.497099</td>\n",
" <td>0.099084</td>\n",
" <td>3.191225</td>\n",
" <td>3.078332</td>\n",
" <td>3.134779</td>\n",
" <td>0.013810</td>\n",
" <td>0.196264</td>\n",
" <td>1.886661</td>\n",
" <td>65.552439</td>\n",
" <td>Capital Goods</td>\n",
" <td>...</td>\n",
" <td>-2.336146</td>\n",
" <td>0.682374</td>\n",
" <td>-0.072323</td>\n",
" <td>-0.056068</td>\n",
" <td>0.016255</td>\n",
" <td>-0.070101</td>\n",
" <td>0.002222</td>\n",
" <td>2007</td>\n",
" <td>3</td>\n",
" <td>0.020184</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2007-04-30</th>\n",
" <td>-0.552364</td>\n",
" <td>0.007060</td>\n",
" <td>3.234956</td>\n",
" <td>3.189514</td>\n",
" <td>3.212235</td>\n",
" <td>0.038382</td>\n",
" <td>0.234341</td>\n",
" <td>1.630554</td>\n",
" <td>48.928443</td>\n",
" <td>Capital Goods</td>\n",
" <td>...</td>\n",
" <td>-2.304566</td>\n",
" <td>0.665849</td>\n",
" <td>0.003918</td>\n",
" <td>-0.026027</td>\n",
" <td>-0.029945</td>\n",
" <td>-0.029424</td>\n",
" <td>-0.033342</td>\n",
" <td>2007</td>\n",
" <td>4</td>\n",
" <td>0.110562</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 27 columns</p>\n",
"</div>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMqnrU6YZDvJ"
},
"source": [
"## Simple Regression Tree with Time Series Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lSuL8bbKZDvK"
},
"source": [
"Regression trees predict the mean outcome of the training samples assigned to a given node and typically use the mean-squared error to select the optimal rule at each step of recursive binary splitting.\n",
"\n",
"Given a training set, the algorithm iterates over the predictors, $X_1, X_2, \\ldots, X_p$, and possible cutpoints, $s_1, s_2, \\ldots, s_N$, to find the optimal combination. The optimal rule splits the feature space into two regions, $\\{X\\mid X_i \\leq s_j\\}$ and $\\{X\\mid X_i > s_j\\}$, according to whether the value of feature $X_i$ lies at or below the threshold $s_j$ or above it. The chosen rule is the one for which the predictions based on the two training subsets achieve the largest reduction in squared residuals relative to the current node."
]
},
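{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split search described above can be sketched for a single feature in a few lines of NumPy. This is a minimal, unoptimized illustration rather than scikit-learn's implementation, and the function name `best_split` is hypothetical. It sorts the samples by the feature, evaluates every cut between adjacent distinct values, and keeps the threshold whose two child nodes, each predicting its own mean, have the smallest total squared residuals. Like scikit-learn, it places the threshold midway between adjacent values and sends samples with $X_i \\leq s_j$ to the left child."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def best_split(x, y, min_samples_leaf=1):\n",
"    \"\"\"Return (threshold, sse) of the best binary split of y on feature x.\"\"\"\n",
"    order = np.argsort(x)\n",
"    x_sorted, y_sorted = x[order], y[order]\n",
"    best_threshold, best_sse = None, np.inf\n",
"    # Candidate cuts leave at least min_samples_leaf samples on each side\n",
"    for i in range(min_samples_leaf, len(x) - min_samples_leaf + 1):\n",
"        if x_sorted[i - 1] == x_sorted[i]:\n",
"            continue  # cannot cut between identical feature values\n",
"        left, right = y_sorted[:i], y_sorted[i:]\n",
"        # Each child predicts its mean; sum the squared residuals of both children\n",
"        sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()\n",
"        if sse < best_sse:\n",
"            best_threshold = (x_sorted[i - 1] + x_sorted[i]) / 2\n",
"            best_sse = sse\n",
"    return best_threshold, best_sse\n",
"\n",
"x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])\n",
"y = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])\n",
"# The cut between 3 and 10 separates the two outcome groups perfectly\n",
"print(best_split(x, y))"
]
},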
{
"cell_type": "markdown",
"metadata": {
"id": "hMTc4QxQZDvK"
},
"source": [
"### Generate two lags of monthly returns"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:38.965107Z",
"start_time": "2021-04-16T00:32:38.942472Z"
},
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-mnPEltXZDvL",
"outputId": "21412ce4-296b-4e33-d0ec-fbc263d0d779"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"MultiIndex: 77176 entries, ('A', '2007-01-31') to ('ZION', '2017-11-30')\n",
"Data columns (total 3 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 y 77176 non-null float64\n",
" 1 t-1 77176 non-null float64\n",
" 2 t-2 77176 non-null float64\n",
"dtypes: float64(3)\n",
"memory usage: 2.1+ MB\n"
]
}
],
"source": [
"X2 = data.loc[:, ['target', 'return_1m']]\n",
"X2.columns = ['y', 't-1']\n",
"X2['t-2'] = data.groupby(level='ticker').return_1m.shift()\n",
"X2 = X2.dropna()\n",
"X2.info()"
]
},
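{
"cell_type": "markdown",
"metadata": {},
"source": [
"Shifting within `groupby(level='ticker')` matters: a plain `shift()` on the stacked panel would carry the last return of one ticker into the first row of the next. The toy panel below, with hypothetical tickers, dates, and values, shows the difference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical two-ticker panel indexed by (ticker, date), like the notebook's data\n",
"idx = pd.MultiIndex.from_product(\n",
"    [['A', 'B'], ['2007-01-31', '2007-02-28', '2007-03-31']],\n",
"    names=['ticker', 'date'])\n",
"panel = pd.DataFrame({'return_1m': [0.01, 0.02, 0.03, -0.01, -0.02, -0.03]}, index=idx)\n",
"\n",
"# Lag within each ticker: the first month of every ticker has no prior return (NaN)\n",
"panel['t-2'] = panel.groupby(level='ticker').return_1m.shift()\n",
"\n",
"# A naive shift leaks ticker A's last return into ticker B's first row\n",
"panel['naive'] = panel.return_1m.shift()\n",
"print(panel)"
]
},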
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:38.969598Z",
"start_time": "2021-04-16T00:32:38.966574Z"
},
"id": "oaK6FwziZDvM"
},
"outputs": [],
"source": [
"y2 = X2.y\n",
"X2 = X2.drop('y', axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "egeb6VtsZDvN"
},
"source": [
"### Explore Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WhiVPIs7ZDvO"
},
"source": [
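"Note the small spike where the data were clipped: clipping assigns every more extreme return to the clipping threshold, so probability mass accumulates there."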
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:39.517083Z",
"start_time": "2021-04-16T00:32:38.970753Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 279
},
"id": "kzYKbveqZDvP",
"outputId": "5e2da0d1-07f8-4e0c-9af2-6be1d7b269fd"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"sns.histplot(y2, kde=True)  # histogram with a kernel density estimate\n",
"sns.despine();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ClGLEcZ3ZDvR"
},
"source": [
"### Configure Tree"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P7cRADS_ZDvU"
},
"source": [
"Let's start with a simplified example that is easy to visualize: we use only two months of lagged returns to predict the following month, in the vein of an AR(2) model from the last chapter:"
]
},
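{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, a linear AR(2)-style baseline regresses the next value on the same two lags. The sketch below fits both a `LinearRegression` and a depth-limited `DecisionTreeRegressor` to a synthetic AR(2) series; the coefficients 0.3 and -0.2, the noise level, and the sample size are arbitrary assumptions, not estimates from the equity data. The linear model should roughly recover the coefficients, while the tree approximates the same relationship with a piecewise-constant function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Synthetic AR(2) series with illustrative (assumed) coefficients\n",
"rng = np.random.default_rng(42)\n",
"n = 2_000\n",
"r = np.zeros(n)\n",
"for t in range(2, n):\n",
"    r[t] = 0.3 * r[t - 1] - 0.2 * r[t - 2] + rng.normal(scale=0.05)\n",
"\n",
"# Features are the two lags (columns t-1, t-2); the target is the next value\n",
"X = np.column_stack([r[1:-1], r[:-2]])\n",
"y = r[2:]\n",
"\n",
"ar2 = LinearRegression().fit(X, y)\n",
"tree = DecisionTreeRegressor(max_depth=6, min_samples_leaf=50, random_state=42).fit(X, y)\n",
"\n",
"print('AR(2) coefficients:', ar2.coef_)\n",
"print('tree leaves:', tree.get_n_leaves())"
]
},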
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:39.521903Z",
"start_time": "2021-04-16T00:32:39.518821Z"
},
"id": "Cinm1A8TZDvV"
},
"outputs": [],
"source": [
"reg_tree_t2 = DecisionTreeRegressor(criterion='squared_error',  # named 'mse' before scikit-learn 1.0\n",
" splitter='best',\n",
" max_depth=6,\n",
" min_samples_split=2,\n",
" min_samples_leaf=50,\n",
" min_weight_fraction_leaf=0.0,\n",
" max_features=None,\n",
" random_state=42,\n",
" max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c4o82YPnZDvY"
},
"source": [
"### Train Decision Tree"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:46.840780Z",
"start_time": "2021-04-16T00:32:39.523207Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vcOh4asxZDva",
"outputId": "40964f44-f899-4324-f6ac-5faebddd9288"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"150 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"reg_tree_t2.fit(X=X2, y=y2)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:46.944042Z",
"start_time": "2021-04-16T00:32:46.846638Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1uiyb_oDZDvb",
"outputId": "2e7a7193-18df-4158-e9da-808753bb712a"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeRegressor(criterion='mse', max_depth=6, min_samples_leaf=50,\n",
" random_state=42)"
]
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"reg_tree_t2.fit(X=X2, y=y2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lFsAzGNlZDvd"
},
"source": [
"### Visualize Tree"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k2TqwCi-ZDve"
},
"source": [
"You can visualize the tree with the graphviz library (see GitHub for installation instructions) because sklearn can export a description of the tree in the DOT language that graphviz uses. \n",
"\n",
"You can configure the output to include feature labels (and class labels for classification trees) and limit the number of levels to keep the chart readable, as follows:"
]
},
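{
"cell_type": "markdown",
"metadata": {},
"source": [
"If graphviz is not installed, scikit-learn's `export_text` offers a plain-text view of the same split rules, with `feature_names` for labels and `max_depth` to limit how many levels are printed. The sketch below uses hypothetical stand-in data with two features named like the lagged returns above, so its thresholds will not match the tree fitted in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor, export_text\n",
"\n",
"# Hypothetical stand-in data: the outcome depends on the first feature only\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(scale=0.1, size=(1_000, 2))\n",
"y = 0.5 * X[:, 0] + rng.normal(scale=0.05, size=1_000)\n",
"\n",
"tree = DecisionTreeRegressor(max_depth=3, min_samples_leaf=50, random_state=42).fit(X, y)\n",
"\n",
"# max_depth here only truncates the printout, not the fitted tree\n",
"rules = export_text(tree, feature_names=['t-1', 't-2'], max_depth=2, decimals=3)\n",
"print(rules)"
]
},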
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:46.993321Z",
"start_time": "2021-04-16T00:32:46.945531Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 495
},
"id": "9DB_MIwFZDve",
"outputId": "4b3ada4f-22a6-4050-f3f2-223209c3e7a4"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<graphviz.files.Source at 0x7f2b18aa6150>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: Tree Pages: 1 -->\n<svg width=\"608pt\" height=\"356pt\"\n viewBox=\"0.00 0.00 608.00 356.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 352)\">\n<title>Tree</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-352 604,-352 604,4 -4,4\"/>\n<!-- 0 -->\n<g id=\"node1\" class=\"node\">\n<title>0</title>\n<path fill=\"#f1b991\" stroke=\"#000000\" d=\"M346,-348C346,-348 242,-348 242,-348 236,-348 230,-342 230,-336 230,-336 230,-292 230,-292 230,-286 236,-280 242,-280 242,-280 346,-280 346,-280 352,-280 358,-286 358,-292 358,-292 358,-336 358,-336 358,-342 352,-348 346,-348\"/>\n<text text-anchor=\"start\" x=\"257.5\" y=\"-332.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;2 ≤ &#45;0.267</text>\n<text text-anchor=\"start\" x=\"253.5\" y=\"-317.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"238\" y=\"-302.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 77176</text>\n<text text-anchor=\"start\" x=\"254\" y=\"-287.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.01</text>\n</g>\n<!-- 1 -->\n<g id=\"node2\" class=\"node\">\n<title>1</title>\n<path fill=\"#eca470\" stroke=\"#000000\" d=\"M269,-244C269,-244 181,-244 181,-244 175,-244 169,-238 169,-232 169,-232 169,-188 169,-188 169,-182 175,-176 181,-176 181,-176 269,-176 269,-176 275,-176 281,-182 281,-188 281,-188 281,-232 281,-232 281,-238 275,-244 269,-244\"/>\n<text text-anchor=\"start\" x=\"188.5\" y=\"-228.8\" 
font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ &#45;0.137</text>\n<text text-anchor=\"start\" x=\"184.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.035</text>\n<text text-anchor=\"start\" x=\"177\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 722</text>\n<text text-anchor=\"start\" x=\"181\" y=\"-183.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.037</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g id=\"edge1\" class=\"edge\">\n<title>0&#45;&gt;1</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M271.4068,-279.9465C265.6286,-271.2373 259.3425,-261.7626 253.3166,-252.6801\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"256.089,-250.5278 247.6439,-244.13 250.256,-254.3978 256.089,-250.5278\"/>\n<text text-anchor=\"middle\" x=\"242.6866\" y=\"-264.9336\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n</g>\n<!-- 18 -->\n<g id=\"node9\" class=\"node\">\n<title>18</title>\n<path fill=\"#f1b992\" stroke=\"#000000\" d=\"M415,-244C415,-244 311,-244 311,-244 305,-244 299,-238 299,-232 299,-232 299,-188 299,-188 299,-182 305,-176 311,-176 311,-176 415,-176 415,-176 421,-176 427,-182 427,-188 427,-188 427,-232 427,-232 427,-238 421,-244 415,-244\"/>\n<text text-anchor=\"start\" x=\"326.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ &#45;0.147</text>\n<text text-anchor=\"start\" x=\"322.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"307\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 76454</text>\n<text text-anchor=\"start\" x=\"323\" y=\"-183.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.01</text>\n</g>\n<!-- 
0&#45;&gt;18 -->\n<g id=\"edge8\" class=\"edge\">\n<title>0&#45;&gt;18</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M316.5932,-279.9465C322.3714,-271.2373 328.6575,-261.7626 334.6834,-252.6801\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"337.744,-254.3978 340.3561,-244.13 331.911,-250.5278 337.744,-254.3978\"/>\n<text text-anchor=\"middle\" x=\"345.3134\" y=\"-264.9336\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n</g>\n<!-- 2 -->\n<g id=\"node3\" class=\"node\">\n<title>2</title>\n<path fill=\"#e99659\" stroke=\"#000000\" d=\"M140,-140C140,-140 52,-140 52,-140 46,-140 40,-134 40,-128 40,-128 40,-84 40,-84 40,-78 46,-72 52,-72 52,-72 140,-72 140,-72 146,-72 152,-78 152,-84 152,-84 152,-128 152,-128 152,-134 146,-140 140,-140\"/>\n<text text-anchor=\"start\" x=\"59.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ &#45;0.228</text>\n<text text-anchor=\"start\" x=\"55.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.041</text>\n<text text-anchor=\"start\" x=\"48\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 217</text>\n<text text-anchor=\"start\" x=\"52\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.057</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g id=\"edge2\" class=\"edge\">\n<title>1&#45;&gt;2</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M182.7606,-175.9465C171.0669,-166.519 158.2606,-156.1946 146.163,-146.4415\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"148.3162,-143.6816 138.3343,-140.13 143.9227,-149.1312 148.3162,-143.6816\"/>\n</g>\n<!-- 5 -->\n<g id=\"node6\" class=\"node\">\n<title>5</title>\n<path fill=\"#eeab7a\" stroke=\"#000000\" d=\"M270,-140C270,-140 182,-140 182,-140 176,-140 170,-134 170,-128 170,-128 170,-84 170,-84 170,-78 176,-72 182,-72 182,-72 270,-72 270,-72 
276,-72 282,-78 282,-84 282,-84 282,-128 282,-128 282,-134 276,-140 270,-140\"/>\n<text text-anchor=\"start\" x=\"191.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ 0.321</text>\n<text text-anchor=\"start\" x=\"185.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.032</text>\n<text text-anchor=\"start\" x=\"178\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 505</text>\n<text text-anchor=\"start\" x=\"182\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.029</text>\n</g>\n<!-- 1&#45;&gt;5 -->\n<g id=\"edge5\" class=\"edge\">\n<title>1&#45;&gt;5</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M225.3274,-175.9465C225.406,-167.776 225.491,-158.9318 225.5734,-150.3697\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"229.0754,-150.1632 225.6718,-140.13 222.0757,-150.0958 229.0754,-150.1632\"/>\n</g>\n<!-- 3 -->\n<g id=\"node4\" class=\"node\">\n<title>3</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M42,-36C42,-36 12,-36 12,-36 6,-36 0,-30 0,-24 0,-24 0,-12 0,-12 0,-6 6,0 12,0 12,0 42,0 42,0 48,0 54,-6 54,-12 54,-12 54,-24 54,-24 54,-30 48,-36 42,-36\"/>\n<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g id=\"edge3\" class=\"edge\">\n<title>2&#45;&gt;3</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M69.3228,-71.9769C62.0807,-62.7406 54.3531,-52.8851 47.5097,-44.1573\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"50.1888,-41.9017 41.2641,-36.192 44.6802,-46.221 50.1888,-41.9017\"/>\n</g>\n<!-- 4 -->\n<g id=\"node5\" class=\"node\">\n<title>4</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M114,-36C114,-36 84,-36 84,-36 78,-36 72,-30 72,-24 72,-24 72,-12 72,-12 72,-6 78,0 84,0 84,0 114,0 114,0 120,0 126,-6 
126,-12 126,-12 126,-24 126,-24 126,-30 120,-36 114,-36\"/>\n<text text-anchor=\"middle\" x=\"99\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 2&#45;&gt;4 -->\n<g id=\"edge4\" class=\"edge\">\n<title>2&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M97.1599,-71.9769C97.4488,-63.5023 97.7555,-54.5065 98.0339,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"101.537,-46.3054 98.3798,-36.192 94.541,-46.0669 101.537,-46.3054\"/>\n</g>\n<!-- 6 -->\n<g id=\"node7\" class=\"node\">\n<title>6</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M210,-36C210,-36 180,-36 180,-36 174,-36 168,-30 168,-24 168,-24 168,-12 168,-12 168,-6 174,0 180,0 180,0 210,0 210,0 216,0 222,-6 222,-12 222,-12 222,-24 222,-24 222,-30 216,-36 210,-36\"/>\n<text text-anchor=\"middle\" x=\"195\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 5&#45;&gt;6 -->\n<g id=\"edge6\" class=\"edge\">\n<title>5&#45;&gt;6</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M214.0146,-71.9769C210.9621,-63.3119 207.7177,-54.102 204.7895,-45.7894\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"208.0323,-44.4609 201.4085,-36.192 201.43,-46.7868 208.0323,-44.4609\"/>\n</g>\n<!-- 17 -->\n<g id=\"node8\" class=\"node\">\n<title>17</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M282,-36C282,-36 252,-36 252,-36 246,-36 240,-30 240,-24 240,-24 240,-12 240,-12 240,-6 246,0 252,0 252,0 282,0 282,0 288,0 294,-6 294,-12 294,-12 294,-24 294,-24 294,-30 288,-36 282,-36\"/>\n<text text-anchor=\"middle\" x=\"267\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 5&#45;&gt;17 -->\n<g id=\"edge7\" class=\"edge\">\n<title>5&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M241.8517,-71.9769C245.9331,-63.2167 250.2741,-53.8995 254.1802,-45.5157\"/>\n<polygon fill=\"#000000\" 
stroke=\"#000000\" points=\"257.4735,-46.7346 258.5242,-36.192 251.1284,-43.7783 257.4735,-46.7346\"/>\n</g>\n<!-- 19 -->\n<g id=\"node10\" class=\"node\">\n<title>19</title>\n<path fill=\"#f3c5a4\" stroke=\"#000000\" d=\"M410,-140C410,-140 314,-140 314,-140 308,-140 302,-134 302,-128 302,-128 302,-84 302,-84 302,-78 308,-72 314,-72 314,-72 410,-72 410,-72 416,-72 422,-78 422,-84 422,-84 422,-128 422,-128 422,-134 416,-140 410,-140\"/>\n<text text-anchor=\"start\" x=\"329.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;2 ≤ &#45;0.17</text>\n<text text-anchor=\"start\" x=\"321.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.023</text>\n<text text-anchor=\"start\" x=\"310\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3243</text>\n<text text-anchor=\"start\" x=\"315.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = &#45;0.005</text>\n</g>\n<!-- 18&#45;&gt;19 -->\n<g id=\"edge9\" class=\"edge\">\n<title>18&#45;&gt;19</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M362.6726,-175.9465C362.594,-167.776 362.509,-158.9318 362.4266,-150.3697\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"365.9243,-150.0958 362.3282,-140.13 358.9246,-150.1632 365.9243,-150.0958\"/>\n</g>\n<!-- 42 -->\n<g id=\"node13\" class=\"node\">\n<title>42</title>\n<path fill=\"#f1b991\" stroke=\"#000000\" d=\"M556,-140C556,-140 452,-140 452,-140 446,-140 440,-134 440,-128 440,-128 440,-84 440,-84 440,-78 446,-72 452,-72 452,-72 556,-72 556,-72 562,-72 568,-78 568,-84 568,-84 568,-128 568,-128 568,-134 562,-140 556,-140\"/>\n<text text-anchor=\"start\" x=\"469.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ 0.118</text>\n<text text-anchor=\"start\" x=\"463.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" 
fill=\"#000000\">mse = 0.007</text>\n<text text-anchor=\"start\" x=\"448\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 73211</text>\n<text text-anchor=\"start\" x=\"464\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.01</text>\n</g>\n<!-- 18&#45;&gt;42 -->\n<g id=\"edge12\" class=\"edge\">\n<title>18&#45;&gt;42</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M409.1687,-175.9465C422.0719,-166.4293 436.2145,-155.9978 449.5483,-146.163\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"451.7575,-148.8826 457.7276,-140.13 447.6023,-143.2492 451.7575,-148.8826\"/>\n</g>\n<!-- 20 -->\n<g id=\"node11\" class=\"node\">\n<title>20</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M366,-36C366,-36 336,-36 336,-36 330,-36 324,-30 324,-24 324,-24 324,-12 324,-12 324,-6 330,0 336,0 336,0 366,0 366,0 372,0 378,-6 378,-12 378,-12 378,-24 378,-24 378,-30 372,-36 366,-36\"/>\n<text text-anchor=\"middle\" x=\"351\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 19&#45;&gt;20 -->\n<g id=\"edge10\" class=\"edge\">\n<title>19&#45;&gt;20</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M357.7471,-71.9769C356.6878,-63.5023 355.5633,-54.5065 354.5423,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"357.9874,-45.6806 353.274,-36.192 351.0415,-46.5489 357.9874,-45.6806\"/>\n</g>\n<!-- 27 -->\n<g id=\"node12\" class=\"node\">\n<title>27</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M438,-36C438,-36 408,-36 408,-36 402,-36 396,-30 396,-24 396,-24 396,-12 396,-12 396,-6 402,0 408,0 408,0 438,0 438,0 444,0 450,-6 450,-12 450,-12 450,-24 450,-24 450,-30 444,-36 438,-36\"/>\n<text text-anchor=\"middle\" x=\"423\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 19&#45;&gt;27 -->\n<g id=\"edge11\" 
class=\"edge\">\n<title>19&#45;&gt;27</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M385.5842,-71.9769C391.9206,-62.8358 398.6775,-53.0883 404.6809,-44.4276\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"407.5692,-46.4045 410.3897,-36.192 401.8162,-42.4166 407.5692,-46.4045\"/>\n</g>\n<!-- 43 -->\n<g id=\"node14\" class=\"node\">\n<title>43</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M516,-36C516,-36 486,-36 486,-36 480,-36 474,-30 474,-24 474,-24 474,-12 474,-12 474,-6 480,0 486,0 486,0 516,0 516,0 522,0 528,-6 528,-12 528,-12 528,-24 528,-24 528,-30 522,-36 516,-36\"/>\n<text text-anchor=\"middle\" x=\"501\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 42&#45;&gt;43 -->\n<g id=\"edge13\" class=\"edge\">\n<title>42&#45;&gt;43</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M502.8401,-71.9769C502.5512,-63.5023 502.2445,-54.5065 501.9661,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"505.459,-46.0669 501.6202,-36.192 498.463,-46.3054 505.459,-46.0669\"/>\n</g>\n<!-- 54 -->\n<g id=\"node15\" class=\"node\">\n<title>54</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M588,-36C588,-36 558,-36 558,-36 552,-36 546,-30 546,-24 546,-24 546,-12 546,-12 546,-6 552,0 558,0 558,0 588,0 588,0 594,0 600,-6 600,-12 600,-12 600,-24 600,-24 600,-30 594,-36 588,-36\"/>\n<text text-anchor=\"middle\" x=\"573\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 42&#45;&gt;54 -->\n<g id=\"edge14\" class=\"edge\">\n<title>42&#45;&gt;54</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M530.6772,-71.9769C537.9193,-62.7406 545.6469,-52.8851 552.4903,-44.1573\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"555.3198,-46.221 558.7359,-36.192 549.8112,-41.9017 555.3198,-46.221\"/>\n</g>\n</g>\n</svg>\n"
},
"metadata": {},
"execution_count": 15
}
],
"source": [
"out_file = results_path / 'reg_tree_t2.dot'\n",
"# With out_file set, export_graphviz writes the DOT source to disk and returns None\n",
"export_graphviz(reg_tree_t2,\n",
" out_file=out_file.as_posix(),\n",
" feature_names=X2.columns,\n",
" max_depth=2,\n",
" filled=True,\n",
" rounded=True,\n",
" special_characters=True)\n",
"dot_data = out_file.read_text()\n",
"\n",
"graphviz.Source(dot_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FuFWhyOjZDvf"
},
"source": [
"### Compare with Linear Regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ocI5jngbZDvg"
},
"source": [
"The OLS summary below and the visualization of the first two levels of the decision tree above reveal striking differences between the two models. In line with its linearity assumption, the OLS model estimates just three parameters: the intercept and one coefficient for each of the two lagged-return features.\n",
"\n",
"In contrast, for each node in the first two levels, the regression tree chart above displays the feature and threshold used to split the data (features can be reused at different nodes), the mean-squared error (MSE) of the samples in the node, the number of samples, and the predicted value, i.e., the mean outcome of these training samples."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arHLQ4L6ZDvg"
},
"source": [
"The tree chart also highlights how unevenly the samples are distributed across nodes after only two splits: in this run, for example, one second-level node holds 3,243 samples while its sibling holds 73,211."
]
},
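{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-node statistics shown in the chart can also be read directly from a fitted tree through its low-level `tree_` attribute. The cell below is a minimal sketch under stated assumptions: it fits a shallow `DecisionTreeRegressor` to *synthetic* lagged returns with a deliberately exaggerated signal (the data, seed, and the helper `describe_nodes` are hypothetical, not part of the original analysis) and walks the `children_left`, `children_right`, `feature`, `threshold`, `impurity`, `n_node_samples`, and `value` arrays to tabulate the first two levels. Passing `reg_tree_t2` and `X2.columns` instead should reproduce the numbers in the chart above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Synthetic stand-in for the two lagged-return features (assumption: not the notebook's data;\n",
"# the signal is exaggerated so that the splits are well defined)\n",
"rng = np.random.default_rng(42)\n",
"X_demo = pd.DataFrame(rng.normal(0, 0.1, size=(5000, 2)), columns=['t-1', 't-2'])\n",
"y_demo = 0.5 * X_demo['t-1'] - 0.3 * X_demo['t-2'] + rng.normal(0, 0.05, 5000)\n",
"tree_demo = DecisionTreeRegressor(max_depth=4, min_samples_leaf=50, random_state=42).fit(X_demo, y_demo)\n",
"\n",
"\n",
"def describe_nodes(model, feature_names, max_depth=2):\n",
"    # Hypothetical helper: one row per node up to max_depth with split rule, MSE, size, prediction\n",
"    t = model.tree_\n",
"    rows, stack = [], [(0, 0)]  # (node_id, depth), starting at the root\n",
"    while stack:\n",
"        node, depth = stack.pop()\n",
"        is_leaf = t.children_left[node] == t.children_right[node]  # both are -1 for a leaf\n",
"        rows.append({'node': int(node),\n",
"                     'depth': depth,\n",
"                     'split': None if is_leaf else f'{feature_names[t.feature[node]]} <= {t.threshold[node]:.3f}',\n",
"                     'mse': t.impurity[node],\n",
"                     'samples': int(t.n_node_samples[node]),\n",
"                     'value': float(t.value[node][0][0])})\n",
"        if not is_leaf and depth < max_depth:\n",
"            stack.append((t.children_right[node], depth + 1))\n",
"            stack.append((t.children_left[node], depth + 1))\n",
"    return pd.DataFrame(rows).sort_values('node').reset_index(drop=True)\n",
"\n",
"\n",
"nodes = describe_nodes(tree_demo, X_demo.columns)\n",
"print(nodes)"
]
},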
{
"cell_type": "markdown",
"metadata": {
"id": "nbZtCbWPZDvg"
},
"source": [
"#### statsmodels OLS"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:47.020618Z",
"start_time": "2021-04-16T00:32:46.994715Z"
},
"scrolled": true,
"id": "T17M8HKhZDvh"
},
"outputs": [],
"source": [
"ols_model = sm.OLS(endog=y2, exog=sm.add_constant(X2))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:50.247927Z",
"start_time": "2021-04-16T00:32:47.021772Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zQu9lEWqZDvh",
"outputId": "1d177d7f-ba36-49b2-e254-3ebaec5e624c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"117 µs ± 4.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"ols_model.fit()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:50.306911Z",
"start_time": "2021-04-16T00:32:50.249060Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c7DSZDEQZDvj",
"outputId": "72770407-1dd1-43a4-95fe-62dfa4039cd1"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" OLS Regression Results \n",
"==============================================================================\n",
"Dep. Variable: y R-squared: 0.001\n",
"Model: OLS Adj. R-squared: 0.001\n",
"Method: Least Squares F-statistic: 31.83\n",
"Date: Fri, 14 Oct 2022 Prob (F-statistic): 1.53e-14\n",
"Time: 09:09:19 Log-Likelihood: 75823.\n",
"No. Observations: 77176 AIC: -1.516e+05\n",
"Df Residuals: 77173 BIC: -1.516e+05\n",
"Df Model: 2 \n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err t P>|t| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"const 0.0100 0.000 30.232 0.000 0.009 0.011\n",
"t-1 0.0227 0.004 6.322 0.000 0.016 0.030\n",
"t-2 -0.0179 0.004 -5.003 0.000 -0.025 -0.011\n",
"==============================================================================\n",
"Omnibus: 3351.598 Durbin-Watson: 1.995\n",
"Prob(Omnibus): 0.000 Jarque-Bera (JB): 10867.041\n",
"Skew: 0.102 Prob(JB): 0.00\n",
"Kurtosis: 4.827 Cond. No. 11.1\n",
"==============================================================================\n",
"\n",
"Notes:\n",
"[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n"
]
}
],
"source": [
"result = ols_model.fit()\n",
"print(result.summary())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "965CzM_XZDvl"
},
"source": [
"#### sklearn Linear Regression"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:50.311548Z",
"start_time": "2021-04-16T00:32:50.309603Z"
},
"id": "WFu4cNK0ZDvm"
},
"outputs": [],
"source": [
"lin_reg = LinearRegression()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:52.560013Z",
"start_time": "2021-04-16T00:32:50.312750Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZZ2e92JqZDvm",
"outputId": "a617eef2-f23a-438b-e264-b3b4de9841dc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"5.1 ms ± 182 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"lin_reg.fit(X=X2,y=y2)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:52.647179Z",
"start_time": "2021-04-16T00:32:52.560943Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AAEjWxLvZDvp",
"outputId": "6bdf344c-b85b-4ad2-96c7-646149b01caa"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LinearRegression()"
]
},
"metadata": {},
"execution_count": 21
}
],
"source": [
"lin_reg.fit(X=X2,y=y2)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:52.651998Z",
"start_time": "2021-04-16T00:32:52.648519Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NXqXzDhRZDvq",
"outputId": "fce47608-3bf4-481e-9212-fbf0bbbe78f0"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.009971254720772652"
]
},
"metadata": {},
"execution_count": 22
}
],
"source": [
"lin_reg.intercept_"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:52.681349Z",
"start_time": "2021-04-16T00:32:52.653554Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "XUgMfvtqZDvr",
"outputId": "46380bc6-3109-4fc6-f475-10b71f12d8ca"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([ 0.02269243, -0.01794206])"
]
},
"metadata": {},
"execution_count": 23
}
],
"source": [
"lin_reg.coef_"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KsEOXjFsZDvr"
},
"source": [
"### Linear Regression vs Regression Tree Decision Surfaces"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D8fs8ioyZDvr"
},
"source": [
"To further illustrate the different assumptions about the functional form of the relationship between inputs and output, we can visualize the predicted current return across the feature space, that is, across the range of values of the two lagged returns. The figure below shows the predicted current-period return as a function of the returns one and two periods earlier, for both the linear regression and the regression tree.\n",
"\n",
"The linear-regression surface on the left underlines the linearity of the relationship between lagged and current returns, whereas the regression-tree surface on the right illustrates the non-linear, step-wise relationship encoded in the recursive partitioning of the feature space."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:52.741131Z",
"start_time": "2021-04-16T00:32:52.716930Z"
},
"id": "XlI91MHSZDvs"
},
"outputs": [],
"source": [
"t1, t2 = np.meshgrid(np.linspace(X2['t-1'].quantile(.01), X2['t-1'].quantile(.99), 100),\n",
" np.linspace(X2['t-2'].quantile(.01), X2['t-2'].quantile(.99), 100))\n",
"X_data = np.c_[t1.ravel(), t2.ravel()]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:53.406640Z",
"start_time": "2021-04-16T00:32:52.742829Z"
},
"scrolled": false,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 373
},
"id": "x4zBIj_jZDvt",
"outputId": "b7a4e77b-9f87-4c9b-90ce-07ad1a111662"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x360 with 4 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"fig, axes = plt.subplots(ncols=2, figsize=(12,5))\n",
"\n",
"# Linear Regression\n",
"ret1 = lin_reg.predict(X_data).reshape(t1.shape)\n",
"surface1 = axes[0].contourf(t1, t2, ret1, cmap='Blues')\n",
"plt.colorbar(mappable=surface1, ax=axes[0])\n",
"\n",
"# Regression Tree\n",
"ret2 = reg_tree_t2.predict(X_data).reshape(t1.shape)\n",
"surface2 = axes[1].contourf(t1, t2, ret2, cmap='Blues')\n",
"plt.colorbar(mappable=surface2, ax=axes[1])\n",
"\n",
"# Format plots\n",
"titles = ['Linear Regression', 'Regression Tree']\n",
"for i, ax in enumerate(axes):\n",
" ax.set_xlabel('t-1')\n",
" ax.set_ylabel('t-2')\n",
" ax.set_title(titles[i])\n",
"\n",
"fig.suptitle('Decision Surfaces', fontsize=14)\n",
"sns.despine()\n",
"fig.tight_layout()\n",
"fig.subplots_adjust(top=.9);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vs0fZNmiZDvu"
},
"source": [
"## Simple Classification Tree with Time Series Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zS0HJw_dZDvu"
},
"source": [
"A classification tree works just like the regression version, except that the categorical nature of the outcome requires a different approach to making predictions and measuring the loss. While a regression tree predicts the response for an observation assigned to a leaf node using the mean outcome of the associated training samples, a classification tree uses the mode, that is, the most common class among the training samples in the relevant region. A classification tree can also generate probabilistic predictions based on the relative class frequencies in each leaf."
]
},
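{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of these prediction rules, under stated assumptions: the cell below fits a small `DecisionTreeClassifier` to *synthetic* up/down labels (the data and the names `X_clf`, `y_clf`, and `clf_demo` are hypothetical, not the notebook's `X2`/`y_binary`), uses `apply` to find the leaf each training sample falls into, and checks that `predict_proba` matches the class shares among the training samples in that leaf while `predict` returns the leaf's most common class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic up/down labels weakly driven by two lagged returns (illustrative assumption)\n",
"rng = np.random.default_rng(0)\n",
"X_clf = rng.normal(0, 0.1, size=(2000, 2))\n",
"y_clf = (0.5 * X_clf[:, 0] - 0.3 * X_clf[:, 1] + rng.normal(0, 0.1, 2000) > 0).astype(int)\n",
"clf_demo = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_clf, y_clf)\n",
"\n",
"# Leaf index of every training sample, then the class shares per leaf computed by hand\n",
"leaves = clf_demo.apply(X_clf)\n",
"freq = (pd.crosstab(leaves, y_clf, normalize='index')\n",
"        .reindex(columns=clf_demo.classes_, fill_value=0.0))\n",
"\n",
"# Compare with the tree's probabilistic (predict_proba) and hard (predict) predictions\n",
"proba = clf_demo.predict_proba(X_clf)\n",
"manual_proba = freq.loc[leaves].to_numpy()\n",
"manual_class = clf_demo.classes_[manual_proba.argmax(axis=1)]\n",
"\n",
"print(freq.round(3))\n",
"print('predict_proba equals leaf class shares:', np.allclose(proba, manual_proba))\n",
"print('predict equals leaf mode:', (clf_demo.predict(X_clf) == manual_class).all())"
]
},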
{
"cell_type": "markdown",
"metadata": {
"id": "yM6HKy0eZDvu"
},
"source": [
"### Loss Functions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O9elgoq4ZDvv"
},
"source": [
"When growing a classification tree, we also use recursive binary splitting but, instead of evaluating the quality of a decision rule by the reduction in mean-squared error, we can use the classification error rate, which is simply the fraction of the training samples in a given (leaf) node that do not belong to the most common class."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Al5e92sZDvv"
},
"source": [
"However, the alternative measures, the Gini impurity and cross-entropy, are preferred because they are more sensitive to node purity than the classification error rate. Node purity refers to the degree to which a single class dominates a node. A node that only contains samples belonging to a single class is pure and implies successful classification for this particular region of the feature space."
]
},
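{
"cell_type": "markdown",
"metadata": {},
"source": [
"The added sensitivity of Gini and cross-entropy can be checked with a small worked example. The sketch below uses hypothetical class counts (a textbook-style illustration, not data from this notebook): a parent node with 400 down and 400 up samples and two candidate splits. Both splits leave the weighted misclassification rate at 0.25, but split B creates a pure child node, so Gini (0.375 for A versus roughly 0.333 for B) and entropy (here in bits, not halved) both score it lower, i.e., better. The helper names `node_error`, `node_gini`, `node_entropy`, and `weighted_impurity` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical class counts (down, up) per child node for two candidate splits of a\n",
"# parent node holding 400 down and 400 up samples\n",
"splits = {'A': [(300, 100), (100, 300)],\n",
"          'B': [(200, 400), (200, 0)]}\n",
"\n",
"\n",
"def node_error(counts):\n",
"    p = np.asarray(counts) / np.sum(counts)\n",
"    return 1 - p.max()\n",
"\n",
"\n",
"def node_gini(counts):\n",
"    p = np.asarray(counts) / np.sum(counts)\n",
"    return 1 - np.sum(p ** 2)\n",
"\n",
"\n",
"def node_entropy(counts):\n",
"    p = np.asarray(counts) / np.sum(counts)\n",
"    p = p[p > 0]  # treat 0 * log2(0) as 0\n",
"    return -np.sum(p * np.log2(p))\n",
"\n",
"\n",
"def weighted_impurity(children, measure):\n",
"    # Impurity of each child node, weighted by its share of the parent's samples\n",
"    n = sum(sum(c) for c in children)\n",
"    return sum(sum(c) / n * measure(c) for c in children)\n",
"\n",
"\n",
"results = {name: {m.__name__: weighted_impurity(children, m)\n",
"                  for m in (node_error, node_gini, node_entropy)}\n",
"           for name, children in splits.items()}\n",
"for name, scores in results.items():\n",
"    print(name, {k: round(float(v), 4) for k, v in scores.items()})"
]
},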
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:53.411641Z",
"start_time": "2021-04-16T00:32:53.409446Z"
},
"id": "cmxjmgodZDvw"
},
"outputs": [],
"source": [
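"def entropy(f):\n",
" # Binary cross-entropy in bits, halved so its maximum at f=0.5 equals 0.5 (same peak as gini)\n",
" # f=0 or f=1 yields NaN (0 * log2(0)); matplotlib simply skips these two points when plotting\n",
" return (-f*np.log2(f) - (1-f)*np.log2(1-f))/2"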
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:53.426940Z",
"start_time": "2021-04-16T00:32:53.414041Z"
},
"id": "s-NY0IORZDvw"
},
"outputs": [],
"source": [
"def gini(f):\n",
" # Binary Gini impurity: 1 - f**2 - (1-f)**2 simplifies to 2*f*(1-f)\n",
" return 2*f*(1-f)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:53.447019Z",
"start_time": "2021-04-16T00:32:53.443818Z"
},
"id": "CeNbTm4AZDvy"
},
"outputs": [],
"source": [
"def misclassification_rate(f):\n",
" # Share of samples outside the majority class, i.e. min(f, 1-f)\n",
" return np.where(f<=.5, f, 1-f)"
]
},
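{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, with $f$ denoting the share of samples in a node that belong to the positive class, the three functions defined above implement the following binary-case formulas:\n",
"\n",
"$$\\text{Gini}(f) = 1 - f^2 - (1-f)^2 = 2f(1-f)$$\n",
"\n",
"$$H(f) = -f\\log_2 f - (1-f)\\log_2(1-f)$$\n",
"\n",
"$$\\text{Error}(f) = \\min(f,\\, 1-f)$$\n",
"\n",
"Note that `entropy` returns $H(f)/2$ rather than $H(f)$. This rescaling, presumably chosen for the comparison plot, makes all three measures peak at 0.5 when $f = 0.5$, but it does not change where the curves attain their minima or maxima."
]
},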
{
"cell_type": "markdown",
"metadata": {
"id": "qRgl3LVPZDvz"
},
"source": [
"Both the Gini impurity and cross-entropy take on smaller values as the class proportions approach zero or one, that is, as the child nodes become purer as a result of the split, and are highest when the classes are evenly balanced, i.e., at a proportion of 0.5 in the binary case.\n",
"\n",
"The chart below plots these two measures and the misclassification error rate across the [0, 1] interval of class proportions."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:32:53.721436Z",
"start_time": "2021-04-16T00:32:53.452115Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
},
"id": "R2XpaOEsZDv0",
"outputId": "d2d7fa55-1107-48ef-d78d-9124a352cffe"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"x = np.linspace(0, 1, 10000)\n",
"(pd.DataFrame({'Gini': gini(x), \n",
" 'Entropy': entropy(x),\n",
" 'Misclassification Rate': misclassification_rate(x)}, index=x)\n",
" .plot(title='Classification Loss Functions', lw=2, style=['-', '--', ':']))\n",
"sns.despine()\n",
"plt.tight_layout();"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a2FDpy9iZDv1"
},
"source": [
"#### Compare computation time"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGT65YdOZDv3"
},
"source": [
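"Gini is often preferred over entropy because it is cheaper to compute, since it avoids logarithms. The timings below compare the three functions on the same array of proportions:"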
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:08.588281Z",
"start_time": "2021-04-16T00:32:53.723873Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wMOMsrN6ZDv4",
"outputId": "d925c5ed-8928-426c-c138-5888dae2638f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"23.8 µs ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"misclassification_rate(x)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:18.035051Z",
"start_time": "2021-04-16T00:33:08.589304Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "na3e9J9QZDv5",
"outputId": "a45a3594-eb36-4cab-e825-d8ad918b43c6"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"12.8 µs ± 312 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"gini(x)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:19.423307Z",
"start_time": "2021-04-16T00:33:18.036362Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "h2WKvDL9ZDv5",
"outputId": "3fd7b9e8-0310-42a5-94ff-4354102a88ef"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"366 µs ± 7.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"entropy(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xJL7p_gAZDv6"
},
"source": [
"### Configure Tree"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:19.427483Z",
"start_time": "2021-04-16T00:33:19.424763Z"
},
"id": "YQzjopLuZDv6"
},
"outputs": [],
"source": [
"clf_tree_t2 = DecisionTreeClassifier(criterion='gini',\n",
" splitter='best',\n",
" max_depth=5,\n",
" min_samples_split=1000,\n",
" min_samples_leaf=1,\n",
" min_weight_fraction_leaf=0.0,\n",
" max_features=None,\n",
" random_state=42,\n",
" max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0,\n",
" class_weight=None)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gZ7rEHS9ZDv7"
},
"source": [
"### Train Tree"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:19.460489Z",
"start_time": "2021-04-16T00:33:19.428704Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6cSy3mQUZDv7",
"outputId": "ca139a97-2f59-47a6-9b90-8e6b282ff38c"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1 43399\n",
"0 33777\n",
"Name: y, dtype: int64"
]
},
"metadata": {},
"execution_count": 34
}
],
"source": [
"y_binary = (y2>0).astype(int)\n",
"y_binary.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:26.269743Z",
"start_time": "2021-04-16T00:33:19.461996Z"
},
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yqUk1DurZDv7",
"outputId": "94db9f51-f4db-4f49-9689-d00dc9ec5720"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"183 ms ± 48.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"clf_tree_t2.fit(X=X2, y=y_binary)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:26.363009Z",
"start_time": "2021-04-16T00:33:26.270887Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zqyVpMOGZDv7",
"outputId": "f3258149-a562-4b23-b604-85c112789888"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeClassifier(max_depth=5, min_samples_split=1000, random_state=42)"
]
},
"metadata": {},
"execution_count": 36
}
],
"source": [
"clf_tree_t2.fit(X=X2, y=y_binary)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "52pbXiC7ZDv7"
},
"source": [
"### Visualize Tree"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:26.404932Z",
"start_time": "2021-04-16T00:33:26.363983Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 555
},
"id": "RiC-lAbfZDv7",
"outputId": "887faee7-ef26-4b4a-da26-c72dc7617ed8"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<graphviz.files.Source at 0x7f2b161c3850>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: Tree Pages: 1 -->\n<svg width=\"658pt\" height=\"401pt\"\n viewBox=\"0.00 0.00 657.50 401.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 397)\">\n<title>Tree</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-397 653.5,-397 653.5,4 -4,4\"/>\n<!-- 0 -->\n<g id=\"node1\" class=\"node\">\n<title>0</title>\n<path fill=\"#d3e9f9\" stroke=\"#000000\" d=\"M382.5,-393C382.5,-393 241.5,-393 241.5,-393 235.5,-393 229.5,-387 229.5,-381 229.5,-381 229.5,-322 229.5,-322 229.5,-316 235.5,-310 241.5,-310 241.5,-310 382.5,-310 382.5,-310 388.5,-310 394.5,-316 394.5,-322 394.5,-322 394.5,-381 394.5,-381 394.5,-387 388.5,-393 382.5,-393\"/>\n<text text-anchor=\"start\" x=\"275.5\" y=\"-377.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ &#45;0.146</text>\n<text text-anchor=\"start\" x=\"274\" y=\"-362.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.492</text>\n<text text-anchor=\"start\" x=\"256\" y=\"-347.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 77176</text>\n<text text-anchor=\"start\" x=\"237.5\" y=\"-332.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [33777, 43399]</text>\n<text text-anchor=\"start\" x=\"277.5\" y=\"-317.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 1 -->\n<g id=\"node2\" class=\"node\">\n<title>1</title>\n<path fill=\"#fef9f5\" stroke=\"#000000\" d=\"M287.5,-274C287.5,-274 162.5,-274 162.5,-274 156.5,-274 150.5,-268 150.5,-262 150.5,-262 
150.5,-203 150.5,-203 150.5,-197 156.5,-191 162.5,-191 162.5,-191 287.5,-191 287.5,-191 293.5,-191 299.5,-197 299.5,-203 299.5,-203 299.5,-262 299.5,-262 299.5,-268 293.5,-274 287.5,-274\"/>\n<text text-anchor=\"start\" x=\"188.5\" y=\"-258.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;2 ≤ &#45;0.179</text>\n<text text-anchor=\"start\" x=\"195.5\" y=\"-243.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.5</text>\n<text text-anchor=\"start\" x=\"173\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3459</text>\n<text text-anchor=\"start\" x=\"158.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [1775, 1684]</text>\n<text text-anchor=\"start\" x=\"181\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Down</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g id=\"edge1\" class=\"edge\">\n<title>0&#45;&gt;1</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M281.5716,-309.8796C275.1189,-301.0534 268.243,-291.6485 261.5887,-282.5466\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"264.2856,-280.3051 255.5583,-274.2981 258.6347,-284.4364 264.2856,-280.3051\"/>\n<text text-anchor=\"middle\" x=\"251.7208\" y=\"-295.3018\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n</g>\n<!-- 10 -->\n<g id=\"node7\" class=\"node\">\n<title>10</title>\n<path fill=\"#d1e8f9\" stroke=\"#000000\" d=\"M470.5,-274C470.5,-274 329.5,-274 329.5,-274 323.5,-274 317.5,-268 317.5,-262 317.5,-262 317.5,-203 317.5,-203 317.5,-197 323.5,-191 329.5,-191 329.5,-191 470.5,-191 470.5,-191 476.5,-191 482.5,-197 482.5,-203 482.5,-203 482.5,-262 482.5,-262 482.5,-268 476.5,-274 470.5,-274\"/>\n<text text-anchor=\"start\" x=\"365.5\" y=\"-258.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;2 ≤ 0.181</text>\n<text 
text-anchor=\"start\" x=\"362\" y=\"-243.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.491</text>\n<text text-anchor=\"start\" x=\"344\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 73717</text>\n<text text-anchor=\"start\" x=\"325.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [32002, 41715]</text>\n<text text-anchor=\"start\" x=\"365.5\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 0&#45;&gt;10 -->\n<g id=\"edge6\" class=\"edge\">\n<title>0&#45;&gt;10</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M342.7781,-309.8796C349.3051,-301.0534 356.2599,-291.6485 362.9908,-282.5466\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"365.9588,-284.4195 369.0905,-274.2981 360.3305,-280.2574 365.9588,-284.4195\"/>\n<text text-anchor=\"middle\" x=\"372.7932\" y=\"-295.3224\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n</g>\n<!-- 2 -->\n<g id=\"node3\" class=\"node\">\n<title>2</title>\n<path fill=\"#d6ebfa\" stroke=\"#000000\" d=\"M120,-147.5C120,-147.5 12,-147.5 12,-147.5 6,-147.5 0,-141.5 0,-135.5 0,-135.5 0,-91.5 0,-91.5 0,-85.5 6,-79.5 12,-79.5 12,-79.5 120,-79.5 120,-79.5 126,-79.5 132,-85.5 132,-91.5 132,-91.5 132,-135.5 132,-135.5 132,-141.5 126,-147.5 120,-147.5\"/>\n<text text-anchor=\"start\" x=\"28\" y=\"-132.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.493</text>\n<text text-anchor=\"start\" x=\"18\" y=\"-117.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 471</text>\n<text text-anchor=\"start\" x=\"8\" y=\"-102.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [208, 263]</text>\n<text text-anchor=\"start\" x=\"31.5\" y=\"-87.3\" font-family=\"Helvetica,sans-Serif\" 
font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g id=\"edge2\" class=\"edge\">\n<title>1&#45;&gt;2</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M169.3895,-190.8796C153.2987,-178.8368 135.7685,-165.7167 119.7909,-153.7586\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"121.6962,-150.8129 111.593,-147.623 117.5018,-156.4171 121.6962,-150.8129\"/>\n</g>\n<!-- 3 -->\n<g id=\"node4\" class=\"node\">\n<title>3</title>\n<path fill=\"#fdf3ed\" stroke=\"#000000\" d=\"M287.5,-155C287.5,-155 162.5,-155 162.5,-155 156.5,-155 150.5,-149 150.5,-143 150.5,-143 150.5,-84 150.5,-84 150.5,-78 156.5,-72 162.5,-72 162.5,-72 287.5,-72 287.5,-72 293.5,-72 299.5,-78 299.5,-84 299.5,-84 299.5,-143 299.5,-143 299.5,-149 293.5,-155 287.5,-155\"/>\n<text text-anchor=\"start\" x=\"188.5\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ &#45;0.262</text>\n<text text-anchor=\"start\" x=\"187\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.499</text>\n<text text-anchor=\"start\" x=\"173\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 2988</text>\n<text text-anchor=\"start\" x=\"158.5\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [1567, 1421]</text>\n<text text-anchor=\"start\" x=\"181\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Down</text>\n</g>\n<!-- 1&#45;&gt;3 -->\n<g id=\"edge3\" class=\"edge\">\n<title>1&#45;&gt;3</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M225,-190.8796C225,-182.6838 225,-173.9891 225,-165.5013\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"228.5001,-165.298 225,-155.2981 221.5001,-165.2981 228.5001,-165.298\"/>\n</g>\n<!-- 4 -->\n<g id=\"node5\" class=\"node\">\n<title>4</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M212,-36C212,-36 182,-36 
182,-36 176,-36 170,-30 170,-24 170,-24 170,-12 170,-12 170,-6 176,0 182,0 182,0 212,0 212,0 218,0 224,-6 224,-12 224,-12 224,-24 224,-24 224,-30 218,-36 212,-36\"/>\n<text text-anchor=\"middle\" x=\"197\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g id=\"edge4\" class=\"edge\">\n<title>3&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M212.8002,-71.8901C210.2305,-63.1253 207.5784,-54.0798 205.1981,-45.9615\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"208.5274,-44.8764 202.3552,-36.2651 201.8102,-46.8459 208.5274,-44.8764\"/>\n</g>\n<!-- 5 -->\n<g id=\"node6\" class=\"node\">\n<title>5</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M284,-36C284,-36 254,-36 254,-36 248,-36 242,-30 242,-24 242,-24 242,-12 242,-12 242,-6 248,0 254,0 254,0 284,0 284,0 290,0 296,-6 296,-12 296,-12 296,-24 296,-24 296,-30 290,-36 284,-36\"/>\n<text text-anchor=\"middle\" x=\"269\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 3&#45;&gt;5 -->\n<g id=\"edge5\" class=\"edge\">\n<title>3&#45;&gt;5</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M244.171,-71.8901C248.299,-62.9305 252.5621,-53.6777 256.3659,-45.4217\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"259.5789,-46.8121 260.5847,-36.2651 253.2212,-43.8829 259.5789,-46.8121\"/>\n</g>\n<!-- 11 -->\n<g id=\"node8\" class=\"node\">\n<title>11</title>\n<path fill=\"#d0e8f9\" stroke=\"#000000\" d=\"M470.5,-155C470.5,-155 329.5,-155 329.5,-155 323.5,-155 317.5,-149 317.5,-143 317.5,-143 317.5,-84 317.5,-84 317.5,-78 323.5,-72 329.5,-72 329.5,-72 470.5,-72 470.5,-72 476.5,-72 482.5,-78 482.5,-84 482.5,-84 482.5,-143 482.5,-143 482.5,-149 476.5,-155 470.5,-155\"/>\n<text text-anchor=\"start\" x=\"363.5\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;2 ≤ &#45;0.148</text>\n<text 
text-anchor=\"start\" x=\"362\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.491</text>\n<text text-anchor=\"start\" x=\"344\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 71340</text>\n<text text-anchor=\"start\" x=\"325.5\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [30833, 40507]</text>\n<text text-anchor=\"start\" x=\"365.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 10&#45;&gt;11 -->\n<g id=\"edge7\" class=\"edge\">\n<title>10&#45;&gt;11</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M400,-190.8796C400,-182.6838 400,-173.9891 400,-165.5013\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"403.5001,-165.298 400,-155.2981 396.5001,-165.2981 403.5001,-165.298\"/>\n</g>\n<!-- 24 -->\n<g id=\"node11\" class=\"node\">\n<title>24</title>\n<path fill=\"#f9fcfe\" stroke=\"#000000\" d=\"M637.5,-155C637.5,-155 512.5,-155 512.5,-155 506.5,-155 500.5,-149 500.5,-143 500.5,-143 500.5,-84 500.5,-84 500.5,-78 506.5,-72 512.5,-72 512.5,-72 637.5,-72 637.5,-72 643.5,-72 649.5,-78 649.5,-84 649.5,-84 649.5,-143 649.5,-143 649.5,-149 643.5,-155 637.5,-155\"/>\n<text text-anchor=\"start\" x=\"540.5\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">t&#45;1 ≤ 0.144</text>\n<text text-anchor=\"start\" x=\"545.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.5</text>\n<text text-anchor=\"start\" x=\"523\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 2377</text>\n<text text-anchor=\"start\" x=\"508.5\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [1169, 1208]</text>\n<text text-anchor=\"start\" x=\"540.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" 
font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 10&#45;&gt;24 -->\n<g id=\"edge10\" class=\"edge\">\n<title>10&#45;&gt;24</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M461.2065,-190.8796C475.4606,-181.1868 490.741,-170.7961 505.3307,-160.8752\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"507.5269,-163.6143 513.8281,-155.0969 503.5907,-157.8258 507.5269,-163.6143\"/>\n</g>\n<!-- 12 -->\n<g id=\"node9\" class=\"node\">\n<title>12</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M387,-36C387,-36 357,-36 357,-36 351,-36 345,-30 345,-24 345,-24 345,-12 345,-12 345,-6 351,0 357,0 357,0 387,0 387,0 393,0 399,-6 399,-12 399,-12 399,-24 399,-24 399,-30 393,-36 387,-36\"/>\n<text text-anchor=\"middle\" x=\"372\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 11&#45;&gt;12 -->\n<g id=\"edge8\" class=\"edge\">\n<title>11&#45;&gt;12</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M387.8002,-71.8901C385.2305,-63.1253 382.5784,-54.0798 380.1981,-45.9615\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"383.5274,-44.8764 377.3552,-36.2651 376.8102,-46.8459 383.5274,-44.8764\"/>\n</g>\n<!-- 17 -->\n<g id=\"node10\" class=\"node\">\n<title>17</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M459,-36C459,-36 429,-36 429,-36 423,-36 417,-30 417,-24 417,-24 417,-12 417,-12 417,-6 423,0 429,0 429,0 459,0 459,0 465,0 471,-6 471,-12 471,-12 471,-24 471,-24 471,-30 465,-36 459,-36\"/>\n<text text-anchor=\"middle\" x=\"444\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 11&#45;&gt;17 -->\n<g id=\"edge9\" class=\"edge\">\n<title>11&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M419.171,-71.8901C423.299,-62.9305 427.5621,-53.6777 431.3659,-45.4217\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"434.5789,-46.8121 435.5847,-36.2651 428.2212,-43.8829 
434.5789,-46.8121\"/>\n</g>\n<!-- 25 -->\n<g id=\"node12\" class=\"node\">\n<title>25</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M546,-36C546,-36 516,-36 516,-36 510,-36 504,-30 504,-24 504,-24 504,-12 504,-12 504,-6 510,0 516,0 516,0 546,0 546,0 552,0 558,-6 558,-12 558,-12 558,-24 558,-24 558,-30 552,-36 546,-36\"/>\n<text text-anchor=\"middle\" x=\"531\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 24&#45;&gt;25 -->\n<g id=\"edge11\" class=\"edge\">\n<title>24&#45;&gt;25</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M555.829,-71.8901C551.701,-62.9305 547.4379,-53.6777 543.6341,-45.4217\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"546.7788,-43.8829 539.4153,-36.2651 540.4211,-46.8121 546.7788,-43.8829\"/>\n</g>\n<!-- 30 -->\n<g id=\"node13\" class=\"node\">\n<title>30</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M618,-36C618,-36 588,-36 588,-36 582,-36 576,-30 576,-24 576,-24 576,-12 576,-12 576,-6 582,0 588,0 588,0 618,0 618,0 624,0 630,-6 630,-12 630,-12 630,-24 630,-24 630,-30 624,-36 618,-36\"/>\n<text text-anchor=\"middle\" x=\"603\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 24&#45;&gt;30 -->\n<g id=\"edge12\" class=\"edge\">\n<title>24&#45;&gt;30</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M587.1998,-71.8901C589.7695,-63.1253 592.4216,-54.0798 594.8019,-45.9615\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"598.1898,-46.8459 597.6448,-36.2651 591.4726,-44.8764 598.1898,-46.8459\"/>\n</g>\n</g>\n</svg>\n"
},
"metadata": {},
"execution_count": 37
}
],
"source": [
"out_file = results_path / 'clf_tree_t2.dot'\n",
"dot_data = export_graphviz(clf_tree_t2,\n",
" out_file=out_file.as_posix(),\n",
" feature_names=X2.columns,\n",
" class_names=['Down', 'Up'],\n",
" max_depth=2,\n",
" filled=True,\n",
" rounded=True,\n",
" special_characters=True)\n",
"if out_file is not None:\n",
" dot_data = Path(out_file).read_text()\n",
"\n",
"graphviz.Source(dot_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M5M-76A7ZDv9"
},
"source": [
"### Compare with Logistic Regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uy1vUnnlZDv9"
},
"source": [
"#### Statsmodels"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:26.431166Z",
"start_time": "2021-04-16T00:33:26.406282Z"
},
"id": "DBOt-zXeZDv9"
},
"outputs": [],
"source": [
"log_reg_sm = sm.Logit(endog=y_binary, exog=sm.add_constant(X2))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:28.117793Z",
"start_time": "2021-04-16T00:33:26.432422Z"
},
"scrolled": false,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "apv6nmAtZDv-",
"outputId": "e0cb226f-3a6c-47c7-c01f-285db9c004bf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"104 ms ± 39.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"log_reg_sm.fit(disp=False)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:28.199185Z",
"start_time": "2021-04-16T00:33:28.119038Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-egRlB08ZDv_",
"outputId": "dd0e0444-94b3-4ab4-85a6-8dc6090a22d3"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Optimization terminated successfully.\n",
" Current function value: 0.685278\n",
" Iterations 4\n"
]
}
],
"source": [
"log_result = log_reg_sm.fit()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:28.452764Z",
"start_time": "2021-04-16T00:33:28.200353Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nVEXhbs8ZDv_",
"outputId": "f6df0219-c4d8-4924-95ee-d512e923d4a0"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Logit Regression Results \n",
"==============================================================================\n",
"Dep. Variable: y No. Observations: 77176\n",
"Model: Logit Df Residuals: 77173\n",
"Method: MLE Df Model: 2\n",
"Date: Fri, 14 Oct 2022 Pseudo R-squ.: 0.0001127\n",
"Time: 09:10:41 Log-Likelihood: -52887.\n",
"converged: True LL-Null: -52893.\n",
"Covariance Type: nonrobust LLR p-value: 0.002584\n",
"==============================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"const 0.2485 0.007 33.873 0.000 0.234 0.263\n",
"t-1 0.2712 0.080 3.394 0.001 0.115 0.428\n",
"t-2 -0.0560 0.080 -0.701 0.483 -0.212 0.100\n",
"==============================================================================\n"
]
}
],
"source": [
"print(log_result.summary())"
]
},
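{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fit statistics in the summary follow directly from the two log-likelihoods. `Pseudo R-squ.` is McFadden's pseudo $R^2$, which compares the fitted log-likelihood $\\ln L$ with that of the intercept-only model $\\ln L_0$ (`LL-Null`). The likelihood-ratio statistic behind `LLR p-value` is $2(\\ln L - \\ln L_0)$, compared against a $\\chi^2$ distribution with `Df Model` = 2 degrees of freedom. Plugging in the rounded values from the table:\n",
"\n",
"$$\n",
"R^2_{\\text{McF}} = 1 - \\frac{\\ln L}{\\ln L_0} = 1 - \\frac{-52887}{-52893} \\approx 1.1 \\times 10^{-4},\n",
"\\qquad\n",
"\\text{LLR} = 2\\,(\\ln L - \\ln L_0) \\approx 2\\,(-52887 + 52893) = 12.\n",
"$$\n",
"\n",
"The small differences from the reported values (0.0001127 and a p-value of 0.002584) come from rounding the log-likelihoods. Although the `t-1` coefficient is statistically significant, the lagged returns explain only a tiny share of the variation in price direction."
]
},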
{
"cell_type": "markdown",
"metadata": {
"id": "Dcv0InzeZDwA"
},
"source": [
"#### sklearn"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:28.458948Z",
"start_time": "2021-04-16T00:33:28.455385Z"
},
"id": "2axW9sP4ZDwA"
},
"outputs": [],
"source": [
"log_reg_sk = LogisticRegression()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:33.909359Z",
"start_time": "2021-04-16T00:33:28.461195Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DUUbh3P8ZDwA",
"outputId": "f33cc68d-3b48-49f3-84e2-32d11aa6376d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"190 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"log_reg_sk.fit(X=X2, y=y_binary)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:34.057197Z",
"start_time": "2021-04-16T00:33:33.910465Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xoeUzJORZDwA",
"outputId": "9a19ba3a-2dc1-4eee-933a-97e9627fea49"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"LogisticRegression()"
]
},
"metadata": {},
"execution_count": 44
}
],
"source": [
"log_reg_sk.fit(X=X2, y=y_binary)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:34.066180Z",
"start_time": "2021-04-16T00:33:34.058359Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NgwRYAPHZDwB",
"outputId": "220f0ef1-38ca-4d88-e571-3d620d4029a8"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 0.26951533, -0.05558622]])"
]
},
"metadata": {},
"execution_count": 45
}
],
"source": [
"log_reg_sk.coef_"
]
},
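{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sklearn coefficients (0.2695 and -0.0556) are slightly smaller in magnitude than the statsmodels estimates (0.2712 and -0.0560). This is expected: `LogisticRegression` applies an L2 penalty by default (`penalty='l2'`, `C=1.0`) that shrinks the coefficients toward zero, whereas `sm.Logit` computes the unpenalized maximum-likelihood estimate. The sketch below illustrates the effect on synthetic data; the features are hypothetical stand-ins for `t-1` and `t-2`, not the notebook's `X2`. A very large `C` makes the penalty negligible, which avoids relying on the version-specific option for disabling it.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Hypothetical data: two small-variance features, similar in scale to returns\n",
"rng = np.random.default_rng(0)\n",
"n = 2000\n",
"X = rng.normal(scale=0.05, size=(n, 2))\n",
"logits = 0.25 + 3.0 * X[:, 0] - 0.5 * X[:, 1]\n",
"y = (rng.random(n) < 1 / (1 + np.exp(-logits))).astype(int)\n",
"\n",
"# Default settings: L2 penalty with C=1.0 shrinks coefficients toward zero\n",
"penalized = LogisticRegression().fit(X, y)\n",
"\n",
"# Very large C: penalty is negligible, approximating the unpenalized MLE\n",
"unpenalized = LogisticRegression(C=1e10, max_iter=1000).fit(X, y)\n",
"\n",
"print('C=1.0 :', penalized.coef_.round(3))\n",
"print('C=1e10:', unpenalized.coef_.round(3))\n",
"```\n",
"\n",
"With only 2,000 observations and return-like feature scales, the default penalty shrinks the coefficients noticeably; with the 77,176 observations used above, the data term dominates and the gap is small, as the two sets of estimates show."
]
},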
{
"cell_type": "markdown",
"metadata": {
"id": "TtKbAUDcZDwB"
},
"source": [
"### Decision Surfaces: Classification Tree vs. Logistic Regression"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:34.454616Z",
"start_time": "2021-04-16T00:33:34.072296Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 373
},
"id": "AxGe_fpPZDwB",
"outputId": "a2e423e7-e53f-48ca-b81c-6266f3932e2a"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x360 with 4 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"fig, axes = plt.subplots(ncols=2, figsize=(12,5))\n",
"\n",
"# Logistic Regression\n",
"ret1 = log_reg_sk.predict_proba(X_data)[:, 1].reshape(t1.shape)\n",
"surface1 = axes[0].contourf(t1, t2, ret1, cmap='Blues')\n",
"plt.colorbar(mappable=surface1, ax=axes[0])\n",
"\n",
"# Classification Tree\n",
"ret2 = clf_tree_t2.predict_proba(X_data)[:, 1].reshape(t1.shape)\n",
"surface2 = axes[1].contourf(t1, t2, ret2, cmap='Blues')\n",
"plt.colorbar(mappable=surface2, ax=axes[1])\n",
"\n",
"# Format plots\n",
"titles = ['Logistic Regression', 'Classification Tree']\n",
"for i, ax in enumerate(axes):\n",
" ax.set_xlabel('t-1')\n",
" ax.set_ylabel('t-2')\n",
" ax.set_title(titles[i])\n",
"\n",
"fig.suptitle('Decision Surfaces', fontsize=20)\n",
"sns.despine()\n",
"fig.tight_layout()\n",
"fig.subplots_adjust(top=.9);"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P6U7fhuJZDwC"
},
"source": [
"## Regression Tree with all Features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rPZEVW_CZDwC"
},
"source": [
"We now train, visualize, and evaluate a regression tree with up to 5 consecutive splits, using 80% of the samples for training and the remaining 20% for testing.\n",
"\n",
"To simplify the illustration, we take a shortcut and use the built-in `train_test_split` instead of our custom iterator. Unlike the custom iterator, `train_test_split` shuffles the samples and therefore does not protect against lookahead bias. A depth of 5 implies up to $2^5=32$ leaf nodes; with 62,230 training samples (the root node in the tree plot below), a balanced tree would hold roughly 1,945 samples per leaf on average."
]
},
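{
"cell_type": "markdown",
"metadata": {},
"source": [
"A lookahead-free alternative is to split chronologically, so that every training observation precedes every test observation. The sketch below is not the notebook's custom iterator; it assumes a hypothetical panel indexed by `(ticker, date)` with made-up values and holds out the last 20% of dates for testing.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical panel: 3 tickers x 60 monthly dates, random values\n",
"dates = pd.date_range('2000-01-01', periods=60, freq='MS')\n",
"idx = pd.MultiIndex.from_product([['A', 'B', 'C'], dates],\n",
"                                 names=['ticker', 'date'])\n",
"rng = np.random.default_rng(42)\n",
"df = pd.DataFrame({'feature': rng.normal(size=len(idx)),\n",
"                   'target': rng.normal(size=len(idx))}, index=idx)\n",
"\n",
"# Cutoff = first date of the final 20% of unique dates\n",
"unique_dates = df.index.get_level_values('date').unique().sort_values()\n",
"cutoff = unique_dates[int(len(unique_dates) * 0.8)]\n",
"\n",
"# Rows before the cutoff train the model; rows on or after it test it\n",
"date_level = df.index.get_level_values('date')\n",
"train, test = df[date_level < cutoff], df[date_level >= cutoff]\n",
"print(len(train), len(test))\n",
"```\n",
"\n",
"For repeated out-of-sample evaluation, the same idea extends to a sequence of rolling or expanding cutoff dates."
]
},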
{
"cell_type": "markdown",
"metadata": {
"id": "GBFo4fFbZDwC"
},
"source": [
"### Train-Test Split"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:34.491936Z",
"start_time": "2021-04-16T00:33:34.455838Z"
},
"id": "MId6OaARZDwD"
},
"outputs": [],
"source": [
"X = pd.get_dummies(data.drop('target', axis=1))\n",
"y = data.target"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:34.529344Z",
"start_time": "2021-04-16T00:33:34.493349Z"
},
"id": "aRfJc-R7ZDwD"
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zDYRzxvlZDwD"
},
"source": [
"### Configure Tree"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5Lk2Lkz_ZDwE"
},
"source": [
"The cell below sets every `DecisionTreeRegressor` parameter explicitly; we will address them in more detail in the next section when we discuss parameter tuning. Note that the printed output after training lists only the parameters that differ from their defaults."
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:34.532666Z",
"start_time": "2021-04-16T00:33:34.530505Z"
},
"id": "Yj8yMMqSZDwE"
},
"outputs": [],
"source": [
"regression_tree = DecisionTreeRegressor(criterion='squared_error',  # called 'mse' before scikit-learn 1.0, as in the outputs below\n",
" splitter='best',\n",
" max_depth=5,\n",
" min_samples_split=2,\n",
" min_samples_leaf=1,\n",
" min_weight_fraction_leaf=0.0,\n",
" max_features=None,\n",
" random_state=42,\n",
" max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lcWykdcXZDwF"
},
"source": [
"### Train Model"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.266789Z",
"start_time": "2021-04-16T00:33:34.534056Z"
},
"scrolled": false,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VXBR5dpcZDwF",
"outputId": "5f16a936-4cee-4a33-ce9a-2b0b7e437088"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeRegressor(criterion='mse', max_depth=5, random_state=42)"
]
},
"metadata": {},
"execution_count": 50
}
],
"source": [
"regression_tree.fit(X=X_train, y=y_train)"
]
},
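{
"cell_type": "markdown",
"metadata": {},
"source": [
"The leaf arithmetic from the introduction to this section can be checked on any fitted tree with scikit-learn's `get_depth()` and `get_n_leaves()`. The sketch below uses synthetic data (an assumption, not the notebook's features) and the default squared-error criterion.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Hypothetical data: 10,000 samples, 5 features, weak signal in the first one\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(size=(10_000, 5))\n",
"y = 0.1 * X[:, 0] + rng.normal(scale=0.5, size=10_000)\n",
"\n",
"# Squared-error criterion is the default ('mse' in older scikit-learn)\n",
"tree = DecisionTreeRegressor(max_depth=5, random_state=42).fit(X, y)\n",
"\n",
"n_leaves = tree.get_n_leaves()\n",
"print('depth :', tree.get_depth())\n",
"print('leaves:', n_leaves, 'of at most', 2 ** 5)\n",
"print('average samples per leaf:', len(X) / n_leaves)\n",
"```\n",
"\n",
"Because `min_samples_leaf=1` permits very small leaves, the actual distribution of samples across leaves can be far from balanced, so the per-leaf average is only a rough guide."
]
},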
{
"cell_type": "markdown",
"metadata": {
"id": "8FM090j3ZDwG"
},
"source": [
"### Visualize Tree"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JCKjr5ueZDwG"
},
"source": [
"The result shows that the model splits on several different features, such as `month`, `momentum_6`, `return_12m`, `atr`, and `year`, and displays the split rule at each node for both continuous and categorical (dummy) variables."
]
},
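{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a text-only complement to the graphviz rendering, `sklearn.tree.export_text` prints the same kind of split rules without requiring the graphviz binaries. The sketch below fits a shallow tree to hypothetical data whose column names merely mimic two of the notebook's features.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.tree import DecisionTreeRegressor, export_text\n",
"\n",
"# Hypothetical data: column names mimic the notebook's features only\n",
"rng = np.random.default_rng(1)\n",
"n = 1000\n",
"X = pd.DataFrame({'month': rng.integers(1, 13, n),\n",
"                  'momentum_6': rng.normal(size=n)})\n",
"# Made-up target: a bump in months 1-3 plus a small momentum effect and noise\n",
"y = (0.02 * (X['month'] <= 3) + 0.01 * X['momentum_6']\n",
"     + rng.normal(scale=0.05, size=n))\n",
"\n",
"tree = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)\n",
"\n",
"# One line per node; nesting depth is shown by the '|   ' indentation\n",
"rules = export_text(tree, feature_names=list(X.columns), decimals=3)\n",
"print(rules)\n",
"```"
]
},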
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.332234Z",
"start_time": "2021-04-16T00:33:35.267941Z"
},
"colab": {
"base_uri": "https://localhost:8080/",
"height": 634
},
"id": "XIlriA5rZDwG",
"outputId": "c2fe4ae5-0675-43df-ca2b-5b5a7e6ce9c7"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<graphviz.files.Source at 0x7f2b2541b990>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: Tree Pages: 1 -->\n<svg width=\"1171pt\" height=\"460pt\"\n viewBox=\"0.00 0.00 1171.00 460.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 456)\">\n<title>Tree</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-456 1167,-456 1167,4 -4,4\"/>\n<!-- 0 -->\n<g id=\"node1\" class=\"node\">\n<title>0</title>\n<path fill=\"#f6d3b9\" stroke=\"#000000\" d=\"M608,-452C608,-452 504,-452 504,-452 498,-452 492,-446 492,-440 492,-440 492,-396 492,-396 492,-390 498,-384 504,-384 504,-384 608,-384 608,-384 614,-384 620,-390 620,-396 620,-396 620,-440 620,-440 620,-446 614,-452 608,-452\"/>\n<text text-anchor=\"start\" x=\"518\" y=\"-436.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 3.5</text>\n<text text-anchor=\"start\" x=\"515.5\" y=\"-421.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"500\" y=\"-406.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 62230</text>\n<text text-anchor=\"start\" x=\"516\" y=\"-391.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.01</text>\n</g>\n<!-- 1 -->\n<g id=\"node2\" class=\"node\">\n<title>1</title>\n<path fill=\"#f5d0b4\" stroke=\"#000000\" d=\"M494,-348C494,-348 356,-348 356,-348 350,-348 344,-342 344,-336 344,-336 344,-292 344,-292 344,-286 350,-280 356,-280 356,-280 494,-280 494,-280 500,-280 506,-286 506,-292 506,-292 506,-336 506,-336 506,-342 500,-348 494,-348\"/>\n<text text-anchor=\"start\" x=\"352\" y=\"-332.8\" 
font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">momentum_6 ≤ &#45;0.225</text>\n<text text-anchor=\"start\" x=\"384.5\" y=\"-317.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"369\" y=\"-302.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 15600</text>\n<text text-anchor=\"start\" x=\"381\" y=\"-287.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.021</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g id=\"edge1\" class=\"edge\">\n<title>0&#45;&gt;1</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M513.1057,-383.9465C501.2307,-374.519 488.2259,-364.1946 475.9407,-354.4415\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"477.9989,-351.6066 467.9907,-348.13 473.6464,-357.089 477.9989,-351.6066\"/>\n<text text-anchor=\"middle\" x=\"470.8442\" y=\"-369.2666\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n</g>\n<!-- 32 -->\n<g id=\"node17\" class=\"node\">\n<title>32</title>\n<path fill=\"#f6d4bb\" stroke=\"#000000\" d=\"M759,-348C759,-348 655,-348 655,-348 649,-348 643,-342 643,-336 643,-336 643,-292 643,-292 643,-286 649,-280 655,-280 655,-280 759,-280 759,-280 765,-280 771,-286 771,-292 771,-292 771,-336 771,-336 771,-342 765,-348 759,-348\"/>\n<text text-anchor=\"start\" x=\"672.5\" y=\"-332.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">atr ≤ 0.029</text>\n<text text-anchor=\"start\" x=\"666.5\" y=\"-317.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"651\" y=\"-302.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46630</text>\n<text text-anchor=\"start\" x=\"663\" y=\"-287.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.006</text>\n</g>\n<!-- 
0&#45;&gt;32 -->\n<g id=\"edge16\" class=\"edge\">\n<title>0&#45;&gt;32</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M605.443,-383.9465C619.3917,-374.3395 634.6929,-363.8009 649.0903,-353.8848\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"651.1955,-356.6848 657.4459,-348.13 647.2249,-350.9198 651.1955,-356.6848\"/>\n<text text-anchor=\"middle\" x=\"652.9144\" y=\"-369.0159\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n</g>\n<!-- 2 -->\n<g id=\"node3\" class=\"node\">\n<title>2</title>\n<path fill=\"#f0b68c\" stroke=\"#000000\" d=\"M272,-244C272,-244 184,-244 184,-244 178,-244 172,-238 172,-232 172,-232 172,-188 172,-188 172,-182 178,-176 184,-176 184,-176 272,-176 272,-176 278,-176 284,-182 284,-188 284,-188 284,-232 284,-232 284,-238 278,-244 272,-244\"/>\n<text text-anchor=\"start\" x=\"190\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 2.5</text>\n<text text-anchor=\"start\" x=\"187.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.028</text>\n<text text-anchor=\"start\" x=\"180\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 259</text>\n<text text-anchor=\"start\" x=\"184\" y=\"-183.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.112</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g id=\"edge2\" class=\"edge\">\n<title>1&#45;&gt;2</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M360.4948,-279.9465C338.8937,-268.5428 314.8063,-255.8267 293.1558,-244.3969\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"294.5179,-241.1583 284.0405,-239.5848 291.2498,-247.3486 294.5179,-241.1583\"/>\n</g>\n<!-- 17 -->\n<g id=\"node10\" class=\"node\">\n<title>17</title>\n<path fill=\"#f5d0b5\" stroke=\"#000000\" d=\"M486.5,-244C486.5,-244 363.5,-244 363.5,-244 357.5,-244 351.5,-238 351.5,-232 351.5,-232 351.5,-188 351.5,-188 
351.5,-182 357.5,-176 363.5,-176 363.5,-176 486.5,-176 486.5,-176 492.5,-176 498.5,-182 498.5,-188 498.5,-188 498.5,-232 498.5,-232 498.5,-238 492.5,-244 486.5,-244\"/>\n<text text-anchor=\"start\" x=\"359.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">return_12m ≤ &#45;0.045</text>\n<text text-anchor=\"start\" x=\"384.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.007</text>\n<text text-anchor=\"start\" x=\"369\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 15341</text>\n<text text-anchor=\"start\" x=\"385\" y=\"-183.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.02</text>\n</g>\n<!-- 1&#45;&gt;17 -->\n<g id=\"edge9\" class=\"edge\">\n<title>1&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M425,-279.9465C425,-271.776 425,-262.9318 425,-254.3697\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"428.5001,-254.13 425,-244.13 421.5001,-254.13 428.5001,-254.13\"/>\n</g>\n<!-- 3 -->\n<g id=\"node4\" class=\"node\">\n<title>3</title>\n<path fill=\"#f5ccb0\" stroke=\"#000000\" d=\"M142,-140C142,-140 54,-140 54,-140 48,-140 42,-134 42,-128 42,-128 42,-84 42,-84 42,-78 48,-72 54,-72 54,-72 142,-72 142,-72 148,-72 154,-78 154,-84 154,-84 154,-128 154,-128 154,-134 148,-140 142,-140\"/>\n<text text-anchor=\"start\" x=\"63.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">atr ≤ 0.666</text>\n<text text-anchor=\"start\" x=\"57.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.022</text>\n<text text-anchor=\"start\" x=\"50\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 130</text>\n<text text-anchor=\"start\" x=\"54\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 
0.032</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g id=\"edge3\" class=\"edge\">\n<title>2&#45;&gt;3</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M185.4331,-175.9465C173.6488,-166.519 160.7432,-156.1946 148.5519,-146.4415\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"150.6577,-143.6439 140.6625,-140.13 146.2848,-149.11 150.6577,-143.6439\"/>\n</g>\n<!-- 10 -->\n<g id=\"node7\" class=\"node\">\n<title>10</title>\n<path fill=\"#eb9f68\" stroke=\"#000000\" d=\"M272,-140C272,-140 184,-140 184,-140 178,-140 172,-134 172,-128 172,-128 172,-84 172,-84 172,-78 178,-72 184,-72 184,-72 272,-72 272,-72 278,-72 284,-78 284,-84 284,-84 284,-128 284,-128 284,-134 278,-140 272,-140\"/>\n<text text-anchor=\"start\" x=\"183.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2002.5</text>\n<text text-anchor=\"start\" x=\"191.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.02</text>\n<text text-anchor=\"start\" x=\"180\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 129</text>\n<text text-anchor=\"start\" x=\"184\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.193</text>\n</g>\n<!-- 2&#45;&gt;10 -->\n<g id=\"edge6\" class=\"edge\">\n<title>2&#45;&gt;10</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M228,-175.9465C228,-167.776 228,-158.9318 228,-150.3697\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"231.5001,-150.13 228,-140.13 224.5001,-150.13 231.5001,-150.13\"/>\n</g>\n<!-- 4 -->\n<g id=\"node5\" class=\"node\">\n<title>4</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M42,-36C42,-36 12,-36 12,-36 6,-36 0,-30 0,-24 0,-24 0,-12 0,-12 0,-6 6,0 12,0 12,0 42,0 42,0 48,0 54,-6 54,-12 54,-12 54,-24 54,-24 54,-30 48,-36 42,-36\"/>\n<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" 
fill=\"#000000\">(...)</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g id=\"edge4\" class=\"edge\">\n<title>3&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M70.5495,-71.9769C63.0975,-62.7406 55.1459,-52.8851 48.1042,-44.1573\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"50.6809,-41.777 41.6776,-36.192 45.2329,-46.1725 50.6809,-41.777\"/>\n</g>\n<!-- 7 -->\n<g id=\"node6\" class=\"node\">\n<title>7</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M114,-36C114,-36 84,-36 84,-36 78,-36 72,-30 72,-24 72,-24 72,-12 72,-12 72,-6 78,0 84,0 84,0 114,0 114,0 120,0 126,-6 126,-12 126,-12 126,-24 126,-24 126,-30 120,-36 114,-36\"/>\n<text text-anchor=\"middle\" x=\"99\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 3&#45;&gt;7 -->\n<g id=\"edge5\" class=\"edge\">\n<title>3&#45;&gt;7</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M98.3866,-71.9769C98.4829,-63.5023 98.5852,-54.5065 98.678,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"102.1793,-46.2311 98.7933,-36.192 95.1798,-46.1515 102.1793,-46.2311\"/>\n</g>\n<!-- 11 -->\n<g id=\"node8\" class=\"node\">\n<title>11</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M201,-36C201,-36 171,-36 171,-36 165,-36 159,-30 159,-24 159,-24 159,-12 159,-12 159,-6 165,0 171,0 171,0 201,0 201,0 207,0 213,-6 213,-12 213,-12 213,-24 213,-24 213,-30 207,-36 201,-36\"/>\n<text text-anchor=\"middle\" x=\"186\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 10&#45;&gt;11 -->\n<g id=\"edge7\" class=\"edge\">\n<title>10&#45;&gt;11</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M211.7617,-71.9769C207.5352,-63.1215 203.0371,-53.6969 199.0022,-45.2427\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"202.1486,-43.7092 194.6825,-36.192 195.8312,-46.7243 202.1486,-43.7092\"/>\n</g>\n<!-- 14 -->\n<g id=\"node9\" class=\"node\">\n<title>14</title>\n<path 
fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M273,-36C273,-36 243,-36 243,-36 237,-36 231,-30 231,-24 231,-24 231,-12 231,-12 231,-6 237,0 243,0 243,0 273,0 273,0 279,0 285,-6 285,-12 285,-12 285,-24 285,-24 285,-30 279,-36 273,-36\"/>\n<text text-anchor=\"middle\" x=\"258\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 10&#45;&gt;14 -->\n<g id=\"edge8\" class=\"edge\">\n<title>10&#45;&gt;14</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M239.5988,-71.9769C242.5528,-63.3119 245.6925,-54.102 248.5263,-45.7894\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"251.8842,-46.7864 251.7982,-36.192 245.2586,-44.5277 251.8842,-46.7864\"/>\n</g>\n<!-- 18 -->\n<g id=\"node11\" class=\"node\">\n<title>18</title>\n<path fill=\"#f3c7a7\" stroke=\"#000000\" d=\"M402,-140C402,-140 314,-140 314,-140 308,-140 302,-134 302,-128 302,-128 302,-84 302,-84 302,-78 308,-72 314,-72 314,-72 402,-72 402,-72 408,-72 414,-78 414,-84 414,-84 414,-128 414,-128 414,-134 408,-140 402,-140\"/>\n<text text-anchor=\"start\" x=\"320\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 1.5</text>\n<text text-anchor=\"start\" x=\"321.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.03</text>\n<text text-anchor=\"start\" x=\"310\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 863</text>\n<text text-anchor=\"start\" x=\"314\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.052</text>\n</g>\n<!-- 17&#45;&gt;18 -->\n<g id=\"edge10\" class=\"edge\">\n<title>17&#45;&gt;18</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M403.0617,-175.9465C397.451,-167.2373 391.3471,-157.7626 385.4958,-148.6801\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"388.3457,-146.641 379.9876,-140.13 382.4611,-150.4321 388.3457,-146.641\"/>\n</g>\n<!-- 25 
-->\n<g id=\"node14\" class=\"node\">\n<title>25</title>\n<path fill=\"#f5d0b6\" stroke=\"#000000\" d=\"M548,-140C548,-140 444,-140 444,-140 438,-140 432,-134 432,-128 432,-128 432,-84 432,-84 432,-78 438,-72 444,-72 444,-72 548,-72 548,-72 554,-72 560,-78 560,-84 560,-84 560,-128 560,-128 560,-134 554,-140 548,-140\"/>\n<text text-anchor=\"start\" x=\"451.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2009.5</text>\n<text text-anchor=\"start\" x=\"455.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.006</text>\n<text text-anchor=\"start\" x=\"440\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 14478</text>\n<text text-anchor=\"start\" x=\"452\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.018</text>\n</g>\n<!-- 17&#45;&gt;25 -->\n<g id=\"edge13\" class=\"edge\">\n<title>17&#45;&gt;25</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M448.2481,-175.9465C454.255,-167.1475 460.7955,-157.5672 467.0543,-148.3993\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"469.952,-150.3623 472.6997,-140.13 464.1707,-146.4155 469.952,-150.3623\"/>\n</g>\n<!-- 19 -->\n<g id=\"node12\" class=\"node\">\n<title>19</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M352,-36C352,-36 322,-36 322,-36 316,-36 310,-30 310,-24 310,-24 310,-12 310,-12 310,-6 316,0 322,0 322,0 352,0 352,0 358,0 364,-6 364,-12 364,-12 364,-24 364,-24 364,-30 358,-36 352,-36\"/>\n<text text-anchor=\"middle\" x=\"337\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 18&#45;&gt;19 -->\n<g id=\"edge11\" class=\"edge\">\n<title>18&#45;&gt;19</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M349.8809,-71.9769C347.8358,-63.4071 345.6635,-54.3043 343.697,-46.0638\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"347.0669,-45.1064 
341.3413,-36.192 340.2581,-46.7313 347.0669,-45.1064\"/>\n</g>\n<!-- 22 -->\n<g id=\"node13\" class=\"node\">\n<title>22</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M424,-36C424,-36 394,-36 394,-36 388,-36 382,-30 382,-24 382,-24 382,-12 382,-12 382,-6 388,0 394,0 394,0 424,0 424,0 430,0 436,-6 436,-12 436,-12 436,-24 436,-24 436,-30 430,-36 424,-36\"/>\n<text text-anchor=\"middle\" x=\"409\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 18&#45;&gt;22 -->\n<g id=\"edge12\" class=\"edge\">\n<title>18&#45;&gt;22</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M377.7179,-71.9769C382.9052,-63.0262 388.4295,-53.4941 393.3695,-44.9703\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"396.4709,-46.599 398.4569,-36.192 390.4144,-43.089 396.4709,-46.599\"/>\n</g>\n<!-- 26 -->\n<g id=\"node15\" class=\"node\">\n<title>26</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M500,-36C500,-36 470,-36 470,-36 464,-36 458,-30 458,-24 458,-24 458,-12 458,-12 458,-6 464,0 470,0 470,0 500,0 500,0 506,0 512,-6 512,-12 512,-12 512,-24 512,-24 512,-30 506,-36 500,-36\"/>\n<text text-anchor=\"middle\" x=\"485\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 25&#45;&gt;26 -->\n<g id=\"edge14\" class=\"edge\">\n<title>25&#45;&gt;26</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M491.7471,-71.9769C490.6878,-63.5023 489.5633,-54.5065 488.5423,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"491.9874,-45.6806 487.274,-36.192 485.0415,-46.5489 491.9874,-45.6806\"/>\n</g>\n<!-- 29 -->\n<g id=\"node16\" class=\"node\">\n<title>29</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M572,-36C572,-36 542,-36 542,-36 536,-36 530,-30 530,-24 530,-24 530,-12 530,-12 530,-6 536,0 542,0 542,0 572,0 572,0 578,0 584,-6 584,-12 584,-12 584,-24 584,-24 584,-30 578,-36 572,-36\"/>\n<text text-anchor=\"middle\" x=\"557\" 
y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 25&#45;&gt;29 -->\n<g id=\"edge15\" class=\"edge\">\n<title>25&#45;&gt;29</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M519.5842,-71.9769C525.9206,-62.8358 532.6775,-53.0883 538.6809,-44.4276\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"541.5692,-46.4045 544.3897,-36.192 535.8162,-42.4166 541.5692,-46.4045\"/>\n</g>\n<!-- 33 -->\n<g id=\"node18\" class=\"node\">\n<title>33</title>\n<path fill=\"#f6d2b8\" stroke=\"#000000\" d=\"M759,-244C759,-244 655,-244 655,-244 649,-244 643,-238 643,-232 643,-232 643,-188 643,-188 643,-182 649,-176 655,-176 655,-176 759,-176 759,-176 765,-176 771,-182 771,-188 771,-188 771,-232 771,-232 771,-238 765,-244 759,-244\"/>\n<text text-anchor=\"start\" x=\"662.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2002.5</text>\n<text text-anchor=\"start\" x=\"666.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.009</text>\n<text text-anchor=\"start\" x=\"651\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 24132</text>\n<text text-anchor=\"start\" x=\"663\" y=\"-183.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.013</text>\n</g>\n<!-- 32&#45;&gt;33 -->\n<g id=\"edge17\" class=\"edge\">\n<title>32&#45;&gt;33</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M707,-279.9465C707,-271.776 707,-262.9318 707,-254.3697\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"710.5001,-254.13 707,-244.13 703.5001,-254.13 710.5001,-254.13\"/>\n</g>\n<!-- 48 -->\n<g id=\"node25\" class=\"node\">\n<title>48</title>\n<path fill=\"#f7d6be\" stroke=\"#000000\" d=\"M974,-244C974,-244 870,-244 870,-244 864,-244 858,-238 858,-232 858,-232 858,-188 858,-188 858,-182 864,-176 870,-176 870,-176 974,-176 974,-176 980,-176 986,-182 986,-188 
986,-188 986,-232 986,-232 986,-238 980,-244 974,-244\"/>\n<text text-anchor=\"start\" x=\"877.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2008.5</text>\n<text text-anchor=\"start\" x=\"881.5\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"866\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 22498</text>\n<text text-anchor=\"start\" x=\"875.5\" y=\"-183.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = &#45;0.002</text>\n</g>\n<!-- 32&#45;&gt;48 -->\n<g id=\"edge24\" class=\"edge\">\n<title>32&#45;&gt;48</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M771.2159,-282.9374C795.5707,-271.1565 823.484,-257.6543 848.6204,-245.4953\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"850.2657,-248.5874 857.7438,-241.0821 847.2175,-242.2859 850.2657,-248.5874\"/>\n</g>\n<!-- 34 -->\n<g id=\"node19\" class=\"node\">\n<title>34</title>\n<path fill=\"#f7dbc6\" stroke=\"#000000\" d=\"M686,-140C686,-140 590,-140 590,-140 584,-140 578,-134 578,-128 578,-128 578,-84 578,-84 578,-78 584,-72 590,-72 590,-72 686,-72 686,-72 692,-72 698,-78 698,-84 698,-84 698,-128 698,-128 698,-134 692,-140 686,-140\"/>\n<text text-anchor=\"start\" x=\"600\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 8.5</text>\n<text text-anchor=\"start\" x=\"597.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.014</text>\n<text text-anchor=\"start\" x=\"586\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 1587</text>\n<text text-anchor=\"start\" x=\"591.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = &#45;0.018</text>\n</g>\n<!-- 33&#45;&gt;34 -->\n<g id=\"edge18\" 
class=\"edge\">\n<title>33&#45;&gt;34</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M684.4068,-175.9465C678.6286,-167.2373 672.3425,-157.7626 666.3166,-148.6801\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"669.089,-146.5278 660.6439,-140.13 663.256,-150.3978 669.089,-146.5278\"/>\n</g>\n<!-- 41 -->\n<g id=\"node22\" class=\"node\">\n<title>41</title>\n<path fill=\"#f6d1b7\" stroke=\"#000000\" d=\"M832,-140C832,-140 728,-140 728,-140 722,-140 716,-134 716,-128 716,-128 716,-84 716,-84 716,-78 722,-72 728,-72 728,-72 832,-72 832,-72 838,-72 844,-78 844,-84 844,-84 844,-128 844,-128 844,-134 838,-140 832,-140\"/>\n<text text-anchor=\"start\" x=\"735.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2003.5</text>\n<text text-anchor=\"start\" x=\"739.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.008</text>\n<text text-anchor=\"start\" x=\"724\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 22545</text>\n<text text-anchor=\"start\" x=\"736\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.015</text>\n</g>\n<!-- 33&#45;&gt;41 -->\n<g id=\"edge21\" class=\"edge\">\n<title>33&#45;&gt;41</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M730.9029,-175.9465C737.0791,-167.1475 743.8038,-157.5672 750.239,-148.3993\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"753.1629,-150.3258 756.0434,-140.13 747.4334,-146.3041 753.1629,-150.3258\"/>\n</g>\n<!-- 35 -->\n<g id=\"node20\" class=\"node\">\n<title>35</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M646,-36C646,-36 616,-36 616,-36 610,-36 604,-30 604,-24 604,-24 604,-12 604,-12 604,-6 610,0 616,0 616,0 646,0 646,0 652,0 658,-6 658,-12 658,-12 658,-24 658,-24 658,-30 652,-36 646,-36\"/>\n<text text-anchor=\"middle\" x=\"631\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" 
fill=\"#000000\">(...)</text>\n</g>\n<!-- 34&#45;&gt;35 -->\n<g id=\"edge19\" class=\"edge\">\n<title>34&#45;&gt;35</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M635.2936,-71.9769C634.6195,-63.5023 633.9039,-54.5065 633.2542,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"636.7291,-45.8829 632.4471,-36.192 629.7511,-46.438 636.7291,-45.8829\"/>\n</g>\n<!-- 38 -->\n<g id=\"node21\" class=\"node\">\n<title>38</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M718,-36C718,-36 688,-36 688,-36 682,-36 676,-30 676,-24 676,-24 676,-12 676,-12 676,-6 682,0 688,0 688,0 718,0 718,0 724,0 730,-6 730,-12 730,-12 730,-24 730,-24 730,-30 724,-36 718,-36\"/>\n<text text-anchor=\"middle\" x=\"703\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 34&#45;&gt;38 -->\n<g id=\"edge20\" class=\"edge\">\n<title>34&#45;&gt;38</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M663.1307,-71.9769C669.8826,-62.8358 677.0825,-53.0883 683.4796,-44.4276\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"686.4367,-46.3151 689.5628,-36.192 680.8061,-42.1561 686.4367,-46.3151\"/>\n</g>\n<!-- 42 -->\n<g id=\"node23\" class=\"node\">\n<title>42</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M791,-36C791,-36 761,-36 761,-36 755,-36 749,-30 749,-24 749,-24 749,-12 749,-12 749,-6 755,0 761,0 761,0 791,0 791,0 797,0 803,-6 803,-12 803,-12 803,-24 803,-24 803,-30 797,-36 791,-36\"/>\n<text text-anchor=\"middle\" x=\"776\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 41&#45;&gt;42 -->\n<g id=\"edge22\" class=\"edge\">\n<title>41&#45;&gt;42</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M778.4535,-71.9769C778.0683,-63.5023 777.6594,-54.5065 777.2881,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"780.7775,-46.0227 776.8269,-36.192 773.7847,-46.3406 780.7775,-46.0227\"/>\n</g>\n<!-- 45 -->\n<g id=\"node24\" 
class=\"node\">\n<title>45</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M863,-36C863,-36 833,-36 833,-36 827,-36 821,-30 821,-24 821,-24 821,-12 821,-12 821,-6 827,0 833,0 833,0 863,0 863,0 869,0 875,-6 875,-12 875,-12 875,-24 875,-24 875,-30 869,-36 863,-36\"/>\n<text text-anchor=\"middle\" x=\"848\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 41&#45;&gt;45 -->\n<g id=\"edge23\" class=\"edge\">\n<title>41&#45;&gt;45</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M806.2906,-71.9769C813.4277,-62.7406 821.0433,-52.8851 827.7875,-44.1573\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"830.5976,-46.2449 833.9426,-36.192 825.0586,-41.9647 830.5976,-46.2449\"/>\n</g>\n<!-- 49 -->\n<g id=\"node26\" class=\"node\">\n<title>49</title>\n<path fill=\"#f8dcc8\" stroke=\"#000000\" d=\"M970,-140C970,-140 874,-140 874,-140 868,-140 862,-134 862,-128 862,-128 862,-84 862,-84 862,-78 868,-72 874,-72 874,-72 970,-72 970,-72 976,-72 982,-78 982,-84 982,-84 982,-128 982,-128 982,-134 976,-140 970,-140\"/>\n<text text-anchor=\"start\" x=\"877.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2007.5</text>\n<text text-anchor=\"start\" x=\"881.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.013</text>\n<text text-anchor=\"start\" x=\"870\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 6081</text>\n<text text-anchor=\"start\" x=\"875.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = &#45;0.023</text>\n</g>\n<!-- 48&#45;&gt;49 -->\n<g id=\"edge25\" class=\"edge\">\n<title>48&#45;&gt;49</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M922,-175.9465C922,-167.776 922,-158.9318 922,-150.3697\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"925.5001,-150.13 922,-140.13 918.5001,-150.13 
925.5001,-150.13\"/>\n</g>\n<!-- 56 -->\n<g id=\"node29\" class=\"node\">\n<title>56</title>\n<path fill=\"#f6d4bb\" stroke=\"#000000\" d=\"M1116,-140C1116,-140 1012,-140 1012,-140 1006,-140 1000,-134 1000,-128 1000,-128 1000,-84 1000,-84 1000,-78 1006,-72 1012,-72 1012,-72 1116,-72 1116,-72 1122,-72 1128,-78 1128,-84 1128,-84 1128,-128 1128,-128 1128,-134 1122,-140 1116,-140\"/>\n<text text-anchor=\"start\" x=\"1008\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">bb_down ≤ 0.038</text>\n<text text-anchor=\"start\" x=\"1023.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">mse = 0.006</text>\n<text text-anchor=\"start\" x=\"1008\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 16417</text>\n<text text-anchor=\"start\" x=\"1020\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = 0.006</text>\n</g>\n<!-- 48&#45;&gt;56 -->\n<g id=\"edge28\" class=\"edge\">\n<title>48&#45;&gt;56</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M968.4961,-175.9465C981.4908,-166.4293 995.7337,-155.9978 1009.1621,-146.163\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1011.3998,-148.8624 1017.3994,-140.13 1007.2637,-143.2151 1011.3998,-148.8624\"/>\n</g>\n<!-- 50 -->\n<g id=\"node27\" class=\"node\">\n<title>50</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M935,-36C935,-36 905,-36 905,-36 899,-36 893,-30 893,-24 893,-24 893,-12 893,-12 893,-6 899,0 905,0 905,0 935,0 935,0 941,0 947,-6 947,-12 947,-12 947,-24 947,-24 947,-30 941,-36 935,-36\"/>\n<text text-anchor=\"middle\" x=\"920\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 49&#45;&gt;50 -->\n<g id=\"edge26\" class=\"edge\">\n<title>49&#45;&gt;50</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M921.2267,-71.9769C921.0341,-63.5023 920.8297,-54.5065 
920.6441,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"924.1399,-46.1098 920.4135,-36.192 917.1417,-46.2689 924.1399,-46.1098\"/>\n</g>\n<!-- 53 -->\n<g id=\"node28\" class=\"node\">\n<title>53</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1007,-36C1007,-36 977,-36 977,-36 971,-36 965,-30 965,-24 965,-24 965,-12 965,-12 965,-6 971,0 977,0 977,0 1007,0 1007,0 1013,0 1019,-6 1019,-12 1019,-12 1019,-24 1019,-24 1019,-30 1013,-36 1007,-36\"/>\n<text text-anchor=\"middle\" x=\"992\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 49&#45;&gt;53 -->\n<g id=\"edge27\" class=\"edge\">\n<title>49&#45;&gt;53</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M949.0638,-71.9769C956.4109,-62.7406 964.2505,-52.8851 971.1931,-44.1573\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"974.043,-46.1968 977.5291,-36.192 968.5647,-41.8391 974.043,-46.1968\"/>\n</g>\n<!-- 57 -->\n<g id=\"node30\" class=\"node\">\n<title>57</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1079,-36C1079,-36 1049,-36 1049,-36 1043,-36 1037,-30 1037,-24 1037,-24 1037,-12 1037,-12 1037,-6 1043,0 1049,0 1049,0 1079,0 1079,0 1085,0 1091,-6 1091,-12 1091,-12 1091,-24 1091,-24 1091,-30 1085,-36 1079,-36\"/>\n<text text-anchor=\"middle\" x=\"1064\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 56&#45;&gt;57 -->\n<g id=\"edge29\" class=\"edge\">\n<title>56&#45;&gt;57</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1064,-71.9769C1064,-63.5023 1064,-54.5065 1064,-46.3388\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1067.5001,-46.1919 1064,-36.192 1060.5001,-46.192 1067.5001,-46.1919\"/>\n</g>\n<!-- 60 -->\n<g id=\"node31\" class=\"node\">\n<title>60</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1151,-36C1151,-36 1121,-36 1121,-36 1115,-36 1109,-30 1109,-24 1109,-24 1109,-12 1109,-12 1109,-6 1115,0 1121,0 1121,0 
1151,0 1151,0 1157,0 1163,-6 1163,-12 1163,-12 1163,-24 1163,-24 1163,-30 1157,-36 1151,-36\"/>\n<text text-anchor=\"middle\" x=\"1136\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 56&#45;&gt;60 -->\n<g id=\"edge30\" class=\"edge\">\n<title>56&#45;&gt;60</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1091.8371,-71.9769C1099.3941,-62.7406 1107.4576,-52.8851 1114.5986,-44.1573\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1117.4921,-46.1479 1121.1157,-36.192 1112.0744,-41.7152 1117.4921,-46.1479\"/>\n</g>\n</g>\n</svg>\n"
},
"metadata": {},
"execution_count": 51
}
],
"source": [
"out_file = results_path / 'reg_tree.dot'\n",
"dot_data = export_graphviz(regression_tree,\n",
" out_file=out_file.as_posix(),\n",
" feature_names=X_train.columns,\n",
" max_depth=3,\n",
" filled=True,\n",
" rounded=True,\n",
" special_characters=True)\n",
"if out_file is not None:\n",
" dot_data = Path(out_file).read_text()\n",
"\n",
"graphviz.Source(dot_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kimQ7QYiZDwH"
},
"source": [
"### Evaluate Test Set"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.355228Z",
"start_time": "2021-04-16T00:33:35.333642Z"
},
"id": "nDCYUa3vZDwI"
},
"outputs": [],
"source": [
"y_pred = regression_tree.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.363324Z",
"start_time": "2021-04-16T00:33:35.358601Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qdfmha5mZDwJ",
"outputId": "a472a9d4-fa7a-4ab5-f74d-aa83edd21a04"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.08659823080682821"
]
},
"metadata": {},
"execution_count": 53
}
],
"source": [
"np.sqrt(mean_squared_error(y_pred=y_pred, y_true=y_test))"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.379470Z",
"start_time": "2021-04-16T00:33:35.365615Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "educlMgkZDwL",
"outputId": "b2043fdb-13bf-43f4-be73-09fc8b098227"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"18.95 (p-value=0.00%)\n"
]
}
],
"source": [
"r, p = spearmanr(y_pred, y_test)\n",
"print(f'{r*100:.2f} (p-value={p:.2%})')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2HPpWFYjZDwM"
},
"source": [
"## Classification Tree with All Features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OW2XC4rBZDwN"
},
"source": [
"We now train, visualize, and evaluate a classification tree with up to five consecutive splits (`max_depth=5`) that predicts positive versus negative returns, using 80% of the samples for training and the remaining 20% for testing. To simplify the illustration, we take a shortcut and use scikit-learn's built-in `train_test_split` instead of our custom iterator; because it shuffles observations at random, it does not protect against lookahead bias. The tree configuration implies up to $2^5=32$ leaf nodes which, in the balanced case, would each contain about 1,945 of the 62,230 training samples on average."
]
},
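{
"cell_type": "markdown",
"metadata": {},
"source": [
"The lookahead problem with a random split can be made concrete with a minimal sketch. It uses a small synthetic panel (five hypothetical tickers, 120 monthly observations, random feature values; none of this is the notebook's data) and compares `train_test_split` with a simple chronological cutoff. Under the random split, test rows are scattered across the whole period, so the model is trained on observations dated after many of the test rows. The chronological cutoff keeps every test date after every training date. This is an illustration under those assumptions, not the custom iterator mentioned above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Synthetic panel (illustrative assumption): 5 hypothetical tickers x 120 month starts\n",
"dates = pd.date_range('2001-01-01', periods=120, freq='MS')\n",
"idx = pd.MultiIndex.from_product([list('ABCDE'), dates], names=['ticker', 'date'])\n",
"rng = np.random.default_rng(42)\n",
"panel = pd.DataFrame({'feature': rng.normal(size=len(idx))}, index=idx)\n",
"\n",
"# 1) Random 80/20 split, as used for the classification tree: rows are shuffled across time\n",
"train_r, test_r = train_test_split(panel, test_size=0.2, random_state=42)\n",
"last_train_date = train_r.index.get_level_values('date').max()\n",
"share_before = (test_r.index.get_level_values('date') < last_train_date).mean()\n",
"print(f'Random split: {share_before:.0%} of test rows predate the last training date')\n",
"\n",
"# 2) Chronological split: train on the first 80% of months, test on the remaining 20%\n",
"cutoff = dates[int(len(dates) * 0.8)]\n",
"date_values = panel.index.get_level_values('date')\n",
"train_t, test_t = panel[date_values < cutoff], panel[date_values >= cutoff]\n",
"train_end = train_t.index.get_level_values('date').max()\n",
"test_start = test_t.index.get_level_values('date').min()\n",
"assert train_end < test_start  # no test observation precedes any training observation\n",
"print(f'Chronological split: train ends {train_end.strftime(\"%Y-%m\")}, test starts {test_start.strftime(\"%Y-%m\")}')"
]
},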
{
"cell_type": "markdown",
"metadata": {
"id": "4s2bXgPiZDwO"
},
"source": [
"### Train-Test Split"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.391198Z",
"start_time": "2021-04-16T00:33:35.380409Z"
},
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jyuyf81OZDwO",
"outputId": "2761b409-3e20-4545-e7dc-3e6698a4733f"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1 43733\n",
"0 34055\n",
"Name: target, dtype: int64"
]
},
"metadata": {},
"execution_count": 55
}
],
"source": [
"y_binary = (y>0).astype(int)\n",
"y_binary.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.421858Z",
"start_time": "2021-04-16T00:33:35.392557Z"
},
"id": "q3riNoiJZDwP"
},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y_binary, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:35.425897Z",
"start_time": "2021-04-16T00:33:35.423388Z"
},
"id": "QdoF0Pl1ZDwP"
},
"outputs": [],
"source": [
"clf = DecisionTreeClassifier(criterion='gini',\n",
" max_depth=5,\n",
" random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.216311Z",
"start_time": "2021-04-16T00:33:35.427617Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3bRtiGpdZDwQ",
"outputId": "a91c0a43-d4a7-4211-cb47-24d287f3d707"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeClassifier(max_depth=5, random_state=42)"
]
},
"metadata": {},
"execution_count": 58
}
],
"source": [
"clf.fit(X=X_train, y=y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NJVtIeWbZDwR"
},
"source": [
"### Plot Tree"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.260644Z",
"start_time": "2021-04-16T00:33:36.219967Z"
},
"scrolled": false,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 734
},
"id": "z_6xrpF1ZDwR",
"outputId": "f726b54e-d9c5-4bec-a4d5-16e97628caf3"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<graphviz.files.Source at 0x7f2b159a5ed0>"
],
"image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: Tree Pages: 1 -->\n<svg width=\"1293pt\" height=\"520pt\"\n viewBox=\"0.00 0.00 1293.00 520.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 516)\">\n<title>Tree</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-516 1289,-516 1289,4 -4,4\"/>\n<!-- 0 -->\n<g id=\"node1\" class=\"node\">\n<title>0</title>\n<path fill=\"#d4eaf9\" stroke=\"#000000\" d=\"M701,-512C701,-512 560,-512 560,-512 554,-512 548,-506 548,-500 548,-500 548,-441 548,-441 548,-435 554,-429 560,-429 560,-429 701,-429 701,-429 707,-429 713,-435 713,-441 713,-441 713,-500 713,-500 713,-506 707,-512 701,-512\"/>\n<text text-anchor=\"start\" x=\"592.5\" y=\"-496.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 3.5</text>\n<text text-anchor=\"start\" x=\"592.5\" y=\"-481.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.492</text>\n<text text-anchor=\"start\" x=\"574.5\" y=\"-466.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 62230</text>\n<text text-anchor=\"start\" x=\"556\" y=\"-451.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [27286, 34944]</text>\n<text text-anchor=\"start\" x=\"596\" y=\"-436.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 1 -->\n<g id=\"node2\" class=\"node\">\n<title>1</title>\n<path fill=\"#b7dbf5\" stroke=\"#000000\" d=\"M534,-393C534,-393 409,-393 409,-393 403,-393 397,-387 397,-381 397,-381 397,-322 397,-322 397,-316 403,-310 409,-310 409,-310 534,-310 534,-310 
540,-310 546,-316 546,-322 546,-322 546,-381 546,-381 546,-387 540,-393 534,-393\"/>\n<text text-anchor=\"start\" x=\"427\" y=\"-377.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2008.5</text>\n<text text-anchor=\"start\" x=\"433.5\" y=\"-362.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.475</text>\n<text text-anchor=\"start\" x=\"415.5\" y=\"-347.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 15600</text>\n<text text-anchor=\"start\" x=\"405\" y=\"-332.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [6053, 9547]</text>\n<text text-anchor=\"start\" x=\"437\" y=\"-317.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 0&#45;&gt;1 -->\n<g id=\"edge1\" class=\"edge\">\n<title>0&#45;&gt;1</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M574.8895,-428.8796C562.0597,-419.2774 548.3148,-408.9903 535.1714,-399.1534\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"537.1823,-396.2868 527.0791,-393.0969 532.9879,-401.891 537.1823,-396.2868\"/>\n<text text-anchor=\"middle\" x=\"530.6395\" y=\"-414.1421\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n</g>\n<!-- 32 -->\n<g id=\"node17\" class=\"node\">\n<title>32</title>\n<path fill=\"#dfeffb\" stroke=\"#000000\" d=\"M856,-393C856,-393 715,-393 715,-393 709,-393 703,-387 703,-381 703,-381 703,-322 703,-322 703,-316 709,-310 715,-310 715,-310 856,-310 856,-310 862,-310 868,-316 868,-322 868,-322 868,-381 868,-381 868,-387 862,-393 856,-393\"/>\n<text text-anchor=\"start\" x=\"741\" y=\"-377.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2002.5</text>\n<text text-anchor=\"start\" x=\"747.5\" y=\"-362.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.496</text>\n<text text-anchor=\"start\" 
x=\"729.5\" y=\"-347.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46630</text>\n<text text-anchor=\"start\" x=\"711\" y=\"-332.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [21233, 25397]</text>\n<text text-anchor=\"start\" x=\"751\" y=\"-317.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 0&#45;&gt;32 -->\n<g id=\"edge16\" class=\"edge\">\n<title>0&#45;&gt;32</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M684.7115,-428.8796C697.1006,-419.368 710.3649,-409.1843 723.0676,-399.432\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"725.5186,-401.9628 731.3191,-393.0969 721.2558,-396.4105 725.5186,-401.9628\"/>\n<text text-anchor=\"middle\" x=\"728.0624\" y=\"-414.1839\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n</g>\n<!-- 2 -->\n<g id=\"node3\" class=\"node\">\n<title>2</title>\n<path fill=\"#d8ecfa\" stroke=\"#000000\" d=\"M296,-274C296,-274 171,-274 171,-274 165,-274 159,-268 159,-262 159,-262 159,-203 159,-203 159,-197 165,-191 171,-191 171,-191 296,-191 296,-191 302,-191 308,-197 308,-203 308,-203 308,-262 308,-262 308,-268 302,-274 296,-274\"/>\n<text text-anchor=\"start\" x=\"170\" y=\"-258.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">return_12m ≤ 0.032</text>\n<text text-anchor=\"start\" x=\"195.5\" y=\"-243.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.494</text>\n<text text-anchor=\"start\" x=\"181.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 6335</text>\n<text text-anchor=\"start\" x=\"167\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [2826, 3509]</text>\n<text text-anchor=\"start\" x=\"199\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" 
fill=\"#000000\">class = Up</text>\n</g>\n<!-- 1&#45;&gt;2 -->\n<g id=\"edge2\" class=\"edge\">\n<title>1&#45;&gt;2</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M396.8483,-314.1742C371.6056,-301.5528 343.2561,-287.378 317.3485,-274.4243\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"318.8052,-271.2395 308.2957,-269.8978 315.6747,-277.5005 318.8052,-271.2395\"/>\n</g>\n<!-- 17 -->\n<g id=\"node10\" class=\"node\">\n<title>17</title>\n<path fill=\"#a3d1f3\" stroke=\"#000000\" d=\"M534,-274C534,-274 409,-274 409,-274 403,-274 397,-268 397,-262 397,-262 397,-203 397,-203 397,-197 403,-191 409,-191 409,-191 534,-191 534,-191 540,-191 546,-197 546,-203 546,-203 546,-262 546,-262 546,-268 540,-274 534,-274\"/>\n<text text-anchor=\"start\" x=\"433.5\" y=\"-258.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 2.5</text>\n<text text-anchor=\"start\" x=\"433.5\" y=\"-243.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.454</text>\n<text text-anchor=\"start\" x=\"419.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 9265</text>\n<text text-anchor=\"start\" x=\"405\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [3227, 6038]</text>\n<text text-anchor=\"start\" x=\"437\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 1&#45;&gt;17 -->\n<g id=\"edge9\" class=\"edge\">\n<title>1&#45;&gt;17</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M471.5,-309.8796C471.5,-301.6838 471.5,-292.9891 471.5,-284.5013\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"475.0001,-284.298 471.5,-274.2981 468.0001,-284.2981 475.0001,-284.298\"/>\n</g>\n<!-- 3 -->\n<g id=\"node4\" class=\"node\">\n<title>3</title>\n<path fill=\"#cee7f9\" stroke=\"#000000\" d=\"M137,-155C137,-155 12,-155 12,-155 6,-155 0,-149 0,-143 0,-143 0,-84 0,-84 
0,-78 6,-72 12,-72 12,-72 137,-72 137,-72 143,-72 149,-78 149,-84 149,-84 149,-143 149,-143 149,-149 143,-155 137,-155\"/>\n<text text-anchor=\"start\" x=\"36.5\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 1.5</text>\n<text text-anchor=\"start\" x=\"41\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.49</text>\n<text text-anchor=\"start\" x=\"22.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 5396</text>\n<text text-anchor=\"start\" x=\"8\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [2320, 3076]</text>\n<text text-anchor=\"start\" x=\"40\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g id=\"edge3\" class=\"edge\">\n<title>2&#45;&gt;3</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M177.8895,-190.8796C165.0597,-181.2774 151.3148,-170.9903 138.1714,-161.1534\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"140.1823,-158.2868 130.0791,-155.0969 135.9879,-163.891 140.1823,-158.2868\"/>\n</g>\n<!-- 10 -->\n<g id=\"node7\" class=\"node\">\n<title>10</title>\n<path fill=\"#fbede2\" stroke=\"#000000\" d=\"M287.5,-155C287.5,-155 179.5,-155 179.5,-155 173.5,-155 167.5,-149 167.5,-143 167.5,-143 167.5,-84 167.5,-84 167.5,-78 173.5,-72 179.5,-72 179.5,-72 287.5,-72 287.5,-72 293.5,-72 299.5,-78 299.5,-84 299.5,-84 299.5,-143 299.5,-143 299.5,-149 293.5,-155 287.5,-155\"/>\n<text text-anchor=\"start\" x=\"195.5\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 2.5</text>\n<text text-anchor=\"start\" x=\"195.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.497</text>\n<text text-anchor=\"start\" x=\"185.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" 
fill=\"#000000\">samples = 939</text>\n<text text-anchor=\"start\" x=\"175.5\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [506, 433]</text>\n<text text-anchor=\"start\" x=\"189.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Down</text>\n</g>\n<!-- 2&#45;&gt;10 -->\n<g id=\"edge6\" class=\"edge\">\n<title>2&#45;&gt;10</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M233.5,-190.8796C233.5,-182.6838 233.5,-173.9891 233.5,-165.5013\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"237.0001,-165.298 233.5,-155.2981 230.0001,-165.2981 237.0001,-165.298\"/>\n</g>\n<!-- 4 -->\n<g id=\"node5\" class=\"node\">\n<title>4</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M53.5,-36C53.5,-36 23.5,-36 23.5,-36 17.5,-36 11.5,-30 11.5,-24 11.5,-24 11.5,-12 11.5,-12 11.5,-6 17.5,0 23.5,0 23.5,0 53.5,0 53.5,0 59.5,0 65.5,-6 65.5,-12 65.5,-12 65.5,-24 65.5,-24 65.5,-30 59.5,-36 53.5,-36\"/>\n<text text-anchor=\"middle\" x=\"38.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 3&#45;&gt;4 -->\n<g id=\"edge4\" class=\"edge\">\n<title>3&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M58.8146,-71.8901C55.4739,-63.0279 52.025,-53.8788 48.9386,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"52.1877,-44.3878 45.3853,-36.2651 45.6376,-46.8569 52.1877,-44.3878\"/>\n</g>\n<!-- 7 -->\n<g id=\"node6\" class=\"node\">\n<title>7</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M125.5,-36C125.5,-36 95.5,-36 95.5,-36 89.5,-36 83.5,-30 83.5,-24 83.5,-24 83.5,-12 83.5,-12 83.5,-6 89.5,0 95.5,0 95.5,0 125.5,0 125.5,0 131.5,0 137.5,-6 137.5,-12 137.5,-12 137.5,-24 137.5,-24 137.5,-30 131.5,-36 125.5,-36\"/>\n<text text-anchor=\"middle\" x=\"110.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 3&#45;&gt;7 -->\n<g 
id=\"edge5\" class=\"edge\">\n<title>3&#45;&gt;7</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M90.1854,-71.8901C93.5261,-63.0279 96.975,-53.8788 100.0614,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"103.3624,-46.8569 103.6147,-36.2651 96.8123,-44.3878 103.3624,-46.8569\"/>\n</g>\n<!-- 11 -->\n<g id=\"node8\" class=\"node\">\n<title>11</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M207.5,-36C207.5,-36 177.5,-36 177.5,-36 171.5,-36 165.5,-30 165.5,-24 165.5,-24 165.5,-12 165.5,-12 165.5,-6 171.5,0 177.5,0 177.5,0 207.5,0 207.5,0 213.5,0 219.5,-6 219.5,-12 219.5,-12 219.5,-24 219.5,-24 219.5,-30 213.5,-36 207.5,-36\"/>\n<text text-anchor=\"middle\" x=\"192.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 10&#45;&gt;11 -->\n<g id=\"edge7\" class=\"edge\">\n<title>10&#45;&gt;11</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M215.6361,-71.8901C211.8314,-63.0279 207.9035,-53.8788 204.3884,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"207.5028,-44.0733 200.3416,-36.2651 201.0705,-46.8348 207.5028,-44.0733\"/>\n</g>\n<!-- 14 -->\n<g id=\"node9\" class=\"node\">\n<title>14</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M279.5,-36C279.5,-36 249.5,-36 249.5,-36 243.5,-36 237.5,-30 237.5,-24 237.5,-24 237.5,-12 237.5,-12 237.5,-6 243.5,0 249.5,0 249.5,0 279.5,0 279.5,0 285.5,0 291.5,-6 291.5,-12 291.5,-12 291.5,-24 291.5,-24 291.5,-30 285.5,-36 279.5,-36\"/>\n<text text-anchor=\"middle\" x=\"264.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 10&#45;&gt;14 -->\n<g id=\"edge8\" class=\"edge\">\n<title>10&#45;&gt;14</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M247.0069,-71.8901C249.852,-63.1253 252.7882,-54.0798 255.4235,-45.9615\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"258.8125,-46.8572 258.571,-36.2651 252.1545,-44.6959 
258.8125,-46.8572\"/>\n</g>\n<!-- 18 -->\n<g id=\"node11\" class=\"node\">\n<title>18</title>\n<path fill=\"#96cbf1\" stroke=\"#000000\" d=\"M455,-155C455,-155 330,-155 330,-155 324,-155 318,-149 318,-143 318,-143 318,-84 318,-84 318,-78 324,-72 330,-72 330,-72 455,-72 455,-72 461,-72 467,-78 467,-84 467,-84 467,-143 467,-143 467,-149 461,-155 455,-155\"/>\n<text text-anchor=\"start\" x=\"348\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2009.5</text>\n<text text-anchor=\"start\" x=\"354.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.435</text>\n<text text-anchor=\"start\" x=\"340.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 6210</text>\n<text text-anchor=\"start\" x=\"326\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [1983, 4227]</text>\n<text text-anchor=\"start\" x=\"358\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 17&#45;&gt;18 -->\n<g id=\"edge10\" class=\"edge\">\n<title>17&#45;&gt;18</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M443.8696,-190.8796C438.07,-182.1434 431.8941,-172.8404 425.9092,-163.8253\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"428.6951,-161.6935 420.2483,-155.2981 422.8632,-165.5652 428.6951,-161.6935\"/>\n</g>\n<!-- 25 -->\n<g id=\"node14\" class=\"node\">\n<title>25</title>\n<path fill=\"#c1e0f7\" stroke=\"#000000\" d=\"M622,-155C622,-155 497,-155 497,-155 491,-155 485,-149 485,-143 485,-143 485,-84 485,-84 485,-78 491,-72 497,-72 497,-72 622,-72 622,-72 628,-72 634,-78 634,-84 634,-84 634,-143 634,-143 634,-149 628,-155 622,-155\"/>\n<text text-anchor=\"start\" x=\"515\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2011.5</text>\n<text text-anchor=\"start\" x=\"521.5\" y=\"-124.8\" 
font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.483</text>\n<text text-anchor=\"start\" x=\"507.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3055</text>\n<text text-anchor=\"start\" x=\"493\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [1244, 1811]</text>\n<text text-anchor=\"start\" x=\"525\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 17&#45;&gt;25 -->\n<g id=\"edge13\" class=\"edge\">\n<title>17&#45;&gt;25</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M502.2781,-190.8796C508.8051,-182.0534 515.7599,-172.6485 522.4908,-163.5466\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"525.4588,-165.4195 528.5905,-155.2981 519.8305,-161.2574 525.4588,-165.4195\"/>\n</g>\n<!-- 19 -->\n<g id=\"node12\" class=\"node\">\n<title>19</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M370.5,-36C370.5,-36 340.5,-36 340.5,-36 334.5,-36 328.5,-30 328.5,-24 328.5,-24 328.5,-12 328.5,-12 328.5,-6 334.5,0 340.5,0 340.5,0 370.5,0 370.5,0 376.5,0 382.5,-6 382.5,-12 382.5,-12 382.5,-24 382.5,-24 382.5,-30 376.5,-36 370.5,-36\"/>\n<text text-anchor=\"middle\" x=\"355.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 18&#45;&gt;19 -->\n<g id=\"edge11\" class=\"edge\">\n<title>18&#45;&gt;19</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M376.3789,-71.8901C372.9454,-63.0279 369.4007,-53.8788 366.2285,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"369.4529,-44.3253 362.5765,-36.2651 362.9257,-46.8542 369.4529,-44.3253\"/>\n</g>\n<!-- 22 -->\n<g id=\"node13\" class=\"node\">\n<title>22</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M442.5,-36C442.5,-36 412.5,-36 412.5,-36 406.5,-36 400.5,-30 400.5,-24 400.5,-24 400.5,-12 400.5,-12 400.5,-6 406.5,0 412.5,0 412.5,0 
442.5,0 442.5,0 448.5,0 454.5,-6 454.5,-12 454.5,-12 454.5,-24 454.5,-24 454.5,-30 448.5,-36 442.5,-36\"/>\n<text text-anchor=\"middle\" x=\"427.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 18&#45;&gt;22 -->\n<g id=\"edge12\" class=\"edge\">\n<title>18&#45;&gt;22</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M407.7497,-71.8901C410.9976,-63.0279 414.3507,-53.8788 417.3514,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"420.6511,-46.8588 420.806,-36.2651 414.0786,-44.45 420.6511,-46.8588\"/>\n</g>\n<!-- 26 -->\n<g id=\"node15\" class=\"node\">\n<title>26</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M523.5,-36C523.5,-36 493.5,-36 493.5,-36 487.5,-36 481.5,-30 481.5,-24 481.5,-24 481.5,-12 481.5,-12 481.5,-6 487.5,0 493.5,0 493.5,0 523.5,0 523.5,0 529.5,0 535.5,-6 535.5,-12 535.5,-12 535.5,-24 535.5,-24 535.5,-30 529.5,-36 523.5,-36\"/>\n<text text-anchor=\"middle\" x=\"508.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 25&#45;&gt;26 -->\n<g id=\"edge14\" class=\"edge\">\n<title>25&#45;&gt;26</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M537.279,-71.8901C532.4423,-62.8331 527.4456,-53.4765 523.0005,-45.1528\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"526.0522,-43.4373 518.2541,-36.2651 519.8775,-46.7348 526.0522,-43.4373\"/>\n</g>\n<!-- 29 -->\n<g id=\"node16\" class=\"node\">\n<title>29</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M595.5,-36C595.5,-36 565.5,-36 565.5,-36 559.5,-36 553.5,-30 553.5,-24 553.5,-24 553.5,-12 553.5,-12 553.5,-6 559.5,0 565.5,0 565.5,0 595.5,0 595.5,0 601.5,0 607.5,-6 607.5,-12 607.5,-12 607.5,-24 607.5,-24 607.5,-30 601.5,-36 595.5,-36\"/>\n<text text-anchor=\"middle\" x=\"580.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 25&#45;&gt;29 -->\n<g id=\"edge15\" 
class=\"edge\">\n<title>25&#45;&gt;29</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M568.6498,-71.8901C570.5557,-63.2227 572.522,-54.2808 574.2918,-46.2325\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"577.7542,-46.7835 576.4836,-36.2651 570.9175,-45.2801 577.7542,-46.7835\"/>\n</g>\n<!-- 33 -->\n<g id=\"node18\" class=\"node\">\n<title>33</title>\n<path fill=\"#f6d5bd\" stroke=\"#000000\" d=\"M848,-274C848,-274 723,-274 723,-274 717,-274 711,-268 711,-262 711,-262 711,-203 711,-203 711,-197 717,-191 723,-191 723,-191 848,-191 848,-191 854,-191 860,-197 860,-203 860,-203 860,-262 860,-262 860,-268 854,-274 848,-274\"/>\n<text text-anchor=\"start\" x=\"747.5\" y=\"-258.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 8.5</text>\n<text text-anchor=\"start\" x=\"752\" y=\"-243.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.48</text>\n<text text-anchor=\"start\" x=\"733.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 2516</text>\n<text text-anchor=\"start\" x=\"719\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [1508, 1008]</text>\n<text text-anchor=\"start\" x=\"741.5\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Down</text>\n</g>\n<!-- 32&#45;&gt;33 -->\n<g id=\"edge17\" class=\"edge\">\n<title>32&#45;&gt;33</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M785.5,-309.8796C785.5,-301.6838 785.5,-292.9891 785.5,-284.5013\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"789.0001,-284.298 785.5,-274.2981 782.0001,-284.2981 789.0001,-284.298\"/>\n</g>\n<!-- 48 -->\n<g id=\"node25\" class=\"node\">\n<title>48</title>\n<path fill=\"#d9ecfa\" stroke=\"#000000\" d=\"M1098,-274C1098,-274 957,-274 957,-274 951,-274 945,-268 945,-262 945,-262 945,-203 945,-203 945,-197 951,-191 957,-191 957,-191 1098,-191 1098,-191 
1104,-191 1110,-197 1110,-203 1110,-203 1110,-262 1110,-262 1110,-268 1104,-274 1098,-274\"/>\n<text text-anchor=\"start\" x=\"983\" y=\"-258.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2006.5</text>\n<text text-anchor=\"start\" x=\"989.5\" y=\"-243.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.494</text>\n<text text-anchor=\"start\" x=\"971.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 44114</text>\n<text text-anchor=\"start\" x=\"953\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [19725, 24389]</text>\n<text text-anchor=\"start\" x=\"993\" y=\"-198.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 32&#45;&gt;48 -->\n<g id=\"edge24\" class=\"edge\">\n<title>32&#45;&gt;48</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M868.1058,-310.8798C889.8881,-300.1687 913.5223,-288.5469 935.7548,-277.6144\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"937.4266,-280.6926 944.8559,-273.1391 934.3377,-274.411 937.4266,-280.6926\"/>\n</g>\n<!-- 34 -->\n<g id=\"node19\" class=\"node\">\n<title>34</title>\n<path fill=\"#f1b991\" stroke=\"#000000\" d=\"M772.5,-155C772.5,-155 664.5,-155 664.5,-155 658.5,-155 652.5,-149 652.5,-143 652.5,-143 652.5,-84 652.5,-84 652.5,-78 658.5,-72 664.5,-72 664.5,-72 772.5,-72 772.5,-72 778.5,-72 784.5,-78 784.5,-84 784.5,-84 784.5,-143 784.5,-143 784.5,-149 778.5,-155 772.5,-155\"/>\n<text text-anchor=\"start\" x=\"680.5\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 4.5</text>\n<text text-anchor=\"start\" x=\"680.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.426</text>\n<text text-anchor=\"start\" x=\"666.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" 
fill=\"#000000\">samples = 1398</text>\n<text text-anchor=\"start\" x=\"660.5\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [968, 430]</text>\n<text text-anchor=\"start\" x=\"674.5\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Down</text>\n</g>\n<!-- 33&#45;&gt;34 -->\n<g id=\"edge18\" class=\"edge\">\n<title>33&#45;&gt;34</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M762.0667,-190.8796C757.1987,-182.2335 752.0181,-173.0322 746.9914,-164.1042\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"749.9893,-162.2947 742.0334,-155.2981 743.8897,-165.729 749.9893,-162.2947\"/>\n</g>\n<!-- 41 -->\n<g id=\"node22\" class=\"node\">\n<title>41</title>\n<path fill=\"#f2f9fd\" stroke=\"#000000\" d=\"M922.5,-155C922.5,-155 814.5,-155 814.5,-155 808.5,-155 802.5,-149 802.5,-143 802.5,-143 802.5,-84 802.5,-84 802.5,-78 808.5,-72 814.5,-72 814.5,-72 922.5,-72 922.5,-72 928.5,-72 934.5,-78 934.5,-84 934.5,-84 934.5,-143 934.5,-143 934.5,-149 928.5,-155 922.5,-155\"/>\n<text text-anchor=\"start\" x=\"826\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 10.5</text>\n<text text-anchor=\"start\" x=\"830.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.499</text>\n<text text-anchor=\"start\" x=\"816.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 1118</text>\n<text text-anchor=\"start\" x=\"810.5\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [540, 578]</text>\n<text text-anchor=\"start\" x=\"834\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 33&#45;&gt;41 -->\n<g id=\"edge21\" class=\"edge\">\n<title>33&#45;&gt;41</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M814.5294,-190.8796C820.6855,-182.0534 
827.2451,-172.6485 833.5936,-163.5466\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"836.4967,-165.5024 839.3467,-155.2981 830.7552,-161.4978 836.4967,-165.5024\"/>\n</g>\n<!-- 35 -->\n<g id=\"node20\" class=\"node\">\n<title>35</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M672.5,-36C672.5,-36 642.5,-36 642.5,-36 636.5,-36 630.5,-30 630.5,-24 630.5,-24 630.5,-12 630.5,-12 630.5,-6 636.5,0 642.5,0 642.5,0 672.5,0 672.5,0 678.5,0 684.5,-6 684.5,-12 684.5,-12 684.5,-24 684.5,-24 684.5,-30 678.5,-36 672.5,-36\"/>\n<text text-anchor=\"middle\" x=\"657.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 34&#45;&gt;35 -->\n<g id=\"edge19\" class=\"edge\">\n<title>34&#45;&gt;35</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M691.922,-71.8901C686.0747,-62.7357 680.0318,-53.2752 674.6724,-44.8847\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"677.4994,-42.8086 669.1667,-36.2651 671.6001,-46.5767 677.4994,-42.8086\"/>\n</g>\n<!-- 38 -->\n<g id=\"node21\" class=\"node\">\n<title>38</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M744.5,-36C744.5,-36 714.5,-36 714.5,-36 708.5,-36 702.5,-30 702.5,-24 702.5,-24 702.5,-12 702.5,-12 702.5,-6 708.5,0 714.5,0 714.5,0 744.5,0 744.5,0 750.5,0 756.5,-6 756.5,-12 756.5,-12 756.5,-24 756.5,-24 756.5,-30 750.5,-36 744.5,-36\"/>\n<text text-anchor=\"middle\" x=\"729.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 34&#45;&gt;38 -->\n<g id=\"edge20\" class=\"edge\">\n<title>34&#45;&gt;38</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M723.2928,-71.8901C724.2911,-63.2227 725.3211,-54.2808 726.2481,-46.2325\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"729.7288,-46.5999 727.3962,-36.2651 722.7748,-45.7989 729.7288,-46.5999\"/>\n</g>\n<!-- 42 -->\n<g id=\"node23\" class=\"node\">\n<title>42</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" 
d=\"M818.5,-36C818.5,-36 788.5,-36 788.5,-36 782.5,-36 776.5,-30 776.5,-24 776.5,-24 776.5,-12 776.5,-12 776.5,-6 782.5,0 788.5,0 788.5,0 818.5,0 818.5,0 824.5,0 830.5,-6 830.5,-12 830.5,-12 830.5,-24 830.5,-24 830.5,-30 824.5,-36 818.5,-36\"/>\n<text text-anchor=\"middle\" x=\"803.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 41&#45;&gt;42 -->\n<g id=\"edge22\" class=\"edge\">\n<title>41&#45;&gt;42</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M840.1791,-71.8901C833.8821,-62.6384 827.3723,-53.0739 821.6165,-44.6173\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"824.4518,-42.5626 815.9317,-36.2651 818.665,-46.5013 824.4518,-42.5626\"/>\n</g>\n<!-- 45 -->\n<g id=\"node24\" class=\"node\">\n<title>45</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M890.5,-36C890.5,-36 860.5,-36 860.5,-36 854.5,-36 848.5,-30 848.5,-24 848.5,-24 848.5,-12 848.5,-12 848.5,-6 854.5,0 860.5,0 860.5,0 890.5,0 890.5,0 896.5,0 902.5,-6 902.5,-12 902.5,-12 902.5,-24 902.5,-24 902.5,-30 896.5,-36 890.5,-36\"/>\n<text text-anchor=\"middle\" x=\"875.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 41&#45;&gt;45 -->\n<g id=\"edge23\" class=\"edge\">\n<title>41&#45;&gt;45</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M871.5499,-71.8901C872.1781,-63.3201 872.8259,-54.4817 873.4107,-46.5041\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"876.9207,-46.4942 874.1612,-36.2651 869.9394,-45.9825 876.9207,-46.4942\"/>\n</g>\n<!-- 49 -->\n<g id=\"node26\" class=\"node\">\n<title>49</title>\n<path fill=\"#b3d9f5\" stroke=\"#000000\" d=\"M1090,-155C1090,-155 965,-155 965,-155 959,-155 953,-149 953,-143 953,-143 953,-84 953,-84 953,-78 959,-72 965,-72 965,-72 1090,-72 1090,-72 1096,-72 1102,-78 1102,-84 1102,-84 1102,-143 1102,-143 1102,-149 1096,-155 1090,-155\"/>\n<text text-anchor=\"start\" x=\"989.5\" y=\"-139.8\" 
font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">month ≤ 8.5</text>\n<text text-anchor=\"start\" x=\"989.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.472</text>\n<text text-anchor=\"start\" x=\"971.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 10831</text>\n<text text-anchor=\"start\" x=\"961\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [4129, 6702]</text>\n<text text-anchor=\"start\" x=\"993\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 48&#45;&gt;49 -->\n<g id=\"edge25\" class=\"edge\">\n<title>48&#45;&gt;49</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1027.5,-190.8796C1027.5,-182.6838 1027.5,-173.9891 1027.5,-165.5013\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1031.0001,-165.298 1027.5,-155.2981 1024.0001,-165.2981 1031.0001,-165.298\"/>\n</g>\n<!-- 56 -->\n<g id=\"node29\" class=\"node\">\n<title>56</title>\n<path fill=\"#e8f3fc\" stroke=\"#000000\" d=\"M1273,-155C1273,-155 1132,-155 1132,-155 1126,-155 1120,-149 1120,-143 1120,-143 1120,-84 1120,-84 1120,-78 1126,-72 1132,-72 1132,-72 1273,-72 1273,-72 1279,-72 1285,-78 1285,-84 1285,-84 1285,-143 1285,-143 1285,-149 1279,-155 1273,-155\"/>\n<text text-anchor=\"start\" x=\"1158\" y=\"-139.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">year ≤ 2008.5</text>\n<text text-anchor=\"start\" x=\"1164.5\" y=\"-124.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.498</text>\n<text text-anchor=\"start\" x=\"1146.5\" y=\"-109.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 33283</text>\n<text text-anchor=\"start\" x=\"1128\" y=\"-94.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [15596, 
17687]</text>\n<text text-anchor=\"start\" x=\"1168\" y=\"-79.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Up</text>\n</g>\n<!-- 48&#45;&gt;56 -->\n<g id=\"edge28\" class=\"edge\">\n<title>48&#45;&gt;56</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1088.7065,-190.8796C1102.9606,-181.1868 1118.241,-170.7961 1132.8307,-160.8752\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1135.0269,-163.6143 1141.3281,-155.0969 1131.0907,-157.8258 1135.0269,-163.6143\"/>\n</g>\n<!-- 50 -->\n<g id=\"node27\" class=\"node\">\n<title>50</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1003.5,-36C1003.5,-36 973.5,-36 973.5,-36 967.5,-36 961.5,-30 961.5,-24 961.5,-24 961.5,-12 961.5,-12 961.5,-6 967.5,0 973.5,0 973.5,0 1003.5,0 1003.5,0 1009.5,0 1015.5,-6 1015.5,-12 1015.5,-12 1015.5,-24 1015.5,-24 1015.5,-30 1009.5,-36 1003.5,-36\"/>\n<text text-anchor=\"middle\" x=\"988.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 49&#45;&gt;50 -->\n<g id=\"edge26\" class=\"edge\">\n<title>49&#45;&gt;50</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1010.5075,-71.8901C1006.8884,-63.0279 1003.1521,-53.8788 999.8085,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1002.98,-44.1996 995.959,-36.2651 996.4995,-46.8461 1002.98,-44.1996\"/>\n</g>\n<!-- 53 -->\n<g id=\"node28\" class=\"node\">\n<title>53</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1075.5,-36C1075.5,-36 1045.5,-36 1045.5,-36 1039.5,-36 1033.5,-30 1033.5,-24 1033.5,-24 1033.5,-12 1033.5,-12 1033.5,-6 1039.5,0 1045.5,0 1045.5,0 1075.5,0 1075.5,0 1081.5,0 1087.5,-6 1087.5,-12 1087.5,-12 1087.5,-24 1087.5,-24 1087.5,-30 1081.5,-36 1075.5,-36\"/>\n<text text-anchor=\"middle\" x=\"1060.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 49&#45;&gt;53 -->\n<g id=\"edge27\" 
class=\"edge\">\n<title>49&#45;&gt;53</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1041.8783,-71.8901C1044.907,-63.1253 1048.0326,-54.0798 1050.8379,-45.9615\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1054.2305,-46.8599 1054.1885,-36.2651 1047.6144,-44.5736 1054.2305,-46.8599\"/>\n</g>\n<!-- 57 -->\n<g id=\"node30\" class=\"node\">\n<title>57</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1181.5,-36C1181.5,-36 1151.5,-36 1151.5,-36 1145.5,-36 1139.5,-30 1139.5,-24 1139.5,-24 1139.5,-12 1139.5,-12 1139.5,-6 1145.5,0 1151.5,0 1151.5,0 1181.5,0 1181.5,0 1187.5,0 1193.5,-6 1193.5,-12 1193.5,-12 1193.5,-24 1193.5,-24 1193.5,-30 1187.5,-36 1181.5,-36\"/>\n<text text-anchor=\"middle\" x=\"1166.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 56&#45;&gt;57 -->\n<g id=\"edge29\" class=\"edge\">\n<title>56&#45;&gt;57</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1186.8146,-71.8901C1183.4739,-63.0279 1180.025,-53.8788 1176.9386,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"1180.1877,-44.3878 1173.3853,-36.2651 1173.6376,-46.8569 1180.1877,-44.3878\"/>\n</g>\n<!-- 60 -->\n<g id=\"node31\" class=\"node\">\n<title>60</title>\n<path fill=\"#c0c0c0\" stroke=\"#000000\" d=\"M1253.5,-36C1253.5,-36 1223.5,-36 1223.5,-36 1217.5,-36 1211.5,-30 1211.5,-24 1211.5,-24 1211.5,-12 1211.5,-12 1211.5,-6 1217.5,0 1223.5,0 1223.5,0 1253.5,0 1253.5,0 1259.5,0 1265.5,-6 1265.5,-12 1265.5,-12 1265.5,-24 1265.5,-24 1265.5,-30 1259.5,-36 1253.5,-36\"/>\n<text text-anchor=\"middle\" x=\"1238.5\" y=\"-14.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">(...)</text>\n</g>\n<!-- 56&#45;&gt;60 -->\n<g id=\"edge30\" class=\"edge\">\n<title>56&#45;&gt;60</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M1218.1854,-71.8901C1221.5261,-63.0279 1224.975,-53.8788 1228.0614,-45.6913\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" 
points=\"1231.3624,-46.8569 1231.6147,-36.2651 1224.8123,-44.3878 1231.3624,-46.8569\"/>\n</g>\n</g>\n</svg>\n"
},
"metadata": {},
"execution_count": 59
}
],
"source": [
"out_file = results_path / 'clf_tree.dot'\n",
"dot_data = export_graphviz(clf,\n",
" out_file=out_file.as_posix(),\n",
" feature_names=X.columns,\n",
" class_names=['Down', 'Up'],\n",
" max_depth=3,\n",
" filled=True,\n",
" rounded=True,\n",
" special_characters=True)\n",
"if out_file is not None:\n",
" dot_data = Path(out_file).read_text()\n",
"\n",
"graphviz.Source(dot_data)"
]
},
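{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, `export_graphviz` can also return the DOT source directly. When `out_file=None`, it returns a string instead of writing a file, which avoids writing the file and reading it back as in the cell above. The following sketch uses a small synthetic dataset and the hypothetical feature names `x0` and `x1` rather than the notebook's data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.tree import DecisionTreeClassifier, export_graphviz\n",
"\n",
"# Small synthetic dataset (illustrative assumption, not the notebook's data)\n",
"X_demo, y_demo = make_classification(n_samples=200, n_features=2, n_informative=2,\n",
"                                     n_redundant=0, random_state=3)\n",
"demo_clf = DecisionTreeClassifier(max_depth=2, random_state=3).fit(X_demo, y_demo)\n",
"\n",
"# With out_file=None, export_graphviz returns the DOT source as a string\n",
"dot_src = export_graphviz(demo_clf,\n",
"                          out_file=None,\n",
"                          feature_names=['x0', 'x1'],  # hypothetical names\n",
"                          class_names=['Down', 'Up'],\n",
"                          filled=True,\n",
"                          rounded=True)\n",
"print(dot_src.splitlines()[0])  # first line of the DOT source\n",
"# graphviz.Source(dot_src) would render it, given the graphviz package"
]
},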
{
"cell_type": "markdown",
"metadata": {
"id": "-0hmTMQoZDwS"
},
"source": [
"### Evaluate Test Set"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4jlpD99RZDwS"
},
"source": [
"To evaluate the predictive accuracy of our first classification tree, we will use our test set to generate predicted class probabilities. \n",
"\n",
"The `.predict_proba()` method produces one probability for each class. In the binary case, these probabilities are complementary and sum to 1, so we only need the value for the positive class, which is the second column (`[:, 1]`). "
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.270863Z",
"start_time": "2021-04-16T00:33:36.262990Z"
},
"id": "eV1ZsGsKZDwS"
},
"outputs": [],
"source": [
"y_score = clf.predict_proba(X=X_test)[:, 1]"
]
},
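{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why selecting column 1 is enough, the following sketch fits a small tree on synthetic data (an assumption standing in for the notebook's data) and checks that each row of `predict_proba()` sums to 1. The column order follows the fitted model's `classes_` attribute, so `[:, 1]` holds the probability of the second class, label 1 here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic binary data (illustrative assumption)\n",
"X_demo, y_demo = make_classification(n_samples=500, n_features=4, random_state=42)\n",
"demo_clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_demo, y_demo)\n",
"\n",
"proba = demo_clf.predict_proba(X_demo)  # one column per class\n",
"print(demo_clf.classes_)                # column order of proba\n",
"\n",
"# The two columns are complementary, so column 1 alone is enough\n",
"assert np.allclose(proba.sum(axis=1), 1.0)\n",
"pos_score = proba[:, 1]                 # probability of label 1"
]
},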
{
"cell_type": "markdown",
"metadata": {
"id": "Bpr6yEe1ZDwT"
},
"source": [
"To evaluate the generalization error, we use the area under the receiver operating characteristic (ROC) curve, which we introduced in Chapter 6, The Machine Learning Process. The resulting AUC of about 0.62 improves on the baseline value of 0.5 for random predictions:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.283613Z",
"start_time": "2021-04-16T00:33:36.272140Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oPQElmo4ZDwT",
"outputId": "c390cc55-f25a-4548-94f4-8c93183a7dd3"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6162304070004103"
]
},
"metadata": {},
"execution_count": 61
}
],
"source": [
"roc_auc_score(y_score=y_score, y_true=y_test)"
]
},
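{
"cell_type": "markdown",
"metadata": {},
"source": [
"The AUC has a useful interpretation: it equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, which is the Mann-Whitney U statistic divided by the number of positive-negative pairs. The following sketch, on simulated labels and scores (assumptions, not the notebook's data), computes the AUC from ranks and compares it with `roc_auc_score`; scores that carry no signal land near the 0.5 baseline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import rankdata\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"rng = np.random.default_rng(0)\n",
"y_sim = rng.integers(0, 2, size=1000)        # simulated binary labels\n",
"s_sim = rng.normal(size=1000) + 0.4 * y_sim  # scores with a weak signal\n",
"\n",
"def auc_by_ranks(y, s):\n",
"    # Mann-Whitney U: rank sum of the positives minus its minimum possible value,\n",
"    # scaled by the number of positive-negative pairs; average ranks split ties\n",
"    ranks = rankdata(s)\n",
"    n_pos = int(y.sum())\n",
"    n_neg = len(y) - n_pos\n",
"    u = ranks[y == 1].sum() - n_pos * (n_pos + 1) / 2\n",
"    return u / (n_pos * n_neg)\n",
"\n",
"print(auc_by_ranks(y_sim, s_sim), roc_auc_score(y_sim, s_sim))\n",
"\n",
"# Uninformative scores should land near the 0.5 baseline\n",
"auc_random = roc_auc_score(y_sim, rng.normal(size=1000))\n",
"print(auc_random)"
]
},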
{
"cell_type": "markdown",
"metadata": {
"id": "_--mqdElZDwU"
},
"source": [
"### Print Decision Path"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X8fEv6WEZDwU"
},
"source": [
"Inspired by https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.294116Z",
"start_time": "2021-04-16T00:33:36.284877Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "84x4UWlJZDwU",
"outputId": "c94b2c21-c47d-4a73-d671-b6d20d1f3dcc"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Help on class Tree in module sklearn.tree._tree:\n",
"\n",
"class Tree(builtins.object)\n",
" | Array-based representation of a binary decision tree.\n",
" | \n",
" | The binary tree is represented as a number of parallel arrays. The i-th\n",
" | element of each array holds information about the node `i`. Node 0 is the\n",
" | tree's root. You can find a detailed description of all arrays in\n",
" | `_tree.pxd`. NOTE: Some of the arrays only apply to either leaves or split\n",
" | nodes, resp. In this case the values of nodes of the other type are\n",
" | arbitrary!\n",
" | \n",
" | Attributes\n",
" | ----------\n",
" | node_count : int\n",
" | The number of nodes (internal nodes + leaves) in the tree.\n",
" | \n",
" | capacity : int\n",
" | The current capacity (i.e., size) of the arrays, which is at least as\n",
" | great as `node_count`.\n",
" | \n",
" | max_depth : int\n",
" | The depth of the tree, i.e. the maximum depth of its leaves.\n",
" | \n",
" | children_left : array of int, shape [node_count]\n",
" | children_left[i] holds the node id of the left child of node i.\n",
" | For leaves, children_left[i] == TREE_LEAF. Otherwise,\n",
" | children_left[i] > i. This child handles the case where\n",
" | X[:, feature[i]] <= threshold[i].\n",
" | \n",
" | children_right : array of int, shape [node_count]\n",
" | children_right[i] holds the node id of the right child of node i.\n",
" | For leaves, children_right[i] == TREE_LEAF. Otherwise,\n",
" | children_right[i] > i. This child handles the case where\n",
" | X[:, feature[i]] > threshold[i].\n",
" | \n",
" | feature : array of int, shape [node_count]\n",
" | feature[i] holds the feature to split on, for the internal node i.\n",
" | \n",
" | threshold : array of double, shape [node_count]\n",
" | threshold[i] holds the threshold for the internal node i.\n",
" | \n",
" | value : array of double, shape [node_count, n_outputs, max_n_classes]\n",
" | Contains the constant prediction value of each node.\n",
" | \n",
" | impurity : array of double, shape [node_count]\n",
" | impurity[i] holds the impurity (i.e., the value of the splitting\n",
" | criterion) at node i.\n",
" | \n",
" | n_node_samples : array of int, shape [node_count]\n",
" | n_node_samples[i] holds the number of training samples reaching node i.\n",
" | \n",
" | weighted_n_node_samples : array of int, shape [node_count]\n",
" | weighted_n_node_samples[i] holds the weighted number of training samples\n",
" | reaching node i.\n",
" | \n",
" | Methods defined here:\n",
" | \n",
" | __getstate__(...)\n",
" | Getstate re-implementation, for pickling.\n",
" | \n",
" | __reduce__(...)\n",
" | Reduce re-implementation, for pickling.\n",
" | \n",
" | __setstate__(...)\n",
" | Setstate re-implementation, for unpickling.\n",
" | \n",
" | apply(...)\n",
" | Finds the terminal region (=leaf node) for each sample in X.\n",
" | \n",
" | compute_feature_importances(...)\n",
" | Computes the importance of each feature (aka variable).\n",
" | \n",
" | compute_partial_dependence(...)\n",
" | Partial dependence of the response on the ``target_feature`` set.\n",
" | \n",
" | For each sample in ``X`` a tree traversal is performed.\n",
" | Each traversal starts from the root with weight 1.0.\n",
" | \n",
" | At each non-leaf node that splits on a target feature, either\n",
" | the left child or the right child is visited based on the feature\n",
" | value of the current sample, and the weight is not modified.\n",
" | At each non-leaf node that splits on a complementary feature,\n",
" | both children are visited and the weight is multiplied by the fraction\n",
" | of training samples which went to each child.\n",
" | \n",
" | At each leaf, the value of the node is multiplied by the current\n",
" | weight (weights sum to 1 for all visited terminal nodes).\n",
" | \n",
" | Parameters\n",
" | ----------\n",
" | X : view on 2d ndarray, shape (n_samples, n_target_features)\n",
" | The grid points on which the partial dependence should be\n",
" | evaluated.\n",
" | target_features : view on 1d ndarray, shape (n_target_features)\n",
" | The set of target features for which the partial dependence\n",
" | should be evaluated.\n",
" | out : view on 1d ndarray, shape (n_samples)\n",
" | The value of the partial dependence function on each grid\n",
" | point.\n",
" | \n",
" | decision_path(...)\n",
" | Finds the decision path (=node) for each sample in X.\n",
" | \n",
" | predict(...)\n",
" | Predict target for X.\n",
" | \n",
" | ----------------------------------------------------------------------\n",
" | Static methods defined here:\n",
" | \n",
" | __new__(*args, **kwargs) from builtins.type\n",
" | Create and return a new object. See help(type) for accurate signature.\n",
" | \n",
" | ----------------------------------------------------------------------\n",
" | Data descriptors defined here:\n",
" | \n",
" | capacity\n",
" | \n",
" | children_left\n",
" | \n",
" | children_right\n",
" | \n",
" | feature\n",
" | \n",
" | impurity\n",
" | \n",
" | max_depth\n",
" | \n",
" | max_n_classes\n",
" | \n",
" | n_classes\n",
" | \n",
" | n_features\n",
" | \n",
" | n_leaves\n",
" | \n",
" | n_node_samples\n",
" | \n",
" | n_outputs\n",
" | \n",
" | node_count\n",
" | \n",
" | threshold\n",
" | \n",
" | value\n",
" | \n",
" | weighted_n_node_samples\n",
" | \n",
" | ----------------------------------------------------------------------\n",
" | Data and other attributes defined here:\n",
" | \n",
" | __pyx_vtable__ = <capsule object NULL>\n",
"\n"
]
}
],
"source": [
"from sklearn.tree._tree import Tree\n",
"help(Tree)"
]
},
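{
"cell_type": "markdown",
"metadata": {},
"source": [
"The docstring above describes the fitted tree as parallel arrays indexed by node id. The following sketch walks these arrays for a depth-2 tree fitted on synthetic data (an assumption for illustration). Leaves are the nodes whose `children_left` and `children_right` both equal `TREE_LEAF` (-1); for split nodes, the left child handles samples with `X[:, feature[i]] <= threshold[i]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic data (illustrative assumption)\n",
"X_demo, y_demo = make_classification(n_samples=300, n_features=3, n_informative=2,\n",
"                                     n_redundant=0, random_state=1)\n",
"demo_tree = DecisionTreeClassifier(max_depth=2, random_state=1).fit(X_demo, y_demo).tree_\n",
"\n",
"# Entry i of each array describes node i; node 0 is the root\n",
"for i in range(demo_tree.node_count):\n",
"    left, right = demo_tree.children_left[i], demo_tree.children_right[i]\n",
"    if left == -1:  # TREE_LEAF: a leaf has no children\n",
"        print(f'node {i}: leaf with {demo_tree.n_node_samples[i]} samples')\n",
"    else:  # split node: left child handles X[:, feature] <= threshold\n",
"        print(f'node {i}: X[:, {demo_tree.feature[i]}] <= {demo_tree.threshold[i]:.3f}'\n",
"              f' -> left {left}, right {right}')"
]
},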
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.302800Z",
"start_time": "2021-04-16T00:33:36.295041Z"
},
"id": "c5vTXcTwZDwV"
},
"outputs": [],
"source": [
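"from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, _tree\n",
"\n",
"def tree_to_code(tree, feature_names):\n",
"    # Print a fitted tree as nested if/else rules. Thresholds and leaf values are\n",
"    # formatted as percentages because the features here are returns.\n",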
" if isinstance(tree, DecisionTreeClassifier):\n",
" model = 'clf'\n",
" elif isinstance(tree, DecisionTreeRegressor):\n",
" model = 'reg'\n",
" else:\n",
" raise ValueError('Need Regression or Classification Tree')\n",
" \n",
" tree_ = tree.tree_\n",
" feature_name = [\n",
" feature_names[i] if i != _tree.TREE_UNDEFINED else \"undefined!\"\n",
" for i in tree_.feature\n",
" ]\n",
" print(\"def tree({}):\".format(\", \".join(feature_names)))\n",
"\n",
" def recurse(node, depth):\n",
" indent = \" \" * depth\n",
" if tree_.feature[node] != _tree.TREE_UNDEFINED:\n",
" name = feature_name[node]\n",
" threshold = tree_.threshold[node]\n",
" print(indent, f'if {name} <= {threshold:.2%}')\n",
" recurse(tree_.children_left[node], depth + 1)\n",
" print(indent, f'else: # if {name} > {threshold:.2%}')\n",
" recurse(tree_.children_right[node], depth + 1)\n",
" else:\n",
" pred = tree_.value[node][0]\n",
" val = pred[1]/sum(pred) if model == 'clf' else pred[0]\n",
" print(indent, f'return {val:.2%}')\n",
" recurse(0, 1)"
]
},
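{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn also provides `sklearn.tree.export_text`, which prints the same kind of nested rules without a hand-written recursion; thresholds appear as raw values rather than percentages. A minimal sketch on synthetic data, using the names `t-1` and `t-2` only as illustrative labels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
"\n",
"# Synthetic two-feature data (illustrative assumption)\n",
"X_demo, y_demo = make_classification(n_samples=300, n_features=2, n_informative=2,\n",
"                                     n_redundant=0, random_state=7)\n",
"demo_clf = DecisionTreeClassifier(max_depth=2, random_state=7).fit(X_demo, y_demo)\n",
"\n",
"# feature_names gives one label per column; decimals sets threshold precision\n",
"rules = export_text(demo_clf, feature_names=['t-1', 't-2'], decimals=4)\n",
"print(rules)"
]
},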
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.322605Z",
"start_time": "2021-04-16T00:33:36.304005Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s-KuRzTVZDwW",
"outputId": "6cfc838a-3658-4dc9-de72-e1355d93097c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"def tree(t-1, t-2):\n",
" if t-1 <= -14.64%\n",
" if t-2 <= -17.93%\n",
" return 55.84%\n",
" else: # if t-2 > -17.93%\n",
" if t-1 <= -26.16%\n",
" return 41.08%\n",
" else: # if t-1 > -26.16%\n",
" if t-1 <= -26.14%\n",
" return 100.00%\n",
" else: # if t-1 > -26.14%\n",
" if t-1 <= -14.66%\n",
" return 49.09%\n",
" else: # if t-1 > -14.66%\n",
" return 0.00%\n",
" else: # if t-1 > -14.64%\n",
" if t-2 <= 18.14%\n",
" if t-2 <= -14.83%\n",
" if t-2 <= -14.85%\n",
" if t-1 <= -14.31%\n",
" return 77.78%\n",
" else: # if t-1 > -14.31%\n",
" return 52.24%\n",
" else: # if t-2 > -14.85%\n",
" return 0.00%\n",
" else: # if t-2 > -14.83%\n",
" if t-1 <= -6.59%\n",
" if t-2 <= -6.04%\n",
" return 58.58%\n",
" else: # if t-2 > -6.04%\n",
" return 53.78%\n",
" else: # if t-1 > -6.59%\n",
" if t-1 <= 2.68%\n",
" return 58.02%\n",
" else: # if t-1 > 2.68%\n",
" return 56.44%\n",
" else: # if t-2 > 18.14%\n",
" if t-1 <= 14.36%\n",
" if t-1 <= -4.91%\n",
" return 57.74%\n",
" else: # if t-1 > -4.91%\n",
" if t-1 <= -1.17%\n",
" return 45.69%\n",
" else: # if t-1 > -1.17%\n",
" return 52.47%\n",
" else: # if t-1 > 14.36%\n",
" return 41.73%\n"
]
}
],
"source": [
"tree_to_code(clf_tree_t2, X2.columns)"
]
},
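{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to walking the `tree_` internals by hand, scikit-learn provides `sklearn.tree.export_text`, which prints a comparable nested rule listing for a fitted tree. The sketch below is illustrative only: it fits a depth-2 classifier to synthetic data, and the arrays and the `t-1`/`t-2` feature names are assumptions, so its thresholds will not match the output above. Unlike `tree_to_code`, `export_text` reports raw thresholds and the predicted class at each leaf rather than percentages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeClassifier, export_text\n",
"\n",
"# Synthetic stand-ins for two lagged returns (illustrative assumption, not the notebook's data)\n",
"rng = np.random.default_rng(3)\n",
"X_demo = rng.normal(scale=0.1, size=(500, 2))\n",
"y_demo = (X_demo[:, 0] + rng.normal(scale=0.1, size=500) > 0).astype(int)\n",
"\n",
"clf_demo = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_demo, y_demo)\n",
"\n",
"# export_text returns the fitted tree as an indented text listing of split rules\n",
"rules = export_text(clf_demo, feature_names=['t-1', 't-2'], decimals=4)\n",
"print(rules)"
]
},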
{
"cell_type": "markdown",
"metadata": {
"id": "cVIK-BGkZDwX"
},
"source": [
"## Overfitting, Regularization & Parameter Tuning"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BssQtPn0ZDwX"
},
"source": [
"Decision trees have a strong tendency to overfit, especially when a dataset has a large number of features relative to the number of samples. As discussed in previous chapters, overfitting increases the prediction error because the model does not only learn the signal contained in the training data, but also the noise.\n",
"There are several ways to address the risk of overfitting."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X-fsFNPaZDwY"
},
"source": [
"Decision trees provide several regularization hyperparameters to limit the growth of a tree and the associated complexity. While every split increases the number of nodes, it also reduces the number of samples available per node to support a prediction. For each additional level, twice the number of samples is needed to populate the new nodes with the same sample density. "
]
},
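{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the sample-density argument concrete, the sketch below computes the average number of training samples per leaf in a balanced binary tree, which has up to `2**depth` leaves at a given depth. The training-set size of 10,000 is a hypothetical value chosen for illustration: at depth 10 such a tree has 1,024 leaves and fewer than ten samples per leaf on average."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical training-set size, for illustration only\n",
"n_samples = 10_000\n",
"\n",
"for depth in range(1, 11):\n",
"    n_leaves = 2 ** depth              # each extra level can double the number of leaves\n",
"    per_leaf = n_samples / n_leaves    # average samples available to support each leaf's prediction\n",
"    print(f'depth={depth:>2}  leaves={n_leaves:>4}  avg samples per leaf={per_leaf:>8.1f}')"
]
},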
{
"cell_type": "markdown",
"metadata": {
"id": "mv1DsvxJZDwY"
},
"source": [
"### Decision Tree Parameters"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WNN7fetVZDwZ"
},
"source": [
"The following table lists key parameters available for this purpose in the sklearn decision tree implementation. After introducing the most important parameters, we will illustrate how to use cross-validation to optimize the hyperparameter settings with respect to the bias-variance tradeoff and lower prediction errors:"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ONSoSPnAZDwZ"
},
"source": [
"| Parameter | Default | Options | Description |\n",
"|--------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n",
"| criterion | gini | Regression: MSE, MAE Classification: Gini impurity, Cross Entropy | Metric to evaluate split quality. |\n",
"| splitter | best | best, random | How to choose the split at each node. Supported strategies are “best” to choose the best split and “random” to choose the best random split. |\n",
"| max_depth | None | int | Max # of levels in tree. Split nodes until max_depth is reached or all leaves are pure or all leaves contain less than min_samples_split samples. |\n",
"| max_features | None | None: max_features=n_features; int; float (fraction): int(max_features * n_features) auto, sqrt: max_features=sqrt(n_features). log2: max_features=log2(n_features). | # of features to consider when evaluating split |\n",
"| max_leaf_nodes | None | None: unlimited # of leaf nodes int | Continue to split nodes that reduce relative impurity the most until reaching max_leaf_nodes. |\n",
"| min_impurity_decrease | 0 | float | Split node if impurity decreases by at least this value. |\n",
"| min_samples_leaf | 1 | int; float (as percentage of N) | Minimum # of samples to be at a leaf node. A split will only be considered if there are at least min_samples_leaf training samples in each of the left and right branches. May smoothen the model, esp. for regression. |\n",
"| min_samples_split | 2 | int; float (as percentage of N) | The minimum number of samples required to split an internal node: |\n",
"| min_weight_fraction_leaf | 0 | NA | The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided (in fit method). |\n",
"| class_weight | None | balanced: inversely proportional to class frequencies dict: {class_label: weight} list of dicts (for multi-output) | Weights associated with classes |"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wv07vR34ZDwZ"
},
"source": [
"The `max_depth` parameter imposes a hard limit on the number of consecutive splits and represents the most straightforward way to cap the growth of a tree.\n",
"\n",
"The m`in_samples_split` and `min_samples_leaf` parameters are alternative, data-driven ways to limit the growth of a tree. Rather than imposing a hard limit on the number of consecutive splits, these parameters control the minimum number of samples required to further split the data. The latter guarantees a certain number of samples per leaf, while the former can create very small leaves if a split results in a very uneven distribution. Small parameter values facilitate overfitting, while a high number may prevent the tree from learning the signal in the data. \n",
"\n",
"The default values are often quite low, and you should use cross-validation to explore a range of potential values. You can also use a float to indicate a percentage as opposed to an absolute number. "
]
},
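{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between `min_samples_split` and `min_samples_leaf` is easy to check by inspecting leaf sizes. The sketch below uses synthetic data (an assumption, not the notebook's features) and reads the fitted tree's `tree_.n_node_samples` and `tree_.children_left` arrays; the helper `leaf_sizes` is hypothetical. With `min_samples_leaf=50` every leaf holds at least 50 samples, and the float `0.05` is converted to `ceil(0.05 * n_samples)`, which is also 50 here. `min_samples_split=100` only constrains the parent node, so uneven splits can still produce much smaller leaves."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Synthetic data for illustration only (not the notebook's returns data)\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.normal(size=(1000, 5))\n",
"y_demo = X_demo[:, 0] + rng.normal(scale=2.0, size=1000)\n",
"\n",
"def leaf_sizes(model):\n",
"    '''Return the number of training samples in each leaf of a fitted tree.'''\n",
"    t = model.tree_\n",
"    is_leaf = t.children_left == -1  # leaves have no children\n",
"    return t.n_node_samples[is_leaf]\n",
"\n",
"# min_samples_split only requires 100 samples in the node being split ...\n",
"by_split = DecisionTreeRegressor(min_samples_split=100, random_state=42).fit(X_demo, y_demo)\n",
"# ... while min_samples_leaf requires at least 50 samples in every resulting leaf\n",
"by_leaf = DecisionTreeRegressor(min_samples_leaf=50, random_state=42).fit(X_demo, y_demo)\n",
"# A float is a fraction of the training samples: ceil(0.05 * 1000) = 50\n",
"by_frac = DecisionTreeRegressor(min_samples_leaf=0.05, random_state=42).fit(X_demo, y_demo)\n",
"\n",
"for name, model in [('min_samples_split=100', by_split),\n",
"                    ('min_samples_leaf=50', by_leaf),\n",
"                    ('min_samples_leaf=0.05', by_frac)]:\n",
"    sizes = leaf_sizes(model)\n",
"    print(f'{name:<22} leaves={len(sizes):>3} smallest leaf={sizes.min()}')"
]
},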
{
"cell_type": "markdown",
"metadata": {
"id": "RUymy9VaZDwa"
},
"source": [
"### Cross-validation parameters"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.328359Z",
"start_time": "2021-04-16T00:33:36.323742Z"
},
"id": "2zX80RQeZDwb"
},
"outputs": [],
"source": [
"n_splits = 10\n",
"train_period_length = 60\n",
"test_period_length = 6\n",
"lookahead = 1\n",
"\n",
"cv = MultipleTimeSeriesCV(n_splits=n_splits,\n",
" train_period_length=train_period_length,\n",
" test_period_length=test_period_length,\n",
" lookahead=lookahead)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.334074Z",
"start_time": "2021-04-16T00:33:36.329663Z"
},
"id": "ExDruCerZDwb"
},
"outputs": [],
"source": [
"max_depths = range(1, 16)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RKocRO8QZDwb"
},
"source": [
"### Finding the best trees using GridSearchCV"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3W0vg_3YZDwb"
},
"source": [
"scikit-learn provides a method to define ranges of values for multiple hyperparameters. It automates the process of cross-validating the various combinations of these parameter values to identify the optimal configuration. Let's walk through the process of automatically tuning your model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WzwvPwWPZDwc"
},
"source": [
"#### Define parameter grid"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kx8XKP6IZDwc"
},
"source": [
"The first step is to define a dictionary where the keywords name the hyperparameters, and the values list the parameter settings to be tested:"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.346973Z",
"start_time": "2021-04-16T00:33:36.335032Z"
},
"id": "TQyjkhhzZDwc"
},
"outputs": [],
"source": [
"param_grid = {'max_depth': [2, 3, 4, 5, 6, 7, 8, 10, 12, 15],\n",
" 'min_samples_leaf': [5, 25, 50, 100],\n",
" 'max_features': ['sqrt', 'auto']}"
]
},
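{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before launching the search, it helps to know how many models it will fit. The sketch below uses scikit-learn's `ParameterGrid` to enumerate the combinations in a copy of the grid above and multiplies by the ten cross-validation splits defined earlier. Note that `'auto'` is an alias in older scikit-learn versions (equal to `'sqrt'` for `DecisionTreeClassifier` and to all features for `DecisionTreeRegressor`); it was deprecated in version 1.1 and removed in 1.3, so for the classifier the two `max_features` values are duplicates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"# Copy of the grid defined above\n",
"grid_spec = {'max_depth': [2, 3, 4, 5, 6, 7, 8, 10, 12, 15],\n",
"             'min_samples_leaf': [5, 25, 50, 100],\n",
"             'max_features': ['sqrt', 'auto']}\n",
"n_cv_splits = 10  # matches n_splits used for the cross-validator above\n",
"\n",
"combinations = list(ParameterGrid(grid_spec))\n",
"print(len(combinations), 'parameter combinations')                  # 10 * 4 * 2 = 80\n",
"print(len(combinations) * n_cv_splits, 'fits, plus one final refit')  # 800"
]
},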
{
"cell_type": "markdown",
"metadata": {
"id": "vF1UJqRVZDwc"
},
"source": [
"#### Classification Tree"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IKLHcTjbZDwe"
},
"source": [
"Then, we instantiate a model object:"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.354743Z",
"start_time": "2021-04-16T00:33:36.348117Z"
},
"id": "p97Cb6ZiZDwe"
},
"outputs": [],
"source": [
"clf = DecisionTreeClassifier(random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yiZUlKNeZDwf"
},
"source": [
"Now we instantiate the GridSearchCV object, providing the estimator object and parameter grid, as well as a scoring method and cross-validation choice to the initialization method. We'll use an object of our custom OneStepTimeSeriesSplit class, initialized to use ten folds for the cv parameter, and set the scoring to the roc_auc metric. We can parallelize the search using the n_jobs parameter and automatically obtain a trained model that uses the optimal hyperparameters by setting `refit=True`."
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:36.363000Z",
"start_time": "2021-04-16T00:33:36.355848Z"
},
"id": "FzPz0zxJZDwf"
},
"outputs": [],
"source": [
"gridsearch_clf = GridSearchCV(estimator=clf,\n",
" param_grid=param_grid,\n",
" scoring='roc_auc',\n",
" n_jobs=-1,\n",
" cv=cv,\n",
" refit=True,\n",
" return_train_score=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MK7s4jo5ZDwg"
},
"source": [
"With all settings in place, we can fit GridSearchCV just like any other model:"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:53.853282Z",
"start_time": "2021-04-16T00:33:36.364001Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cnBsuzLKZDwg",
"outputId": "29cfa149-9eae-4c5e-f95f-0818109f36c1"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GridSearchCV(cv=<__main__.MultipleTimeSeriesCV object at 0x7f2b15753450>,\n",
" estimator=DecisionTreeClassifier(random_state=42), n_jobs=-1,\n",
" param_grid={'max_depth': [2, 3, 4, 5, 6, 7, 8, 10, 12, 15],\n",
" 'max_features': ['sqrt', 'auto'],\n",
" 'min_samples_leaf': [5, 25, 50, 100]},\n",
" return_train_score=True, scoring='roc_auc')"
]
},
"metadata": {},
"execution_count": 70
}
],
"source": [
"gridsearch_clf.fit(X=X, y=y_binary)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d6B9swt8ZDwh"
},
"source": [
"The training process produces some new attributes for our GridSearchCV object, most importantly the information about the optimal settings and the best cross-validation score (now using the proper setup that avoids lookahead bias).\n",
"\n",
"Setting `max_depth` to 10, `min_samples_leaf` to 750, and randomly selecting only a number corresponding to the square root of the total number of features when deciding on a split, produces the best results, with an AUC of 0.532:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:53.857480Z",
"start_time": "2021-04-16T00:33:53.854544Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xaz7n88kZDwh",
"outputId": "75c1331a-aaa1-4700-dd25-35a592bf3ea2"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'max_depth': 15, 'max_features': 'sqrt', 'min_samples_leaf': 50}"
]
},
"metadata": {},
"execution_count": 71
}
],
"source": [
"gridsearch_clf.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:53.871480Z",
"start_time": "2021-04-16T00:33:53.858438Z"
},
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pjgB02nGZDwi",
"outputId": "7ee78230-35cd-4a10-930c-96ea80572db9"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.5199693215653027"
]
},
"metadata": {},
"execution_count": 72
}
],
"source": [
"gridsearch_clf.best_score_"
]
},
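{
"cell_type": "markdown",
"metadata": {},
"source": [
"`GridSearchCV` accepts any object that implements `split(X, y, groups)` and `get_n_splits()` as its `cv` argument, which is how the custom `MultipleTimeSeriesCV` plugs in. The sketch below illustrates that interface with a hypothetical, simplified `WalkForwardCV` splitter on synthetic data that is assumed to be in time order. It is not the notebook's `MultipleTimeSeriesCV` implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"\n",
"class WalkForwardCV:\n",
"    '''Hypothetical minimal splitter: each fold trains on a fixed-length window of\n",
"    earlier rows and tests on the block of rows that immediately follows it.'''\n",
"\n",
"    def __init__(self, n_splits=3, train_length=200, test_length=50):\n",
"        self.n_splits = n_splits\n",
"        self.train_length = train_length\n",
"        self.test_length = test_length\n",
"\n",
"    def split(self, X, y=None, groups=None):\n",
"        n = len(X)\n",
"        # Walk backwards from the end of the sample, one test block per fold\n",
"        for i in range(self.n_splits):\n",
"            test_end = n - i * self.test_length\n",
"            test_start = test_end - self.test_length\n",
"            train_start = max(0, test_start - self.train_length)\n",
"            yield np.arange(train_start, test_start), np.arange(test_start, test_end)\n",
"\n",
"    def get_n_splits(self, X=None, y=None, groups=None):\n",
"        return self.n_splits\n",
"\n",
"\n",
"# Synthetic data, assumed to be sorted in time order\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.normal(size=(500, 4))\n",
"y_demo = (X_demo[:, 0] + rng.normal(size=500) > 0).astype(int)\n",
"\n",
"search = GridSearchCV(DecisionTreeClassifier(random_state=42),\n",
"                      param_grid={'max_depth': [2, 4], 'min_samples_leaf': [5, 25]},\n",
"                      scoring='roc_auc', cv=WalkForwardCV(), refit=True)\n",
"search.fit(X_demo, y_demo)\n",
"print(search.best_params_, round(search.best_score_, 3))"
]
},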
{
"cell_type": "markdown",
"metadata": {
"id": "s966SYVBZDwj"
},
"source": [
"#### Define Custom IC score"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:53.880255Z",
"start_time": "2021-04-16T00:33:53.872541Z"
},
"id": "SR4z_pu5ZDwj"
},
"outputs": [],
"source": [
"def rank_correl(y, y_pred):\n",
" return spearmanr(y, y_pred)[0]\n",
"ic = make_scorer(rank_correl)"
]
},
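{
"cell_type": "markdown",
"metadata": {},
"source": [
"`make_scorer` wraps a plain metric function `f(y_true, y_pred)` into a scorer that `GridSearchCV` calls as `scorer(estimator, X, y)`; `greater_is_better` defaults to `True`, which suits the IC. The sketch below checks, on synthetic data (an assumption, not the notebook's returns), that such a scorer returns the same value as calling `spearmanr` on the model's predictions directly. The names `rank_ic` and `ic_scorer` are hypothetical stand-ins for `rank_correl` and `ic`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import spearmanr\n",
"from sklearn.metrics import make_scorer\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"def rank_ic(y_true, y_pred):\n",
"    # Spearman rank correlation between realized and predicted values (the IC)\n",
"    return spearmanr(y_true, y_pred)[0]\n",
"\n",
"# greater_is_better=True by default, so a higher IC counts as a better score\n",
"ic_scorer = make_scorer(rank_ic)\n",
"\n",
"# Synthetic data (illustrative assumption): a weak linear signal plus noise\n",
"rng = np.random.default_rng(1)\n",
"X_demo = rng.normal(size=(300, 3))\n",
"y_demo = 0.5 * X_demo[:, 0] + rng.normal(size=300)\n",
"reg_demo = DecisionTreeRegressor(max_depth=3, random_state=42).fit(X_demo[:200], y_demo[:200])\n",
"\n",
"# A scorer is called as scorer(fitted_estimator, X, y_true)\n",
"score = ic_scorer(reg_demo, X_demo[200:], y_demo[200:])\n",
"manual = rank_ic(y_demo[200:], reg_demo.predict(X_demo[200:]))\n",
"print(round(score, 4), round(manual, 4))"
]
},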
{
"cell_type": "markdown",
"metadata": {
"id": "M4r-sQpMZDwk"
},
"source": [
"#### Regression Tree"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:53.888614Z",
"start_time": "2021-04-16T00:33:53.881766Z"
},
"id": "27Umbp3WZDwk"
},
"outputs": [],
"source": [
"reg_tree = DecisionTreeRegressor(random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:33:53.896846Z",
"start_time": "2021-04-16T00:33:53.889580Z"
},
"id": "eC5p2I21ZDwk"
},
"outputs": [],
"source": [
"gridsearch_reg = GridSearchCV(estimator=reg_tree,\n",
" param_grid=param_grid,\n",
" scoring=ic,\n",
" n_jobs=-1,\n",
" cv=cv,\n",
" refit=True,\n",
" return_train_score=True)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:33.328283Z",
"start_time": "2021-04-16T00:33:53.897801Z"
},
"id": "Z7YZ73s7ZDwl",
"outputId": "08499f8c-0b82-4b24-9943-4ab9f09f344d",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GridSearchCV(cv=<__main__.MultipleTimeSeriesCV object at 0x7f2b15753450>,\n",
" estimator=DecisionTreeRegressor(random_state=42), n_jobs=-1,\n",
" param_grid={'max_depth': [2, 3, 4, 5, 6, 7, 8, 10, 12, 15],\n",
" 'max_features': ['sqrt', 'auto'],\n",
" 'min_samples_leaf': [5, 25, 50, 100]},\n",
" return_train_score=True, scoring=make_scorer(rank_correl))"
]
},
"metadata": {},
"execution_count": 76
}
],
"source": [
"gridsearch_reg.fit(X=X, y=y)"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:33.340010Z",
"start_time": "2021-04-16T00:34:33.333245Z"
},
"id": "x12D5jeuZDwl",
"outputId": "9cd40dcc-87bb-4721-a544-b93cd83fb501",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'max_depth': 10, 'max_features': 'sqrt', 'min_samples_leaf': 5}"
]
},
"metadata": {},
"execution_count": 77
}
],
"source": [
"gridsearch_reg.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:33.351382Z",
"start_time": "2021-04-16T00:34:33.341496Z"
},
"id": "CghDOUptZDwl",
"outputId": "a00ec3a9-48be-46a0-b7bb-badf7f6fe002",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.04896325082800322"
]
},
"metadata": {},
"execution_count": 78
}
],
"source": [
"gridsearch_reg.best_score_"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:33.363598Z",
"start_time": "2021-04-16T00:34:33.352847Z"
},
"id": "aDHLVtlOZDwl",
"outputId": "a96713ae-b08f-42b3-fb7a-cf964a9a1ea6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Regression Classification\n",
"max_depth 10 15\n",
"max_features sqrt sqrt\n",
"min_samples_leaf 5 50"
],
"text/html": [
"\n",
" <div id=\"df-be62d867-f8d3-4931-96d1-4bef06b7d765\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Regression</th>\n",
" <th>Classification</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>max_depth</th>\n",
" <td>10</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max_features</th>\n",
" <td>sqrt</td>\n",
" <td>sqrt</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min_samples_leaf</th>\n",
" <td>5</td>\n",
" <td>50</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-be62d867-f8d3-4931-96d1-4bef06b7d765')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-be62d867-f8d3-4931-96d1-4bef06b7d765 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-be62d867-f8d3-4931-96d1-4bef06b7d765');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 79
}
],
"source": [
"pd.DataFrame({'Regression': pd.Series(gridsearch_reg.best_params_),\n",
" 'Classification': pd.Series(gridsearch_clf.best_params_)})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aMqsdUhKZDwm"
},
"source": [
"### Classifier Cross-Validation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W8afUvm2ZDwm"
},
"source": [
"Cross-validation is the most important tool to obtain an unbiased estimate of the generalization error, which in turn permits an informed choice among the various configuration options. sklearn offers several tools to facilitate the process of cross-validating numerous parameter settings, namely the GridSearchCV convenience class that we will illustrate in the next section. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K-muYhTPZDwn"
},
"source": [
"The following code illustrates how to run cross-validation more manually to obtain custom tree attributes, such as the total number of nodes or leaf nodes associated with certain hyperparameter settings. \n",
"\n",
"The following function accesses the internal `.tree_` attribute to retrieve information about the total node count, and how many of these nodes are leaf nodes:"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:33.369594Z",
"start_time": "2021-04-16T00:34:33.364452Z"
},
"id": "O-vkLpnpZDwn"
},
"outputs": [],
"source": [
"def get_leaves_count(tree):\n",
" t = tree.tree_\n",
" n = t.node_count\n",
" leaves = len([i for i in range(t.node_count) if t.children_left[i]== -1])\n",
" return leaves"
]
},
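{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fitted scikit-learn trees also expose `get_n_leaves()` and `get_depth()`, which can serve as a cross-check on `get_leaves_count`. The sketch below fits a small classifier to synthetic data (an assumption for illustration) and compares the manual count of nodes without a left child to `get_n_leaves()`. Because every internal node of a scikit-learn tree has exactly two children, the total node count should equal twice the number of leaves minus one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic, noisy binary target (illustrative assumption)\n",
"rng = np.random.default_rng(7)\n",
"X_demo = rng.normal(size=(400, 3))\n",
"y_demo = (X_demo[:, 0] + rng.normal(size=400) > 0).astype(int)\n",
"\n",
"clf_demo = DecisionTreeClassifier(max_depth=4, min_samples_leaf=5,\n",
"                                  random_state=42).fit(X_demo, y_demo)\n",
"t = clf_demo.tree_\n",
"\n",
"# Same rule as get_leaves_count: a leaf has no left child (marked with -1)\n",
"n_leaves_manual = int((t.children_left == -1).sum())\n",
"\n",
"print('nodes:', t.node_count)\n",
"print('leaves (manual):', n_leaves_manual)\n",
"print('leaves (get_n_leaves):', clf_demo.get_n_leaves())\n",
"print('depth (get_depth):', clf_demo.get_depth())"
]
},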
{
"cell_type": "markdown",
"metadata": {
"id": "GP0fPPiTZDwn"
},
"source": [
"We can combine this information with the train and test scores to gain detailed knowledge about the model behavior throughout the cross-validation process, as follows:"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:46.896130Z",
"start_time": "2021-04-16T00:34:33.377385Z"
},
"scrolled": true,
"id": "jx_WuOJPZDwn",
"outputId": "4d791d7d-bfa6-437c-d751-60be3ea599ad",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 "
]
}
],
"source": [
"train_scores, val_scores, leaves = {}, {}, {}\n",
"for max_depth in max_depths:\n",
" print(max_depth, end=' ', flush=True)\n",
" clf = DecisionTreeClassifier(criterion='gini', \n",
" max_depth=max_depth,\n",
" min_samples_leaf=5,\n",
" max_features='sqrt',\n",
" random_state=42)\n",
" train_scores[max_depth], val_scores[max_depth], leaves[max_depth] = [], [], []\n",
" for train_idx, test_idx in cv.split(X):\n",
" X_train, y_train, = X.iloc[train_idx], y_binary.iloc[train_idx]\n",
" X_test, y_test = X.iloc[test_idx], y_binary.iloc[test_idx]\n",
" clf.fit(X=X_train, y=y_train)\n",
"\n",
" train_pred = clf.predict_proba(X=X_train)[:, 1]\n",
" train_score = roc_auc_score(y_score=train_pred, y_true=y_train)\n",
" train_scores[max_depth].append(train_score)\n",
"\n",
" test_pred = clf.predict_proba(X=X_test)[:, 1]\n",
" val_score = roc_auc_score(y_score=test_pred, y_true=y_test)\n",
" val_scores[max_depth].append(val_score) \n",
" leaves[max_depth].append(get_leaves_count(clf))\n",
" \n",
"clf_train_scores = pd.DataFrame(train_scores)\n",
"clf_valid_scores = pd.DataFrame(val_scores)\n",
"clf_leaves = pd.DataFrame(leaves)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:46.906671Z",
"start_time": "2021-04-16T00:34:46.898503Z"
},
"id": "Sk_LpmFRZDwo"
},
"outputs": [],
"source": [
"clf_cv_data = pd.concat([pd.melt(clf_train_scores,\n",
" var_name='Max. Depth',\n",
" value_name='ROC AUC').assign(Data='Train'),\n",
" pd.melt(clf_valid_scores,\n",
" var_name='Max. Depth',\n",
" value_name='ROC AUC').assign(Data='Valid')])"
]
},
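{
"cell_type": "markdown",
"metadata": {},
"source": [
"The manual loop above can also be expressed with scikit-learn's `cross_validate`, which returns per-fold train and test scores and, with `return_estimator=True`, the fitted models from which leaf counts can be read. The sketch below is a minimal version on synthetic data, and it substitutes scikit-learn's `TimeSeriesSplit` for the notebook's `MultipleTimeSeriesCV`; both choices are assumptions for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import TimeSeriesSplit, cross_validate\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic, noisy binary target (illustrative assumption)\n",
"rng = np.random.default_rng(5)\n",
"X_demo = rng.normal(size=(600, 4))\n",
"y_demo = (X_demo[:, 0] + rng.normal(size=600) > 0).astype(int)\n",
"\n",
"results = {}\n",
"for max_depth in range(1, 6):\n",
"    clf_demo = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=5,\n",
"                                      max_features='sqrt', random_state=42)\n",
"    cv_res = cross_validate(clf_demo, X_demo, y_demo, scoring='roc_auc',\n",
"                            cv=TimeSeriesSplit(n_splits=5),\n",
"                            return_train_score=True, return_estimator=True)\n",
"    # Average the per-fold scores and leaf counts for this depth\n",
"    results[max_depth] = {\n",
"        'train_auc': cv_res['train_score'].mean(),\n",
"        'valid_auc': cv_res['test_score'].mean(),\n",
"        'leaves': np.mean([est.get_n_leaves() for est in cv_res['estimator']]),\n",
"    }\n",
"\n",
"for depth, res in results.items():\n",
"    print(depth, {k: round(v, 3) for k, v in res.items()})"
]
},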
{
"cell_type": "markdown",
"metadata": {
"id": "Vq4tWo38ZDwo"
},
"source": [
"### Regression tree cross-validation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iRhxlR1wZDwo"
},
"source": [
"#### Run cross-validation"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:58.712044Z",
"start_time": "2021-04-16T00:34:46.907862Z"
},
"id": "yzjFnePjZDwo",
"outputId": "4e6a4e89-48b3-4531-cb92-9c9317a61ecf",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 "
]
}
],
"source": [
"train_scores, val_scores, leaves = {}, {}, {}\n",
"for max_depth in max_depths:\n",
" print(max_depth, end=' ', flush=True)\n",
" reg_tree = DecisionTreeRegressor(max_depth=max_depth,\n",
" min_samples_leaf=50,\n",
" max_features= 'sqrt',\n",
" random_state=42)\n",
" train_scores[max_depth], val_scores[max_depth], leaves[max_depth] = [], [], []\n",
" for train_idx, test_idx in cv.split(X):\n",
" X_train, y_train, = X.iloc[train_idx], y.iloc[train_idx]\n",
" X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]\n",
" reg_tree.fit(X=X_train, y=y_train)\n",
"\n",
" train_pred = reg_tree.predict(X=X_train)\n",
" train_score = spearmanr(train_pred, y_train)[0]\n",
" train_scores[max_depth].append(train_score)\n",
"\n",
" test_pred = reg_tree.predict(X=X_test)\n",
" val_score = spearmanr(test_pred, y_test)[0]\n",
" val_scores[max_depth].append(val_score)\n",
" leaves[max_depth].append(get_leaves_count(reg_tree))\n",
"\n",
"reg_train_scores = pd.DataFrame(train_scores)\n",
"reg_valid_scores = pd.DataFrame(val_scores)\n",
"reg_leaves = pd.DataFrame(leaves)"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:34:58.719502Z",
"start_time": "2021-04-16T00:34:58.712895Z"
},
"id": "W17AybwRZDwp"
},
"outputs": [],
"source": [
"reg_cv_data = (pd.melt(reg_train_scores, var_name='Max. Depth',\n",
" value_name='IC').assign(Data='Train').append(\n",
" pd.melt(reg_valid_scores,\n",
" var_name='Max. Depth',\n",
" value_name='IC').assign(Data='Valid')))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6o43CQAeZDwp"
},
"source": [
"### Compare CV Results"
]
},
{
"cell_type": "markdown",
"source": [
"There is a problem in this piece of code, scroll down to find the solution. "
],
"metadata": {
"id": "TrrcqN3I0r0J"
}
},
{
"cell_type": "code",
"source": [
"fig, axes = plt.subplots(ncols=2, figsize=(14, 5))\n",
"\n",
"sns.lineplot(data=reg_cv_data,\n",
" x='Max. Depth', y='IC',\n",
" hue='Data', ci=95,\n",
" ax=axes[0], lw=2)\n",
"\n",
"axes[0].set_title('Regression Tree')\n",
"axes[0].axvline(x=reg_valid_scores.mean().idxmax(), ls='--', c='k', lw=1)\n",
"axes[0].axhline(y=0, ls='--', c='k', lw=1)\n",
"\n",
"sns.lineplot(data=clf_cv_data,\n",
" x='Max. Depth', y='ROC AUC',\n",
" hue='Data', ci=95,\n",
" ax=axes[1], lw=2)\n",
"\n",
"axes[1].set_title('Classification Tree')\n",
"axes[1].axvline(x=clf_valid_scores.mean().idxmax(), ls='--', c='k', lw=1)\n",
"axes[1].axhline(y=.5, ls='--', c='k', lw=1)\n",
"for ax in axes:\n",
" ax.set_xlim(min(param_grid['max_depth']),\n",
" max(param_grid['max_depth']))\n",
"\n",
"fig.suptitle(f'Train-Validation Scores', fontsize=14)\n",
"sns.despine()\n",
"fig.tight_layout()\n",
"fig.subplots_adjust(top=.91)"
],
"metadata": {
"id": "F2dzOzK0ydLZ",
"outputId": "8bbe2e3f-0b7b-476b-f3f7-4cf244b588d1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 697
}
},
"execution_count": 95,
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-95-d48a31ad5f61>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Max. Depth'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'IC'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mhue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Data'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mci\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m95\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m ax=axes[0], lw=2)\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0maxes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_title\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Regression Tree'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/seaborn/_decorators.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 45\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 47\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/seaborn/relational.py\u001b[0m in \u001b[0;36mlineplot\u001b[0;34m(x, y, hue, size, style, data, palette, hue_order, hue_norm, sizes, size_order, size_norm, dashes, markers, style_order, units, estimator, ci, n_boot, seed, sort, err_style, err_kws, legend, ax, **kwargs)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_attach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 709\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 710\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 711\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 712\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/seaborn/relational.py\u001b[0m in \u001b[0;36mplot\u001b[0;34m(self, ax, kws)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[0;31m# Loop over the semantic subsets and add to the plot\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[0mgrouping_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"hue\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"size\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"style\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 471\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0msub_vars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msub_data\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miter_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrouping_vars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfrom_comp_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 472\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/seaborn/_core.py\u001b[0m in \u001b[0;36miter_data\u001b[0;34m(self, grouping_vars, reverse, from_comp_data)\u001b[0m\n\u001b[1;32m 981\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 982\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfrom_comp_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 983\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomp_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 984\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 985\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot_data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/seaborn/_core.py\u001b[0m in \u001b[0;36mcomp_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1055\u001b[0m \u001b[0morig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mvar\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropna\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1056\u001b[0m \u001b[0mcomp_col\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSeries\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1057\u001b[0;31m \u001b[0mcomp_col\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0morig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_numeric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_units\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1058\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1059\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_scale\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"log\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 722\u001b[0m \u001b[0miloc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"iloc\"\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 723\u001b[0;31m \u001b[0miloc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setitem_with_indexer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 724\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_validate_key\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_setitem_with_indexer\u001b[0;34m(self, indexer, value, name)\u001b[0m\n\u001b[1;32m 1730\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setitem_with_indexer_split_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1731\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1732\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_setitem_single_block\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1733\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_setitem_with_indexer_split_path\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_setitem_single_block\u001b[0;34m(self, indexer, value, name)\u001b[0m\n\u001b[1;32m 1957\u001b[0m \u001b[0;31m# setting for extensionarrays that store dicts. Need to decide\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1958\u001b[0m \u001b[0;31m# if it's worth supporting that.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1959\u001b[0;31m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_align_series\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSeries\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1960\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1961\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mABCDataFrame\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m\"iloc\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_align_series\u001b[0;34m(self, indexer, ser, multiindex_indexer)\u001b[0m\n\u001b[1;32m 2094\u001b[0m \u001b[0;31m# series, so need to broadcast (see GH5206)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2095\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msum_aligners\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_sequence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindexer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2096\u001b[0;31m \u001b[0mser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2097\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2098\u001b[0m \u001b[0;31m# single indexer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/series.py\u001b[0m in \u001b[0;36mreindex\u001b[0;34m(self, index, **kwargs)\u001b[0m\n\u001b[1;32m 4578\u001b[0m )\n\u001b[1;32m 4579\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mreindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4580\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4581\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4582\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mdeprecate_nonkeyword_arguments\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mversion\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallowed_args\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"self\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"labels\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36mreindex\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 4817\u001b[0m \u001b[0;31m# perform the reindex on the axes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4818\u001b[0m return self._reindex_axes(\n\u001b[0;32m-> 4819\u001b[0;31m \u001b[0maxes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlimit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtolerance\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfill_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4820\u001b[0m ).__finalize__(self, method=\"reindex\")\n\u001b[1;32m 4821\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m_reindex_axes\u001b[0;34m(self, axes, level, limit, tolerance, method, fill_value, copy)\u001b[0m\n\u001b[1;32m 4841\u001b[0m \u001b[0mfill_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfill_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4842\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4843\u001b[0;31m \u001b[0mallow_dups\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4844\u001b[0m )\n\u001b[1;32m 4845\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m_reindex_with_indexers\u001b[0;34m(self, reindexers, fill_value, copy, allow_dups)\u001b[0m\n\u001b[1;32m 4887\u001b[0m \u001b[0mfill_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfill_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4888\u001b[0m \u001b[0mallow_dups\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mallow_dups\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4889\u001b[0;31m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4890\u001b[0m )\n\u001b[1;32m 4891\u001b[0m \u001b[0;31m# If we've made a copy once, no need to make another one\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/internals/managers.py\u001b[0m in \u001b[0;36mreindex_indexer\u001b[0;34m(self, new_axis, indexer, axis, fill_value, allow_dups, copy, consolidate, only_slice)\u001b[0m\n\u001b[1;32m 668\u001b[0m \u001b[0;31m# some axes don't allow reindexing with dups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mallow_dups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 670\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_can_reindex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 671\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 672\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/pandas/core/indexes/base.py\u001b[0m in \u001b[0;36m_validate_can_reindex\u001b[0;34m(self, indexer)\u001b[0m\n\u001b[1;32m 3783\u001b[0m \u001b[0;31m# trying to reindex on an axis with duplicates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3784\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_index_as_unique\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3785\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cannot reindex from a duplicate axis\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3786\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3787\u001b[0m def reindex(\n",
"\u001b[0;31mValueError\u001b[0m: cannot reindex from a duplicate axis"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1008x360 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
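{
"cell_type": "markdown",
"source": [
"Notice that the code above fails. Here is how I fixed it. The error message was:\n",
"\n",
"```\n",
"ValueError: cannot reindex from a duplicate axis\n",
"```\n",
"\n",
"I searched for the message in quotes, \"ValueError: cannot reindex from a duplicate axis\", together with the piece of code \"sns.lineplot\", and found this seaborn issue:\n",
"https://github.com/mwaskom/seaborn/issues/2515\n",
"\n",
"The error means that the DataFrame passed as `data` has duplicate index labels (for example, `pd.concat` keeps the original index of each frame it stacks). Seaborn's internal reindexing, the `comp_data` step in the traceback, rejects these duplicates. Calling `.reset_index()` on the data before passing it to `sns.lineplot` gives it a unique index and fixes the error.\n"
],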
"metadata": {
"id": "dl4fZ5D2zv-Z"
}
},
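{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of why `.reset_index()` helps. It assumes the plotted data was assembled the same way as the learning-curve data below, by stacking melted train and validation frames with `pd.concat`. The score tables are small random placeholders shaped like fold-by-depth scores, not results from this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Hypothetical toy scores: 3 CV folds (rows) x 2 max depths (columns)\n",
"rng = np.random.default_rng(0)\n",
"train = pd.DataFrame(rng.random((3, 2)), columns=[2, 3])\n",
"valid = pd.DataFrame(rng.random((3, 2)), columns=[2, 3])\n",
"\n",
"# Long format: each melted frame gets its own index 0..5\n",
"parts = [pd.melt(train, var_name='Max. Depth', value_name='IC').assign(Data='Train'),\n",
"         pd.melt(valid, var_name='Max. Depth', value_name='IC').assign(Data='Valid')]\n",
"\n",
"cv_data = pd.concat(parts)  # keeps both 0..5 ranges, so every label appears twice\n",
"print(cv_data.index.is_unique)  # False\n",
"\n",
"# Fix 1: renumber afterwards; drop=True discards the old labels instead of adding an 'index' column\n",
"fixed = cv_data.reset_index(drop=True)\n",
"print(fixed.index.is_unique)  # True\n",
"\n",
"# Fix 2: request a fresh RangeIndex while concatenating\n",
"fixed_too = pd.concat(parts, ignore_index=True)\n",
"print(fixed_too.index.is_unique)  # True\n",
"```\n",
"\n",
"The notebook's plain `.reset_index()` also works. It keeps the old labels as an extra `index` column, which the plot ignores because only the named `x`, `y` and `hue` columns are used."
]
},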
{
"cell_type": "code",
"execution_count": 96,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:00.042515Z",
"start_time": "2021-04-16T00:34:58.720596Z"
},
"scrolled": true,
"id": "zbK5aMThZDwq",
"outputId": "d248efc1-5ef2-49ca-f79a-16147661e97d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 373
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1008x360 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"fig, axes = plt.subplots(ncols=2, figsize=(14, 5))\n",
"\n",
"sns.lineplot(data=reg_cv_data.reset_index(),\n",
" x='Max. Depth', y='IC',\n",
" hue='Data', ci=95,\n",
" ax=axes[0], lw=2)\n",
"\n",
"axes[0].set_title('Regression Tree')\n",
"axes[0].axvline(x=reg_valid_scores.mean().idxmax(), ls='--', c='k', lw=1)\n",
"axes[0].axhline(y=0, ls='--', c='k', lw=1)\n",
"\n",
"sns.lineplot(data=clf_cv_data.reset_index(),\n",
" x='Max. Depth', y='ROC AUC',\n",
" hue='Data', ci=95,\n",
" ax=axes[1], lw=2)\n",
"\n",
"axes[1].set_title('Classification Tree')\n",
"axes[1].axvline(x=clf_valid_scores.mean().idxmax(), ls='--', c='k', lw=1)\n",
"axes[1].axhline(y=.5, ls='--', c='k', lw=1)\n",
"for ax in axes:\n",
" ax.set_xlim(min(param_grid['max_depth']),\n",
" max(param_grid['max_depth']))\n",
"\n",
"fig.suptitle(f'Train-Validation Scores', fontsize=14)\n",
"sns.despine()\n",
"fig.tight_layout()\n",
"fig.subplots_adjust(top=.91)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "umXQR0xZZDwq"
},
"source": [
"### Learning Curves for best models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oSl_QZv3ZDwq"
},
"source": [
"A learning curve shows how the training and validation scores change as the number of training samples grows.\n",
"\n",
"Its purpose is to find out whether, and by how much, the model would benefit from more training data. It also helps diagnose whether the model's generalization error is driven more by bias or by variance.\n",
"\n",
"Suppose both the validation score and the training score converge to a similarly low value as the training set grows. Then the error is more likely due to bias, and additional training data is unlikely to help. Conversely, a persistent gap between a high training score and a lower validation score points to variance, which more training data may help reduce."
]
},
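{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this bias/variance reading concrete, here is a minimal sketch on synthetic data. The `make_regression` dataset, the depth-3 tree and the `r2` scoring are illustrative assumptions, not this notebook's setup. The sketch averages the per-fold scores that `learning_curve` returns at each training size and prints the train-validation gap.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import make_regression\n",
"from sklearn.model_selection import learning_curve\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Synthetic stand-in (assumption) for the notebook's features and returns\n",
"X_demo, y_demo = make_regression(n_samples=500, n_features=5, noise=10.0, random_state=0)\n",
"\n",
"sizes_demo, train_sc, valid_sc = learning_curve(\n",
"    DecisionTreeRegressor(max_depth=3, random_state=0),\n",
"    X_demo, y_demo,\n",
"    train_sizes=np.linspace(0.1, 1.0, 5),  # fractions of the largest training fold\n",
"    cv=5,\n",
"    scoring='r2')\n",
"\n",
"# Score arrays have one row per training size and one column per CV fold\n",
"train_mean = train_sc.mean(axis=1)\n",
"valid_mean = valid_sc.mean(axis=1)\n",
"for n, tr, va in zip(sizes_demo, train_mean, valid_mean):\n",
"    print(f'{int(n):4d} samples: train={tr:.3f} valid={va:.3f} gap={tr - va:.3f}')\n",
"```\n",
"\n",
"Read the printout as a rough heuristic. A gap that stays wide as the sample size grows suggests variance, while two curves that meet at a low score suggest bias."
]
},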
{
"cell_type": "markdown",
"metadata": {
"id": "lTbRwFxLZDwr"
},
"source": [
"#### Classifier"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:00.045677Z",
"start_time": "2021-04-16T00:35:00.043552Z"
},
"id": "86_rUBhXZDwr"
},
"outputs": [],
"source": [
"sizes = np.arange(.1, 1.01, .1)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:01.724318Z",
"start_time": "2021-04-16T00:35:00.052073Z"
},
"id": "mfAybfqUZDwr"
},
"outputs": [],
"source": [
"train_sizes, train_scores, valid_scores = learning_curve(gridsearch_clf.best_estimator_,\n",
" X,\n",
" y_binary,\n",
" train_sizes=sizes,\n",
" cv=cv,\n",
" scoring='roc_auc',\n",
" n_jobs=-1,\n",
" shuffle=True,\n",
" random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:01.738102Z",
"start_time": "2021-04-16T00:35:01.725318Z"
},
"id": "L3tlE21wZDwr",
"outputId": "e1d619d2-3aee-424e-d7e8-3d94dfbfb5fd",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Int64Index: 200 entries, 0 to 99\n",
"Data columns (total 3 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Train Size 200 non-null int64 \n",
" 1 ROC AUC 200 non-null float64\n",
" 2 Data 200 non-null object \n",
"dtypes: float64(1), int64(1), object(1)\n",
"memory usage: 6.2+ KB\n"
]
}
],
"source": [
"clf_lc_data = pd.concat([\n",
" pd.melt(pd.DataFrame(train_scores.T, columns=train_sizes),\n",
" var_name='Train Size',\n",
" value_name='ROC AUC').assign(Data='Train'),\n",
" pd.melt(pd.DataFrame(valid_scores.T, columns=train_sizes),\n",
" var_name='Train Size',\n",
" value_name='ROC AUC').assign(Data='Valid')])\n",
"clf_lc_data.info()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xsvn3n6OZDws"
},
"source": [
"#### Regression Tree"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:03.354249Z",
"start_time": "2021-04-16T00:35:01.739800Z"
},
"id": "YsCDfhZkZDws"
},
"outputs": [],
"source": [
"train_sizes, train_scores, valid_scores = learning_curve(gridsearch_reg.best_estimator_,\n",
" X, y,\n",
" train_sizes=sizes,\n",
" cv=cv,\n",
" scoring=ic,\n",
" n_jobs=-1,\n",
" shuffle=True,\n",
" random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:03.370414Z",
"start_time": "2021-04-16T00:35:03.355248Z"
},
"id": "ldF010knZDws",
"outputId": "59b3b6ea-a8d2-4309-8bdb-a0d561400195",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Int64Index: 200 entries, 0 to 99\n",
"Data columns (total 3 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 Train Size 200 non-null int64 \n",
" 1 IC 200 non-null float64\n",
" 2 Data 200 non-null object \n",
"dtypes: float64(1), int64(1), object(1)\n",
"memory usage: 6.2+ KB\n"
]
}
],
"source": [
"reg_lc_data = pd.concat([\n",
" pd.melt(pd.DataFrame(train_scores.T,\n",
" columns=train_sizes),\n",
" var_name='Train Size',\n",
" value_name='IC').assign(Data='Train'),\n",
" pd.melt(pd.DataFrame(valid_scores.T,\n",
" columns=train_sizes),\n",
" var_name='Train Size',\n",
" value_name='IC').assign(Data='Valid')])\n",
"reg_lc_data.info()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YDJto6zaZDws"
},
"source": [
"#### Compare Learning Curves"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:04.388657Z",
"start_time": "2021-04-16T00:35:03.371672Z"
},
"id": "qMe-G5GDZDwt",
"outputId": "b6d2f4ec-e16b-4e64-a142-32c3040a52c4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 373
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1008x360 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"fig, axes = plt.subplots(ncols=2, figsize=(14, 5))\n",
"xmin, xmax = reg_lc_data['Train Size'].min(), reg_lc_data['Train Size'].max()\n",
"\n",
"sns.lineplot(data=reg_lc_data.reset_index(),\n",
" x='Train Size', y='IC',\n",
" hue='Data', ci=95, ax=axes[0], lw=2)\n",
"axes[0].set_title('Best Regression Tree')\n",
"axes[0].set_ylabel('IC')\n",
"\n",
"axes[0].xaxis.set_major_formatter(\n",
" FuncFormatter(lambda x, _: '{:,.0f}'.format(x)))\n",
"\n",
"sns.lineplot(data=clf_lc_data.reset_index(),\n",
" x='Train Size',\n",
" y='ROC AUC',\n",
" hue='Data',\n",
" ci=95,\n",
" ax=axes[1],\n",
" lw=2)\n",
"axes[1].set_title('Best Classification Tree')\n",
"axes[1].set_ylabel('ROC AUC')\n",
"axes[1].xaxis.set_major_formatter(\n",
" FuncFormatter(lambda x, _: '{:,.0f}'.format(x)))\n",
"\n",
"for i in [0, 1]:\n",
" axes[i].tick_params(axis='both', which='major', labelsize=10)\n",
" axes[i].tick_params(axis='both', which='minor', labelsize=8)\n",
" axes[i].set_xlim(xmin, xmax)\n",
"\n",
"fig.suptitle('Learning Curves', fontsize=14)\n",
"sns.despine()\n",
"fig.tight_layout()\n",
"fig.subplots_adjust(top=.9)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eTkJemTUZDwt"
},
"source": [
"### Feature Importance"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d5Y2P_y-ZDwu"
},
"source": [
"Beyond visualizing the decision path for a given sample, decision trees also provide a summary measure of each feature's contribution to the model's fit on the training data.\n",
"\n",
"The feature importance captures how much the splits on a feature improved the criterion used to evaluate split quality. This is the Gini impurity for the classification tree and, with scikit-learn's default settings, the mean squared error for the regression tree.\n",
"\n",
"A feature's importance is computed as the (normalized) total reduction of this criterion, weighted by the number of samples that reach each split. Hence, features used near the root of the tree, where nodes contain more samples, typically receive higher importance."
]
},
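{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a check on this description, the sketch below recomputes impurity-based importances by hand from a fitted tree's `tree_` arrays and compares them with scikit-learn's `feature_importances_`. The `make_classification` data and the depth-4 tree are assumptions standing in for the notebook's data and tuned model.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"# Synthetic stand-in (assumption) for the notebook's features and labels\n",
"X_demo, y_demo = make_classification(n_samples=500, n_features=6, random_state=0)\n",
"clf_demo = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X_demo, y_demo)\n",
"\n",
"t = clf_demo.tree_\n",
"w = t.weighted_n_node_samples  # (weighted) number of training samples reaching each node\n",
"importance = np.zeros(X_demo.shape[1])\n",
"for node in range(t.node_count):\n",
"    left, right = t.children_left[node], t.children_right[node]\n",
"    if left == -1:  # leaf node: no split, so no contribution\n",
"        continue\n",
"    # Sample-weighted Gini decrease produced by this split\n",
"    importance[t.feature[node]] += (w[node] * t.impurity[node]\n",
"                                    - w[left] * t.impurity[left]\n",
"                                    - w[right] * t.impurity[right])\n",
"\n",
"importance /= importance.sum()  # normalize so the importances sum to 1\n",
"print(np.allclose(importance, clf_demo.feature_importances_))  # True\n",
"```\n",
"\n",
"Because each split's contribution is scaled by the number of samples reaching its node, splits near the root, which see the most samples, dominate these totals."
]
},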
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:04.395752Z",
"start_time": "2021-04-16T00:35:04.389724Z"
},
"id": "SrEJOhk7ZDwu"
},
"outputs": [],
"source": [
"top_n = 15\n",
"labels = X.columns.str.replace('_', ' ').str.upper()\n",
"fi_clf = (pd.Series(gridsearch_clf.best_estimator_.feature_importances_, \n",
" index=labels).sort_values(ascending=False).iloc[:top_n])\n",
"fi_reg = (pd.Series(gridsearch_reg.best_estimator_.feature_importances_, \n",
" index=labels).sort_values(ascending=False).iloc[:top_n])"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {
"ExecuteTime": {
"end_time": "2021-04-16T00:35:04.714765Z",
"start_time": "2021-04-16T00:35:04.396825Z"
},
"id": "4mSomnjEZDwu",
"outputId": "b80683e3-97f5-497f-8163-af20365b0096",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 302
}
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x288 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
],
"source": [
"fig, axes= plt.subplots(ncols=2, figsize=(12,4), sharex=True)\n",
"color = cm.Blues(np.linspace(.4,.9, top_n))\n",
"fi_clf.sort_values().plot.barh(ax=axes[1], title='Classification Tree', color=color)\n",
"fi_reg.sort_values().plot.barh(ax=axes[0], title='Regression Tree', color=color)\n",
"axes[0].set_xlabel('Feature Importance')\n",
"axes[1].set_xlabel('Feature Importance')\n",
"fig.suptitle(f'Top {top_n} Features', fontsize=14)\n",
"sns.despine()\n",
"fig.tight_layout()\n",
"fig.subplots_adjust(top=.9);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "343.837px"
},
"toc_section_display": true,
"toc_window_display": true
},
"colab": {
"name": "Decesion Tree Regression (HW2).ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}