Created June 7, 2023 14:59
Cox_PH_and_RSF-colab.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Cox_PH_and_RSF-colab.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/ada27894eca947da7f2ccd706463d0de/cox_ph_and_rsf-colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UN3PoUTSb2nT"
},
"source": [
"The objective of this post is to compare two models, the Cox proportional hazards model and the random survival forest model, for estimating the survival probability given a set of features/covariates. For a broader comparison, see:\n",
"\n",
">[\"Experimental Comparison of Semi-parametric, Parametric, and Machine Learning Models for Time-to-Event Analysis Through the Concordance Index,\"](https://arxiv.org/abs/2003.08820)\n",
"Camila Fernandez, Chung Shue Chen, Pierre Gaillard, Alonso Silva\n",
"\n",
"To perform this analysis we use [scikit-learn](https://scikit-learn.org/) and [scikit-survival](https://pypi.org/project/scikit-survival/). Finally, we use [eli5](https://eli5.readthedocs.io/en/latest/index.html) to study feature importances (computed by permutation importance)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pzqGFx_Bb_6A"
},
"source": [
"!pip install -q scikit-survival"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VoQVEI5p_rga"
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SnT6e_JPb2ns"
},
"source": [
"We first load a dataset that ships with scikit-survival."
]
},
{
"cell_type": "code",
"metadata": {
"id": "D0xxNWzI-N3j"
},
"source": [
"from sksurv.datasets import load_gbsg2\n",
"\n",
"X, y = load_gbsg2()"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBTo4q_Hb2n0"
},
"source": [
"## An example: German Breast Cancer Study Group 2 (GBSG2)\n",
"\n",
"This dataset contains the following 8 features/covariates:\n",
"\n",
"- age: age (in years),\n",
"- estrec: estrogen receptor (in fmol),\n",
"- horTh: hormonal therapy (yes or no),\n",
"- menostat: menopausal status (premenopausal or postmenopausal),\n",
"- pnodes: number of positive nodes,\n",
"- progrec: progesterone receptor (in fmol),\n",
"- tgrade: tumor grade (I < II < III),\n",
"- tsize: tumor size (in mm).\n",
"\n",
"The two outputs are:\n",
"\n",
"- recurrence-free time (in days),\n",
"- censoring indicator (0 - censored, 1 - event).\n",
"\n",
"The dataset has 686 samples and 8 features/covariates.\n",
"\n",
"**References**\n",
"\n",
"M. Schumacher, G. Bastert, H. Bojar, K. Huebner, M. Olschewski, W. Sauerbrei, C. Schmoor, C. Beyerle, R.L.A. Neumann and H.F. Rauschecker for the German Breast Cancer Study Group (1994), [Randomized 2 x 2 trial evaluating hormonal treatment and the duration of chemotherapy in node-positive breast cancer patients](https://www.ncbi.nlm.nih.gov/pubmed/7931478). Journal of Clinical Oncology, 12, 2086–2093."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kAsZ72YYb2n3"
},
"source": [
"Let's take a look at the features/covariates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_OlmlI6g-X43",
"outputId": "f3b97327-9b8a-491b-e36a-00887b6b2f9d",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
}
},
"source": [
"# Show the first 10 rows, with the numeric columns formatted without decimals\n",
"cols = [\"age\", \"estrec\", \"pnodes\", \"progrec\", \"tsize\"]\n",
"formatdict = {col: \"{:,.0f}\" for col in cols}\n",
"X.head(10).style.hide(axis=\"index\").format(formatdict)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f6efea2a9b0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_b723e\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th id=\"T_b723e_level0_col0\" class=\"col_heading level0 col0\" >age</th>\n",
" <th id=\"T_b723e_level0_col1\" class=\"col_heading level0 col1\" >estrec</th>\n",
" <th id=\"T_b723e_level0_col2\" class=\"col_heading level0 col2\" >horTh</th>\n",
" <th id=\"T_b723e_level0_col3\" class=\"col_heading level0 col3\" >menostat</th>\n",
" <th id=\"T_b723e_level0_col4\" class=\"col_heading level0 col4\" >pnodes</th>\n",
" <th id=\"T_b723e_level0_col5\" class=\"col_heading level0 col5\" >progrec</th>\n",
" <th id=\"T_b723e_level0_col6\" class=\"col_heading level0 col6\" >tgrade</th>\n",
" <th id=\"T_b723e_level0_col7\" class=\"col_heading level0 col7\" >tsize</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_b723e_row0_col0\" class=\"data row0 col0\" >70</td>\n",
" <td id=\"T_b723e_row0_col1\" class=\"data row0 col1\" >66</td>\n",
" <td id=\"T_b723e_row0_col2\" class=\"data row0 col2\" >no</td>\n",
" <td id=\"T_b723e_row0_col3\" class=\"data row0 col3\" >Post</td>\n",
" <td id=\"T_b723e_row0_col4\" class=\"data row0 col4\" >3</td>\n",
" <td id=\"T_b723e_row0_col5\" class=\"data row0 col5\" >48</td>\n",
" <td id=\"T_b723e_row0_col6\" class=\"data row0 col6\" >II</td>\n",
" <td id=\"T_b723e_row0_col7\" class=\"data row0 col7\" >21</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row1_col0\" class=\"data row1 col0\" >56</td>\n",
" <td id=\"T_b723e_row1_col1\" class=\"data row1 col1\" >77</td>\n",
" <td id=\"T_b723e_row1_col2\" class=\"data row1 col2\" >yes</td>\n",
" <td id=\"T_b723e_row1_col3\" class=\"data row1 col3\" >Post</td>\n",
" <td id=\"T_b723e_row1_col4\" class=\"data row1 col4\" >7</td>\n",
" <td id=\"T_b723e_row1_col5\" class=\"data row1 col5\" >61</td>\n",
" <td id=\"T_b723e_row1_col6\" class=\"data row1 col6\" >II</td>\n",
" <td id=\"T_b723e_row1_col7\" class=\"data row1 col7\" >12</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row2_col0\" class=\"data row2 col0\" >58</td>\n",
" <td id=\"T_b723e_row2_col1\" class=\"data row2 col1\" >271</td>\n",
" <td id=\"T_b723e_row2_col2\" class=\"data row2 col2\" >yes</td>\n",
" <td id=\"T_b723e_row2_col3\" class=\"data row2 col3\" >Post</td>\n",
" <td id=\"T_b723e_row2_col4\" class=\"data row2 col4\" >9</td>\n",
" <td id=\"T_b723e_row2_col5\" class=\"data row2 col5\" >52</td>\n",
" <td id=\"T_b723e_row2_col6\" class=\"data row2 col6\" >II</td>\n",
" <td id=\"T_b723e_row2_col7\" class=\"data row2 col7\" >35</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row3_col0\" class=\"data row3 col0\" >59</td>\n",
" <td id=\"T_b723e_row3_col1\" class=\"data row3 col1\" >29</td>\n",
" <td id=\"T_b723e_row3_col2\" class=\"data row3 col2\" >yes</td>\n",
" <td id=\"T_b723e_row3_col3\" class=\"data row3 col3\" >Post</td>\n",
" <td id=\"T_b723e_row3_col4\" class=\"data row3 col4\" >4</td>\n",
" <td id=\"T_b723e_row3_col5\" class=\"data row3 col5\" >60</td>\n",
" <td id=\"T_b723e_row3_col6\" class=\"data row3 col6\" >II</td>\n",
" <td id=\"T_b723e_row3_col7\" class=\"data row3 col7\" >17</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row4_col0\" class=\"data row4 col0\" >73</td>\n",
" <td id=\"T_b723e_row4_col1\" class=\"data row4 col1\" >65</td>\n",
" <td id=\"T_b723e_row4_col2\" class=\"data row4 col2\" >no</td>\n",
" <td id=\"T_b723e_row4_col3\" class=\"data row4 col3\" >Post</td>\n",
" <td id=\"T_b723e_row4_col4\" class=\"data row4 col4\" >1</td>\n",
" <td id=\"T_b723e_row4_col5\" class=\"data row4 col5\" >26</td>\n",
" <td id=\"T_b723e_row4_col6\" class=\"data row4 col6\" >II</td>\n",
" <td id=\"T_b723e_row4_col7\" class=\"data row4 col7\" >35</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row5_col0\" class=\"data row5 col0\" >32</td>\n",
" <td id=\"T_b723e_row5_col1\" class=\"data row5 col1\" >13</td>\n",
" <td id=\"T_b723e_row5_col2\" class=\"data row5 col2\" >no</td>\n",
" <td id=\"T_b723e_row5_col3\" class=\"data row5 col3\" >Pre</td>\n",
" <td id=\"T_b723e_row5_col4\" class=\"data row5 col4\" >24</td>\n",
" <td id=\"T_b723e_row5_col5\" class=\"data row5 col5\" >0</td>\n",
" <td id=\"T_b723e_row5_col6\" class=\"data row5 col6\" >III</td>\n",
" <td id=\"T_b723e_row5_col7\" class=\"data row5 col7\" >57</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row6_col0\" class=\"data row6 col0\" >59</td>\n",
" <td id=\"T_b723e_row6_col1\" class=\"data row6 col1\" >0</td>\n",
" <td id=\"T_b723e_row6_col2\" class=\"data row6 col2\" >yes</td>\n",
" <td id=\"T_b723e_row6_col3\" class=\"data row6 col3\" >Post</td>\n",
" <td id=\"T_b723e_row6_col4\" class=\"data row6 col4\" >2</td>\n",
" <td id=\"T_b723e_row6_col5\" class=\"data row6 col5\" >181</td>\n",
" <td id=\"T_b723e_row6_col6\" class=\"data row6 col6\" >II</td>\n",
" <td id=\"T_b723e_row6_col7\" class=\"data row6 col7\" >8</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row7_col0\" class=\"data row7 col0\" >65</td>\n",
" <td id=\"T_b723e_row7_col1\" class=\"data row7 col1\" >25</td>\n",
" <td id=\"T_b723e_row7_col2\" class=\"data row7 col2\" >no</td>\n",
" <td id=\"T_b723e_row7_col3\" class=\"data row7 col3\" >Post</td>\n",
" <td id=\"T_b723e_row7_col4\" class=\"data row7 col4\" >1</td>\n",
" <td id=\"T_b723e_row7_col5\" class=\"data row7 col5\" >192</td>\n",
" <td id=\"T_b723e_row7_col6\" class=\"data row7 col6\" >II</td>\n",
" <td id=\"T_b723e_row7_col7\" class=\"data row7 col7\" >16</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row8_col0\" class=\"data row8 col0\" >80</td>\n",
" <td id=\"T_b723e_row8_col1\" class=\"data row8 col1\" >59</td>\n",
" <td id=\"T_b723e_row8_col2\" class=\"data row8 col2\" >no</td>\n",
" <td id=\"T_b723e_row8_col3\" class=\"data row8 col3\" >Post</td>\n",
" <td id=\"T_b723e_row8_col4\" class=\"data row8 col4\" >30</td>\n",
" <td id=\"T_b723e_row8_col5\" class=\"data row8 col5\" >0</td>\n",
" <td id=\"T_b723e_row8_col6\" class=\"data row8 col6\" >II</td>\n",
" <td id=\"T_b723e_row8_col7\" class=\"data row8 col7\" >39</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row9_col0\" class=\"data row9 col0\" >66</td>\n",
" <td id=\"T_b723e_row9_col1\" class=\"data row9 col1\" >3</td>\n",
" <td id=\"T_b723e_row9_col2\" class=\"data row9 col2\" >no</td>\n",
" <td id=\"T_b723e_row9_col3\" class=\"data row9 col3\" >Post</td>\n",
" <td id=\"T_b723e_row9_col4\" class=\"data row9 col4\" >7</td>\n",
" <td id=\"T_b723e_row9_col5\" class=\"data row9 col5\" >0</td>\n",
" <td id=\"T_b723e_row9_col6\" class=\"data row9 col6\" >II</td>\n",
" <td id=\"T_b723e_row9_col7\" class=\"data row9 col7\" >18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zcMjHeAEb2n_"
},
"source": [
"Let's take a look at the outputs of the first 10 samples."
]
},
{
"cell_type": "code",
"metadata": {
"id": "h8ltRTa4_WOn",
"outputId": "74cce024-2643-4713-b117-fc47f9b672c2",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"y[:10]"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([( True, 1814.), ( True, 2018.), ( True, 712.), ( True, 1807.),\n",
" ( True, 772.), ( True, 448.), (False, 2172.), (False, 2161.),\n",
" ( True, 471.), (False, 2014.)],\n",
" dtype=[('cens', '?'), ('time', '<f8')])"
]
},
"metadata": {},
"execution_count": 5
}
]
},
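{
"cell_type": "markdown",
"metadata": {
"id": "A5NXlNYob2oH"
},
"source": [
"scikit-survival stores the outputs in a NumPy structured array with two fields: `cens`, the event indicator (`True` if the recurrence was observed), and `time`, the recurrence-free time in days. To display it more readably, we convert it to a pandas DataFrame."
]
},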
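{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of this format, the next cell builds such a structured array by hand with `sksurv.util.Surv.from_arrays`. The three patients and their times are hypothetical toy values, and the field names `cens` and `time` are passed explicitly to mirror the array returned by `load_gbsg2`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from sksurv.util import Surv\n",
"\n",
"# Hypothetical toy data: three patients, the second one is right-censored\n",
"toy_event = np.array([True, False, True])\n",
"toy_time = np.array([712.0, 2172.0, 448.0])\n",
"\n",
"# Structured array with a boolean event field followed by a float time field\n",
"y_toy = Surv.from_arrays(event=toy_event, time=toy_time, name_event='cens', name_time='time')\n",
"print(y_toy.dtype)\n",
"print(y_toy)"
],
"execution_count": null,
"outputs": []
},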
{
"cell_type": "code",
"source": [
"df_y = pd.DataFrame(data={'time': y['time'].astype(int), 'event': y['cens']})\n",
"# Highlight the right-censored samples (event == False, i.e. the column minimum)\n",
"df_y[:10].style.hide(axis=\"index\").highlight_min('event', color='lightgreen')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "AfPvZcjJ-GwQ",
"outputId": "290f3b50-93fd-47df-9a47-389a7036c7e8"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f6ec8b32770>"
],
"text/html": [
"<style type=\"text/css\">\n",
"#T_bcba8_row6_col1, #T_bcba8_row7_col1, #T_bcba8_row9_col1 {\n",
" background-color: lightgreen;\n",
"}\n",
"</style>\n",
"<table id=\"T_bcba8\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th id=\"T_bcba8_level0_col0\" class=\"col_heading level0 col0\" >time</th>\n",
" <th id=\"T_bcba8_level0_col1\" class=\"col_heading level0 col1\" >event</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_bcba8_row0_col0\" class=\"data row0 col0\" >1814</td>\n",
" <td id=\"T_bcba8_row0_col1\" class=\"data row0 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row1_col0\" class=\"data row1 col0\" >2018</td>\n",
" <td id=\"T_bcba8_row1_col1\" class=\"data row1 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row2_col0\" class=\"data row2 col0\" >712</td>\n",
" <td id=\"T_bcba8_row2_col1\" class=\"data row2 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row3_col0\" class=\"data row3 col0\" >1807</td>\n",
" <td id=\"T_bcba8_row3_col1\" class=\"data row3 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row4_col0\" class=\"data row4 col0\" >772</td>\n",
" <td id=\"T_bcba8_row4_col1\" class=\"data row4 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row5_col0\" class=\"data row5 col0\" >448</td>\n",
" <td id=\"T_bcba8_row5_col1\" class=\"data row5 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row6_col0\" class=\"data row6 col0\" >2172</td>\n",
" <td id=\"T_bcba8_row6_col1\" class=\"data row6 col1\" >False</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row7_col0\" class=\"data row7 col0\" >2161</td>\n",
" <td id=\"T_bcba8_row7_col1\" class=\"data row7 col1\" >False</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row8_col0\" class=\"data row8 col0\" >471</td>\n",
" <td id=\"T_bcba8_row8_col1\" class=\"data row8 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_bcba8_row9_col0\" class=\"data row9 col0\" >2014</td>\n",
" <td id=\"T_bcba8_row9_col1\" class=\"data row9 col1\" >False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xzW5x7ljb2oP"
},
"source": [
"One of the main challenges of survival analysis is **right censoring**: by the end of the study, the event of interest (for example, the death of a patient in medicine, or the recurrence of cancer in this dataset) has only occurred for a subset of the observations. For the remaining observations we only know that the event had not happened by their last follow-up time.\n",
"\n",
"In this dataset, right censoring is encoded in the column named 'event'. It is `True` if the patient had a recurrence of cancer at the indicated time, and `False` if the patient was still recurrence-free at that time (a right-censored sample)."
]
},
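{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of why censored samples cannot simply be discarded. It assumes scikit-survival's `kaplan_meier_estimator` (an estimator not used elsewhere in this post) and an arbitrary horizon of 1,000 days. Dropping the censored patients and counting how many of the remaining ones were still recurrence-free at the horizon ignores the fact that every censored patient stayed recurrence-free at least until their last follow-up. The Kaplan-Meier estimator instead keeps censored patients in the risk set until they are censored, so the two numbers differ."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from sksurv.datasets import load_gbsg2\n",
"from sksurv.nonparametric import kaplan_meier_estimator\n",
"\n",
"_, y_all = load_gbsg2()\n",
"horizon = 1000  # days, arbitrary choice for illustration\n",
"\n",
"# Naive estimate: drop the censored patients, then count who is still recurrence-free\n",
"events_only = y_all[y_all['cens']]\n",
"naive = np.mean(events_only['time'] > horizon)\n",
"\n",
"# Kaplan-Meier: evaluate the step function at the last observed time <= horizon\n",
"km_times, km_surv = kaplan_meier_estimator(y_all['cens'], y_all['time'])\n",
"km = km_surv[np.searchsorted(km_times, horizon, side='right') - 1]\n",
"\n",
"print(f'Naive estimate (censored samples dropped): {naive:.3f}')\n",
"print(f'Kaplan-Meier estimate: {km:.3f}')"
],
"execution_count": null,
"outputs": []
},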
{
"cell_type": "markdown",
"metadata": {
"id": "VmfAR7igb2oW"
},
"source": [
"Let's see how many right-censored samples we have."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rzS8h1GG_o_A",
"outputId": "3fc3c587-1978-45c5-abaa-6eed2bd5af6a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"print(f'Number of samples: {len(df_y)}')\n",
"print(f'Number of right censored samples: {len(df_y.query(\"event == False\"))}')\n",
"print(f'Percentage of right censored samples: {100*len(df_y.query(\"event == False\"))/len(df_y):.1f}%')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of samples: 686\n",
"Number of right censored samples: 387\n",
"Percentage of right censored samples: 56.4%\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VtsENFsnQhZx"
},
"source": [
"There are 387 patients (56.4%) who were right-censored, i.e. still recurrence-free at their last follow-up.\n",
"\n",
"Let's split our dataset into training and test sets. By default, `train_test_split` keeps 25% of the samples for the test set."
]
},
{
"cell_type": "code",
"metadata": {
"id": "duYhddUr_1nH",
"outputId": "17d1a3fb-2793-446e-cdd6-3a866d24fbc3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"print(f'Number of training samples: {len(y_trn)}')\n",
"print(f'Number of test samples: {len(y_test)}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of training samples: 514\n",
"Number of test samples: 172\n"
]
}
]
},
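{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible variant, not used in the rest of this notebook: passing the event indicator to the `stratify` argument of `train_test_split` keeps the proportion of right-censored samples about the same in the training and test sets. A minimal sketch; it uses new variable names so the split above is left untouched."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sksurv.datasets import load_gbsg2\n",
"\n",
"X_all, y_all = load_gbsg2()\n",
"# Stratify on the event indicator so that both splits have a similar censoring rate\n",
"X_trn_s, X_test_s, y_trn_s, y_test_s = train_test_split(\n",
"    X_all, y_all, stratify=y_all['cens'], random_state=42)\n",
"\n",
"for name, part in [('train', y_trn_s), ('test', y_test_s)]:\n",
"    censored_pct = 100 * (~part['cens']).mean()\n",
"    print(f'{name}: {len(part)} samples, {censored_pct:.1f}% right-censored')"
],
"execution_count": null,
"outputs": []
},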
{
"cell_type": "markdown",
"metadata": {
"id": "3VEOV-vWb2ow"
},
"source": [
"We divide the features/covariates into continuous and categorical ones."
]
},
{
"cell_type": "code",
"metadata": {
"id": "R1UVhukR_8VA"
},
"source": [
"# Integer and float columns are continuous; all the other columns are categorical\n",
"scaling_cols = [c for c in X.columns if X[c].dtype.kind in ['i', 'f']]\n",
"cat_cols = [c for c in X.columns if X[c].dtype.kind not in ['i', 'f']]"
],
"execution_count": null,
"outputs": []
},
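{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split relies on NumPy's `dtype.kind` codes: `'i'` for signed integers and `'f'` for floats. A quick check of which branch each column falls into follows. The categorical columns of this dataset are stored as pandas categoricals, which report kind `'O'` and therefore end up in `cat_cols`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from sksurv.datasets import load_gbsg2\n",
"\n",
"X_all, _ = load_gbsg2()\n",
"# dtype.kind: 'i' = integer, 'f' = float, 'O' = object or pandas categorical\n",
"for col in X_all.columns:\n",
"    kind = X_all[col].dtype.kind\n",
"    group = 'continuous' if kind in ['i', 'f'] else 'categorical'\n",
"    print(f'{col:10s} {str(X_all[col].dtype):10s} kind={kind} -> {group}')"
],
"execution_count": null,
"outputs": []
},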
{
"cell_type": "markdown",
"metadata": {
"id": "jDjgk9PWb2o3"
},
"source": [
"We use ordinal encoding for categorical features/covariates and standard scaling for continuous features/covariates."
]
},
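{
"cell_type": "markdown",
"metadata": {},
"source": [
"One point worth checking with `OrdinalEncoder`: with the default `categories='auto'`, scikit-learn sorts the categories. As a result, `tgrade` is encoded consistently with I < II < III only because the strings happen to sort in that order. A minimal sketch on hypothetical toy values compares the default with an explicit `categories` list that states the intended order."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pandas as pd\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"\n",
"# Hypothetical toy values for the tumor grade\n",
"grades = pd.DataFrame({'tgrade': ['II', 'III', 'I', 'II']})\n",
"\n",
"# Default: categories are sorted as strings, here 'I' < 'II' < 'III'\n",
"auto = OrdinalEncoder().fit(grades)\n",
"# Explicit: the order is stated instead of being inferred from string sorting\n",
"explicit = OrdinalEncoder(categories=[['I', 'II', 'III']]).fit(grades)\n",
"\n",
"print(auto.categories_)\n",
"print(auto.transform(grades).ravel())\n",
"print(explicit.transform(grades).ravel())"
],
"execution_count": null,
"outputs": []
},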
{
"cell_type": "code",
"metadata": {
"id": "axfu8RQeACQx"
},
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"preprocessor = ColumnTransformer(\n",
"    [('cat-preprocessor', OrdinalEncoder(), cat_cols),\n",
"     ('standard-scaler', StandardScaler(), scaling_cols)],\n",
"    remainder='passthrough', sparse_threshold=0)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "N_nodxlEb2o-"
},
"source": [
"# Baseline: Cox Proportional Hazards model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KqdChwYJb2o_"
},
"source": [
"The Cox proportional hazards model assumes that the log-hazard of a subject is the sum of the logarithm of a population-level baseline hazard function $h_0(t)$, which changes over time, and a linear function of their $m$ static covariates/features $x_i$, $i\\in\\{1,\\ldots,m\\}$, with coefficients $\\beta_i$:\n",
"\\begin{equation}\n",
"h(t\\mid x)=h_0(t)\\exp\\left(\\sum_{i=1}^m\\beta_i(x_i-\\bar{x}_i)\\right),\n",
"\\end{equation}\n",
"where $\\bar{x}_i$ is the mean of covariate $i$ (centering only rescales the baseline hazard).\n",
"\n",
"The term *proportional hazards* refers to the assumption that the ratio of the hazards of any two subjects $x$ and $x'$,\n",
"\\begin{equation}\n",
"\\frac{h(t\\mid x)}{h(t\\mid x')}=\\exp\\left(\\sum_{i=1}^m\\beta_i(x_i-x'_i)\\right),\n",
"\\end{equation}\n",
"does not depend on time: the covariates scale the baseline hazard by a constant factor."
]
},
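{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small numerical sketch of this property, with hypothetical coefficients `beta`, centering values `x_bar` and an arbitrary baseline hazard `h0` (none of them come from the fitted model). The ratio of the hazards of two subjects is the same at every time $t$ because $h_0(t)$ cancels out."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical coefficients and covariate means for m = 2 covariates\n",
"beta = np.array([0.5, -0.3])\n",
"x_bar = np.array([1.0, 2.0])\n",
"\n",
"def h0(t):\n",
"    # Arbitrary time-varying baseline hazard, chosen only for illustration\n",
"    return 0.01 * t ** 0.5\n",
"\n",
"def hazard(t, x):\n",
"    # Cox model: baseline hazard times exp of the linear predictor\n",
"    return h0(t) * np.exp(np.dot(beta, x - x_bar))\n",
"\n",
"x_a = np.array([2.0, 1.0])\n",
"x_b = np.array([0.0, 3.0])\n",
"for t in [1.0, 10.0, 100.0]:\n",
"    print(t, hazard(t, x_a) / hazard(t, x_b))\n",
"\n",
"# The constant ratio predicted by the model: exp(beta . (x_a - x_b))\n",
"print(np.exp(np.dot(beta, x_a - x_b)))"
],
"execution_count": null,
"outputs": []
},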
{
"cell_type": "code",
"metadata": {
"id": "77YbwMKvAFHQ",
"outputId": "f5a6349f-b86e-49df-9a2a-9ae06d610fc0",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"cox = make_pipeline(preprocessor, CoxPHSurvivalAnalysis())\n",
"cox.fit(X_trn, y_trn)\n",
"\n",
"# predict() returns a risk score per patient; a higher score means a higher hazard\n",
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox.predict(X_test))\n",
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox is given by 0.635\n"
]
}
]
},
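{
"cell_type": "markdown",
"metadata": {},
"source": [
"`concordance_index_censored` computes Harrell's concordance index. A pair is comparable when the subject with the shorter time experienced the event, and the index is the fraction of comparable pairs in which that subject also received the higher predicted risk (ties in risk count as one half). A minimal from-scratch sketch on hypothetical toy data follows, checked against scikit-survival. It assumes all observed times are distinct; scikit-survival additionally handles tied times."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"def harrell_c(event, time, risk):\n",
"    # Sketch assuming distinct times: (i, j) is comparable if i had the event before time[j]\n",
"    concordant, comparable = 0.0, 0\n",
"    for i in range(len(time)):\n",
"        if not event[i]:\n",
"            continue\n",
"        for j in range(len(time)):\n",
"            if time[i] < time[j]:\n",
"                comparable += 1\n",
"                if risk[i] > risk[j]:\n",
"                    concordant += 1.0\n",
"                elif risk[i] == risk[j]:\n",
"                    concordant += 0.5\n",
"    return concordant / comparable\n",
"\n",
"# Hypothetical toy data with distinct times\n",
"toy_event = np.array([True, False, True, True, False])\n",
"toy_time = np.array([5.0, 8.0, 3.0, 10.0, 12.0])\n",
"toy_risk = np.array([0.8, 0.3, 0.9, 0.1, 0.2])\n",
"\n",
"print(harrell_c(toy_event, toy_time, toy_risk))  # 7 of 8 comparable pairs concordant: 0.875\n",
"print(concordance_index_censored(toy_event, toy_time, toy_risk)[0])  # 0.875"
],
"execution_count": null,
"outputs": []
},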
{
"cell_type": "code",
"metadata": {
"id": "ZNHQ-bkWALNy",
"outputId": "8683c357-0633-4f9f-8096-9bf13e2d8321",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from scipy.stats import reciprocal\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"\n",
"# Sample the ridge penalty alpha log-uniformly between 0.1 and 100\n",
"param_distributions = {\n",
"    'coxphsurvivalanalysis__alpha': reciprocal(0.1, 100),\n",
"}\n",
"\n",
"model_random_search = RandomizedSearchCV(\n",
"    cox, param_distributions=param_distributions, n_iter=50, n_jobs=-1, cv=3, random_state=42)\n",
"model_random_search.fit(X_trn, y_trn)\n",
"\n",
"# score() of a scikit-survival estimator returns Harrell's c-index\n",
"print(\n",
"    f\"The c-index of Cox using a {model_random_search.__class__.__name__} is \"\n",
"    f\"{model_random_search.score(X_test, y_test):.3f}\")\n",
"print(\n",
"    f\"The best set of parameters is: {model_random_search.best_params_}\"\n",
")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox using a RandomizedSearchCV is 0.646\n",
"The best set of parameters is: {'coxphsurvivalanalysis__alpha': 31.428808908401084}\n"
]
}
]
},
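{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `CoxPHSurvivalAnalysis`, `alpha` is the weight of a ridge (L2) penalty on the coefficients, which is why the search draws it from a log-uniform (`reciprocal`) distribution spanning several orders of magnitude. A minimal sketch of its effect follows, fitting the model on the standardized continuous features only (a simplification of the pipeline used here) for a few illustrative values of `alpha`. The coefficients shrink towards zero as `alpha` grows."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sksurv.datasets import load_gbsg2\n",
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n",
"\n",
"X_all, y_all = load_gbsg2()\n",
"num_cols = ['age', 'estrec', 'pnodes', 'progrec', 'tsize']\n",
"X_num = StandardScaler().fit_transform(X_all[num_cols])\n",
"\n",
"for alpha in [0.0, 1.0, 10.0, 100.0, 1000.0]:\n",
"    coef = CoxPHSurvivalAnalysis(alpha=alpha).fit(X_num, y_all).coef_\n",
"    # The L2 norm of the coefficients decreases as the penalty grows\n",
"    print(f'alpha={alpha:>7.1f}  coef={np.round(coef, 3)}  norm={np.linalg.norm(coef):.3f}')"
],
"execution_count": null,
"outputs": []
},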
{
"cell_type": "code",
"metadata": {
"id": "MrhAveCQAdxF",
"outputId": "0f791b44-21ac-4546-e7d0-8553c757af6e",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Refit the pipeline on the whole training set with the best alpha found by the search\n",
"alpha = model_random_search.best_params_['coxphsurvivalanalysis__alpha']\n",
"cox_best = make_pipeline(preprocessor, CoxPHSurvivalAnalysis(alpha=alpha))\n",
"cox_best.fit(X_trn, y_trn)\n",
"\n",
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox_best.predict(X_test))\n",
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox is given by 0.646\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WArdzxG0cZ5d",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "05cf000a-579c-42d5-aaeb-f53d3e68a6c8"
},
"source": [
"!pip install -q eli5"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Oy3whfxOAlzN"
},
"source": [
"from eli5.sklearn import PermutationImportance"
],
"execution_count": null,
"outputs": []
},
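{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged alternative to eli5 (which is no longer actively developed and may fail to import with recent scikit-learn releases), scikit-learn's own `sklearn.inspection.permutation_importance` also works with scikit-survival estimators, because their `score` method returns the concordance index. This sketch is an assumption, not the method used in this post. It permutes the raw columns of the test set through the whole pipeline, so the importances are labelled with the original feature names. The value `alpha=31.43` reuses the best alpha reported by the randomized search."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pandas as pd\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.inspection import permutation_importance\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import OrdinalEncoder, StandardScaler\n",
"from sksurv.datasets import load_gbsg2\n",
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n",
"\n",
"X_all, y_all = load_gbsg2()\n",
"X_tr, X_te, y_tr, y_te = train_test_split(X_all, y_all, random_state=42)\n",
"\n",
"pre = ColumnTransformer([\n",
"    ('cat', OrdinalEncoder(), ['horTh', 'menostat', 'tgrade']),\n",
"    ('num', StandardScaler(), ['age', 'estrec', 'pnodes', 'progrec', 'tsize'])])\n",
"model = make_pipeline(pre, CoxPHSurvivalAnalysis(alpha=31.43)).fit(X_tr, y_tr)\n",
"\n",
"# Each raw column of X_te is shuffled n_repeats times; the mean drop in c-index is reported\n",
"result = permutation_importance(model, X_te, y_te, n_repeats=30, random_state=42)\n",
"importances = pd.Series(result.importances_mean, index=X_te.columns)\n",
"print(importances.sort_values(ascending=False).round(4))"
],
"execution_count": null,
"outputs": []
},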
{
"cell_type": "code",
"metadata": {
"id": "j2wH_sxgBX1-"
},
"source": [
"# Permutation importance of the fitted Cox model (last pipeline step), measured with its\n",
"# score method (Harrell's c-index) on the preprocessed training data, 100 shuffles per feature\n",
"perm = PermutationImportance(\n",
"    cox_best.steps[-1][1], n_iter=100, random_state=42).fit(preprocessor.fit_transform(X_trn), y_trn)"
],
"execution_count": null,
"outputs": []
},
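{
"cell_type": "code",
"metadata": {
"id": "1Rdepqh18UDV"
},
"source": [
"# The ColumnTransformer outputs the categorical columns first, then the scaled ones,\n",
"# so the permutation results must be labelled in that order (not in X_trn.columns order)\n",
"feature_names = cat_cols + scaling_cols\n",
"data = pd.DataFrame(perm.results_, columns=feature_names)\n",
"# Order the features by their median importance, largest first\n",
"meds = data.median().sort_values(ascending=False)\n",
"data = data[meds.index]"
],
"execution_count": null,
"outputs": []
},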
{
"cell_type": "code",
"metadata": {
"id": "PaPRs7hR9QL5",
"outputId": "149f6323-0025-4cee-cb8c-bc443a8420f2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 444
}
},
"source": [
"fig, ax = plt.subplots(figsize=(10, 7))\n",
"data.boxplot(ax=ax)\n",
"ax.set_title('Feature Importances')\n",
"ax.set_ylabel('Decrease in c-index when permuted')\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IjDH0p1dBuyV",
"outputId": "1f8b09f3-52ce-4b4a-9297-813d798e343f",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
}
},
"source": [
"import eli5\n",
"\n",
"# Label the weights in the column order produced by the ColumnTransformer (categorical first)\n",
"eli5.show_weights(perm, feature_names=cat_cols + scaling_cols)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" <style>\n",
" table.eli5-weights tr:hover {\n",
" filter: brightness(85%);\n",
" }\n",
"</style>\n",
" <table class=\"eli5-weights eli5-feature-importances\" style=\"border-collapse: collapse; border: none; margin-top: 0em; table-layout: auto;\">\n",
" <thead>\n",
" <tr style=\"border: none;\">\n",
" <th style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">Weight</th>\n",
" <th style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">Feature</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 80.00%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0703\n",
" ± 0.0298\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" progrec\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 81.86%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0611\n",
" ± 0.0198\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" pnodes\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 92.67%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0167\n",
" ± 0.0121\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tgrade\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 93.29%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0148\n",
" ± 0.0111\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" age\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 96.19%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0066\n",
" ± 0.0092\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" horTh\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 96.47%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0059\n",
" ± 0.0074\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tsize\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 97.64%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0033\n",
" ± 0.0045\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" estrec\n",
" </td>\n",
" </tr>\n",
" <tr style=\"background-color: hsl(120, 100.00%, 99.29%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0006\n",
" ± 0.0018\n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" menostat\n",
" </td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"execution_count": 24
}
]
}
]
}