Copy of Cox_PH_and_RSF-colab.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/98a7662a8de782b2ba9816a37527fc7f/copy-of-cox_ph_and_rsf-colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UN3PoUTSb2nT"
},
"source": [
"The objective of this notebook is to compare different models for estimating the survival probability given a set of features/covariates.\n",
"\n",
">[\"Experimental Comparison of Semi-parametric, Parametric, and Machine Learning Models for Time-to-Event Analysis Through the Concordance Index,\"](https://arxiv.org/abs/2003.08820)\n",
"Camila Fernandez, Chung Shue Chen, Pierre Gaillard, Alonso Silva\n",
"\n",
"For this analysis we use [scikit-learn](https://scikit-learn.org/) and [scikit-survival](https://pypi.org/project/scikit-survival/), and we then use [eli5](https://eli5.readthedocs.io/en/latest/index.html) to study feature importances computed with permutation importance."
]
},
{
"cell_type": "code",
"source": [
"%pip install --quiet scikit-learn\n",
"%pip install --quiet scikit-survival"
],
"metadata": {
"id": "xMhD09DUq_E4"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VoQVEI5p_rga"
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SnT6e_JPb2ns"
},
"source": [
"We first load the GBSG2 dataset, which ships with scikit-survival."
]
},
{
"cell_type": "code",
"metadata": {
"id": "D0xxNWzI-N3j"
},
"source": [
"from sksurv.datasets import load_gbsg2\n",
"\n",
"X, y = load_gbsg2()"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBTo4q_Hb2n0"
},
"source": [
"## An example: German Breast Cancer Study Group 2 (GBSG2)\n",
"\n",
"This dataset contains the following 8 features/covariates:\n",
"\n",
"- age: age (in years),\n",
"- estrec: estrogen receptor (in fmol),\n",
"- horTh: hormonal therapy (yes or no),\n",
"- menostat: menopausal status (premenopausal or postmenopausal),\n",
"- pnodes: number of positive nodes,\n",
"- progrec: progesterone receptor (in fmol),\n",
"- tgrade: tumor grade (I < II < III),\n",
"- tsize: tumor size (in mm).\n",
"\n",
"and the two outputs:\n",
"\n",
"- recurrence-free time (in days),\n",
"- event indicator (`True` - recurrence observed, `False` - censored; stored in `y` under the field name `cens`).\n",
"\n",
"The dataset has 686 samples and 8 features/covariates.\n",
"\n",
"\n",
"**References**\n",
"\n",
"M. Schumacher, G. Basert, H. Bojar, K. Huebner, M. Olschewski, W. Sauerbrei, C. Schmoor, C. Beyerle, R.L.A. Neumann and H.F. Rauschecker for the German Breast Cancer Study Group (1994), [Randomized 2 x 2 trial evaluating hormonal treatment and the duration of chemotherapy in node-positive breast cancer patients](https://www.ncbi.nlm.nih.gov/pubmed/7931478). Journal of Clinical Oncology, 12, 2086–2093."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kAsZ72YYb2n3"
},
"source": [
"Let's take a look at the features/covariates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_OlmlI6g-X43",
"outputId": "874a8141-c402-4be1-a2db-ada1b9a2975e",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
}
},
"source": [
"cols = [\"age\", \"estrec\", \"pnodes\", \"progrec\", \"tsize\"]\n",
"# Display the continuous columns without decimals\n",
"formatdict = {col: \"{:,.0f}\" for col in cols}\n",
"X.head(10).style.hide(axis=\"index\").format(formatdict)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7a30d7a54af0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_f9879\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th id=\"T_f9879_level0_col0\" class=\"col_heading level0 col0\" >age</th>\n",
" <th id=\"T_f9879_level0_col1\" class=\"col_heading level0 col1\" >estrec</th>\n",
" <th id=\"T_f9879_level0_col2\" class=\"col_heading level0 col2\" >horTh</th>\n",
" <th id=\"T_f9879_level0_col3\" class=\"col_heading level0 col3\" >menostat</th>\n",
" <th id=\"T_f9879_level0_col4\" class=\"col_heading level0 col4\" >pnodes</th>\n",
" <th id=\"T_f9879_level0_col5\" class=\"col_heading level0 col5\" >progrec</th>\n",
" <th id=\"T_f9879_level0_col6\" class=\"col_heading level0 col6\" >tgrade</th>\n",
" <th id=\"T_f9879_level0_col7\" class=\"col_heading level0 col7\" >tsize</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_f9879_row0_col0\" class=\"data row0 col0\" >70</td>\n",
" <td id=\"T_f9879_row0_col1\" class=\"data row0 col1\" >66</td>\n",
" <td id=\"T_f9879_row0_col2\" class=\"data row0 col2\" >no</td>\n",
" <td id=\"T_f9879_row0_col3\" class=\"data row0 col3\" >Post</td>\n",
" <td id=\"T_f9879_row0_col4\" class=\"data row0 col4\" >3</td>\n",
" <td id=\"T_f9879_row0_col5\" class=\"data row0 col5\" >48</td>\n",
" <td id=\"T_f9879_row0_col6\" class=\"data row0 col6\" >II</td>\n",
" <td id=\"T_f9879_row0_col7\" class=\"data row0 col7\" >21</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row1_col0\" class=\"data row1 col0\" >56</td>\n",
" <td id=\"T_f9879_row1_col1\" class=\"data row1 col1\" >77</td>\n",
" <td id=\"T_f9879_row1_col2\" class=\"data row1 col2\" >yes</td>\n",
" <td id=\"T_f9879_row1_col3\" class=\"data row1 col3\" >Post</td>\n",
" <td id=\"T_f9879_row1_col4\" class=\"data row1 col4\" >7</td>\n",
" <td id=\"T_f9879_row1_col5\" class=\"data row1 col5\" >61</td>\n",
" <td id=\"T_f9879_row1_col6\" class=\"data row1 col6\" >II</td>\n",
" <td id=\"T_f9879_row1_col7\" class=\"data row1 col7\" >12</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row2_col0\" class=\"data row2 col0\" >58</td>\n",
" <td id=\"T_f9879_row2_col1\" class=\"data row2 col1\" >271</td>\n",
" <td id=\"T_f9879_row2_col2\" class=\"data row2 col2\" >yes</td>\n",
" <td id=\"T_f9879_row2_col3\" class=\"data row2 col3\" >Post</td>\n",
" <td id=\"T_f9879_row2_col4\" class=\"data row2 col4\" >9</td>\n",
" <td id=\"T_f9879_row2_col5\" class=\"data row2 col5\" >52</td>\n",
" <td id=\"T_f9879_row2_col6\" class=\"data row2 col6\" >II</td>\n",
" <td id=\"T_f9879_row2_col7\" class=\"data row2 col7\" >35</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row3_col0\" class=\"data row3 col0\" >59</td>\n",
" <td id=\"T_f9879_row3_col1\" class=\"data row3 col1\" >29</td>\n",
" <td id=\"T_f9879_row3_col2\" class=\"data row3 col2\" >yes</td>\n",
" <td id=\"T_f9879_row3_col3\" class=\"data row3 col3\" >Post</td>\n",
" <td id=\"T_f9879_row3_col4\" class=\"data row3 col4\" >4</td>\n",
" <td id=\"T_f9879_row3_col5\" class=\"data row3 col5\" >60</td>\n",
" <td id=\"T_f9879_row3_col6\" class=\"data row3 col6\" >II</td>\n",
" <td id=\"T_f9879_row3_col7\" class=\"data row3 col7\" >17</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row4_col0\" class=\"data row4 col0\" >73</td>\n",
" <td id=\"T_f9879_row4_col1\" class=\"data row4 col1\" >65</td>\n",
" <td id=\"T_f9879_row4_col2\" class=\"data row4 col2\" >no</td>\n",
" <td id=\"T_f9879_row4_col3\" class=\"data row4 col3\" >Post</td>\n",
" <td id=\"T_f9879_row4_col4\" class=\"data row4 col4\" >1</td>\n",
" <td id=\"T_f9879_row4_col5\" class=\"data row4 col5\" >26</td>\n",
" <td id=\"T_f9879_row4_col6\" class=\"data row4 col6\" >II</td>\n",
" <td id=\"T_f9879_row4_col7\" class=\"data row4 col7\" >35</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row5_col0\" class=\"data row5 col0\" >32</td>\n",
" <td id=\"T_f9879_row5_col1\" class=\"data row5 col1\" >13</td>\n",
" <td id=\"T_f9879_row5_col2\" class=\"data row5 col2\" >no</td>\n",
" <td id=\"T_f9879_row5_col3\" class=\"data row5 col3\" >Pre</td>\n",
" <td id=\"T_f9879_row5_col4\" class=\"data row5 col4\" >24</td>\n",
" <td id=\"T_f9879_row5_col5\" class=\"data row5 col5\" >0</td>\n",
" <td id=\"T_f9879_row5_col6\" class=\"data row5 col6\" >III</td>\n",
" <td id=\"T_f9879_row5_col7\" class=\"data row5 col7\" >57</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row6_col0\" class=\"data row6 col0\" >59</td>\n",
" <td id=\"T_f9879_row6_col1\" class=\"data row6 col1\" >0</td>\n",
" <td id=\"T_f9879_row6_col2\" class=\"data row6 col2\" >yes</td>\n",
" <td id=\"T_f9879_row6_col3\" class=\"data row6 col3\" >Post</td>\n",
" <td id=\"T_f9879_row6_col4\" class=\"data row6 col4\" >2</td>\n",
" <td id=\"T_f9879_row6_col5\" class=\"data row6 col5\" >181</td>\n",
" <td id=\"T_f9879_row6_col6\" class=\"data row6 col6\" >II</td>\n",
" <td id=\"T_f9879_row6_col7\" class=\"data row6 col7\" >8</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row7_col0\" class=\"data row7 col0\" >65</td>\n",
" <td id=\"T_f9879_row7_col1\" class=\"data row7 col1\" >25</td>\n",
" <td id=\"T_f9879_row7_col2\" class=\"data row7 col2\" >no</td>\n",
" <td id=\"T_f9879_row7_col3\" class=\"data row7 col3\" >Post</td>\n",
" <td id=\"T_f9879_row7_col4\" class=\"data row7 col4\" >1</td>\n",
" <td id=\"T_f9879_row7_col5\" class=\"data row7 col5\" >192</td>\n",
" <td id=\"T_f9879_row7_col6\" class=\"data row7 col6\" >II</td>\n",
" <td id=\"T_f9879_row7_col7\" class=\"data row7 col7\" >16</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row8_col0\" class=\"data row8 col0\" >80</td>\n",
" <td id=\"T_f9879_row8_col1\" class=\"data row8 col1\" >59</td>\n",
" <td id=\"T_f9879_row8_col2\" class=\"data row8 col2\" >no</td>\n",
" <td id=\"T_f9879_row8_col3\" class=\"data row8 col3\" >Post</td>\n",
" <td id=\"T_f9879_row8_col4\" class=\"data row8 col4\" >30</td>\n",
" <td id=\"T_f9879_row8_col5\" class=\"data row8 col5\" >0</td>\n",
" <td id=\"T_f9879_row8_col6\" class=\"data row8 col6\" >II</td>\n",
" <td id=\"T_f9879_row8_col7\" class=\"data row8 col7\" >39</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_f9879_row9_col0\" class=\"data row9 col0\" >66</td>\n",
" <td id=\"T_f9879_row9_col1\" class=\"data row9 col1\" >3</td>\n",
" <td id=\"T_f9879_row9_col2\" class=\"data row9 col2\" >no</td>\n",
" <td id=\"T_f9879_row9_col3\" class=\"data row9 col3\" >Post</td>\n",
" <td id=\"T_f9879_row9_col4\" class=\"data row9 col4\" >7</td>\n",
" <td id=\"T_f9879_row9_col5\" class=\"data row9 col5\" >0</td>\n",
" <td id=\"T_f9879_row9_col6\" class=\"data row9 col6\" >II</td>\n",
" <td id=\"T_f9879_row9_col7\" class=\"data row9 col7\" >18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zcMjHeAEb2n_"
},
"source": [
"Let's take a look at the output."
]
},
{
"cell_type": "code",
"metadata": {
"id": "h8ltRTa4_WOn",
"outputId": "e8ddf347-88a8-4383-8b32-dcfa34fcc1fa",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"y[:10]"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([( True, 1814.), ( True, 2018.), ( True, 712.), ( True, 1807.),\n",
" ( True, 772.), ( True, 448.), (False, 2172.), (False, 2161.),\n",
" ( True, 471.), (False, 2014.)],\n",
" dtype=[('cens', '?'), ('time', '<f8')])"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A5NXlNYob2oH"
},
"source": [
"scikit-survival stores the outputs in a NumPy structured array with the fields `cens` (event indicator) and `time`, so to display them we convert the array into a DataFrame."
]
},
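{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this layout, using three (event, time) pairs taken from the first rows of `y` shown above: a NumPy structured array with a boolean field `cens` (`True` = recurrence observed) and a float field `time` (days), converted to a DataFrame in the same way as in the next cell. The names `y_toy` and `df_toy` are hypothetical."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Three entries of y, with the same field names and dtypes as y\n",
"y_toy = np.array(\n",
"    [(True, 1814.0), (False, 2172.0), (True, 448.0)],\n",
"    dtype=[('cens', '?'), ('time', '<f8')],\n",
")\n",
"\n",
"# Fields are accessed by name, like DataFrame columns\n",
"print(y_toy['cens'])  # boolean event indicator\n",
"print(y_toy['time'])  # recurrence-free time in days\n",
"\n",
"# One DataFrame column per field\n",
"df_toy = pd.DataFrame({'time': y_toy['time'].astype(int), 'event': y_toy['cens']})\n",
"print(df_toy)"
],
"execution_count": null,
"outputs": []
},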
{
"cell_type": "code",
"source": [
"df_y = pd.DataFrame(data={'time': y['time'].astype(int), 'event': y['cens']})\n",
"df_y[:10].style.hide(axis=\"index\").highlight_min('event', color='lightgreen')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "AfPvZcjJ-GwQ",
"outputId": "7ee1358d-69bb-4d15-876a-92e708147a24"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7a30d61bea10>"
],
"text/html": [
"<style type=\"text/css\">\n",
"#T_118da_row6_col1, #T_118da_row7_col1, #T_118da_row9_col1 {\n",
" background-color: lightgreen;\n",
"}\n",
"</style>\n",
"<table id=\"T_118da\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th id=\"T_118da_level0_col0\" class=\"col_heading level0 col0\" >time</th>\n",
" <th id=\"T_118da_level0_col1\" class=\"col_heading level0 col1\" >event</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_118da_row0_col0\" class=\"data row0 col0\" >1814</td>\n",
" <td id=\"T_118da_row0_col1\" class=\"data row0 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row1_col0\" class=\"data row1 col0\" >2018</td>\n",
" <td id=\"T_118da_row1_col1\" class=\"data row1 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row2_col0\" class=\"data row2 col0\" >712</td>\n",
" <td id=\"T_118da_row2_col1\" class=\"data row2 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row3_col0\" class=\"data row3 col0\" >1807</td>\n",
" <td id=\"T_118da_row3_col1\" class=\"data row3 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row4_col0\" class=\"data row4 col0\" >772</td>\n",
" <td id=\"T_118da_row4_col1\" class=\"data row4 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row5_col0\" class=\"data row5 col0\" >448</td>\n",
" <td id=\"T_118da_row5_col1\" class=\"data row5 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row6_col0\" class=\"data row6 col0\" >2172</td>\n",
" <td id=\"T_118da_row6_col1\" class=\"data row6 col1\" >False</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row7_col0\" class=\"data row7 col0\" >2161</td>\n",
" <td id=\"T_118da_row7_col1\" class=\"data row7 col1\" >False</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row8_col0\" class=\"data row8 col0\" >471</td>\n",
" <td id=\"T_118da_row8_col1\" class=\"data row8 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_118da_row9_col0\" class=\"data row9 col0\" >2014</td>\n",
" <td id=\"T_118da_row9_col1\" class=\"data row9 col1\" >False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xzW5x7ljb2oP"
},
"source": [
"One of the main challenges of survival analysis is **right censoring**, i.e., by the end of the study, the event of interest (for example, in medicine 'death of a patient' or in this dataset 'recurrence of cancer') has only occurred for a subset of the observations.\n",
"\n",
"In this dataset, right censoring is encoded in the column `event` (the field `cens` of `y`): it is `True` if the patient had a recurrence of cancer and `False` if the patient was still recurrence-free at the recorded time (a right-censored sample)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VmfAR7igb2oW"
},
"source": [
"Let's see how many right-censored samples we have."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rzS8h1GG_o_A",
"outputId": "d7d49914-42cb-4407-cc43-98aea98545f3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"print(f'Number of samples: {len(df_y)}')\n",
"print(f'Number of right censored samples: {len(df_y.query(\"event == False\"))}')\n",
"print(f'Percentage of right censored samples: {100*len(df_y.query(\"event == False\"))/len(df_y):.1f}%')"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of samples: 686\n",
"Number of right censored samples: 387\n",
"Percentage of right censored samples: 56.4%\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VtsENFsnQhZx"
},
"source": [
"There are 387 patients (56.4%) who were right-censored (still recurrence-free at the end of their follow-up).\n",
"\n",
"Before dividing the dataset into training and test sets, we encode the categorical features/covariates as numbers."
]
},
{
"cell_type": "code",
"source": [
"from sklearn.preprocessing import OrdinalEncoder"
],
"metadata": {
"id": "PV9SQ8LZ20BL"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X[\"horTh\"] = (X[\"horTh\"] == \"yes\").astype(int)  # yes -> 1, no -> 0"
],
"metadata": {
"id": "MxHGiw0E4-hP"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X[\"menostat\"] = (X[\"menostat\"] == \"Post\").astype(int)  # Post -> 1, Pre -> 0"
],
"metadata": {
"id": "P_R8Fr4a5JUt"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Ordinal encoding that respects the grade order: I -> 0, II -> 1, III -> 2\n",
"X[\"tgrade\"] = OrdinalEncoder(categories=[['I', 'II', 'III']]).fit_transform(X[[\"tgrade\"]]).ravel()"
],
"metadata": {
"id": "b6ABtXNK3Dd9"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "duYhddUr_1nH",
"outputId": "3789d704-0a2e-42a5-ee2f-f5ad03a7810c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=20)\n",
"\n",
"print(f'Number of training samples: {len(y_trn)}')\n",
"print(f'Number of test samples: {len(y_test)}')"
],
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of training samples: 514\n",
"Number of test samples: 172\n"
]
}
]
},
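{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 514/172 split above follows from scikit-learn's default, as we understand its rounding rule: when neither `test_size` nor `train_size` is given, `train_test_split` uses `test_size=0.25`, rounds the number of test samples up, and assigns the remaining samples to the training set. A quick arithmetic check:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"n_samples = 686\n",
"test_size = 0.25  # default used by train_test_split when no size is given\n",
"\n",
"n_test = math.ceil(test_size * n_samples)  # 171.5 rounded up\n",
"n_train = n_samples - n_test  # the remaining samples\n",
"\n",
"print(n_train, n_test)"
],
"execution_count": null,
"outputs": []
},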
{
"cell_type": "markdown",
"metadata": {
"id": "3VEOV-vWb2ow"
},
"source": [
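"We divide the features/covariates into continuous and categorical ones. After the encoding above every column is numeric, so the dtypes below no longer distinguish the two groups."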
]
},
{
"cell_type": "code",
"source": [
"X.dtypes"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 335
},
"id": "4OqMuC40mXDb",
"outputId": "717dbdb6-b48e-4a6b-9671-03cc1b05b9f2"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"age float64\n",
"estrec float64\n",
"horTh int64\n",
"menostat int64\n",
"pnodes float64\n",
"progrec float64\n",
"tgrade float64\n",
"tsize float64\n",
"dtype: object"
],
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>age</th>\n",
" <td>float64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>estrec</th>\n",
" <td>float64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>horTh</th>\n",
" <td>int64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>menostat</th>\n",
" <td>int64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>pnodes</th>\n",
" <td>float64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>progrec</th>\n",
" <td>float64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tgrade</th>\n",
" <td>float64</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tsize</th>\n",
" <td>float64</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div><br><label><b>dtype:</b> object</label>"
]
},
"metadata": {},
"execution_count": 16
}
]
},
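{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the dtypes above show, every column is numeric after encoding, so a minimal sketch is to list the categorical columns explicitly and treat the rest as continuous. The two rows below are the first two rows of `X` as displayed earlier, with `horTh`, `menostat` and `tgrade` encoded as in the cells above; `X_toy`, `categorical_cols` and `continuous_cols` are hypothetical names."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pandas as pd\n",
"\n",
"# First two rows of X as displayed earlier, after encoding (horTh: no -> 0, yes -> 1; menostat: Post -> 1; tgrade: II -> 1)\n",
"X_toy = pd.DataFrame({\n",
"    'age': [70.0, 56.0], 'estrec': [66.0, 77.0], 'horTh': [0, 1], 'menostat': [1, 1],\n",
"    'pnodes': [3.0, 7.0], 'progrec': [48.0, 61.0], 'tgrade': [1.0, 1.0], 'tsize': [21.0, 12.0],\n",
"})\n",
"\n",
"# Categorical columns have to be listed explicitly; everything else is continuous\n",
"categorical_cols = ['horTh', 'menostat', 'tgrade']\n",
"continuous_cols = [col for col in X_toy.columns if col not in categorical_cols]\n",
"\n",
"print('Categorical:', categorical_cols)\n",
"print('Continuous:', continuous_cols)"
],
"execution_count": null,
"outputs": []
},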
{
"cell_type": "markdown",
"metadata": {
"id": "jDjgk9PWb2o3"
},
"source": [
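"The categorical features/covariates were encoded above (binary indicators for `horTh` and `menostat`, ordinal encoding for `tgrade`). Standard scaling can be applied to the continuous features/covariates; note that the Cox baseline below is fitted on the unscaled features."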
]
},
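{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the continuous features/covariates are standardized, the mean and standard deviation should be estimated on the training set only and then reused on the test set, so that no test information leaks into the model. A minimal pandas sketch with hypothetical values and hypothetical names (`X_trn_toy`, `X_test_toy`); in scikit-learn, `StandardScaler` fitted on the training data plays this role and likewise uses the population standard deviation (`ddof=0`)."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical training and test values for two continuous columns\n",
"X_trn_toy = pd.DataFrame({'age': [40.0, 50.0, 60.0], 'tsize': [10.0, 20.0, 30.0]})\n",
"X_test_toy = pd.DataFrame({'age': [55.0], 'tsize': [25.0]})\n",
"continuous_cols = ['age', 'tsize']\n",
"\n",
"# Scaling statistics come from the training set only (population std, ddof=0)\n",
"mean = X_trn_toy[continuous_cols].mean()\n",
"std = X_trn_toy[continuous_cols].std(ddof=0)\n",
"\n",
"# The same training statistics are applied to both sets\n",
"X_trn_scaled = (X_trn_toy[continuous_cols] - mean) / std\n",
"X_test_scaled = (X_test_toy[continuous_cols] - mean) / std\n",
"\n",
"print(X_trn_scaled.round(3))\n",
"print(X_test_scaled.round(3))"
],
"execution_count": null,
"outputs": []
},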
{
"cell_type": "markdown",
"metadata": {
"id": "N_nodxlEb2o-"
},
"source": [
"# Baseline: Cox Proportional Hazards model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KqdChwYJb2o_"
},
"source": [
"The Cox proportional hazards model assumes that the log-hazard of a subject is a linear function of their $m$ static covariates/features $x_i, i\\in\\{1,\\ldots,m\\}$, plus a population-level baseline log-hazard $\\log h_0(t)$ that changes over time:\n",
"\\begin{equation}\n",
"h(t|x)=h_0(t)\\exp\\left(\\sum_{i=1}^m\\beta_i(x_i-\\bar{x}_i)\\right),\n",
"\\end{equation}\n",
"where $\\beta_i$ are the regression coefficients and $\\bar{x}_i$ is the mean of the $i$-th covariate.\n",
"\n",
"The term *proportional hazards* refers to the assumption that the hazard ratio of any two subjects is constant over time: the baseline hazard cancels, so $h(t|x)/h(t|x')=\\exp\\left(\\sum_{i=1}^m\\beta_i(x_i-x'_i)\\right)$ does not depend on $t$."
]
},
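{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal numerical sketch of the formula above, assuming hypothetical regression coefficients, covariate means and a made-up baseline hazard $h_0(t)$ (none of them come from the fitted model; `baseline_hazard` and `hazard` are hypothetical helpers). It evaluates $h(t|x)$ for two subjects at several times and checks that their ratio is the same at every $t$, which is the proportional hazards property."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical regression coefficients and covariate means for m = 2 covariates\n",
"beta = np.array([0.03, 0.5])\n",
"x_mean = np.array([50.0, 1.0])\n",
"\n",
"def baseline_hazard(t):\n",
"    # Made-up baseline hazard h_0(t), increasing with time\n",
"    return 0.001 * (1.0 + t / 1000.0)\n",
"\n",
"def hazard(t, x):\n",
"    # h(t|x) = h_0(t) * exp(sum_i beta_i * (x_i - mean_i))\n",
"    return baseline_hazard(t) * np.exp(np.dot(beta, x - x_mean))\n",
"\n",
"x_a = np.array([60.0, 2.0])  # subject A\n",
"x_b = np.array([45.0, 0.0])  # subject B\n",
"\n",
"times = np.array([100.0, 500.0, 1000.0, 2000.0])\n",
"ratios = hazard(times, x_a) / hazard(times, x_b)\n",
"\n",
"# The baseline hazard cancels, so the ratio equals exp(beta . (x_a - x_b)) at every time\n",
"print(ratios)\n",
"print(np.exp(np.dot(beta, x_a - x_b)))"
],
"execution_count": null,
"outputs": []
},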
{
"cell_type": "code",
"metadata": {
"id": "77YbwMKvAFHQ",
"outputId": "a60ddcf3-cfca-4d47-b20c-031d48b6e03c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 80
}
},
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"cox = CoxPHSurvivalAnalysis()\n",
"cox.fit(X_trn, y_trn)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"CoxPHSurvivalAnalysis()"
],
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>CoxPHSurvivalAnalysis()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;CoxPHSurvivalAnalysis<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>CoxPHSurvivalAnalysis()</pre></div> </div></div></div></div>"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"source": [
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox.predict(X_test))\n",
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oVdrGM0NtWdR",
"outputId": "1387a9ea-0d1d-461c-d7d3-5b726c49e09f"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox is given by 0.665\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Let's now attempt to quantify how a survival curve estimated on a training set performs on a test set.\n",
"\n",
"## Survival model evaluation using the Integrated Brier Score (IBS) and the Concordance Index (C-index)"
],
"metadata": {
"id": "I2wAzNXar551"
}
},
{
"cell_type": "markdown",
"source": [
"The Brier score and the C-index are measures that **assess the quality of a predicted survival curve** on a finite data sample.\n",
"\n",
"- **The Brier score is a proper scoring rule**, meaning that an estimate of the survival curve has minimal Brier score if and only if it matches the true survival probabilities induced by the underlying data generating process. In that respect the **Brier score** assesses both the **calibration** and the **ranking power** of a survival probability estimator.\n",
"\n",
"- On the other hand, the **C-index** only assesses the **ranking power**: it is invariant to any strictly increasing transform of the predicted scores. It focuses solely on the ability of a predictive survival model to identify which of two individuals is likely to fail first.\n",
"\n",
"The Brier score lies between 0 and 1 (lower is better).\n",
"It answers the question \"how close are our estimated survival probabilities to the true ones?\""
],
"metadata": {
"id": "Gap1YWH5sAA1"
}
},
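{
"cell_type": "markdown",
"source": [
"The claim that the C-index only measures ranking power can be checked on a toy example. The cell below is a minimal sketch, not the scikit-survival implementation. `harrell_c_index` is a hypothetical helper implementing a simplified Harrell's C-index: a pair is comparable when the individual with the shorter time experienced the event, ties in predicted risk count as 0.5, and tied times get no special treatment. The data are made up for illustration. Applying a strictly increasing transform such as `np.exp` to the risk scores leaves the C-index unchanged, because the ranking of the scores is unchanged."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def harrell_c_index(event, time, risk):\n",
"    # Simplified Harrell's C-index (hypothetical helper, for illustration only)\n",
"    concordant, comparable = 0.0, 0\n",
"    for i in range(len(time)):\n",
"        if not event[i]:\n",
"            continue  # a censored individual cannot be the earlier failure of a pair\n",
"        for j in range(len(time)):\n",
"            if time[i] < time[j]:  # i failed before j was last observed\n",
"                comparable += 1\n",
"                if risk[i] > risk[j]:\n",
"                    concordant += 1.0  # higher risk assigned to the earlier failure\n",
"                elif risk[i] == risk[j]:\n",
"                    concordant += 0.5  # tied risk scores count as half\n",
"    return concordant / comparable\n",
"\n",
"# Toy data, made up for illustration\n",
"event = np.array([True, False, True, True, False])\n",
"time = np.array([5.0, 8.0, 3.0, 10.0, 2.0])\n",
"risk = np.array([0.4, 0.9, 1.5, 0.1, 0.3])\n",
"\n",
"print(harrell_c_index(event, time, risk))          # 0.8\n",
"print(harrell_c_index(event, time, np.exp(risk)))  # 0.8: same ranking, same C-index"
],
"metadata": {},
"execution_count": null,
"outputs": []
},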
{
"cell_type": "markdown",
"source": [
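"**Mathematical formulation**\n",
"\n",
"$$\\mathrm{BS}^c(t) = \\frac{1}{n} \\sum_{i=1}^n \\left[ I(d_i \\leq t \\land \\delta_i = 1)\n",
" \\frac{(0 - \\hat{S}(t | \\mathbf{x}_i))^2}{\\hat{G}(d_i)} + I(d_i > t)\n",
" \\frac{(1 - \\hat{S}(t | \\mathbf{x}_i))^2}{\\hat{G}(t)} \\right]$$\n",
"\n",
"Here $d_i$ is the observed time of individual $i$ (event or censoring time), $\\delta_i$ is its event indicator ($\\delta_i = 1$ if the event was observed), $\\hat{S}(t | \\mathbf{x}_i)$ is the predicted survival probability at $t$, and $\\hat{G}$ is the estimated survival function of the censoring distribution.\n",
"\n",
"In the survival analysis context, the Brier score can be seen as the Mean Squared Error (MSE) between the predicted probability $\\hat{S}(t | \\mathbf{x}_i)$ and the binary status of each individual at $t$ (1 if still event-free, 0 if the event has already occurred). Each term is weighted by the inverse probability of censoring, $1/\\hat{G}(d_i)$ or $1/\\hat{G}(t)$, and individuals censored before $t$ contribute no term. In practice, $\\hat{G}$ is estimated with a variant of the Kaplan-Meier estimator in which the event indicator is swapped, so that censoring plays the role of the event.\n",
"\n",
"- When neither the event nor censoring has happened by $t$, i.e. $I(d_i > t) = 1$, a low survival probability is penalized with $(1 - \\hat{S}(t|\\mathbf{x}_i))^2$.\n",
"- Conversely, when an individual has experienced the event at or before $t$, i.e. $I(d_i \\leq t \\land \\delta_i = 1) = 1$, a high survival probability is penalized with $(0 - \\hat{S}(t|\\mathbf{x}_i))^2$."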
],
"metadata": {
"id": "3bmwqNQisHup"
}
},
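{
"cell_type": "markdown",
"source": [
"The formula above can be transcribed almost literally into NumPy. The cell below is an illustrative sketch that rests on assumptions not taken from this notebook. The helpers `km_censoring` and `brier_score_at` are hypothetical. $\\hat{G}$ is estimated on the same toy sample that is scored, whereas scikit-survival's `brier_score` estimates it from the training outcomes. Tied times get no special care, so the numbers may differ slightly from library implementations."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def km_censoring(time, event):\n",
"    # Hypothetical helper: Kaplan-Meier estimate of G(t) = P(C > t),\n",
"    # where censoring plays the role of the event (swapped indicator)\n",
"    uniq = np.unique(time)\n",
"    n_at_risk = np.array([(time >= u).sum() for u in uniq])\n",
"    n_censored = np.array([((time == u) & ~event).sum() for u in uniq])\n",
"    g_values = np.cumprod(1.0 - n_censored / n_at_risk)\n",
"\n",
"    def G(t):\n",
"        # Right-continuous step function; G = 1 before the first observed time\n",
"        idx = np.searchsorted(uniq, t, side='right') - 1\n",
"        return np.where(idx >= 0, g_values[np.maximum(idx, 0)], 1.0)\n",
"\n",
"    return G\n",
"\n",
"def brier_score_at(t, time, event, surv_at_t, G):\n",
"    # Hypothetical helper: IPCW Brier score BS^c(t), with surv_at_t[i] = S_hat(t | x_i)\n",
"    failed = (time <= t) & event  # event observed by t: target 0, weight 1 / G(d_i)\n",
"    alive = time > t              # still event-free at t: target 1, weight 1 / G(t)\n",
"    terms = np.zeros(len(time))   # individuals censored before t contribute 0\n",
"    terms[failed] = (0.0 - surv_at_t[failed]) ** 2 / G(time[failed])\n",
"    terms[alive] = (1.0 - surv_at_t[alive]) ** 2 / G(t)\n",
"    return terms.mean()\n",
"\n",
"# Toy data, made up for illustration\n",
"time = np.array([2.0, 3.0, 5.0, 8.0, 10.0, 12.0])\n",
"event = np.array([False, True, True, False, True, False])\n",
"surv_at_t = np.array([0.7, 0.2, 0.4, 0.8, 0.6, 0.9])  # predicted S_hat(6 | x_i)\n",
"\n",
"G = km_censoring(time, event)\n",
"print(brier_score_at(6.0, time, event, surv_at_t, G))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},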
{
"cell_type": "code",
"source": [
"times = np.arange(365, 1826)  # daily evaluation grid from day 365 (1 year) to day 1825 (5 years)"
],
"metadata": {
"id": "zoUFRn2yqLPF"
},
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"source": [
"survs = cox.predict_survival_function(X_test)"
],
"metadata": {
"id": "06i_zOUBqMUt"
},
"execution_count": 46,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Evaluate each predicted survival step function on the grid: shape (n_test, len(times))\n",
"preds = np.asarray([[fn(t) for t in times] for fn in survs])"
],
"metadata": {
"id": "prIyWOtMqe7s"
},
"execution_count": 47,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sksurv.metrics import integrated_brier_score\n",
"\n",
"# y_trn is only used to estimate the censoring distribution G; y_test and preds are scored\n",
"integrated_brier_score(y_trn, y_test, preds, times)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qyvn1TexqkoO",
"outputId": "8300b953-9390-465e-9397-cc232bd65c39"
},
"execution_count": 48,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.17509710701666106"
]
},
"metadata": {},
"execution_count": 48
}
]
},
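{
"cell_type": "markdown",
"source": [
"The integrated Brier score summarizes the time-dependent Brier score over the evaluation interval $[t_1, t_{\\max}]$:\n",
"\n",
"$$\\mathrm{IBS} = \\frac{1}{t_{\\max} - t_1} \\int_{t_1}^{t_{\\max}} \\mathrm{BS}^c(t)\\, dt$$\n",
"\n",
"The cell below is a sketch that assumes `y_trn`, `y_test`, `preds` and `times` from the cells above. It computes the Brier score at each time of the grid with `sksurv.metrics.brier_score` and integrates it with the trapezoidal rule. The result should be close to the value returned by `integrated_brier_score`, which, as far as we know, relies on the same trapezoidal approximation."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sksurv.metrics import brier_score\n",
"\n",
"# Brier score at every time of the grid (censoring distribution estimated from y_trn)\n",
"_, bs = brier_score(y_trn, y_test, preds, times)\n",
"\n",
"# Trapezoidal rule, normalized by the length of the evaluation interval\n",
"ibs_manual = np.sum(np.diff(times) * (bs[1:] + bs[:-1]) / 2) / (times[-1] - times[0])\n",
"print(f'IBS computed by hand: {ibs_manual:.4f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},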
{
"cell_type": "code",
"source": [
"!pip install -q eli5"
],
"metadata": {
"id": "yU-3yjgDr3N7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2e0d458d-504f-4a1a-ba12-9fd44c68bd9e"
},
"execution_count": 34,
"outputs": [
]
},
{
"cell_type": "code",
"source": [
"from eli5.sklearn import PermutationImportance"
],
"metadata": {
"id": "k6mzeqdNr7ZB",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 478
},
"outputId": "16fa075b-d501-48c2-f4e5-c8fafeeba9ad"
},
"execution_count": 35,
"outputs": [
{
"output_type": "error",
"ename": "ImportError",
"evalue": "cannot import name 'if_delegate_has_method' from 'sklearn.utils.metaestimators' (/usr/local/lib/python3.10/dist-packages/sklearn/utils/metaestimators.py)",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-35-bd1ee3945164>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0meli5\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msklearn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPermutationImportance\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/eli5/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mexplain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mexplain_weights\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexplain_prediction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0msklearn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mexplain_weights_sklearn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexplain_prediction_sklearn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtransform_feature_names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/eli5/sklearn/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# -*- coding: utf-8 -*-\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0m__future__\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mabsolute_import\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m from .explain_weights import (\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mexplain_weights_sklearn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mexplain_linear_classifier_weights\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/eli5/sklearn/explain_weights.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0mget_feature_importance_explanation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m )\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mpermutation_importance\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPermutationImportance\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/eli5/sklearn/permutation_importance.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_selection\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcheck_cv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetaestimators\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mif_delegate_has_method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcheck_random_state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m from sklearn.base import (\n",
"\u001b[0;31mImportError\u001b[0m: cannot import name 'if_delegate_has_method' from 'sklearn.utils.metaestimators' (/usr/local/lib/python3.10/dist-packages/sklearn/utils/metaestimators.py)"
]
}
]
},
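{
"cell_type": "markdown",
"source": [
"The traceback above shows that `eli5` cannot be imported with recent scikit-learn releases, because `if_delegate_has_method` was removed from `sklearn.utils.metaestimators`. One possible workaround is scikit-learn's own `sklearn.inspection.permutation_importance`. This alternative is sketched below and is not part of the original analysis. With the default `scoring=None`, it relies on the estimator's `score` method, which for scikit-survival estimators is Harrell's C-index.\n",
"\n",
"The sketch assumes `cox`, `X_test` and `y_test` from the cells above, and it evaluates on the test set rather than the training set. It builds a `data` frame with one column per feature (sorted by median importance) and one row per repetition, which is the layout the box-plot cell below expects."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"from sklearn.inspection import permutation_importance\n",
"\n",
"# Drop in C-index when each feature is shuffled, repeated n_repeats times\n",
"result = permutation_importance(cox, X_test, y_test, n_repeats=100, random_state=42)\n",
"\n",
"# result.importances has shape (n_features, n_repeats): transpose to one row per repetition\n",
"data = pd.DataFrame(result.importances.T, columns=X_test.columns)\n",
"data = data[data.median().sort_values(ascending=False).index]\n",
"data.median()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},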
{
"cell_type": "code",
"source": [
"perm = PermutationImportance(\n",
" cox, n_iter=100, random_state=42).fit(X_trn,y_trn)"
],
"metadata": {
"id": "WYW6dDDvsBNe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"data = perm.results_\n",
"data = pd.DataFrame(data, columns=X_trn.columns)\n",
"meds = data.median()\n",
"meds = meds.sort_values(ascending=False)\n",
"data = data[meds.index]"
],
"metadata": {
"id": "n5aKc--PsMHm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"fig, ax = plt.subplots(figsize=(10,7))\n",
"data.boxplot(ax=ax)\n",
"ax.set_title('Feature Importances')\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 622
},
"id": "oBbnA1ZMsQeY",
"outputId": "40006cd2-965b-4c1b-8f84-1f1f9b88ce90"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x700 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"!pip install -q shap\n",
"import shap\n",
"explainer = shap.Explainer(cox.predict, X_trn)\n",
"shap_values = explainer(X_test[:100])\n",
"shap.plots.waterfall(shap_values[0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 540
},
"id": "J03T1z80N8fg",
"outputId": "965e2a52-46dd-471c-84c7-33a7ade8f894"
},
"execution_count": 36,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 800x550 with 3 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"X.describe().transpose().round(2).drop(columns=\"count\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 300
},
"id": "-s0v-Yx5N_3o",
"outputId": "c8a734c0-6b6a-472e-a7d1-c38e84cbe06c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" mean std min 25% 50% 75% max\n",
"age 53.05 10.12 21.0 46.0 53.0 61.00 80.0\n",
"estrec 96.25 153.08 0.0 8.0 36.0 114.00 1144.0\n",
"horTh 0.36 0.48 0.0 0.0 0.0 1.00 1.0\n",
"menostat 0.58 0.49 0.0 0.0 1.0 1.00 1.0\n",
"pnodes 5.01 5.48 1.0 1.0 3.0 7.00 51.0\n",
"progrec 110.00 202.33 0.0 7.0 32.5 131.75 2380.0\n",
"tgrade 1.12 0.58 0.0 1.0 1.0 1.00 2.0\n",
"tsize 29.33 14.30 3.0 20.0 25.0 35.00 120.0"
],
"text/html": [
"\n",
" <div id=\"df-7378cec5-644d-4560-a5ae-03cbfae13f92\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean</th>\n",
" <th>std</th>\n",
" <th>min</th>\n",
" <th>25%</th>\n",
" <th>50%</th>\n",
" <th>75%</th>\n",
" <th>max</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>age</th>\n",
" <td>53.05</td>\n",
" <td>10.12</td>\n",
" <td>21.0</td>\n",
" <td>46.0</td>\n",
" <td>53.0</td>\n",
" <td>61.00</td>\n",
" <td>80.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>estrec</th>\n",
" <td>96.25</td>\n",
" <td>153.08</td>\n",
" <td>0.0</td>\n",
" <td>8.0</td>\n",
" <td>36.0</td>\n",
" <td>114.00</td>\n",
" <td>1144.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>horTh</th>\n",
" <td>0.36</td>\n",
" <td>0.48</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.00</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>menostat</th>\n",
" <td>0.58</td>\n",
" <td>0.49</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.00</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>pnodes</th>\n",
" <td>5.01</td>\n",
" <td>5.48</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>3.0</td>\n",
" <td>7.00</td>\n",
" <td>51.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>progrec</th>\n",
" <td>110.00</td>\n",
" <td>202.33</td>\n",
" <td>0.0</td>\n",
" <td>7.0</td>\n",
" <td>32.5</td>\n",
" <td>131.75</td>\n",
" <td>2380.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tgrade</th>\n",
" <td>1.12</td>\n",
" <td>0.58</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.00</td>\n",
" <td>2.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tsize</th>\n",
" <td>29.33</td>\n",
" <td>14.30</td>\n",
" <td>3.0</td>\n",
" <td>20.0</td>\n",
" <td>25.0</td>\n",
" <td>35.00</td>\n",
" <td>120.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" </div>\n"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"shap.plots.beeswarm(shap_values)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 453
},
"id": "0kbYd1wAODT9",
"outputId": "212ab8e2-adc6-45bf-e75e-c78d3c2eec7a"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 800x470 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZNHQ-bkWALNy",
"outputId": "16d9b6fb-f5dd-4c20-8896-6171641412b0",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from scipy.stats import reciprocal\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"\n",
"param_distributions = {\n",
" 'alpha': reciprocal(0.1, 100),\n",
"}\n",
"\n",
"model_random_search = RandomizedSearchCV(\n",
" cox, param_distributions=param_distributions, n_iter=50, n_jobs=-1, cv=3, random_state=42)\n",
"model_random_search.fit(X_trn, y_trn)\n",
"\n",
"print(\n",
" f\"The c-index of Cox using a {model_random_search.__class__.__name__} is \"\n",
" f\"{model_random_search.score(X_test, y_test):.3f}\")\n",
"print(\n",
" f\"The best set of parameters is: {model_random_search.best_params_}\"\n",
")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox using a RandomizedSearchCV is 0.660\n",
"The best set of parameters is: {'alpha': 39.67605077052987}\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "MrhAveCQAdxF",
"outputId": "7d7474e5-2f54-425a-84ec-5b7f2b701f0e",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"alpha = model_random_search.best_params_['alpha']\n",
"cox_best = make_pipeline(CoxPHSurvivalAnalysis(alpha=alpha))\n",
"cox_best.fit(X_trn, y_trn)\n",
"\n",
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox_best.predict(X_test))\n",
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox is given by 0.660\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sksurv.ensemble import RandomSurvivalForest"
],
"metadata": {
"id": "s34_kkDKKzw5"
},
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"source": [
"rsf = RandomSurvivalForest(\n",
" n_estimators=100, min_samples_leaf=15, n_jobs=-1, random_state=20\n",
")\n",
"rsf.fit(X_trn, y_trn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 80
},
"id": "vU06jwE7Kuec",
"outputId": "0c61ad2e-539b-4b21-ad9d-881ce7042b9d"
},
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomSurvivalForest(min_samples_leaf=15, n_jobs=-1, random_state=20)"
],
"text/html": [
"<style>#sk-container-id-2 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-2 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-2 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-2 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-2 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-2 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-2 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-2 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-2 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-2 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-2 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomSurvivalForest(min_samples_leaf=15, n_jobs=-1, random_state=20)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;RandomSurvivalForest<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>RandomSurvivalForest(min_samples_leaf=15, n_jobs=-1, random_state=20)</pre></div> </div></div></div></div>"
]
},
"metadata": {},
"execution_count": 38
}
]
},
{
"cell_type": "code",
"source": [
"rsf.score(X_test, y_test)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LJ0Tp63pKvMJ",
"outputId": "64980f97-2284-4bfd-bb36-444e2aa9c971"
},
"execution_count": 39,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6785901467505241"
]
},
"metadata": {},
"execution_count": 39
}
]
},
{
"cell_type": "code",
"source": [
"ci_rsf = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], rsf.predict(X_test))\n",
"print(f'The c-index of RSF is given by {ci_rsf[0]:.3f}')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Y1hNGO_QO2pW",
"outputId": "dea05184-dfaa-43d0-e2a9-31667eaeda33"
},
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of RSF is given by 0.679\n"
]
}
]
},
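{
"cell_type": "markdown",
"metadata": {},
"source": [
"`concordance_index_censored` computes Harrell's concordance index (c-index). Two test samples form a *comparable* pair when the one with the shorter time had an observed event. The c-index is the fraction of comparable pairs in which that sample also received the higher predicted risk, with tied predictions counted as 1/2. `RandomSurvivalForest.predict` returns such a risk score, and `rsf.score` uses the same metric, which is why both cells above report about 0.679.\n",
"\n",
"The cell below is a minimal sketch of that definition. `harrell_c` is a hypothetical helper (not part of scikit-survival) and the data are made up. It assumes all observed times are distinct, because tied times need extra rules. On this toy example it should agree with the library function."
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"\n",
"def harrell_c(event, time, risk):\n",
"    # Hypothetical helper: Harrell's C, assuming all times are distinct\n",
"    concordant, comparable = 0.0, 0\n",
"    for i in range(len(time)):\n",
"        if not event[i]:\n",
"            continue  # a censored sample cannot be the earlier member of a pair\n",
"        for j in range(len(time)):\n",
"            if time[j] > time[i]:\n",
"                comparable += 1\n",
"                if risk[i] > risk[j]:\n",
"                    concordant += 1.0  # earlier event was given the higher risk\n",
"                elif risk[i] == risk[j]:\n",
"                    concordant += 0.5  # tied predictions count as half\n",
"    return concordant / comparable\n",
"\n",
"\n",
"# Made-up toy data: event indicator, observed time, predicted risk score\n",
"toy_event = np.array([True, False, True, True, False])\n",
"toy_time = np.array([5.0, 8.0, 3.0, 10.0, 6.0])\n",
"toy_risk = np.array([0.7, 0.2, 0.9, 0.1, 0.8])\n",
"\n",
"manual = harrell_c(toy_event, toy_time, toy_risk)\n",
"library = concordance_index_censored(toy_event, toy_time, toy_risk)[0]\n",
"print(manual, library)  # 6 of the 7 comparable pairs are concordant: 6/7 ≈ 0.857\n",
"assert np.isclose(manual, library)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},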
{
"cell_type": "code",
"source": [
"times = np.arange(365, 1826)"
],
"metadata": {
"id": "XHsIhuW7PHik"
},
"execution_count": 41,
"outputs": []
},
{
"cell_type": "code",
"source": [
"survs = rsf.predict_survival_function(X_test)"
],
"metadata": {
"id": "iSZ0pa17PSJS"
},
"execution_count": 42,
"outputs": []
},
{
"cell_type": "code",
"source": [
"preds = np.asarray([[fn(t) for t in times] for fn in survs])"
],
"metadata": {
"id": "np-god8mPW5W"
},
"execution_count": 43,
"outputs": []
},
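{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each element returned by `predict_survival_function` is a scikit-survival `StepFunction`, and calling it with an array of time points returns one value per point. The nested comprehension above can therefore be replaced by a single call per test sample. The sketch below assumes `survs`, `times`, and `preds` from the cells above. It should produce the same matrix (one row per test sample, one column per point of `times`) with far fewer Python-level function calls."
]
},
{
"cell_type": "code",
"source": [
"# Each StepFunction accepts an array, so evaluate the whole grid in one call per sample\n",
"preds_vec = np.asarray([fn(times) for fn in survs])\n",
"\n",
"# Sanity checks against the element-by-element version above\n",
"assert preds_vec.shape == (len(survs), len(times))\n",
"assert np.allclose(preds_vec, preds)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},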
{
"cell_type": "code",
"source": [
"integrated_brier_score(y_trn, y_test, preds, times)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "U1avp5nfPaOS",
"outputId": "396c114b-abfc-460e-d9b5-91ab544e31d3"
},
"execution_count": 44,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.17664251295573888"
]
},
"metadata": {},
"execution_count": 44
}
]
},
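{
"cell_type": "markdown",
"metadata": {},
"source": [
"`integrated_brier_score` summarizes the accuracy of the predicted survival probabilities over the whole grid `times`. It first computes the time-dependent Brier score BS(t), weighted by the inverse probability of censoring estimated from the training targets. It then averages that curve with the trapezoidal rule: IBS = (area under BS(t)) / (t_max - t_min). Lower is better.\n",
"\n",
"The sketch below should reproduce the value above using `sksurv.metrics.brier_score`. It assumes `y_trn`, `y_test`, `preds`, and `times` from the previous cells."
]
},
{
"cell_type": "code",
"source": [
"from sksurv.metrics import brier_score\n",
"\n",
"# Time-dependent Brier score at every point of the grid\n",
"bs_times, bs_scores = brier_score(y_trn, y_test, preds, times)\n",
"\n",
"# Trapezoidal area under the Brier-score curve, divided by the window length\n",
"area = np.sum(np.diff(bs_times) * (bs_scores[1:] + bs_scores[:-1]) / 2)\n",
"ibs_manual = area / (bs_times[-1] - bs_times[0])\n",
"print(f'Integrated Brier score (manual): {ibs_manual:.4f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},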
{
"cell_type": "code",
"source": [
"param_distributions = {\n",
" 'min_samples_leaf': [3, 7, 15],\n",
" 'max_depth': [3, 7, None]\n",
"}\n",
"\n",
"model_random_search = RandomizedSearchCV(\n",
" rsf, param_distributions=param_distributions, n_iter=50, n_jobs=-1, cv=3, random_state=42)\n",
"model_random_search.fit(X_trn, y_trn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 134
},
"id": "QjiQ8TfBPfZg",
"outputId": "b9697241-63cd-448a-f779-024bea222c72"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"The total space of parameters 9 is smaller than n_iter=50. Running 9 iterations. For exhaustive searches, use GridSearchCV.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomizedSearchCV(cv=3,\n",
" estimator=RandomSurvivalForest(min_samples_leaf=15,\n",
" n_jobs=-1, random_state=20),\n",
" n_iter=50, n_jobs=-1,\n",
" param_distributions={'max_depth': [3, 7, None],\n",
" 'min_samples_leaf': [3, 7, 15]},\n",
" random_state=42)"
]
},
"metadata": {},
"execution_count": 38
}
]
},
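{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the warning says, the grid contains only 3 × 3 = 9 combinations. `n_iter=50` therefore cannot be honored, and all 9 combinations are evaluated anyway. For a grid this small, an exhaustive `GridSearchCV` over the same dictionary states the intent more directly.\n",
"\n",
"The sketch below assumes `rsf`, `param_distributions`, `X_trn`, and `y_trn` from above. With `cv=3` it fits 27 forests plus a final refit, so it can take a while. The folds and the forest's `random_state` are the same as in the randomized search, so it should select the same parameters."
]
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Exhaustive search over the same 9 combinations; like the randomized search,\n",
"# candidates are ranked by the estimator's default score (the c-index)\n",
"model_grid_search = GridSearchCV(\n",
"    rsf, param_grid=param_distributions, cv=3, n_jobs=-1\n",
")\n",
"model_grid_search.fit(X_trn, y_trn)\n",
"print(f'The best set of parameters is: {model_grid_search.best_params_}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},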
{
"cell_type": "code",
"source": [
"model_random_search.score(X_test, y_test)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_X2klnTmQVXX",
"outputId": "48b05237-4a2f-4757-bdc2-b0c975417845"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6805555555555556"
]
},
"metadata": {},
"execution_count": 39
}
]
},
{
"cell_type": "code",
"source": [
"print(\n",
" f\"The best set of parameters is: {model_random_search.best_params_}\"\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6Mam6JN2Qqh3",
"outputId": "30e8221a-6cef-46cd-8b0e-ef3faf08e9c1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The best set of parameters is: {'min_samples_leaf': 15, 'max_depth': 3}\n"
]
}
]
},
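{
"cell_type": "markdown",
"metadata": {},
"source": [
"`best_params_` reports only the winning combination. To see how far apart the candidates are, `cv_results_` can be loaded into a pandas DataFrame. In that table, `mean_test_score` is the mean cross-validated c-index, which is the default `score` of `RandomSurvivalForest`. The sketch below assumes `model_random_search` from above."
]
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# One row per parameter combination, best-ranked first\n",
"cv_results = pd.DataFrame(model_random_search.cv_results_)\n",
"cols = ['param_max_depth', 'param_min_samples_leaf',\n",
"        'mean_test_score', 'std_test_score', 'rank_test_score']\n",
"print(cv_results[cols].sort_values('rank_test_score').to_string(index=False))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},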
{
"cell_type": "code",
"source": [
"rsf_best = RandomSurvivalForest(\n",
" n_estimators=100, min_samples_leaf=15, max_depth=3, n_jobs=-1, random_state=20\n",
")\n",
"rsf_best.fit(X_trn, y_trn)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 91
},
"id": "PbSNPjgQQruO",
"outputId": "41a76e72-7928-4a0a-8186-3966912f9582"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomSurvivalForest(max_depth=3, min_samples_leaf=15, n_jobs=-1,\n",
" random_state=20)"
]
},
"metadata": {},
"execution_count": 41
}
]
},
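{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the default `refit=True`, `RandomizedSearchCV` already refits the best configuration on the full training set. `model_random_search.best_estimator_` should therefore be equivalent to the manually built `rsf_best`: it has the same parameters, the same `random_state=20`, and the same training data. Using it avoids copying the best parameters by hand. The quick check below assumes the objects defined above."
]
},
{
"cell_type": "code",
"source": [
"best = model_random_search.best_estimator_\n",
"\n",
"# Should show max_depth=3, min_samples_leaf=15, like rsf_best\n",
"print(best)\n",
"\n",
"# Both forests should give the same test c-index\n",
"print(best.score(X_test, y_test), rsf_best.score(X_test, y_test))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},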
{
"cell_type": "code",
"source": [
"times = np.arange(365, 1826)\n",
"survs = rsf_best.predict_survival_function(X_test)"
],
"metadata": {
"id": "iO5hU-exQZup"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"preds = np.asarray([[fn(t) for t in times] for fn in survs])"
],
"metadata": {
"id": "nVQsQOocQhOI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"integrated_brier_score(y_trn, y_test, preds, times)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rYRRmruYRCgq",
"outputId": "789050d7-318f-4157-ba2c-6e911751659d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.17850725604478493"
]
},
"metadata": {},
"execution_count": 44
}
]
}
]
}