Created October 16, 2023 12:32
Gist: alonsosilvaallende/0371023c5a09eb07b2b8825b3ab99057
Copy of Cox_PH_and_RSF-colab.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.2" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/0371023c5a09eb07b2b8825b3ab99057/copy-of-cox_ph_and_rsf-colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UN3PoUTSb2nT" | |
}, | |
"source": [ | |
"The objective of this notebook is to compare different models that estimate the survival probability given a set of features/covariates.\n", | |
"\n", | |
">[\"Experimental Comparison of Semi-parametric, Parametric, and Machine Learning Models for Time-to-Event Analysis Through the Concordance Index,\"](https://arxiv.org/abs/2003.08820)\n", | |
"Camila Fernandez, Chung Shue Chen, Pierre Gaillard, Alonso Silva\n", | |
"\n", | |
"To perform this analysis, we use [scikit-learn](https://scikit-learn.org/) and [scikit-survival](https://pypi.org/project/scikit-survival/). We also use [eli5](https://eli5.readthedocs.io/en/latest/index.html) to study feature importances (computed with permutation importance)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pzqGFx_Bb_6A" | |
}, | |
"source": [ | |
"!pip install -q scikit-survival" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VoQVEI5p_rga" | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import matplotlib.pyplot as plt" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SnT6e_JPb2ns" | |
}, | |
"source": [ | |
"We first load the GBSG2 dataset, which ships with scikit-survival." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "D0xxNWzI-N3j" | |
}, | |
"source": [ | |
"from sksurv.datasets import load_gbsg2\n", | |
"\n", | |
"X, y = load_gbsg2()" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "IBTo4q_Hb2n0" | |
}, | |
"source": [ | |
"## An example: German Breast Cancer Study Group 2 (GBSG2)\n", | |
"\n", | |
"This dataset contains the following 8 features/covariates:\n", | |
"\n", | |
"- age: age (in years),\n", | |
"- estrec: estrogen receptor (in fmol),\n", | |
"- horTh: hormonal therapy (yes or no),\n", | |
"- menostat: menopausal status (premenopausal or postmenopausal),\n", | |
"- pnodes: number of positive nodes,\n", | |
"- progrec: progesterone receptor (in fmol),\n", | |
"- tgrade: tumor grade (I < II < III),\n", | |
"- tsize: tumor size (in mm).\n", | |
"\n", | |
"and the two outputs:\n", | |
"\n", | |
"- recurrence-free time (in days),\n", | |
"- censoring indicator (0 - censored, 1 - event).\n", | |
"\n", | |
"The dataset has 686 samples and 8 features/covariates.\n", | |
"\n", | |
"\n", | |
"**References**\n", | |
"\n", | |
"M. Schumacher, G. Basert, H. Bojar, K. Huebner, M. Olschewski, W. Sauerbrei, C. Schmoor, C. Beyerle, R.L.A. Neumann and H.F. Rauschecker for the German Breast Cancer Study Group (1994), [Randomized 2 x 2 trial evaluating hormonal treatment and the duration of chemotherapy in node-positive breast cancer patients](https://www.ncbi.nlm.nih.gov/pubmed/7931478). Journal of Clinical Oncology, 12, 2086–2093." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "kAsZ72YYb2n3" | |
}, | |
"source": [ | |
"Let's take a look at the features/covariates." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_OlmlI6g-X43", | |
"outputId": "1eca9943-4e79-4008-ee6e-1b4bbd70e306", | |
"scrolled": true, | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 362 | |
} | |
}, | |
"source": [ | |
"# Display the continuous columns without decimals\n", | |
"cols = [\"age\", \"estrec\", \"pnodes\", \"progrec\", \"tsize\"]\n", | |
"formatdict = {col: \"{:,.0f}\" for col in cols}\n", | |
"X.head(10).style.hide(axis=\"index\").format(formatdict)" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7884921ab910>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"</style>\n", | |
"<table id=\"T_3e856\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th id=\"T_3e856_level0_col0\" class=\"col_heading level0 col0\" >age</th>\n", | |
" <th id=\"T_3e856_level0_col1\" class=\"col_heading level0 col1\" >estrec</th>\n", | |
" <th id=\"T_3e856_level0_col2\" class=\"col_heading level0 col2\" >horTh</th>\n", | |
" <th id=\"T_3e856_level0_col3\" class=\"col_heading level0 col3\" >menostat</th>\n", | |
" <th id=\"T_3e856_level0_col4\" class=\"col_heading level0 col4\" >pnodes</th>\n", | |
" <th id=\"T_3e856_level0_col5\" class=\"col_heading level0 col5\" >progrec</th>\n", | |
" <th id=\"T_3e856_level0_col6\" class=\"col_heading level0 col6\" >tgrade</th>\n", | |
" <th id=\"T_3e856_level0_col7\" class=\"col_heading level0 col7\" >tsize</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row0_col0\" class=\"data row0 col0\" >70</td>\n", | |
" <td id=\"T_3e856_row0_col1\" class=\"data row0 col1\" >66</td>\n", | |
" <td id=\"T_3e856_row0_col2\" class=\"data row0 col2\" >no</td>\n", | |
" <td id=\"T_3e856_row0_col3\" class=\"data row0 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row0_col4\" class=\"data row0 col4\" >3</td>\n", | |
" <td id=\"T_3e856_row0_col5\" class=\"data row0 col5\" >48</td>\n", | |
" <td id=\"T_3e856_row0_col6\" class=\"data row0 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row0_col7\" class=\"data row0 col7\" >21</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row1_col0\" class=\"data row1 col0\" >56</td>\n", | |
" <td id=\"T_3e856_row1_col1\" class=\"data row1 col1\" >77</td>\n", | |
" <td id=\"T_3e856_row1_col2\" class=\"data row1 col2\" >yes</td>\n", | |
" <td id=\"T_3e856_row1_col3\" class=\"data row1 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row1_col4\" class=\"data row1 col4\" >7</td>\n", | |
" <td id=\"T_3e856_row1_col5\" class=\"data row1 col5\" >61</td>\n", | |
" <td id=\"T_3e856_row1_col6\" class=\"data row1 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row1_col7\" class=\"data row1 col7\" >12</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row2_col0\" class=\"data row2 col0\" >58</td>\n", | |
" <td id=\"T_3e856_row2_col1\" class=\"data row2 col1\" >271</td>\n", | |
" <td id=\"T_3e856_row2_col2\" class=\"data row2 col2\" >yes</td>\n", | |
" <td id=\"T_3e856_row2_col3\" class=\"data row2 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row2_col4\" class=\"data row2 col4\" >9</td>\n", | |
" <td id=\"T_3e856_row2_col5\" class=\"data row2 col5\" >52</td>\n", | |
" <td id=\"T_3e856_row2_col6\" class=\"data row2 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row2_col7\" class=\"data row2 col7\" >35</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row3_col0\" class=\"data row3 col0\" >59</td>\n", | |
" <td id=\"T_3e856_row3_col1\" class=\"data row3 col1\" >29</td>\n", | |
" <td id=\"T_3e856_row3_col2\" class=\"data row3 col2\" >yes</td>\n", | |
" <td id=\"T_3e856_row3_col3\" class=\"data row3 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row3_col4\" class=\"data row3 col4\" >4</td>\n", | |
" <td id=\"T_3e856_row3_col5\" class=\"data row3 col5\" >60</td>\n", | |
" <td id=\"T_3e856_row3_col6\" class=\"data row3 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row3_col7\" class=\"data row3 col7\" >17</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row4_col0\" class=\"data row4 col0\" >73</td>\n", | |
" <td id=\"T_3e856_row4_col1\" class=\"data row4 col1\" >65</td>\n", | |
" <td id=\"T_3e856_row4_col2\" class=\"data row4 col2\" >no</td>\n", | |
" <td id=\"T_3e856_row4_col3\" class=\"data row4 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row4_col4\" class=\"data row4 col4\" >1</td>\n", | |
" <td id=\"T_3e856_row4_col5\" class=\"data row4 col5\" >26</td>\n", | |
" <td id=\"T_3e856_row4_col6\" class=\"data row4 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row4_col7\" class=\"data row4 col7\" >35</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row5_col0\" class=\"data row5 col0\" >32</td>\n", | |
" <td id=\"T_3e856_row5_col1\" class=\"data row5 col1\" >13</td>\n", | |
" <td id=\"T_3e856_row5_col2\" class=\"data row5 col2\" >no</td>\n", | |
" <td id=\"T_3e856_row5_col3\" class=\"data row5 col3\" >Pre</td>\n", | |
" <td id=\"T_3e856_row5_col4\" class=\"data row5 col4\" >24</td>\n", | |
" <td id=\"T_3e856_row5_col5\" class=\"data row5 col5\" >0</td>\n", | |
" <td id=\"T_3e856_row5_col6\" class=\"data row5 col6\" >III</td>\n", | |
" <td id=\"T_3e856_row5_col7\" class=\"data row5 col7\" >57</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row6_col0\" class=\"data row6 col0\" >59</td>\n", | |
" <td id=\"T_3e856_row6_col1\" class=\"data row6 col1\" >0</td>\n", | |
" <td id=\"T_3e856_row6_col2\" class=\"data row6 col2\" >yes</td>\n", | |
" <td id=\"T_3e856_row6_col3\" class=\"data row6 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row6_col4\" class=\"data row6 col4\" >2</td>\n", | |
" <td id=\"T_3e856_row6_col5\" class=\"data row6 col5\" >181</td>\n", | |
" <td id=\"T_3e856_row6_col6\" class=\"data row6 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row6_col7\" class=\"data row6 col7\" >8</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row7_col0\" class=\"data row7 col0\" >65</td>\n", | |
" <td id=\"T_3e856_row7_col1\" class=\"data row7 col1\" >25</td>\n", | |
" <td id=\"T_3e856_row7_col2\" class=\"data row7 col2\" >no</td>\n", | |
" <td id=\"T_3e856_row7_col3\" class=\"data row7 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row7_col4\" class=\"data row7 col4\" >1</td>\n", | |
" <td id=\"T_3e856_row7_col5\" class=\"data row7 col5\" >192</td>\n", | |
" <td id=\"T_3e856_row7_col6\" class=\"data row7 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row7_col7\" class=\"data row7 col7\" >16</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row8_col0\" class=\"data row8 col0\" >80</td>\n", | |
" <td id=\"T_3e856_row8_col1\" class=\"data row8 col1\" >59</td>\n", | |
" <td id=\"T_3e856_row8_col2\" class=\"data row8 col2\" >no</td>\n", | |
" <td id=\"T_3e856_row8_col3\" class=\"data row8 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row8_col4\" class=\"data row8 col4\" >30</td>\n", | |
" <td id=\"T_3e856_row8_col5\" class=\"data row8 col5\" >0</td>\n", | |
" <td id=\"T_3e856_row8_col6\" class=\"data row8 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row8_col7\" class=\"data row8 col7\" >39</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_3e856_row9_col0\" class=\"data row9 col0\" >66</td>\n", | |
" <td id=\"T_3e856_row9_col1\" class=\"data row9 col1\" >3</td>\n", | |
" <td id=\"T_3e856_row9_col2\" class=\"data row9 col2\" >no</td>\n", | |
" <td id=\"T_3e856_row9_col3\" class=\"data row9 col3\" >Post</td>\n", | |
" <td id=\"T_3e856_row9_col4\" class=\"data row9 col4\" >7</td>\n", | |
" <td id=\"T_3e856_row9_col5\" class=\"data row9 col5\" >0</td>\n", | |
" <td id=\"T_3e856_row9_col6\" class=\"data row9 col6\" >II</td>\n", | |
" <td id=\"T_3e856_row9_col7\" class=\"data row9 col7\" >18</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zcMjHeAEb2n_" | |
}, | |
"source": [ | |
"Let's take a look at the outputs." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "h8ltRTa4_WOn", | |
"outputId": "6cc7e0ec-ff7d-43fd-df6e-78c2d1037257", | |
"scrolled": true, | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"y[:10]" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([( True, 1814.), ( True, 2018.), ( True, 712.), ( True, 1807.),\n", | |
" ( True, 772.), ( True, 448.), (False, 2172.), (False, 2161.),\n", | |
" ( True, 471.), (False, 2014.)],\n", | |
" dtype=[('cens', '?'), ('time', '<f8')])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "A5NXlNYob2oH" | |
}, | |
"source": [ | |
"scikit-survival stores the outputs in a NumPy structured array; to display them more readably, we convert them into a pandas DataFrame." | |
] | |
}, | |
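To see what this structured array looks like without scikit-survival, here is a minimal NumPy sketch that rebuilds the first ten targets printed above by hand (the names `y_head`, `events` and `event_times` are only for illustration). Each record holds a boolean event indicator `cens` and a float time `time`, and indexing by field name returns a regular 1-D array:

```python
import numpy as np

# Same dtype as the scikit-survival target shown above:
# a boolean event indicator ("cens") and a float time in days ("time")
dtype = [("cens", "?"), ("time", "<f8")]

# First ten (event, time) pairs, copied from the output above
y_head = np.array(
    [(True, 1814.0), (True, 2018.0), (True, 712.0), (True, 1807.0),
     (True, 772.0), (True, 448.0), (False, 2172.0), (False, 2161.0),
     (True, 471.0), (False, 2014.0)],
    dtype=dtype,
)

# Indexing by field name returns ordinary 1-D arrays
events = y_head["cens"]                    # bool array
event_times = y_head["time"].astype(int)   # int array of days

print(events.sum())  # 7 recurrences observed among these ten patients
```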
{ | |
"cell_type": "code", | |
"source": [ | |
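"# Convert the structured array into a DataFrame for display\n", | |
"df_y = pd.DataFrame(data={'time': y['time'].astype(int), 'event': y['cens']})\n", | |
"# highlight_min on a boolean column highlights the False values, i.e. the censored samples\n", | |
"df_y[:10].style.hide(axis=\"index\").highlight_min('event', color='lightgreen')" | |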
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 362 | |
}, | |
"id": "AfPvZcjJ-GwQ", | |
"outputId": "f5831c4a-7152-4268-ac31-4aeb2c9c71ac" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x78845c296230>" | |
], | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"#T_5ca91_row6_col1, #T_5ca91_row7_col1, #T_5ca91_row9_col1 {\n", | |
" background-color: lightgreen;\n", | |
"}\n", | |
"</style>\n", | |
"<table id=\"T_5ca91\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th id=\"T_5ca91_level0_col0\" class=\"col_heading level0 col0\" >time</th>\n", | |
" <th id=\"T_5ca91_level0_col1\" class=\"col_heading level0 col1\" >event</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row0_col0\" class=\"data row0 col0\" >1814</td>\n", | |
" <td id=\"T_5ca91_row0_col1\" class=\"data row0 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row1_col0\" class=\"data row1 col0\" >2018</td>\n", | |
" <td id=\"T_5ca91_row1_col1\" class=\"data row1 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row2_col0\" class=\"data row2 col0\" >712</td>\n", | |
" <td id=\"T_5ca91_row2_col1\" class=\"data row2 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row3_col0\" class=\"data row3 col0\" >1807</td>\n", | |
" <td id=\"T_5ca91_row3_col1\" class=\"data row3 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row4_col0\" class=\"data row4 col0\" >772</td>\n", | |
" <td id=\"T_5ca91_row4_col1\" class=\"data row4 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row5_col0\" class=\"data row5 col0\" >448</td>\n", | |
" <td id=\"T_5ca91_row5_col1\" class=\"data row5 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row6_col0\" class=\"data row6 col0\" >2172</td>\n", | |
" <td id=\"T_5ca91_row6_col1\" class=\"data row6 col1\" >False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row7_col0\" class=\"data row7 col0\" >2161</td>\n", | |
" <td id=\"T_5ca91_row7_col1\" class=\"data row7 col1\" >False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row8_col0\" class=\"data row8 col0\" >471</td>\n", | |
" <td id=\"T_5ca91_row8_col1\" class=\"data row8 col1\" >True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td id=\"T_5ca91_row9_col0\" class=\"data row9 col0\" >2014</td>\n", | |
" <td id=\"T_5ca91_row9_col1\" class=\"data row9 col1\" >False</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "xzW5x7ljb2oP" | |
}, | |
"source": [ | |
"One of the main challenges of survival analysis is **right censoring**, i.e., by the end of the study, the event of interest (for example, in medicine 'death of a patient' or in this dataset 'recurrence of cancer') has only occurred for a subset of the observations.\n", | |
"\n", | |
"In this dataset, **right censoring** is encoded by the boolean column 'event': it is 'True' if the patient had a recurrence of cancer, and 'False' if the patient was still recurrence-free at the indicated time (right-censored sample)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "VmfAR7igb2oW" | |
}, | |
"source": [ | |
"Let's see how many right-censored samples we have." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rzS8h1GG_o_A", | |
"outputId": "7b9af53e-91f8-48f0-93c1-7ee71dea4bb3", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"print(f'Number of samples: {len(df_y)}')\n", | |
"print(f'Number of right censored samples: {len(df_y.query(\"event == False\"))}')\n", | |
"print(f'Percentage of right censored samples: {100*len(df_y.query(\"event == False\"))/len(df_y):.1f}%')" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Number of samples: 686\n", | |
"Number of right censored samples: 387\n", | |
"Percentage of right censored samples: 56.4%\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "VtsENFsnQhZx" | |
}, | |
"source": [ | |
"There are 387 patients (56.4%) who were right censored (recurrence free) at the end of the study.\n", | |
"\n", | |
"Before dividing the dataset into training and test sets, we encode the categorical features/covariates as numbers." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sklearn.preprocessing import OneHotEncoder\n", | |
"from sklearn.preprocessing import OrdinalEncoder" | |
], | |
"metadata": { | |
"id": "PV9SQ8LZ20BL" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Binary encoding: 1 if the patient received hormonal therapy ('yes'), 0 otherwise\n", | |
"X[\"horTh\"] = (X[\"horTh\"] == \"yes\").astype(int)" | |
], | |
"metadata": { | |
"id": "MxHGiw0E4-hP" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Binary encoding: 1 if postmenopausal ('Post'), 0 if premenopausal ('Pre')\n", | |
"X[\"menostat\"] = (X[\"menostat\"] == \"Post\").astype(int)" | |
], | |
"metadata": { | |
"id": "P_R8Fr4a5JUt" | |
}, | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X[\"tgrade\"] = OrdinalEncoder(categories=[['I', 'II', 'III']]).fit_transform(X[[\"tgrade\"]])" | |
], | |
"metadata": { | |
"id": "b6ABtXNK3Dd9" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "duYhddUr_1nH", | |
"outputId": "8e91b087-2c84-41e8-bedd-80fc729a574f", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"\n", | |
"X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=20)\n", | |
"\n", | |
"print(f'Number of training samples: {len(y_trn)}')\n", | |
"print(f'Number of test samples: {len(y_test)}')" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Number of training samples: 514\n", | |
"Number of test samples: 172\n" | |
] | |
} | |
] | |
}, | |
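The split sizes follow from the defaults of `train_test_split`: when neither `test_size` nor `train_size` is given, a quarter of the samples goes to the test set, and the test size is rounded up. A quick check of that arithmetic in plain Python (no scikit-learn needed):

```python
import math

n_samples = 686   # number of GBSG2 samples reported above
test_size = 0.25  # default fraction used by train_test_split when no size is given

# The test-set size is rounded up; the remaining samples form the training set
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test
print(n_train, n_test)  # 514 172
```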
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "3VEOV-vWb2ow" | |
}, | |
"source": [ | |
"Let's check the data types in order to separate the continuous features/covariates from the categorical ones." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X.dtypes" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "4OqMuC40mXDb", | |
"outputId": "ec0c6469-75e7-4262-e618-d327768bb17f" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"age float64\n", | |
"estrec float64\n", | |
"horTh int64\n", | |
"menostat int64\n", | |
"pnodes float64\n", | |
"progrec float64\n", | |
"tgrade float64\n", | |
"tsize float64\n", | |
"dtype: object" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 13 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "jDjgk9PWb2o3" | |
}, | |
"source": [ | |
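"The categorical features/covariates have been encoded numerically above (binary encoding for horTh and menostat, ordinal encoding for tgrade). The continuous features/covariates could additionally be standardized, but the Cox model below is fitted on their original scale." | |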
] | |
}, | |
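If the continuous features/covariates (age, estrec, pnodes, progrec, tsize) are standardized, the means and standard deviations must be estimated on the training set only and then reused on the test set, so that no test information leaks into the model. Below is a minimal NumPy sketch of that logic, mirroring what scikit-learn's `StandardScaler` computes (population standard deviation, constant columns centered but not divided). The helper names are hypothetical and the toy values are the (age, tsize) columns of the first rows of the table above:

```python
import numpy as np

def fit_standard_scaler(X_train):
    """Per-column mean and (population) standard deviation of the training data."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)   # ddof=0, as in scikit-learn's StandardScaler
    std[std == 0] = 1.0         # constant columns: center only, do not divide by 0
    return mean, std

def apply_standard_scaler(X, mean, std):
    """Scale any split (training or test) with the training statistics."""
    return (X - mean) / std

# Toy (age, tsize) values taken from the first rows of the table above
X_train = np.array([[70.0, 21.0], [56.0, 12.0], [58.0, 35.0], [59.0, 17.0]])
X_test = np.array([[73.0, 35.0], [32.0, 57.0]])

mean, std = fit_standard_scaler(X_train)
X_train_scaled = apply_standard_scaler(X_train, mean, std)
X_test_scaled = apply_standard_scaler(X_test, mean, std)  # never refit on the test set
```

In the notebook itself, one option would be to place `sklearn.preprocessing.StandardScaler` in front of the Cox model with `make_pipeline`, restricted to the continuous columns (for instance via `sklearn.compose.ColumnTransformer`).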
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "N_nodxlEb2o-" | |
}, | |
"source": [ | |
"# Baseline: Cox Proportional Hazards model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "KqdChwYJb2o_" | |
}, | |
"source": [ | |
"The Cox Proportional Hazards model assumes that the log-hazard of a subject is a linear function of their $m$ static covariates/features $x_i$, $i\\in\\{1,\\ldots,m\\}$, with regression coefficients $\\beta_i$, plus a population-level baseline log-hazard $\\log h_0(t)$ that changes over time:\n", | |
"\\begin{equation}\n", | |
"h(t|x)=h_0(t)\\exp\\left(\\sum_{i=1}^m\\beta_i(x_i-\\bar{x}_i)\\right).\n", | |
"\\end{equation}\n", | |
"\n", | |
"The term *proportional hazards* refers to the fact that the baseline hazard $h_0(t)$ cancels out in the ratio of the hazards of any two subjects $x$ and $x'$, so that $h(t|x)/h(t|x')=\\exp\\left(\\sum_{i=1}^m\\beta_i(x_i-x'_i)\\right)$ does not depend on time." | |
] | |
}, | |
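A small numerical check of the proportional hazards property, using NumPy and made-up values: the coefficients `beta`, the means `x_bar` and the function `baseline_hazard` below are hypothetical, not quantities fitted on GBSG2. Whatever the shape of $h_0(t)$, the ratio of two subjects' hazards stays constant over time and equals the exponential of the coefficient-weighted difference of their covariates:

```python
import numpy as np

# Hypothetical values for m = 2 covariates (not fitted on GBSG2)
beta = np.array([0.5, -0.2])   # regression coefficients of the formula above
x_bar = np.array([1.0, 2.0])   # covariate means

def baseline_hazard(t):
    """Arbitrary time-varying baseline hazard h_0(t), assumed for illustration."""
    return 0.01 + 0.002 * t

def hazard(t, x):
    """Cox hazard: h_0(t) * exp(sum_i beta_i * (x_i - x_bar_i))."""
    return baseline_hazard(t) * np.exp(np.dot(beta, x - x_bar))

x_a = np.array([2.0, 1.0])
x_b = np.array([0.0, 3.0])
t_grid = np.array([0.0, 1.0, 10.0, 100.0])

# The hazard ratio is identical at every time point...
ratios = hazard(t_grid, x_a) / hazard(t_grid, x_b)
# ...and only depends on the covariate difference: h_0(t) and x_bar cancel out
expected = np.exp(np.dot(beta, x_a - x_b))
print(ratios)
print(expected)
```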
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "77YbwMKvAFHQ", | |
"outputId": "dbeb5e0a-6759-4792-91a9-a3a536a59699", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 74 | |
} | |
}, | |
"source": [ | |
"from sklearn.pipeline import make_pipeline\n", | |
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n", | |
"from sksurv.metrics import concordance_index_censored\n", | |
"\n", | |
"cox = CoxPHSurvivalAnalysis()\n", | |
"cox.fit(X_trn, y_trn)" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"CoxPHSurvivalAnalysis()" | |
], | |
"text/html": [ | |
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>CoxPHSurvivalAnalysis()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">CoxPHSurvivalAnalysis</label><div class=\"sk-toggleable__content\"><pre>CoxPHSurvivalAnalysis()</pre></div></div></div></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Let's now attempt to quantify how a survival curve estimated on a training set performs on a test set.\n", | |
"\n", | |
"## Survival model evaluation using the Integrated Brier Score (IBS) and the Concordance Index (C-index)" | |
], | |
"metadata": { | |
"id": "I2wAzNXar551" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"The Brier score and the C-index are measures that **assess the quality of a predicted survival curve** on a finite data sample.\n", | |
"\n", | |
"- **The Brier score is a proper scoring rule**, meaning that an estimate of the survival curve has minimal expected Brier score if and only if it matches the true survival probabilities induced by the underlying data generating process. In that respect, the **Brier score** assesses both the **calibration** and the **ranking power** of a survival probability estimator.\n", | |
"\n", | |
"- On the other hand, the **C-index** only assesses the **ranking power**: it is invariant to a strictly monotonic transform of the survival probabilities. It only focuses on the ability of a predictive survival model to identify which individual of any pair is likely to fail first. It lies between 0 and 1 (higher is better), with 0.5 corresponding to a random ranking.\n", | |
"\n", | |
"The Brier score lies between 0 and 1 (lower is better).\n", | |
"It answers the question \"how close to the real probabilities are our estimates?\"." | |
], | |
"metadata": { | |
"id": "Gap1YWH5sAA1" | |
} | |
}, | |
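To make the ranking-only nature of the C-index concrete, here is a minimal NumPy sketch of Harrell's C-index. The function name `harrell_c_index` and the toy data are hypothetical; like `concordance_index_censored`, it takes an event indicator, observed times and risk scores, but it is not scikit-survival's implementation, which may treat tied times differently. Applying a strictly increasing transform to the risk scores leaves the index unchanged:

```python
import numpy as np

def harrell_c_index(event, time, risk):
    """Harrell's concordance index (sketch).

    A pair (i, j) is comparable when subject i had an observed event and
    time[i] < time[j]. It is concordant when the subject that failed first
    has the higher predicted risk; ties in risk count as 1/2.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)
    concordant = 0.0
    comparable = 0
    for i in np.flatnonzero(event):
        later = time > time[i]   # subjects still event-free after time[i]
        comparable += later.sum()
        concordant += (risk[i] > risk[later]).sum()
        concordant += 0.5 * (risk[i] == risk[later]).sum()
    return concordant / comparable

# Toy data: 5 subjects, 2 censored; one pair is ranked in the wrong order
event = np.array([True, True, False, True, False])
time = np.array([2.0, 5.0, 6.0, 8.0, 10.0])
risk = np.array([3.0, 2.5, 0.4, 2.8, 0.2])

c_raw = harrell_c_index(event, time, risk)
c_exp = harrell_c_index(event, time, np.exp(risk))  # strictly increasing transform
print(c_raw, c_exp)  # 0.875 0.875
```

Since a higher score must mean an earlier failure, predicted survival probabilities would have to be negated (or otherwise reversed) before being passed as risk scores.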
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**Mathematical formulation**\n", | |
" \n", | |
"$$\\mathrm{BS}^c(t) = \\frac{1}{n} \\sum_{i=1}^n \\left[ I(d_i \\leq t \\land \\delta_i = 1)\n", | |
" \\frac{(0 - \\hat{S}(t | \\mathbf{x}_i))^2}{\\hat{G}(d_i)} + I(d_i > t)\n", | |
" \\frac{(1 - \\hat{S}(t | \\mathbf{x}_i))^2}{\\hat{G}(t)} \\right]$$\n", | |
" \n", | |
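"In the survival analysis context, the Brier score can be seen as the Mean Squared Error (MSE) between our predicted survival probability $\\hat{S}(t | \\mathbf{x}_i)$ and the target label (0 if the event has occurred by time $t$, 1 otherwise), weighted by the inverse probability of censoring $\\frac{1}{\\hat{G}(\\cdot)}$. Here $d_i$ is the observed time and $\\delta_i \\in \\{0, 1\\}$ the event indicator of individual $i$. In practice, we estimate the censoring survival function $\\hat{G}(t)$ with the Kaplan-Meier estimator applied with a swapped event indicator (censorings treated as events).\n", | |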
"\n", | |
"- When no event or censoring has happened at $t$ yet, i.e. $I(d_i > t)$, we penalize a low probability of survival with $(1 - \\hat{S}(t|\\mathbf{x}_i))^2$.\n", | |
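"- Conversely, when an individual has experienced an event before $t$, i.e. $I(d_i \\leq t \\land \\delta_i = 1)$, we penalize a high probability of survival with $(0 - \\hat{S}(t|\\mathbf{x}_i))^2$.\n", | |
"- Individuals censored by time $t$ (i.e. $d_i \\leq t$ and $\\delta_i = 0$) contribute no term; the inverse-probability-of-censoring weights compensate for their absence." | |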
], | |
"metadata": { | |
"id": "3bmwqNQisHup" | |
} | |
}, | |
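The formula above can be sketched in NumPy as follows, assuming we already have each individual's predicted survival probability at the evaluation time $t$. The helper names are hypothetical, and $\hat{G}$ is a plain Kaplan-Meier estimate with the event indicator swapped, evaluated at $d_i$ and $t$ exactly as written in the formula; scikit-survival's own Brier score functions may differ in details such as tie handling, so this is an illustration rather than the library's implementation:

```python
import numpy as np

def censoring_survival(time, event):
    """Kaplan-Meier estimate of the censoring survival function G, obtained by
    swapping the event indicator (censorings are counted as 'events').
    Returns the sorted distinct times and the value of G at each of them."""
    time = np.asarray(time, dtype=float)
    censored = ~np.asarray(event, dtype=bool)
    uniq = np.unique(time)
    G = np.empty(len(uniq))
    g = 1.0
    for k, s in enumerate(uniq):
        at_risk = np.sum(time >= s)               # still under observation at s
        n_cens = np.sum((time == s) & censored)   # censored exactly at s
        g *= 1.0 - n_cens / at_risk
        G[k] = g
    return uniq, G

def G_at(t, uniq, G):
    """Evaluate the step function G at t (G = 1 before the first time)."""
    idx = np.searchsorted(uniq, t, side="right") - 1
    return 1.0 if idx < 0 else G[idx]

def brier_score_ipcw(t, time, event, surv_at_t, uniq, G):
    """IPCW Brier score BS^c(t), following the formula above term by term."""
    total = 0.0
    for d_i, delta_i, s_i in zip(time, event, surv_at_t):
        if d_i <= t and delta_i:   # event observed by t: target 0
            total += (0.0 - s_i) ** 2 / G_at(d_i, uniq, G)
        elif d_i > t:              # still event-free at t: target 1
            total += (1.0 - s_i) ** 2 / G_at(t, uniq, G)
        # censored by t: no term
    return total / len(time)

# Toy data: four individuals, two of them censored
time = np.array([1.0, 2.0, 3.0, 4.0])
event = np.array([True, False, True, False])
surv_at_t = np.array([0.2, 0.5, 0.9, 0.6])  # hypothetical predictions S(t | x_i) at t = 2.5

uniq, G = censoring_survival(time, event)
bs = brier_score_ipcw(2.5, time, event, surv_at_t, uniq, G)
print(round(bs, 5))  # 0.07375
```

Without any censoring, $\hat{G} = 1$ everywhere and the score reduces to the plain MSE between the predictions and the 0/1 targets.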
{ | |
"cell_type": "code", | |
"source": [ | |
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox.predict(X_test))\n", | |
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "oVdrGM0NtWdR", | |
"outputId": "19ea15bc-9e45-4e8e-d0ab-ae3d590fa289" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"The c-index of Cox is given by 0.665\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Evaluation grid: every day from 365 (one year) to 1825 (five years)\n", | |
"times = np.arange(365, 1826)" | |
], | |
"metadata": { | |
"id": "zoUFRn2yqLPF" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"survs = cox.predict_survival_function(X_test)" | |
], | |
"metadata": { | |
"id": "06i_zOUBqMUt" | |
}, | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
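"# Evaluate each survival step function on the time grid: shape (n_test_samples, len(times))\n", | |
"preds = np.asarray([[fn(t) for t in times] for fn in survs])" | |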
], | |
"metadata": { | |
"id": "prIyWOtMqe7s" | |
}, | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.metrics import integrated_brier_score\n", | |
"\n", | |
"integrated_brier_score(y_trn, y_test, preds, times)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Qyvn1TexqkoO", | |
"outputId": "7042880a-70a3-40a6-940f-36a51a0fb246" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.17509710701666106" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
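{
"cell_type": "markdown",
"source": [
"`integrated_brier_score` summarizes the time-dependent Brier score $BS(t)$ by its average over the evaluation grid, $\\frac{1}{t_{max} - t_{min}} \\int_{t_{min}}^{t_{max}} BS(t)\\,dt$. A sketch that computes the curve with `sksurv.metrics.brier_score` and integrates it with the trapezoidal rule; the result should closely match the value above, and `bs_values` also shows how the error evolves between 1 and 5 years."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sksurv.metrics import brier_score, integrated_brier_score\n",
"\n",
"bs_times, bs_values = brier_score(y_trn, y_test, preds, times)\n",
"\n",
"# NumPy 2.0 renamed np.trapz to np.trapezoid; use whichever is available\n",
"trapezoid = getattr(np, \"trapezoid\", None) or np.trapz\n",
"ibs_manual = trapezoid(bs_values, bs_times) / (bs_times[-1] - bs_times[0])\n",
"\n",
"print(f\"IBS (trapezoidal rule): {ibs_manual:.5f}\")\n",
"print(f\"IBS (sksurv):           {integrated_brier_score(y_trn, y_test, preds, times):.5f}\")\n",
"print(f\"BS at 1, 3 and 5 years: {np.interp([365, 1095, 1825], bs_times, bs_values).round(3)}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},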
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install -q shap\n", | |
"import shap\n", | |
"explainer = shap.Explainer(cox.predict, X_trn)\n", | |
"shap_values = explainer(X_test[:100])\n", | |
"shap.plots.waterfall(shap_values[0])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 540 | |
}, | |
"id": "J03T1z80N8fg", | |
"outputId": "2bc340fd-dd5e-484a-d887-bf48eeed72b0" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 800x550 with 3 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
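{
"cell_type": "markdown",
"source": [
"Because `cox.predict` returns the linear predictor $\\mathbf{x}^\\top\\hat{\\beta}$, the model is additive in the features, and with an independent background the SHAP value of feature $j$ should equal $\\hat{\\beta}_j (x_j - \\bar{x}_j)$, where $\\bar{x}_j$ is the background mean. This sketch compares that closed form with the values computed above, using the mean of `X_trn` as $\\bar{x}_j$; the match is only approximate because `shap` may subsample the background data. It assumes `X_trn` and `X_test` are pandas DataFrames and that `cox` is a plain `CoxPHSurvivalAnalysis` (not wrapped in a pipeline)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Closed-form SHAP values of a linear predictor: coef_j * (x_j - mean_j)\n",
"linear_shap = (X_test[:100] - X_trn.mean()).to_numpy() * cox.coef_\n",
"max_gap = np.abs(shap_values.values - linear_shap).max()\n",
"print(f\"largest difference from the closed form: {max_gap:.4f}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},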
{ | |
"cell_type": "code", | |
"source": [ | |
"X.describe().transpose().round(2).drop(columns=\"count\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "-s0v-Yx5N_3o", | |
"outputId": "d872a1e2-5c87-4cfb-803f-a3a5e846d11f" | |
}, | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" mean std min 25% 50% 75% max\n", | |
"age 53.05 10.12 21.0 46.0 53.0 61.00 80.0\n", | |
"estrec 96.25 153.08 0.0 8.0 36.0 114.00 1144.0\n", | |
"horTh 0.36 0.48 0.0 0.0 0.0 1.00 1.0\n", | |
"menostat 0.58 0.49 0.0 0.0 1.0 1.00 1.0\n", | |
"pnodes 5.01 5.48 1.0 1.0 3.0 7.00 51.0\n", | |
"progrec 110.00 202.33 0.0 7.0 32.5 131.75 2380.0\n", | |
"tgrade 1.12 0.58 0.0 1.0 1.0 1.00 2.0\n", | |
"tsize 29.33 14.30 3.0 20.0 25.0 35.00 120.0" | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-219003f2-8249-4967-887f-11a49065d107\" class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>mean</th>\n", | |
" <th>std</th>\n", | |
" <th>min</th>\n", | |
" <th>25%</th>\n", | |
" <th>50%</th>\n", | |
" <th>75%</th>\n", | |
" <th>max</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>age</th>\n", | |
" <td>53.05</td>\n", | |
" <td>10.12</td>\n", | |
" <td>21.0</td>\n", | |
" <td>46.0</td>\n", | |
" <td>53.0</td>\n", | |
" <td>61.00</td>\n", | |
" <td>80.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>estrec</th>\n", | |
" <td>96.25</td>\n", | |
" <td>153.08</td>\n", | |
" <td>0.0</td>\n", | |
" <td>8.0</td>\n", | |
" <td>36.0</td>\n", | |
" <td>114.00</td>\n", | |
" <td>1144.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>horTh</th>\n", | |
" <td>0.36</td>\n", | |
" <td>0.48</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.00</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>menostat</th>\n", | |
" <td>0.58</td>\n", | |
" <td>0.49</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.00</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>pnodes</th>\n", | |
" <td>5.01</td>\n", | |
" <td>5.48</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>3.0</td>\n", | |
" <td>7.00</td>\n", | |
" <td>51.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>progrec</th>\n", | |
" <td>110.00</td>\n", | |
" <td>202.33</td>\n", | |
" <td>0.0</td>\n", | |
" <td>7.0</td>\n", | |
" <td>32.5</td>\n", | |
" <td>131.75</td>\n", | |
" <td>2380.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>tgrade</th>\n", | |
" <td>1.12</td>\n", | |
" <td>0.58</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.0</td>\n", | |
" <td>1.00</td>\n", | |
" <td>2.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>tsize</th>\n", | |
" <td>29.33</td>\n", | |
" <td>14.30</td>\n", | |
" <td>3.0</td>\n", | |
" <td>20.0</td>\n", | |
" <td>25.0</td>\n", | |
" <td>35.00</td>\n", | |
" <td>120.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <div class=\"colab-df-buttons\">\n", | |
"\n", | |
" <div class=\"colab-df-container\">\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-219003f2-8249-4967-887f-11a49065d107')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
"\n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n", | |
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
"\n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" .colab-df-buttons div {\n", | |
" margin-bottom: 4px;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-219003f2-8249-4967-887f-11a49065d107 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-219003f2-8249-4967-887f-11a49065d107');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
"\n", | |
"\n", | |
"<div id=\"df-e3611107-f6c5-45e6-ad1c-b101e98fa6aa\">\n", | |
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-e3611107-f6c5-45e6-ad1c-b101e98fa6aa')\"\n", | |
" title=\"Suggest charts.\"\n", | |
" style=\"display:none;\">\n", | |
"\n", | |
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <g>\n", | |
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n", | |
" </g>\n", | |
"</svg>\n", | |
" </button>\n", | |
"\n", | |
"<style>\n", | |
" .colab-df-quickchart {\n", | |
" --bg-color: #E8F0FE;\n", | |
" --fill-color: #1967D2;\n", | |
" --hover-bg-color: #E2EBFA;\n", | |
" --hover-fill-color: #174EA6;\n", | |
" --disabled-fill-color: #AAA;\n", | |
" --disabled-bg-color: #DDD;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-quickchart {\n", | |
" --bg-color: #3B4455;\n", | |
" --fill-color: #D2E3FC;\n", | |
" --hover-bg-color: #434B5C;\n", | |
" --hover-fill-color: #FFFFFF;\n", | |
" --disabled-bg-color: #3B4455;\n", | |
" --disabled-fill-color: #666;\n", | |
" }\n", | |
"\n", | |
" .colab-df-quickchart {\n", | |
" background-color: var(--bg-color);\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: var(--fill-color);\n", | |
" height: 32px;\n", | |
" padding: 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-quickchart:hover {\n", | |
" background-color: var(--hover-bg-color);\n", | |
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: var(--button-hover-fill-color);\n", | |
" }\n", | |
"\n", | |
" .colab-df-quickchart-complete:disabled,\n", | |
" .colab-df-quickchart-complete:disabled:hover {\n", | |
" background-color: var(--disabled-bg-color);\n", | |
" fill: var(--disabled-fill-color);\n", | |
" box-shadow: none;\n", | |
" }\n", | |
"\n", | |
" .colab-df-spinner {\n", | |
" border: 2px solid var(--fill-color);\n", | |
" border-color: transparent;\n", | |
" border-bottom-color: var(--fill-color);\n", | |
" animation:\n", | |
" spin 1s steps(1) infinite;\n", | |
" }\n", | |
"\n", | |
" @keyframes spin {\n", | |
" 0% {\n", | |
" border-color: transparent;\n", | |
" border-bottom-color: var(--fill-color);\n", | |
" border-left-color: var(--fill-color);\n", | |
" }\n", | |
" 20% {\n", | |
" border-color: transparent;\n", | |
" border-left-color: var(--fill-color);\n", | |
" border-top-color: var(--fill-color);\n", | |
" }\n", | |
" 30% {\n", | |
" border-color: transparent;\n", | |
" border-left-color: var(--fill-color);\n", | |
" border-top-color: var(--fill-color);\n", | |
" border-right-color: var(--fill-color);\n", | |
" }\n", | |
" 40% {\n", | |
" border-color: transparent;\n", | |
" border-right-color: var(--fill-color);\n", | |
" border-top-color: var(--fill-color);\n", | |
" }\n", | |
" 60% {\n", | |
" border-color: transparent;\n", | |
" border-right-color: var(--fill-color);\n", | |
" }\n", | |
" 80% {\n", | |
" border-color: transparent;\n", | |
" border-right-color: var(--fill-color);\n", | |
" border-bottom-color: var(--fill-color);\n", | |
" }\n", | |
" 90% {\n", | |
" border-color: transparent;\n", | |
" border-bottom-color: var(--fill-color);\n", | |
" }\n", | |
" }\n", | |
"</style>\n", | |
"\n", | |
" <script>\n", | |
" async function quickchart(key) {\n", | |
" const quickchartButtonEl =\n", | |
" document.querySelector('#' + key + ' button');\n", | |
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n", | |
" quickchartButtonEl.classList.add('colab-df-spinner');\n", | |
" try {\n", | |
" const charts = await google.colab.kernel.invokeFunction(\n", | |
" 'suggestCharts', [key], {});\n", | |
" } catch (error) {\n", | |
" console.error('Error during call to suggestCharts:', error);\n", | |
" }\n", | |
" quickchartButtonEl.classList.remove('colab-df-spinner');\n", | |
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n", | |
" }\n", | |
" (() => {\n", | |
" let quickchartButtonEl =\n", | |
" document.querySelector('#df-e3611107-f6c5-45e6-ad1c-b101e98fa6aa button');\n", | |
" quickchartButtonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
" })();\n", | |
" </script>\n", | |
"</div>\n", | |
" </div>\n", | |
" </div>\n" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"shap.plots.beeswarm(shap_values)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 453 | |
}, | |
"id": "0kbYd1wAODT9", | |
"outputId": "1c9aa6ea-f435-4c32-c143-ad0f6f71b61f" | |
}, | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 800x470 with 2 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
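{
"cell_type": "markdown",
"source": [
"The beeswarm plot can be condensed into one global importance number per feature, the mean absolute SHAP value (this is also what `shap.plots.bar(shap_values)` displays). A short sketch, assuming `X_test` is a pandas DataFrame so its column names can label the result:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Average magnitude of each feature's contribution over the 100 explained rows\n",
"mean_abs_shap = pd.Series(np.abs(shap_values.values).mean(axis=0), index=X_test.columns)\n",
"print(mean_abs_shap.sort_values(ascending=False).round(3))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},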
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZNHQ-bkWALNy", | |
"outputId": "3a63b8c0-5ce3-4c82-f22b-dfd4325c5fdd", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"from scipy.stats import reciprocal\n", | |
"from sklearn.model_selection import RandomizedSearchCV\n", | |
"\n", | |
"param_distributions = {\n", | |
" 'alpha': reciprocal(0.1, 100),\n", | |
"}\n", | |
"\n", | |
"model_random_search = RandomizedSearchCV(\n", | |
" cox, param_distributions=param_distributions, n_iter=50, n_jobs=-1, cv=3, random_state=42)\n", | |
"model_random_search.fit(X_trn, y_trn)\n", | |
"\n", | |
"print(\n", | |
" f\"The c-index of Cox using a {model_random_search.__class__.__name__} is \"\n", | |
" f\"{model_random_search.score(X_test, y_test):.3f}\")\n", | |
"print(\n", | |
" f\"The best set of parameters is: {model_random_search.best_params_}\"\n", | |
")" | |
], | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"The c-index of Cox using a RandomizedSearchCV is 0.660\n", | |
"The best set of parameters is: {'alpha': 39.67605077052987}\n" | |
] | |
} | |
] | |
}, | |
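{
"cell_type": "markdown",
"source": [
"`reciprocal(0.1, 100)` is a log-uniform distribution, so every decade of the penalty range [0.1, 100] is sampled equally often, which suits a scale parameter such as `alpha`. The search also stores the cross-validated c-index of every candidate in `cv_results_`. This sketch prints both; it has to run before `model_random_search` is reassigned to the random survival forest search further down."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# log10 of a few draws is spread uniformly over [-1, 2]\n",
"print(np.log10(reciprocal(0.1, 100).rvs(size=5, random_state=0)).round(2))\n",
"\n",
"# Cross-validated c-index of each sampled alpha, best first\n",
"cv_cox = pd.DataFrame(model_random_search.cv_results_)\n",
"cols = [\"param_alpha\", \"mean_test_score\", \"std_test_score\", \"rank_test_score\"]\n",
"print(cv_cox[cols].sort_values(\"rank_test_score\").head(10))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},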
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "MrhAveCQAdxF", | |
"outputId": "75475cd8-1ec8-4df1-ec0a-2ff2d00038d4", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"alpha = model_random_search.best_params_['alpha']\n", | |
"cox_best = make_pipeline(CoxPHSurvivalAnalysis(alpha=alpha))\n", | |
"cox_best.fit(X_trn, y_trn)\n", | |
"\n", | |
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox_best.predict(X_test))\n", | |
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')" | |
], | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"The c-index of Cox is given by 0.660\n" | |
] | |
} | |
] | |
}, | |
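{
"cell_type": "markdown",
"source": [
"Since `RandomizedSearchCV` uses `refit=True` by default, `model_random_search.best_estimator_` has already been refit on the full training set with the best `alpha`, so rebuilding the model by hand is optional. A quick check that both give the same risk scores:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# The search's refit model and the hand-built model should agree\n",
"print(np.allclose(model_random_search.best_estimator_.predict(X_test),\n",
"                  cox_best.predict(X_test)))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},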
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.ensemble import RandomSurvivalForest" | |
], | |
"metadata": { | |
"id": "s34_kkDKKzw5" | |
}, | |
"execution_count": 27, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsf = RandomSurvivalForest(\n", | |
" n_estimators=100, min_samples_leaf=15, n_jobs=-1, random_state=20\n", | |
")\n", | |
"rsf.fit(X_trn, y_trn)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 91 | |
}, | |
"id": "vU06jwE7Kuec", | |
"outputId": "e3cc5cf0-0655-46f6-91ea-d74ed12d6c5b" | |
}, | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10, n_jobs=-1,\n", | |
" random_state=20)" | |
], | |
"text/html": [ | |
"<style>#sk-container-id-2 {color: black;}#sk-container-id-2 pre{padding: 0;}#sk-container-id-2 div.sk-toggleable {background-color: white;}#sk-container-id-2 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-2 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-2 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-2 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-2 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-2 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-2 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-2 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-2 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-2 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-2 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-2 div.sk-item {position: relative;z-index: 1;}#sk-container-id-2 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-2 div.sk-item::before, #sk-container-id-2 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-2 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-2 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-2 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-2 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-2 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-2 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-2 div.sk-label-container {text-align: center;}#sk-container-id-2 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-2 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10, n_jobs=-1,\n", | |
" random_state=20)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomSurvivalForest</label><div class=\"sk-toggleable__content\"><pre>RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10, n_jobs=-1,\n", | |
" random_state=20)</pre></div></div></div></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsf.score(X_test, y_test)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LJ0Tp63pKvMJ", | |
"outputId": "908b6916-2562-4533-ab15-853054aab6c8" | |
}, | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.6716457023060797" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 29 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], rsf.predict(X_test))\n", | |
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Y1hNGO_QO2pW", | |
"outputId": "93786c36-160e-4b80-fd59-c779879fff39" | |
}, | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"The c-index of Cox is given by 0.672\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"times = np.arange(365, 1826)" | |
], | |
"metadata": { | |
"id": "XHsIhuW7PHik" | |
}, | |
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"survs = rsf.predict_survival_function(X_test)" | |
], | |
"metadata": { | |
"id": "iSZ0pa17PSJS" | |
}, | |
"execution_count": 32, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"preds = np.asarray([[fn(t) for t in times] for fn in survs])" | |
], | |
"metadata": { | |
"id": "np-god8mPW5W" | |
}, | |
"execution_count": 33, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"integrated_brier_score(y_trn, y_test, preds, times)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "U1avp5nfPaOS", | |
"outputId": "b1e50567-5ccd-40a5-9f19-b3aeb70881ee" | |
}, | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.17751179194838101" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 34 | |
} | |
] | |
}, | |
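{
"cell_type": "markdown",
"source": [
"The integrated scores of the two models (about 0.175 for Cox and 0.178 for the random survival forest) are averages over the whole 1 to 5 year grid. Plotting the time-dependent Brier score of both models on the same grid shows where they actually differ. A sketch using `sksurv.metrics.brier_score` and matplotlib; the helper `survival_matrix` is hypothetical."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sksurv.metrics import brier_score\n",
"\n",
"def survival_matrix(model, X, grid):\n",
"    # Rows: individuals; columns: predicted S(t | x) at each time in grid\n",
"    return np.vstack([fn(grid) for fn in model.predict_survival_function(X)])\n",
"\n",
"for name, model in [(\"Cox PH\", cox), (\"RSF\", rsf)]:\n",
"    _, bs = brier_score(y_trn, y_test, survival_matrix(model, X_test, times), times)\n",
"    plt.plot(times, bs, label=name)\n",
"\n",
"plt.xlabel(\"time (days)\")\n",
"plt.ylabel(\"Brier score\")\n",
"plt.legend()\n",
"plt.show()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},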
{ | |
"cell_type": "code", | |
"source": [ | |
"param_distributions = {\n", | |
" 'min_samples_leaf': [3, 7, 15],\n", | |
" 'max_depth': [3, 7, None]\n", | |
"}\n", | |
"\n", | |
"model_random_search = RandomizedSearchCV(\n", | |
" rsf, param_distributions=param_distributions, n_iter=50, n_jobs=-1, cv=3, random_state=42)\n", | |
"model_random_search.fit(X_trn, y_trn)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 134 | |
}, | |
"id": "QjiQ8TfBPfZg", | |
"outputId": "5438c05e-b20d-4cb7-f05d-b93740c08857" | |
}, | |
"execution_count": 35, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"The total space of parameters 9 is smaller than n_iter=50. Running 9 iterations. For exhaustive searches, use GridSearchCV.\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomizedSearchCV(cv=3,\n", | |
" estimator=RandomSurvivalForest(min_samples_leaf=15,\n", | |
" min_samples_split=10,\n", | |
" n_jobs=-1, random_state=20),\n", | |
" n_iter=50, n_jobs=-1,\n", | |
" param_distributions={'max_depth': [3, 7, None],\n", | |
" 'min_samples_leaf': [3, 7, 15]},\n", | |
" random_state=42)" | |
], | |
"text/html": [ | |
"<style>#sk-container-id-3 {color: black;}#sk-container-id-3 pre{padding: 0;}#sk-container-id-3 div.sk-toggleable {background-color: white;}#sk-container-id-3 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-3 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-3 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-3 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-3 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-3 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-3 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-3 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-3 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-3 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-3 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-3 div.sk-item {position: relative;z-index: 1;}#sk-container-id-3 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-3 div.sk-item::before, #sk-container-id-3 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-3 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-3 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-3 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-3 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-3 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-3 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-3 div.sk-label-container {text-align: center;}#sk-container-id-3 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-3 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomizedSearchCV(cv=3,\n", | |
" estimator=RandomSurvivalForest(min_samples_leaf=15,\n", | |
" min_samples_split=10,\n", | |
" n_jobs=-1, random_state=20),\n", | |
" n_iter=50, n_jobs=-1,\n", | |
" param_distributions={'max_depth': [3, 7, None],\n", | |
" 'min_samples_leaf': [3, 7, 15]},\n", | |
" random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomizedSearchCV</label><div class=\"sk-toggleable__content\"><pre>RandomizedSearchCV(cv=3,\n", | |
" estimator=RandomSurvivalForest(min_samples_leaf=15,\n", | |
" min_samples_split=10,\n", | |
" n_jobs=-1, random_state=20),\n", | |
" n_iter=50, n_jobs=-1,\n", | |
" param_distributions={'max_depth': [3, 7, None],\n", | |
" 'min_samples_leaf': [3, 7, 15]},\n", | |
" random_state=42)</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: RandomSurvivalForest</label><div class=\"sk-toggleable__content\"><pre>RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10, n_jobs=-1,\n", | |
" random_state=20)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomSurvivalForest</label><div class=\"sk-toggleable__content\"><pre>RandomSurvivalForest(min_samples_leaf=15, min_samples_split=10, n_jobs=-1,\n", | |
" random_state=20)</pre></div></div></div></div></div></div></div></div></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 35 | |
} | |
] | |
}, | |
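{
"cell_type": "markdown",
"source": [
"As the warning says, the grid has only 3 × 3 = 9 combinations, so `RandomizedSearchCV` ends up trying all of them. An exhaustive `GridSearchCV` states that intent directly. With the same estimator and the same 3-fold splits it evaluates the same 9 candidates, so its `best_params_` should normally match (ties in the cross-validated score could be broken differently). Note that it fits 27 forests plus the final refit."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Exhaustive search over the same 9 combinations\n",
"grid_search = GridSearchCV(rsf, param_grid=param_distributions, cv=3, n_jobs=-1)\n",
"grid_search.fit(X_trn, y_trn)\n",
"print(grid_search.best_params_)\n",
"print(f\"test c-index: {grid_search.score(X_test, y_test):.3f}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},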
{ | |
"cell_type": "code", | |
"source": [ | |
"model_random_search.score(X_test, y_test)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_X2klnTmQVXX", | |
"outputId": "b15bfb3d-130b-463a-831b-a82dca36fa6a" | |
}, | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.6805555555555556" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 36 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(\n", | |
" f\"The best set of parameters is: {model_random_search.best_params_}\"\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "6Mam6JN2Qqh3", | |
"outputId": "6d5ba74e-d207-4fb6-9d9b-084774b76174" | |
}, | |
"execution_count": 38, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"The best set of parameters is: {'min_samples_leaf': 15, 'max_depth': 3}\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"rsf_best = RandomSurvivalForest(\n", | |
" n_estimators=100, min_samples_leaf=15, max_depth=3, n_jobs=-1, random_state=20\n", | |
")\n", | |
"rsf_best.fit(X_trn, y_trn)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 91 | |
}, | |
"id": "PbSNPjgQQruO", | |
"outputId": "c0218816-d654-48f5-c9f8-24e7ba37c443" | |
}, | |
"execution_count": 40, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomSurvivalForest(max_depth=3, min_samples_leaf=15, n_jobs=-1,\n", | |
" random_state=20)" | |
], | |
"text/html": [ | |
"<style>#sk-container-id-5 {color: black;}#sk-container-id-5 pre{padding: 0;}#sk-container-id-5 div.sk-toggleable {background-color: white;}#sk-container-id-5 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-5 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-5 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-5 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-5 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-5 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-5 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-5 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-5 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-5 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-5 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-5 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-5 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-5 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-5 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-5 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-5 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-5 div.sk-item {position: relative;z-index: 1;}#sk-container-id-5 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-5 div.sk-item::before, #sk-container-id-5 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-5 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-5 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-5 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-5 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-5 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-5 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-5 div.sk-label-container {text-align: center;}#sk-container-id-5 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-5 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-5\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomSurvivalForest(max_depth=3, min_samples_leaf=15, n_jobs=-1,\n", | |
" random_state=20)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" checked><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">RandomSurvivalForest</label><div class=\"sk-toggleable__content\"><pre>RandomSurvivalForest(max_depth=3, min_samples_leaf=15, n_jobs=-1,\n", | |
" random_state=20)</pre></div></div></div></div></div>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 40 | |
} | |
] | |
}, | |
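{ | |
"cell_type": "code", | |
"source": [ | |
"# Evaluation grid: every time point from 365 to 1825 (years 1 to 5, assuming time is measured in days)\n", | |
"times = np.arange(365, 1826)\n", | |
"# Predicted survival function of each test sample (one StepFunction per row of X_test)\n", | |
"survs = rsf_best.predict_survival_function(X_test)" | |
], | |
"metadata": { | |
"id": "iO5hU-exQZup" | |
}, | |
"execution_count": 41, | |
"outputs": [] | |
}, | |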
{ | |
"cell_type": "code", | |
"source": [ | |
"# Evaluate each survival function on the whole grid at once: shape (n_test_samples, len(times))\n", | |
"preds = np.asarray([fn(times) for fn in survs])" | |
], | |
"metadata": { | |
"id": "nVQsQOocQhOI" | |
}, | |
"execution_count": 42, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Integrated Brier score over `times` (lower is better); y_trn is used to estimate the censoring distribution\n", | |
"integrated_brier_score(y_trn, y_test, preds, times)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rYRRmruYRCgq", | |
"outputId": "a1a4bdf8-27d9-4544-f7af-433429972883" | |
}, | |
"execution_count": 43, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.17850725604478493" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 43 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.ensemble import GradientBoostingSurvivalAnalysis" | |
], | |
"metadata": { | |
"id": "wTxXroEmRYMm" | |
}, | |
"execution_count": 44, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "wnfh4hsbRYw6" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |