Created
February 14, 2022 14:22
-
-
Save alonsosilvaallende/3e0e78f9c7938b98af64f49bc4136963 to your computer and use it in GitHub Desktop.
pharmacoSmoking_ci_ibs_hyperparameters.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "pharmacoSmoking_ci_ibs_hyperparameters.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyNnDUK1yOnYmtbHD/ROpniq", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/3e0e78f9c7938b98af64f49bc4136963/pharmacosmoking_ci_ibs_hyperparameters.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "5eBjyOzOioew" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install -q scikit-survival" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import statsmodels.api as sm\n", | |
"pharmacoSmoking = sm.datasets.get_rdataset(\"pharmacoSmoking\", \"asaur\")\n", | |
"data = pharmacoSmoking.data\n", | |
"data.head(3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 198 | |
}, | |
"id": "EPtGZM0GiqKx", | |
"outputId": "361151e8-9d54-46c7-8ca3-453f978303b8" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", | |
" import pandas.util.testing as tm\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div id=\"df-f57040f6-549e-4066-a0ab-b5558ac96e05\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>ttr</th>\n", | |
" <th>relapse</th>\n", | |
" <th>grp</th>\n", | |
" <th>age</th>\n", | |
" <th>gender</th>\n", | |
" <th>race</th>\n", | |
" <th>employment</th>\n", | |
" <th>yearsSmoking</th>\n", | |
" <th>levelSmoking</th>\n", | |
" <th>ageGroup2</th>\n", | |
" <th>ageGroup4</th>\n", | |
" <th>priorAttempts</th>\n", | |
" <th>longestNoSmoke</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>21</td>\n", | |
" <td>182</td>\n", | |
" <td>0</td>\n", | |
" <td>patchOnly</td>\n", | |
" <td>36</td>\n", | |
" <td>Male</td>\n", | |
" <td>white</td>\n", | |
" <td>ft</td>\n", | |
" <td>26</td>\n", | |
" <td>heavy</td>\n", | |
" <td>21-49</td>\n", | |
" <td>35-49</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>113</td>\n", | |
" <td>14</td>\n", | |
" <td>1</td>\n", | |
" <td>patchOnly</td>\n", | |
" <td>41</td>\n", | |
" <td>Male</td>\n", | |
" <td>white</td>\n", | |
" <td>other</td>\n", | |
" <td>27</td>\n", | |
" <td>heavy</td>\n", | |
" <td>21-49</td>\n", | |
" <td>35-49</td>\n", | |
" <td>3</td>\n", | |
" <td>90</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>39</td>\n", | |
" <td>5</td>\n", | |
" <td>1</td>\n", | |
" <td>combination</td>\n", | |
" <td>25</td>\n", | |
" <td>Female</td>\n", | |
" <td>white</td>\n", | |
" <td>other</td>\n", | |
" <td>12</td>\n", | |
" <td>heavy</td>\n", | |
" <td>21-49</td>\n", | |
" <td>21-34</td>\n", | |
" <td>3</td>\n", | |
" <td>21</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f57040f6-549e-4066-a0ab-b5558ac96e05')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-f57040f6-549e-4066-a0ab-b5558ac96e05 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-f57040f6-549e-4066-a0ab-b5558ac96e05');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
" id ttr relapse ... ageGroup4 priorAttempts longestNoSmoke\n", | |
"0 21 182 0 ... 35-49 0 0\n", | |
"1 113 14 1 ... 35-49 3 90\n", | |
"2 39 5 1 ... 21-34 3 21\n", | |
"\n", | |
"[3 rows x 14 columns]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 2 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Drop redundant information and ids\n", | |
"data = data.drop(columns=[\"id\",\"ageGroup2\",\"ageGroup4\"])\n", | |
"data.head(3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 143 | |
}, | |
"id": "NgvnH8-Di2Wd", | |
"outputId": "b741cace-932a-4041-d4d2-f0f7282c5383" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div id=\"df-98443a87-79bc-458b-bc76-d971f0e3a830\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>ttr</th>\n", | |
" <th>relapse</th>\n", | |
" <th>grp</th>\n", | |
" <th>age</th>\n", | |
" <th>gender</th>\n", | |
" <th>race</th>\n", | |
" <th>employment</th>\n", | |
" <th>yearsSmoking</th>\n", | |
" <th>levelSmoking</th>\n", | |
" <th>priorAttempts</th>\n", | |
" <th>longestNoSmoke</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>182</td>\n", | |
" <td>0</td>\n", | |
" <td>patchOnly</td>\n", | |
" <td>36</td>\n", | |
" <td>Male</td>\n", | |
" <td>white</td>\n", | |
" <td>ft</td>\n", | |
" <td>26</td>\n", | |
" <td>heavy</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>14</td>\n", | |
" <td>1</td>\n", | |
" <td>patchOnly</td>\n", | |
" <td>41</td>\n", | |
" <td>Male</td>\n", | |
" <td>white</td>\n", | |
" <td>other</td>\n", | |
" <td>27</td>\n", | |
" <td>heavy</td>\n", | |
" <td>3</td>\n", | |
" <td>90</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>5</td>\n", | |
" <td>1</td>\n", | |
" <td>combination</td>\n", | |
" <td>25</td>\n", | |
" <td>Female</td>\n", | |
" <td>white</td>\n", | |
" <td>other</td>\n", | |
" <td>12</td>\n", | |
" <td>heavy</td>\n", | |
" <td>3</td>\n", | |
" <td>21</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-98443a87-79bc-458b-bc76-d971f0e3a830')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-98443a87-79bc-458b-bc76-d971f0e3a830 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-98443a87-79bc-458b-bc76-d971f0e3a830');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
" ttr relapse grp ... levelSmoking priorAttempts longestNoSmoke\n", | |
"0 182 0 patchOnly ... heavy 0 0\n", | |
"1 14 1 patchOnly ... heavy 3 90\n", | |
"2 5 1 combination ... heavy 3 21\n", | |
"\n", | |
"[3 rows x 11 columns]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.datasets import get_x_y\n", | |
"\n", | |
"X, y = get_x_y(data, attr_labels=[\"relapse\", \"ttr\"], pos_label=True)" | |
], | |
"metadata": { | |
"id": "6MRPeiGIi6V0" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for c in X.columns:\n", | |
" if X[c].dtype.kind not in ['i', 'f']:\n", | |
" X[c] = X[c].astype(\"category\")" | |
], | |
"metadata": { | |
"id": "Nlu5jYKvjDU8" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.preprocessing import StandardScaler\n", | |
"from sksurv.preprocessing import OneHotEncoder\n", | |
"\n", | |
"X_trn, X_val, y_trn, y_val = train_test_split(X, y, random_state=0)\n", | |
"# Preprocessing\n", | |
"enc = OneHotEncoder()\n", | |
"scaler = StandardScaler()\n", | |
"X_trn = enc.fit_transform(X_trn)\n", | |
"X_trn = pd.DataFrame(scaler.fit_transform(X_trn), columns=X_trn.columns)\n", | |
"X_val = enc.transform(X_val)\n", | |
"X_val = pd.DataFrame(scaler.transform(X_val), columns=X_val.columns)" | |
], | |
"metadata": { | |
"id": "kEgdfhNPjF6H" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.ensemble import RandomSurvivalForest\n", | |
"from sksurv.metrics import concordance_index_censored\n", | |
"\n", | |
"model = RandomSurvivalForest(random_state=42)\n", | |
"model.fit(X_trn, y_trn)\n", | |
"ci_trn = concordance_index_censored(y_trn['relapse'], y_trn['ttr'], model.predict(X_trn))\n", | |
"ci_val = concordance_index_censored(y_val['relapse'], y_val['ttr'], model.predict(X_val))\n", | |
"ci_trn[0], ci_val[0]" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "a3xSYwx5jScZ", | |
"outputId": "31b406e0-6986-42a5-bfa7-e0f751387ac5" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(0.8282642894598846, 0.6187363834422658)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"param_grid_rsf = {\n", | |
" 'max_features': (\"sqrt\", 0.5, 1),\n", | |
" 'min_samples_leaf': (1,3,5),\n", | |
" 'n_estimators': (100, 200),\n", | |
" 'max_depth': (5,7,9),\n", | |
"}" | |
], | |
"metadata": { | |
"id": "1AHpkgeykXx2" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sklearn.model_selection import GridSearchCV\n", | |
"\n", | |
"gcv_ci = GridSearchCV(\n", | |
" RandomSurvivalForest(),\n", | |
" param_grid=param_grid_rsf,\n", | |
" cv=4,\n", | |
" n_jobs=-1,\n", | |
")" | |
], | |
"metadata": { | |
"id": "8eJOSMc5jfkU" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"gcv_ci.fit(X_trn, y_trn)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "sszi6Hqrkzg-", | |
"outputId": "1aa0be12-618f-4cbb-f074-bc826693f81e" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"GridSearchCV(cv=4, estimator=RandomSurvivalForest(), n_jobs=-1,\n", | |
" param_grid={'max_depth': (5, 7, 9),\n", | |
" 'max_features': ('sqrt', 0.5, 1),\n", | |
" 'min_samples_leaf': (1, 3, 5),\n", | |
" 'n_estimators': (100, 200)})" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"ci_trn = concordance_index_censored(y_trn['relapse'], y_trn['ttr'], gcv_ci.predict(X_trn))\n", | |
"ci_val = concordance_index_censored(y_val['relapse'], y_val['ttr'], gcv_ci.predict(X_val))\n", | |
"ci_trn[0], ci_val[0]" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Us6zPWeok7xp", | |
"outputId": "33e48935-b340-419d-be0e-0e1851445029" | |
}, | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(0.8626114315679078, 0.5838779956427015)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.metrics import integrated_brier_score\n", | |
"\n", | |
"survs_trn = model.predict_survival_function(X_trn)\n", | |
"survs_val = model.predict_survival_function(X_val)\n", | |
"times = np.arange(y_val['ttr'].min(), y_trn['ttr'][y_trn['ttr']!=y_trn['ttr'].max()].max())\n", | |
"preds_trn = np.asarray([[fn(t) for t in times] for fn in survs_trn])\n", | |
"preds_val = np.asarray([[fn(t) for t in times] for fn in survs_val])\n", | |
"ibs_trn = integrated_brier_score(y_trn, y_trn, preds_trn, times)\n", | |
"ibs_val = integrated_brier_score(y_trn, y_val, preds_val, times)\n", | |
"ibs_trn, ibs_val" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "gW7UwEE_k307", | |
"outputId": "adbc86be-6edd-452c-ebf4-6f1c3e418207" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(0.1413713684488042, 0.21497683508486282)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sklearn.model_selection import KFold\n", | |
"\n", | |
"cv = KFold(n_splits=4, shuffle=True, random_state=0)" | |
], | |
"metadata": { | |
"id": "zNp96QBalY1Y" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"param_grid_rsf = {\n", | |
" 'estimator__max_features': (\"sqrt\", 0.5, 1),\n", | |
" 'estimator__min_samples_leaf': (1,3,5),\n", | |
" 'estimator__n_estimators': (100,200),\n", | |
" 'estimator__max_depth': (3,7,10),\n", | |
"}" | |
], | |
"metadata": { | |
"id": "Esd1F7YNl9Gd" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.ensemble import RandomSurvivalForest\n", | |
"from sksurv.metrics import as_integrated_brier_score_scorer\n", | |
"\n", | |
"gcv_ibs = GridSearchCV(\n", | |
" as_integrated_brier_score_scorer(RandomSurvivalForest(), times=times),\n", | |
" param_grid=param_grid_rsf,\n", | |
" cv=cv,\n", | |
" n_jobs=-1,\n", | |
")" | |
], | |
"metadata": { | |
"id": "plEtG0eVlhMO" | |
}, | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"gcv_ibs.fit(X_trn, y_trn)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xEntw1jclx5w", | |
"outputId": "a969b1c7-1c5e-407a-fa79-45cb8fbe85eb" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_search.py:972: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n", | |
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n", | |
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]\n", | |
" category=UserWarning,\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"GridSearchCV(cv=KFold(n_splits=4, random_state=0, shuffle=True),\n", | |
" estimator=as_integrated_brier_score_scorer(estimator=RandomSurvivalForest(),\n", | |
" times=array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,\n", | |
" 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21.,\n", | |
" 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,\n", | |
" 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,\n", | |
" 44., 45., 46., 47., 48., 49., 50., 51., 5...\n", | |
" 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,\n", | |
" 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153.,\n", | |
" 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164.,\n", | |
" 165., 166., 167., 168., 169.])),\n", | |
" n_jobs=-1,\n", | |
" param_grid={'estimator__max_depth': (3, 7, 10),\n", | |
" 'estimator__max_features': ('sqrt', 0.5, 1),\n", | |
" 'estimator__min_samples_leaf': (1, 3, 5),\n", | |
" 'estimator__n_estimators': (100, 200)})" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 16 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsf = RandomSurvivalForest(max_depth=gcv_ibs.best_params_['estimator__max_depth'],\n", | |
" max_features=gcv_ibs.best_params_['estimator__max_features'], \n", | |
" min_samples_leaf=gcv_ibs.best_params_['estimator__min_samples_leaf'],\n", | |
" n_estimators=gcv_ibs.best_params_['estimator__n_estimators']\n", | |
" )\n", | |
"rsf" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "KwS1i98ul0hF", | |
"outputId": "02c93ec6-bdba-4ddb-860d-772e89730eaf" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomSurvivalForest(max_depth=3, max_features='sqrt', min_samples_leaf=1)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsf.fit(X_trn, y_trn)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ypj1A9UfmRou", | |
"outputId": "823aa33a-1648-4f82-9696-2d378c8e13bf" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomSurvivalForest(max_depth=3, max_features='sqrt', min_samples_leaf=1)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"survs_trn = rsf.predict_survival_function(X_trn)\n", | |
"survs_val = rsf.predict_survival_function(X_val)\n", | |
"preds_trn = np.asarray([[fn(t) for t in times] for fn in survs_trn])\n", | |
"preds_val = np.asarray([[fn(t) for t in times] for fn in survs_val])\n", | |
"ibs_trn = integrated_brier_score(y_trn, y_trn, preds_trn, times)\n", | |
"ibs_val = integrated_brier_score(y_trn, y_val, preds_val, times)\n", | |
"ibs_trn, ibs_val" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "T_D3raDJmV4Q", | |
"outputId": "1db3a419-be13-48fe-b2b0-343e796744c2" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(0.18609360226313928, 0.21563923724817194)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "Dt7tUcgimi7-" | |
}, | |
"execution_count": 19, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment