@alonsosilvaallende
Last active October 16, 2023 08:20
Cox_PH_and_RSF-colab.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Cox_PH_and_RSF-colab.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/684ed7326e5c1192ad339fc8e59e04f9/cox_ph_and_rsf-colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UN3PoUTSb2nT"
},
"source": [
"The objective of this post is to compare two models, the Cox proportional hazards model and the random survival forest model, for estimating the survival probability given a set of features/covariates.\n",
"\n",
">[\"Experimental Comparison of Semi-parametric, Parametric, and Machine Learning Models for Time-to-Event Analysis Through the Concordance Index,\"](https://arxiv.org/abs/2003.08820)\n",
"Camila Fernandez, Chung Shue Chen, Pierre Gaillard, Alonso Silva\n",
"\n",
"To perform this analysis we will use [scikit-learn](https://scikit-learn.org/) and [scikit-survival](https://pypi.org/project/scikit-survival/). Finally, we will use [eli5](https://eli5.readthedocs.io/en/latest/index.html) to study feature importances (computed with permutation importance)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "pzqGFx_Bb_6A"
},
"source": [
"!pip install -q scikit-survival"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VoQVEI5p_rga"
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "SnT6e_JPb2ns"
},
"source": [
"We first load a dataset that ships with scikit-survival."
]
},
{
"cell_type": "code",
"metadata": {
"id": "D0xxNWzI-N3j"
},
"source": [
"from sksurv.datasets import load_gbsg2\n",
"\n",
"X, y = load_gbsg2()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "IBTo4q_Hb2n0"
},
"source": [
"## An example: German Breast Cancer Study Group 2 (GBSG2)\n",
"\n",
"This dataset contains the following 8 features/covariates:\n",
"\n",
"- age: age (in years),\n",
"- estrec: estrogen receptor (in fmol),\n",
"- horTh: hormonal therapy (yes or no),\n",
"- menostat: menopausal status (premenopausal or postmenopausal),\n",
"- pnodes: number of positive nodes,\n",
"- progrec: progesterone receptor (in fmol),\n",
"- tgrade: tumor grade (I < II < III),\n",
"- tsize: tumor size (in mm).\n",
"\n",
"and the two outputs:\n",
"\n",
"- recurrence free time (in days),\n",
"- censoring indicator (0 - censored, 1 - event).\n",
"\n",
"The dataset has 686 samples and 8 features/covariates.\n",
"\n",
"\n",
"**References**\n",
"\n",
"M. Schumacher, G. Bastert, H. Bojar, K. Huebner, M. Olschewski, W. Sauerbrei, C. Schmoor, C. Beyerle, R.L.A. Neumann and H.F. Rauschecker for the German Breast Cancer Study Group (1994), [Randomized 2 x 2 trial evaluating hormonal treatment and the duration of chemotherapy in node-positive breast cancer patients](https://www.ncbi.nlm.nih.gov/pubmed/7931478). Journal of Clinical Oncology, 12, 2086–2093."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kAsZ72YYb2n3"
},
"source": [
"Let's take a look at the features/covariates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "_OlmlI6g-X43",
"outputId": "f3b97327-9b8a-491b-e36a-00887b6b2f9d",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
}
},
"source": [
"cols = [\"age\", \"estrec\", \"pnodes\", \"progrec\", \"tsize\"]\n",
"formatdict = {}\n",
"for col in cols: formatdict[col] = \"{:,.0f}\"\n",
"X.head(10).style.hide(axis=\"index\").format(formatdict)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f6efea2a9b0>"
],
"text/html": [
"<style type=\"text/css\">\n",
"</style>\n",
"<table id=\"T_b723e\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th id=\"T_b723e_level0_col0\" class=\"col_heading level0 col0\" >age</th>\n",
" <th id=\"T_b723e_level0_col1\" class=\"col_heading level0 col1\" >estrec</th>\n",
" <th id=\"T_b723e_level0_col2\" class=\"col_heading level0 col2\" >horTh</th>\n",
" <th id=\"T_b723e_level0_col3\" class=\"col_heading level0 col3\" >menostat</th>\n",
" <th id=\"T_b723e_level0_col4\" class=\"col_heading level0 col4\" >pnodes</th>\n",
" <th id=\"T_b723e_level0_col5\" class=\"col_heading level0 col5\" >progrec</th>\n",
" <th id=\"T_b723e_level0_col6\" class=\"col_heading level0 col6\" >tgrade</th>\n",
" <th id=\"T_b723e_level0_col7\" class=\"col_heading level0 col7\" >tsize</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_b723e_row0_col0\" class=\"data row0 col0\" >70</td>\n",
" <td id=\"T_b723e_row0_col1\" class=\"data row0 col1\" >66</td>\n",
" <td id=\"T_b723e_row0_col2\" class=\"data row0 col2\" >no</td>\n",
" <td id=\"T_b723e_row0_col3\" class=\"data row0 col3\" >Post</td>\n",
" <td id=\"T_b723e_row0_col4\" class=\"data row0 col4\" >3</td>\n",
" <td id=\"T_b723e_row0_col5\" class=\"data row0 col5\" >48</td>\n",
" <td id=\"T_b723e_row0_col6\" class=\"data row0 col6\" >II</td>\n",
" <td id=\"T_b723e_row0_col7\" class=\"data row0 col7\" >21</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row1_col0\" class=\"data row1 col0\" >56</td>\n",
" <td id=\"T_b723e_row1_col1\" class=\"data row1 col1\" >77</td>\n",
" <td id=\"T_b723e_row1_col2\" class=\"data row1 col2\" >yes</td>\n",
" <td id=\"T_b723e_row1_col3\" class=\"data row1 col3\" >Post</td>\n",
" <td id=\"T_b723e_row1_col4\" class=\"data row1 col4\" >7</td>\n",
" <td id=\"T_b723e_row1_col5\" class=\"data row1 col5\" >61</td>\n",
" <td id=\"T_b723e_row1_col6\" class=\"data row1 col6\" >II</td>\n",
" <td id=\"T_b723e_row1_col7\" class=\"data row1 col7\" >12</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row2_col0\" class=\"data row2 col0\" >58</td>\n",
" <td id=\"T_b723e_row2_col1\" class=\"data row2 col1\" >271</td>\n",
" <td id=\"T_b723e_row2_col2\" class=\"data row2 col2\" >yes</td>\n",
" <td id=\"T_b723e_row2_col3\" class=\"data row2 col3\" >Post</td>\n",
" <td id=\"T_b723e_row2_col4\" class=\"data row2 col4\" >9</td>\n",
" <td id=\"T_b723e_row2_col5\" class=\"data row2 col5\" >52</td>\n",
" <td id=\"T_b723e_row2_col6\" class=\"data row2 col6\" >II</td>\n",
" <td id=\"T_b723e_row2_col7\" class=\"data row2 col7\" >35</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row3_col0\" class=\"data row3 col0\" >59</td>\n",
" <td id=\"T_b723e_row3_col1\" class=\"data row3 col1\" >29</td>\n",
" <td id=\"T_b723e_row3_col2\" class=\"data row3 col2\" >yes</td>\n",
" <td id=\"T_b723e_row3_col3\" class=\"data row3 col3\" >Post</td>\n",
" <td id=\"T_b723e_row3_col4\" class=\"data row3 col4\" >4</td>\n",
" <td id=\"T_b723e_row3_col5\" class=\"data row3 col5\" >60</td>\n",
" <td id=\"T_b723e_row3_col6\" class=\"data row3 col6\" >II</td>\n",
" <td id=\"T_b723e_row3_col7\" class=\"data row3 col7\" >17</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row4_col0\" class=\"data row4 col0\" >73</td>\n",
" <td id=\"T_b723e_row4_col1\" class=\"data row4 col1\" >65</td>\n",
" <td id=\"T_b723e_row4_col2\" class=\"data row4 col2\" >no</td>\n",
" <td id=\"T_b723e_row4_col3\" class=\"data row4 col3\" >Post</td>\n",
" <td id=\"T_b723e_row4_col4\" class=\"data row4 col4\" >1</td>\n",
" <td id=\"T_b723e_row4_col5\" class=\"data row4 col5\" >26</td>\n",
" <td id=\"T_b723e_row4_col6\" class=\"data row4 col6\" >II</td>\n",
" <td id=\"T_b723e_row4_col7\" class=\"data row4 col7\" >35</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row5_col0\" class=\"data row5 col0\" >32</td>\n",
" <td id=\"T_b723e_row5_col1\" class=\"data row5 col1\" >13</td>\n",
" <td id=\"T_b723e_row5_col2\" class=\"data row5 col2\" >no</td>\n",
" <td id=\"T_b723e_row5_col3\" class=\"data row5 col3\" >Pre</td>\n",
" <td id=\"T_b723e_row5_col4\" class=\"data row5 col4\" >24</td>\n",
" <td id=\"T_b723e_row5_col5\" class=\"data row5 col5\" >0</td>\n",
" <td id=\"T_b723e_row5_col6\" class=\"data row5 col6\" >III</td>\n",
" <td id=\"T_b723e_row5_col7\" class=\"data row5 col7\" >57</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row6_col0\" class=\"data row6 col0\" >59</td>\n",
" <td id=\"T_b723e_row6_col1\" class=\"data row6 col1\" >0</td>\n",
" <td id=\"T_b723e_row6_col2\" class=\"data row6 col2\" >yes</td>\n",
" <td id=\"T_b723e_row6_col3\" class=\"data row6 col3\" >Post</td>\n",
" <td id=\"T_b723e_row6_col4\" class=\"data row6 col4\" >2</td>\n",
" <td id=\"T_b723e_row6_col5\" class=\"data row6 col5\" >181</td>\n",
" <td id=\"T_b723e_row6_col6\" class=\"data row6 col6\" >II</td>\n",
" <td id=\"T_b723e_row6_col7\" class=\"data row6 col7\" >8</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row7_col0\" class=\"data row7 col0\" >65</td>\n",
" <td id=\"T_b723e_row7_col1\" class=\"data row7 col1\" >25</td>\n",
" <td id=\"T_b723e_row7_col2\" class=\"data row7 col2\" >no</td>\n",
" <td id=\"T_b723e_row7_col3\" class=\"data row7 col3\" >Post</td>\n",
" <td id=\"T_b723e_row7_col4\" class=\"data row7 col4\" >1</td>\n",
" <td id=\"T_b723e_row7_col5\" class=\"data row7 col5\" >192</td>\n",
" <td id=\"T_b723e_row7_col6\" class=\"data row7 col6\" >II</td>\n",
" <td id=\"T_b723e_row7_col7\" class=\"data row7 col7\" >16</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row8_col0\" class=\"data row8 col0\" >80</td>\n",
" <td id=\"T_b723e_row8_col1\" class=\"data row8 col1\" >59</td>\n",
" <td id=\"T_b723e_row8_col2\" class=\"data row8 col2\" >no</td>\n",
" <td id=\"T_b723e_row8_col3\" class=\"data row8 col3\" >Post</td>\n",
" <td id=\"T_b723e_row8_col4\" class=\"data row8 col4\" >30</td>\n",
" <td id=\"T_b723e_row8_col5\" class=\"data row8 col5\" >0</td>\n",
" <td id=\"T_b723e_row8_col6\" class=\"data row8 col6\" >II</td>\n",
" <td id=\"T_b723e_row8_col7\" class=\"data row8 col7\" >39</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_b723e_row9_col0\" class=\"data row9 col0\" >66</td>\n",
" <td id=\"T_b723e_row9_col1\" class=\"data row9 col1\" >3</td>\n",
" <td id=\"T_b723e_row9_col2\" class=\"data row9 col2\" >no</td>\n",
" <td id=\"T_b723e_row9_col3\" class=\"data row9 col3\" >Post</td>\n",
" <td id=\"T_b723e_row9_col4\" class=\"data row9 col4\" >7</td>\n",
" <td id=\"T_b723e_row9_col5\" class=\"data row9 col5\" >0</td>\n",
" <td id=\"T_b723e_row9_col6\" class=\"data row9 col6\" >II</td>\n",
" <td id=\"T_b723e_row9_col7\" class=\"data row9 col7\" >18</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zcMjHeAEb2n_"
},
"source": [
"Let's take a look at the output."
]
},
{
"cell_type": "code",
"metadata": {
"id": "h8ltRTa4_WOn",
"outputId": "9e762f57-bcd9-4961-f2d0-4a445ad69d85",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"y[:10]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([( True, 1814.), ( True, 2018.), ( True, 712.), ( True, 1807.),\n",
" ( True, 772.), ( True, 448.), (False, 2172.), (False, 2161.),\n",
" ( True, 471.), (False, 2014.)],\n",
" dtype=[('cens', '?'), ('time', '<f8')])"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A5NXlNYob2oH"
},
"source": [
"scikit-survival stores the output as a NumPy structured array, so we convert it to a DataFrame to display it."
]
},
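{
"cell_type": "code",
"source": [
"df_y = pd.DataFrame(data={'time': y['time'].astype(int), 'event': y['cens']})\n",
"# Use Styler.hide(axis='index'), as in the features table above (hide_index() is deprecated)\n",
"df_y[:10].style.hide(axis='index').highlight_min('event', color='lightgreen')"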
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "AfPvZcjJ-GwQ",
"outputId": "b8561a10-5167-4901-b454-0c7e6be0136e"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_63853_row6_col1, #T_63853_row7_col1, #T_63853_row9_col1 {\n",
" background-color: lightgreen;\n",
"}\n",
"</style>\n",
"<table id=\"T_63853_\" class=\"dataframe\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"col_heading level0 col0\" >time</th>\n",
" <th class=\"col_heading level0 col1\" >event</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td id=\"T_63853_row0_col0\" class=\"data row0 col0\" >1814</td>\n",
" <td id=\"T_63853_row0_col1\" class=\"data row0 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row1_col0\" class=\"data row1 col0\" >2018</td>\n",
" <td id=\"T_63853_row1_col1\" class=\"data row1 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row2_col0\" class=\"data row2 col0\" >712</td>\n",
" <td id=\"T_63853_row2_col1\" class=\"data row2 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row3_col0\" class=\"data row3 col0\" >1807</td>\n",
" <td id=\"T_63853_row3_col1\" class=\"data row3 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row4_col0\" class=\"data row4 col0\" >772</td>\n",
" <td id=\"T_63853_row4_col1\" class=\"data row4 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row5_col0\" class=\"data row5 col0\" >448</td>\n",
" <td id=\"T_63853_row5_col1\" class=\"data row5 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row6_col0\" class=\"data row6 col0\" >2172</td>\n",
" <td id=\"T_63853_row6_col1\" class=\"data row6 col1\" >False</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row7_col0\" class=\"data row7 col0\" >2161</td>\n",
" <td id=\"T_63853_row7_col1\" class=\"data row7 col1\" >False</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row8_col0\" class=\"data row8 col0\" >471</td>\n",
" <td id=\"T_63853_row8_col1\" class=\"data row8 col1\" >True</td>\n",
" </tr>\n",
" <tr>\n",
" <td id=\"T_63853_row9_col0\" class=\"data row9 col0\" >2014</td>\n",
" <td id=\"T_63853_row9_col1\" class=\"data row9 col1\" >False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x7f8de02b0510>"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xzW5x7ljb2oP"
},
"source": [
"One of the main challenges of survival analysis is **right censoring**, i.e., by the end of the study, the event of interest (for example, in medicine 'death of a patient' or in this dataset 'recurrence of cancer') has only occurred for a subset of the observations.\n",
"\n",
"In this dataset, **right censoring** is encoded in the column named 'event', which takes the value 'True' if the patient had a recurrence of cancer and 'False' if the patient was still recurrence free at the indicated time (right-censored samples)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VmfAR7igb2oW"
},
"source": [
"Let's see how many right-censored samples we have."
]
},
{
"cell_type": "code",
"metadata": {
"id": "rzS8h1GG_o_A",
"outputId": "3fc3c587-1978-45c5-abaa-6eed2bd5af6a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"print(f'Number of samples: {len(df_y)}')\n",
"print(f'Number of right censored samples: {len(df_y.query(\"event == False\"))}')\n",
"print(f'Percentage of right censored samples: {100*len(df_y.query(\"event == False\"))/len(df_y):.1f}%')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of samples: 686\n",
"Number of right censored samples: 387\n",
"Percentage of right censored samples: 56.4%\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VtsENFsnQhZx"
},
"source": [
"There are 387 patients (56.4%) who were right censored (recurrence free) at the end of the study.\n",
"\n",
"Let's divide our dataset into training and test sets."
]
},
{
"cell_type": "code",
"metadata": {
"id": "duYhddUr_1nH",
"outputId": "17d1a3fb-2793-446e-cdd6-3a866d24fbc3",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=42)\n",
"\n",
"print(f'Number of training samples: {len(y_trn)}')\n",
"print(f'Number of test samples: {len(y_test)}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of training samples: 514\n",
"Number of test samples: 172\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3VEOV-vWb2ow"
},
"source": [
"We divide the features/covariates into continuous and categorical."
]
},
{
"cell_type": "code",
"metadata": {
"id": "R1UVhukR_8VA"
},
"source": [
"scaling_cols = [c for c in X.columns if X[c].dtype.kind in ['i', 'f']]\n",
"cat_cols = [c for c in X.columns if X[c].dtype.kind not in [\"i\", \"f\"]]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "jDjgk9PWb2o3"
},
"source": [
"We use ordinal encoding for categorical features/covariates and standard scaling for continuous features/covariates."
]
},
{
"cell_type": "code",
"metadata": {
"id": "axfu8RQeACQx"
},
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"preprocessor = ColumnTransformer(\n",
" [('cat-preprocessor', OrdinalEncoder(), cat_cols),\n",
" ('standard-scaler', StandardScaler(), scaling_cols)],\n",
" remainder='passthrough', sparse_threshold=0)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "N_nodxlEb2o-"
},
"source": [
"# Baseline: Cox Proportional Hazards model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KqdChwYJb2o_"
},
"source": [
"The Cox proportional hazards model assumes that the log-hazard of a subject is a linear function of their $m$ static covariates/features $x_i$, with coefficients $h_i, i\\in\\{1,\\ldots,m\\}$, plus the logarithm of a population-level baseline hazard function $h_0(t)$ that changes over time:\n",
"\\begin{equation}\n",
"h(t|x)=h_0(t)\\exp\\left(\\sum_{i=1}^mh_i(x_i-\\bar{x_i})\\right).\n",
"\\end{equation}\n",
"\n",
"The term *proportional hazards* refers to the assumption that the ratio of the hazards of any two subjects is constant over time: the covariates scale the baseline hazard $h_0(t)$ by a factor that does not depend on $t$."
]
},
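{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a short worked step (a sketch in the notation above, where $x$ and $x'$ denote the covariates of two hypothetical subjects), the baseline hazard cancels in the ratio of their hazards:\n",
"\\begin{equation}\n",
"\\frac{h(t|x)}{h(t|x')}=\\frac{h_0(t)\\exp\\left(\\sum_{i=1}^mh_i(x_i-\\bar{x_i})\\right)}{h_0(t)\\exp\\left(\\sum_{i=1}^mh_i(x'_i-\\bar{x_i})\\right)}=\\exp\\left(\\sum_{i=1}^mh_i(x_i-x'_i)\\right).\n",
"\\end{equation}\n",
"\n",
"The right-hand side does not depend on $t$, so under this model the two hazards stay proportional at all times. For example, if the two subjects differ only in covariate $k$, and by one unit, the ratio is $\\exp(h_k)$."
]
},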
{
"cell_type": "code",
"metadata": {
"id": "77YbwMKvAFHQ",
"outputId": "f5a6349f-b86e-49df-9a2a-9ae06d610fc0",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"cox = make_pipeline(preprocessor, CoxPHSurvivalAnalysis())\n",
"cox.fit(X_trn, y_trn)\n",
"\n",
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox.predict(X_test))\n",
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox is given by 0.635\n"
]
}
]
},
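{
"cell_type": "markdown",
"metadata": {},
"source": [
"The c-index printed above measures how well the predicted risk scores rank the subjects. A sketch of the usual definition (Harrell's concordance index), with notation introduced here for illustration: $T_i$ is the observed time, $\\delta_i\\in\\{0,1\\}$ the event indicator and $\\hat{r}_i$ the predicted risk score of subject $i$:\n",
"\\begin{equation}\n",
"C=\\frac{\\sum_{i\\neq j}\\delta_i\\,\\mathbb{1}[T_i<T_j]\\,\\mathbb{1}[\\hat{r}_i>\\hat{r}_j]}{\\sum_{i\\neq j}\\delta_i\\,\\mathbb{1}[T_i<T_j]}.\n",
"\\end{equation}\n",
"\n",
"Only pairs in which the subject with the shorter observed time experienced the event are comparable. Pairs with tied risk scores are commonly counted as $1/2$ (omitted from the formula for brevity). A value of $0.5$ corresponds to a random ranking and $1$ to a perfect one."
]
},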
{
"cell_type": "code",
"metadata": {
"id": "ZNHQ-bkWALNy",
"outputId": "8683c357-0633-4f9f-8096-9bf13e2d8321",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"from scipy.stats import reciprocal\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"\n",
"param_distributions = {\n",
" 'coxphsurvivalanalysis__alpha': reciprocal(0.1, 100),\n",
"}\n",
"\n",
"model_random_search = RandomizedSearchCV(\n",
" cox, param_distributions=param_distributions, n_iter=50, n_jobs=-1, cv=3, random_state=42)\n",
"model_random_search.fit(X_trn, y_trn)\n",
"\n",
"print(\n",
" f\"The c-index of Cox using a {model_random_search.__class__.__name__} is \"\n",
" f\"{model_random_search.score(X_test, y_test):.3f}\")\n",
"print(\n",
" f\"The best set of parameters is: {model_random_search.best_params_}\"\n",
")"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox using a RandomizedSearchCV is 0.646\n",
"The best set of parameters is: {'coxphsurvivalanalysis__alpha': 31.428808908401084}\n"
]
}
]
},
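{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter `alpha` tuned above is the strength of a ridge ($\\ell_2$) penalty on the coefficients. As a sketch in the notation of the model above, and assuming no tied event times (the exact scaling constants used by the library may differ), the fitted coefficients minimize a penalized negative log partial likelihood of the form:\n",
"\\begin{equation}\n",
"-\\sum_{j:\\delta_j=1}\\left[\\sum_{i=1}^mh_ix_{ji}-\\log\\sum_{k:T_k\\geq T_j}\\exp\\left(\\sum_{i=1}^mh_ix_{ki}\\right)\\right]+\\frac{\\alpha}{2}\\sum_{i=1}^mh_i^2.\n",
"\\end{equation}\n",
"\n",
"Here $x_{ji}$ is covariate $i$ of subject $j$, $T_j$ the observed time and $\\delta_j$ the event indicator (symbols introduced here for illustration). Larger values of $\\alpha$ shrink the coefficients toward zero, and `reciprocal(0.1, 100)` samples $\\alpha$ log-uniformly on $[0.1, 100]$."
]
},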
{
"cell_type": "code",
"metadata": {
"id": "MrhAveCQAdxF",
"outputId": "0f791b44-21ac-4546-e7d0-8553c757af6e",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"alpha = model_random_search.best_params_['coxphsurvivalanalysis__alpha']\n",
"cox_best = make_pipeline(preprocessor, CoxPHSurvivalAnalysis(alpha=alpha))\n",
"cox_best.fit(X_trn, y_trn)\n",
"\n",
"ci_cox = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], cox_best.predict(X_test))\n",
"print(f'The c-index of Cox is given by {ci_cox[0]:.3f}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The c-index of Cox is given by 0.646\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WArdzxG0cZ5d",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "05cf000a-579c-42d5-aaeb-f53d3e68a6c8"
},
"source": [
"!pip install -q eli5"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Oy3whfxOAlzN"
},
"source": [
"from eli5.sklearn import PermutationImportance"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "j2wH_sxgBX1-"
},
"source": [
"perm = PermutationImportance(\n",
" cox_best.steps[-1][1], n_iter=100, random_state=42).fit(preprocessor.fit_transform(X_trn),y_trn)"
],
"execution_count": null,
"outputs": []
},
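{
"cell_type": "markdown",
"metadata": {},
"source": [
"Permutation importance, as computed above, measures how much the score of the fitted model drops when one column is shuffled. A sketch of the definition, with notation introduced here for illustration: let $s$ be the score (here the c-index) on the unshuffled data and $s_j^{(k)}$ the score after randomly permuting column $j$ in repetition $k$ of $K$ (`n_iter=100` above):\n",
"\\begin{equation}\n",
"I_j=\\frac{1}{K}\\sum_{k=1}^K\\left(s-s_j^{(k)}\\right).\n",
"\\end{equation}\n",
"\n",
"A value of $I_j$ close to zero suggests that the model barely relies on column $j$, and the spread of the $K$ individual decreases gives a sense of the uncertainty."
]
},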
{
"cell_type": "code",
"metadata": {
"id": "1Rdepqh18UDV"
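},
"source": [
"data = perm.results_\n",
"# The ColumnTransformer outputs the categorical columns first, then the scaled ones,\n",
"# so the labels must follow that order rather than X_trn.columns\n",
"data = pd.DataFrame(data, columns=cat_cols + scaling_cols)\n",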
"meds = data.median()\n",
"meds = meds.sort_values(ascending=False)\n",
"data = data[meds.index]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PaPRs7hR9QL5",
"outputId": "149f6323-0025-4cee-cb8c-bc443a8420f2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 444
}
},
"source": [
"fig, ax = plt.subplots(figsize=(10,7))\n",
"data.boxplot(ax=ax)\n",
"ax.set_title('Feature Importances')\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x504 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "IjDH0p1dBuyV",
"outputId": "1f8b09f3-52ce-4b4a-9297-813d798e343f",
"scrolled": true,
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
}
},
"source": [
"import eli5\n",
"# Feature names in the order produced by the ColumnTransformer (categorical first)\n",
"eli5.show_weights(perm, feature_names=cat_cols + scaling_cols)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
" <style>\n",
" table.eli5-weights tr:hover {\n",
" filter: brightness(85%);\n",
" }\n",
"</style>\n",
"\n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
" <table class=\"eli5-weights eli5-feature-importances\" style=\"border-collapse: collapse; border: none; margin-top: 0em; table-layout: auto;\">\n",
" <thead>\n",
" <tr style=\"border: none;\">\n",
" <th style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">Weight</th>\n",
" <th style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">Feature</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 80.00%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0703\n",
" \n",
" &plusmn; 0.0298\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tgrade\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 81.86%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0611\n",
" \n",
" &plusmn; 0.0198\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" progrec\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 92.67%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0167\n",
" \n",
" &plusmn; 0.0121\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" horTh\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 93.29%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0148\n",
" \n",
" &plusmn; 0.0111\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" menostat\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 96.19%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0066\n",
" \n",
" &plusmn; 0.0092\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" age\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 96.47%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0059\n",
" \n",
" &plusmn; 0.0074\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tsize\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 97.64%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0033\n",
" \n",
" &plusmn; 0.0045\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" pnodes\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 99.29%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0006\n",
" \n",
" &plusmn; 0.0018\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" estrec\n",
" </td>\n",
" </tr>\n",
" \n",
" \n",
" </tbody>\n",
"</table>\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
"\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"execution_count": 24
}
]
}
]
}