Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save alonsosilvaallende/3e0e78f9c7938b98af64f49bc4136963 to your computer and use it in GitHub Desktop.
Save alonsosilvaallende/3e0e78f9c7938b98af64f49bc4136963 to your computer and use it in GitHub Desktop.
pharmacoSmoking_ci_ibs_hyperparameters.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "pharmacoSmoking_ci_ibs_hyperparameters.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNnDUK1yOnYmtbHD/ROpniq",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/3e0e78f9c7938b98af64f49bc4136963/pharmacosmoking_ci_ibs_hyperparameters.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5eBjyOzOioew"
},
"outputs": [],
"source": [
"!pip install -q scikit-survival"
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import statsmodels.api as sm\n",
"pharmacoSmoking = sm.datasets.get_rdataset(\"pharmacoSmoking\", \"asaur\")\n",
"data = pharmacoSmoking.data\n",
"data.head(3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 198
},
"id": "EPtGZM0GiqKx",
"outputId": "361151e8-9d54-46c7-8ca3-453f978303b8"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
" import pandas.util.testing as tm\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" <div id=\"df-f57040f6-549e-4066-a0ab-b5558ac96e05\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>ttr</th>\n",
" <th>relapse</th>\n",
" <th>grp</th>\n",
" <th>age</th>\n",
" <th>gender</th>\n",
" <th>race</th>\n",
" <th>employment</th>\n",
" <th>yearsSmoking</th>\n",
" <th>levelSmoking</th>\n",
" <th>ageGroup2</th>\n",
" <th>ageGroup4</th>\n",
" <th>priorAttempts</th>\n",
" <th>longestNoSmoke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>21</td>\n",
" <td>182</td>\n",
" <td>0</td>\n",
" <td>patchOnly</td>\n",
" <td>36</td>\n",
" <td>Male</td>\n",
" <td>white</td>\n",
" <td>ft</td>\n",
" <td>26</td>\n",
" <td>heavy</td>\n",
" <td>21-49</td>\n",
" <td>35-49</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>113</td>\n",
" <td>14</td>\n",
" <td>1</td>\n",
" <td>patchOnly</td>\n",
" <td>41</td>\n",
" <td>Male</td>\n",
" <td>white</td>\n",
" <td>other</td>\n",
" <td>27</td>\n",
" <td>heavy</td>\n",
" <td>21-49</td>\n",
" <td>35-49</td>\n",
" <td>3</td>\n",
" <td>90</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>39</td>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>combination</td>\n",
" <td>25</td>\n",
" <td>Female</td>\n",
" <td>white</td>\n",
" <td>other</td>\n",
" <td>12</td>\n",
" <td>heavy</td>\n",
" <td>21-49</td>\n",
" <td>21-34</td>\n",
" <td>3</td>\n",
" <td>21</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f57040f6-549e-4066-a0ab-b5558ac96e05')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-f57040f6-549e-4066-a0ab-b5558ac96e05 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-f57040f6-549e-4066-a0ab-b5558ac96e05');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" id ttr relapse ... ageGroup4 priorAttempts longestNoSmoke\n",
"0 21 182 0 ... 35-49 0 0\n",
"1 113 14 1 ... 35-49 3 90\n",
"2 39 5 1 ... 21-34 3 21\n",
"\n",
"[3 rows x 14 columns]"
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"# Drop redundant information and ids\n",
"data = data.drop(columns=[\"id\",\"ageGroup2\",\"ageGroup4\"])\n",
"data.head(3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"id": "NgvnH8-Di2Wd",
"outputId": "b741cace-932a-4041-d4d2-f0f7282c5383"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" <div id=\"df-98443a87-79bc-458b-bc76-d971f0e3a830\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ttr</th>\n",
" <th>relapse</th>\n",
" <th>grp</th>\n",
" <th>age</th>\n",
" <th>gender</th>\n",
" <th>race</th>\n",
" <th>employment</th>\n",
" <th>yearsSmoking</th>\n",
" <th>levelSmoking</th>\n",
" <th>priorAttempts</th>\n",
" <th>longestNoSmoke</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>182</td>\n",
" <td>0</td>\n",
" <td>patchOnly</td>\n",
" <td>36</td>\n",
" <td>Male</td>\n",
" <td>white</td>\n",
" <td>ft</td>\n",
" <td>26</td>\n",
" <td>heavy</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>14</td>\n",
" <td>1</td>\n",
" <td>patchOnly</td>\n",
" <td>41</td>\n",
" <td>Male</td>\n",
" <td>white</td>\n",
" <td>other</td>\n",
" <td>27</td>\n",
" <td>heavy</td>\n",
" <td>3</td>\n",
" <td>90</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5</td>\n",
" <td>1</td>\n",
" <td>combination</td>\n",
" <td>25</td>\n",
" <td>Female</td>\n",
" <td>white</td>\n",
" <td>other</td>\n",
" <td>12</td>\n",
" <td>heavy</td>\n",
" <td>3</td>\n",
" <td>21</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-98443a87-79bc-458b-bc76-d971f0e3a830')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-98443a87-79bc-458b-bc76-d971f0e3a830 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-98443a87-79bc-458b-bc76-d971f0e3a830');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
" ttr relapse grp ... levelSmoking priorAttempts longestNoSmoke\n",
"0 182 0 patchOnly ... heavy 0 0\n",
"1 14 1 patchOnly ... heavy 3 90\n",
"2 5 1 combination ... heavy 3 21\n",
"\n",
"[3 rows x 11 columns]"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"from sksurv.datasets import get_x_y\n",
"\n",
"X, y = get_x_y(data, attr_labels=[\"relapse\", \"ttr\"], pos_label=True)"
],
"metadata": {
"id": "6MRPeiGIi6V0"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for c in X.columns:\n",
" if X[c].dtype.kind not in ['i', 'f']:\n",
" X[c] = X[c].astype(\"category\")"
],
"metadata": {
"id": "Nlu5jYKvjDU8"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sksurv.preprocessing import OneHotEncoder\n",
"\n",
"X_trn, X_val, y_trn, y_val = train_test_split(X, y, random_state=0)\n",
"# Preprocessing\n",
"enc = OneHotEncoder()\n",
"scaler = StandardScaler()\n",
"X_trn = enc.fit_transform(X_trn)\n",
"X_trn = pd.DataFrame(scaler.fit_transform(X_trn), columns=X_trn.columns)\n",
"X_val = enc.transform(X_val)\n",
"X_val = pd.DataFrame(scaler.transform(X_val), columns=X_val.columns)"
],
"metadata": {
"id": "kEgdfhNPjF6H"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sksurv.ensemble import RandomSurvivalForest\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"model = RandomSurvivalForest(random_state=42)\n",
"model.fit(X_trn, y_trn)\n",
"ci_trn = concordance_index_censored(y_trn['relapse'], y_trn['ttr'], model.predict(X_trn))\n",
"ci_val = concordance_index_censored(y_val['relapse'], y_val['ttr'], model.predict(X_val))\n",
"ci_trn[0], ci_val[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a3xSYwx5jScZ",
"outputId": "31b406e0-6986-42a5-bfa7-e0f751387ac5"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.8282642894598846, 0.6187363834422658)"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"param_grid_rsf = {\n",
" 'max_features': (\"sqrt\", 0.5, 1),\n",
" 'min_samples_leaf': (1,3,5),\n",
" 'n_estimators': (100, 200),\n",
" 'max_depth': (5,7,9),\n",
"}"
],
"metadata": {
"id": "1AHpkgeykXx2"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"gcv_ci = GridSearchCV(\n",
" RandomSurvivalForest(),\n",
" param_grid=param_grid_rsf,\n",
" cv=4,\n",
" n_jobs=-1,\n",
")"
],
"metadata": {
"id": "8eJOSMc5jfkU"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"gcv_ci.fit(X_trn, y_trn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "sszi6Hqrkzg-",
"outputId": "1aa0be12-618f-4cbb-f074-bc826693f81e"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GridSearchCV(cv=4, estimator=RandomSurvivalForest(), n_jobs=-1,\n",
" param_grid={'max_depth': (5, 7, 9),\n",
" 'max_features': ('sqrt', 0.5, 1),\n",
" 'min_samples_leaf': (1, 3, 5),\n",
" 'n_estimators': (100, 200)})"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"ci_trn = concordance_index_censored(y_trn['relapse'], y_trn['ttr'], gcv_ci.predict(X_trn))\n",
"ci_val = concordance_index_censored(y_val['relapse'], y_val['ttr'], gcv_ci.predict(X_val))\n",
"ci_trn[0], ci_val[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Us6zPWeok7xp",
"outputId": "33e48935-b340-419d-be0e-0e1851445029"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.8626114315679078, 0.5838779956427015)"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"source": [
"from sksurv.metrics import integrated_brier_score\n",
"\n",
"survs_trn = model.predict_survival_function(X_trn)\n",
"survs_val = model.predict_survival_function(X_val)\n",
"times = np.arange(y_val['ttr'].min(), y_trn['ttr'][y_trn['ttr']!=y_trn['ttr'].max()].max())\n",
"preds_trn = np.asarray([[fn(t) for t in times] for fn in survs_trn])\n",
"preds_val = np.asarray([[fn(t) for t in times] for fn in survs_val])\n",
"ibs_trn = integrated_brier_score(y_trn, y_trn, preds_trn, times)\n",
"ibs_val = integrated_brier_score(y_trn, y_val, preds_val, times)\n",
"ibs_trn, ibs_val"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gW7UwEE_k307",
"outputId": "adbc86be-6edd-452c-ebf4-6f1c3e418207"
},
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.1413713684488042, 0.21497683508486282)"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import KFold\n",
"\n",
"cv = KFold(n_splits=4, shuffle=True, random_state=0)"
],
"metadata": {
"id": "zNp96QBalY1Y"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"param_grid_rsf = {\n",
" 'estimator__max_features': (\"sqrt\", 0.5, 1),\n",
" 'estimator__min_samples_leaf': (1,3,5),\n",
" 'estimator__n_estimators': (100,200),\n",
" 'estimator__max_depth': (3,7,10),\n",
"}"
],
"metadata": {
"id": "Esd1F7YNl9Gd"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sksurv.ensemble import RandomSurvivalForest\n",
"from sksurv.metrics import as_integrated_brier_score_scorer\n",
"\n",
"gcv_ibs = GridSearchCV(\n",
" as_integrated_brier_score_scorer(RandomSurvivalForest(), times=times),\n",
" param_grid=param_grid_rsf,\n",
" cv=cv,\n",
" n_jobs=-1,\n",
")"
],
"metadata": {
"id": "plEtG0eVlhMO"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"gcv_ibs.fit(X_trn, y_trn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xEntw1jclx5w",
"outputId": "a969b1c7-1c5e-407a-fa79-45cb8fbe85eb"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_search.py:972: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n",
" nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]\n",
" category=UserWarning,\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GridSearchCV(cv=KFold(n_splits=4, random_state=0, shuffle=True),\n",
" estimator=as_integrated_brier_score_scorer(estimator=RandomSurvivalForest(),\n",
" times=array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,\n",
" 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21.,\n",
" 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,\n",
" 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,\n",
" 44., 45., 46., 47., 48., 49., 50., 51., 5...\n",
" 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,\n",
" 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153.,\n",
" 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164.,\n",
" 165., 166., 167., 168., 169.])),\n",
" n_jobs=-1,\n",
" param_grid={'estimator__max_depth': (3, 7, 10),\n",
" 'estimator__max_features': ('sqrt', 0.5, 1),\n",
" 'estimator__min_samples_leaf': (1, 3, 5),\n",
" 'estimator__n_estimators': (100, 200)})"
]
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"rsf = RandomSurvivalForest(max_depth=gcv_ibs.best_params_['estimator__max_depth'],\n",
" max_features=gcv_ibs.best_params_['estimator__max_features'], \n",
" min_samples_leaf=gcv_ibs.best_params_['estimator__min_samples_leaf'],\n",
" n_estimators=gcv_ibs.best_params_['estimator__n_estimators']\n",
" )\n",
"rsf"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KwS1i98ul0hF",
"outputId": "02c93ec6-bdba-4ddb-860d-772e89730eaf"
},
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomSurvivalForest(max_depth=3, max_features='sqrt', min_samples_leaf=1)"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"source": [
"rsf.fit(X_trn, y_trn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ypj1A9UfmRou",
"outputId": "823aa33a-1648-4f82-9696-2d378c8e13bf"
},
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomSurvivalForest(max_depth=3, max_features='sqrt', min_samples_leaf=1)"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"survs_trn = rsf.predict_survival_function(X_trn)\n",
"survs_val = rsf.predict_survival_function(X_val)\n",
"preds_trn = np.asarray([[fn(t) for t in times] for fn in survs_trn])\n",
"preds_val = np.asarray([[fn(t) for t in times] for fn in survs_val])\n",
"ibs_trn = integrated_brier_score(y_trn, y_trn, preds_trn, times)\n",
"ibs_val = integrated_brier_score(y_trn, y_val, preds_val, times)\n",
"ibs_trn, ibs_val"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T_D3raDJmV4Q",
"outputId": "1db3a419-be13-48fe-b2b0-343e796744c2"
},
"execution_count": 19,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.18609360226313928, 0.21563923724817194)"
]
},
"metadata": {},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "Dt7tUcgimi7-"
},
"execution_count": 19,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment