Skip to content

Instantly share code, notes, and snippets.

@alonsosilvaallende
Created November 30, 2022 17:22
Show Gist options
  • Save alonsosilvaallende/2a88abda7dfa28cd17b37b80cbaf95f9 to your computer and use it in GitHub Desktop.
Save alonsosilvaallende/2a88abda7dfa28cd17b37b80cbaf95f9 to your computer and use it in GitHub Desktop.
BSC_project.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNIWt9YZ/k7c6Y4xUNbVSYo",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/2a88abda7dfa28cd17b37b80cbaf95f9/bsc_project.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "lou1mDgsb4d6"
},
"outputs": [],
"source": [
"!pip install -q scikit-survival"
]
},
{
"cell_type": "code",
"source": [
"!pip install -q eli5"
],
"metadata": {
"id": "jjqvp6PC1PVT"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install -q shap"
],
"metadata": {
"id": "FSZYrbwn6AWz"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
],
"metadata": {
"id": "45lQix8lcA2y"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sksurv.datasets import load_gbsg2\n",
"\n",
"X, y = load_gbsg2()\n",
"X.head(3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"id": "zZHRTFkdcIM4",
"outputId": "f78f4767-6d2c-4f99-c954-93f96c069183"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" age estrec horTh menostat pnodes progrec tgrade tsize\n",
"0 70.0 66.0 no Post 3.0 48.0 II 21.0\n",
"1 56.0 77.0 yes Post 7.0 61.0 II 12.0\n",
"2 58.0 271.0 yes Post 9.0 52.0 II 35.0"
],
"text/html": [
"\n",
" <div id=\"df-d17a4e67-f59d-49d6-ae58-db72708f751b\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>estrec</th>\n",
" <th>horTh</th>\n",
" <th>menostat</th>\n",
" <th>pnodes</th>\n",
" <th>progrec</th>\n",
" <th>tgrade</th>\n",
" <th>tsize</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>70.0</td>\n",
" <td>66.0</td>\n",
" <td>no</td>\n",
" <td>Post</td>\n",
" <td>3.0</td>\n",
" <td>48.0</td>\n",
" <td>II</td>\n",
" <td>21.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>56.0</td>\n",
" <td>77.0</td>\n",
" <td>yes</td>\n",
" <td>Post</td>\n",
" <td>7.0</td>\n",
" <td>61.0</td>\n",
" <td>II</td>\n",
" <td>12.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>58.0</td>\n",
" <td>271.0</td>\n",
" <td>yes</td>\n",
" <td>Post</td>\n",
" <td>9.0</td>\n",
" <td>52.0</td>\n",
" <td>II</td>\n",
" <td>35.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d17a4e67-f59d-49d6-ae58-db72708f751b')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-d17a4e67-f59d-49d6-ae58-db72708f751b button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-d17a4e67-f59d-49d6-ae58-db72708f751b');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"y[:3]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ep87Os_IcLI6",
"outputId": "324b232b-be85-4c06-c885-f0e962a37dea"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([( True, 1814.), ( True, 2018.), ( True, 712.)],\n",
" dtype=[('cens', '?'), ('time', '<f8')])"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"% censorship: {100*(1-np.mean(y['cens'])):.2f}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IHYpMoqFcPpA",
"outputId": "93c07c5a-ed43-4054-f864-342e9ed2b744"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"% censorship: 56.41\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=42)"
],
"metadata": {
"id": "c1aVxKXJccAr"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# numeric (float64) vs categorical dtypes; the column split below is based on dtype.kind\n",
"X.dtypes"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "koZphqbdcl5q",
"outputId": "83b85507-5c74-4fba-c40d-e48c4722ae9a"
},
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"age float64\n",
"estrec float64\n",
"horTh category\n",
"menostat category\n",
"pnodes float64\n",
"progrec float64\n",
"tgrade category\n",
"tsize float64\n",
"dtype: object"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"def my_score(est, X_trn=X_trn, y_trn=y_trn, X_test=X_test, y_test=y_test):\n",
" est.fit(X_trn, y_trn)\n",
" ci = concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], est.predict(X_test))\n",
" survs = est.predict_survival_function(X_test)\n",
" times = np.arange(365, 1826)\n",
" preds = np.asarray([[fn(t) for t in times] for fn in survs])\n",
" ibs = integrated_brier_score(y_trn, y_test, preds, times)\n",
" print(f'Concordance index: {ci[0]}')\n",
" print(f'Integrated brier score: {ibs}')"
],
"metadata": {
"id": "Ngpbzw5OdeMU"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"scaling_cols = [c for c in X.columns if X[c].dtype.kind in ['i', 'f']]\n",
"cat_cols = [c for c in X.columns if X[c].dtype.kind not in ['i', 'f']]"
],
"metadata": {
"id": "BFZSrnvYcqGt"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OrdinalEncoder\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"preprocessor = ColumnTransformer(\n",
" [('cat-preprocessor', OrdinalEncoder(), cat_cols),\n",
" ('standard-scaler', StandardScaler(), scaling_cols)],\n",
" remainder='passthrough', sparse_threshold=0)"
],
"metadata": {
"id": "FwVbgZavcunZ"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sksurv.linear_model import CoxPHSurvivalAnalysis\n",
"from sksurv.metrics import concordance_index_censored\n",
"from sksurv.metrics import integrated_brier_score\n",
"\n",
"cph = make_pipeline(preprocessor, CoxPHSurvivalAnalysis())\n",
"my_score(cph)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "i0mE9J_XcyAl",
"outputId": "02e04bc4-d9c9-4ee6-c93e-92c54c5f7520"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.634554233894712\n",
"Integrated brier score: 0.20009274951127964\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"param_grid = {\n",
" 'coxphsurvivalanalysis__alpha': 10**np.linspace(-1,2,5),\n",
"}\n",
"cph_gs = GridSearchCV(\n",
" cph, param_grid=param_grid, n_jobs=-1, cv=3)\n",
"cph_gs.fit(X_trn, y_trn)\n",
"cph_best = make_pipeline(preprocessor, CoxPHSurvivalAnalysis(alpha=cph_gs.best_params_['coxphsurvivalanalysis__alpha']))\n",
"cph_best.fit(X_trn, y_trn)\n",
"my_score(cph_best)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "l3hsTAtne71k",
"outputId": "ecd3a9dc-0600-4f08-b04c-1f6368e729df"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.6430338004946414\n",
"Integrated brier score: 0.19743464033866906\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sksurv.ensemble import RandomSurvivalForest\n",
"\n",
"rsf = make_pipeline(preprocessor, RandomSurvivalForest(random_state=42))\n",
"rsf.fit(X_trn, y_trn)\n",
"my_score(rsf)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hlj--y94u-as",
"outputId": "51da1c55-127b-4869-e42b-256699bad9ce"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.6579908138028501\n",
"Integrated brier score: 0.1984125909087399\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# list the pipeline's step__param names, usable as keys in a grid search\n",
"rsf.get_params()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B_dFO8mYvOOm",
"outputId": "ff6f4935-cd86-44cf-917a-5e2e0e240f6b"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'memory': None,\n",
" 'steps': [('columntransformer',\n",
" ColumnTransformer(remainder='passthrough', sparse_threshold=0,\n",
" transformers=[('cat-preprocessor', OrdinalEncoder(),\n",
" ['horTh', 'menostat', 'tgrade']),\n",
" ('standard-scaler', StandardScaler(),\n",
" ['age', 'estrec', 'pnodes', 'progrec',\n",
" 'tsize'])])),\n",
" ('randomsurvivalforest', RandomSurvivalForest(random_state=42))],\n",
" 'verbose': False,\n",
" 'columntransformer': ColumnTransformer(remainder='passthrough', sparse_threshold=0,\n",
" transformers=[('cat-preprocessor', OrdinalEncoder(),\n",
" ['horTh', 'menostat', 'tgrade']),\n",
" ('standard-scaler', StandardScaler(),\n",
" ['age', 'estrec', 'pnodes', 'progrec',\n",
" 'tsize'])]),\n",
" 'randomsurvivalforest': RandomSurvivalForest(random_state=42),\n",
" 'columntransformer__n_jobs': None,\n",
" 'columntransformer__remainder': 'passthrough',\n",
" 'columntransformer__sparse_threshold': 0,\n",
" 'columntransformer__transformer_weights': None,\n",
" 'columntransformer__transformers': [('cat-preprocessor',\n",
" OrdinalEncoder(),\n",
" ['horTh', 'menostat', 'tgrade']),\n",
" ('standard-scaler',\n",
" StandardScaler(),\n",
" ['age', 'estrec', 'pnodes', 'progrec', 'tsize'])],\n",
" 'columntransformer__verbose': False,\n",
" 'columntransformer__verbose_feature_names_out': True,\n",
" 'columntransformer__cat-preprocessor': OrdinalEncoder(),\n",
" 'columntransformer__standard-scaler': StandardScaler(),\n",
" 'columntransformer__cat-preprocessor__categories': 'auto',\n",
" 'columntransformer__cat-preprocessor__dtype': numpy.float64,\n",
" 'columntransformer__cat-preprocessor__handle_unknown': 'error',\n",
" 'columntransformer__cat-preprocessor__unknown_value': None,\n",
" 'columntransformer__standard-scaler__copy': True,\n",
" 'columntransformer__standard-scaler__with_mean': True,\n",
" 'columntransformer__standard-scaler__with_std': True,\n",
" 'randomsurvivalforest__bootstrap': True,\n",
" 'randomsurvivalforest__max_depth': None,\n",
" 'randomsurvivalforest__max_features': 'auto',\n",
" 'randomsurvivalforest__max_leaf_nodes': None,\n",
" 'randomsurvivalforest__max_samples': None,\n",
" 'randomsurvivalforest__min_samples_leaf': 3,\n",
" 'randomsurvivalforest__min_samples_split': 6,\n",
" 'randomsurvivalforest__min_weight_fraction_leaf': 0.0,\n",
" 'randomsurvivalforest__n_estimators': 100,\n",
" 'randomsurvivalforest__n_jobs': None,\n",
" 'randomsurvivalforest__oob_score': False,\n",
" 'randomsurvivalforest__random_state': 42,\n",
" 'randomsurvivalforest__verbose': 0,\n",
" 'randomsurvivalforest__warm_start': False}"
]
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"param_grid_rsf = {\n",
" 'randomsurvivalforest__max_features': np.arange(3, 8),\n",
" 'randomsurvivalforest__max_depth': [5,10,None],\n",
" 'randomsurvivalforest__min_samples_leaf': [1,3,5],\n",
"}\n",
"\n",
"rsf_gs = GridSearchCV(\n",
" rsf, param_grid=param_grid_rsf, cv=3)\n",
"rsf_gs.fit(X_trn, y_trn)\n",
"rsf_best = make_pipeline(preprocessor, RandomSurvivalForest(max_depth= rsf_gs.best_params_['randomsurvivalforest__max_depth'],\n",
" max_features= rsf_gs.best_params_['randomsurvivalforest__max_features'],\n",
" min_samples_leaf= rsf_gs.best_params_['randomsurvivalforest__min_samples_leaf']))\n",
"rsf_best.fit(X_trn, y_trn)\n",
"my_score(rsf_best)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hjy0rYCyvchH",
"outputId": "5af669d2-eaa9-4181-8ed2-51c59f54c869"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.6247791779531269\n",
"Integrated brier score: 0.20951841008965424\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sksurv.ensemble import GradientBoostingSurvivalAnalysis\n",
"\n",
"gbc = make_pipeline(preprocessor, GradientBoostingSurvivalAnalysis(random_state=42))\n",
"gbc.fit(X_trn, y_trn)\n",
"my_score(gbc)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QBNY5WXjxHMb",
"outputId": "6d7d8ab0-32dc-488a-e187-f674f6aa5824"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.6410905664821576\n",
"Integrated brier score: 0.19772913228546876\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"param_grid = {\n",
" 'gradientboostingsurvivalanalysis__learning_rate': 10**np.linspace(-1,0,3),\n",
" 'gradientboostingsurvivalanalysis__max_depth': np.arange(3, 10, 2),\n",
" 'gradientboostingsurvivalanalysis__min_samples_leaf': np.arange(2,6,2),\n",
"}\n",
"\n",
"gbc_gs = GridSearchCV(gbc, param_grid=param_grid, cv=3, n_jobs=-1)\n",
"gbc_gs.fit(X_trn, y_trn)\n",
"gbc_best = make_pipeline(preprocessor, GradientBoostingSurvivalAnalysis(learning_rate= gbc_gs.best_params_['gradientboostingsurvivalanalysis__learning_rate'],\n",
" max_depth= gbc_gs.best_params_['gradientboostingsurvivalanalysis__max_depth'],\n",
" min_samples_leaf= gbc_gs.best_params_['gradientboostingsurvivalanalysis__min_samples_leaf']))\n",
"gbc_best.fit(X_trn, y_trn)\n",
"my_score(gbc_best)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AuKdh16qxvqo",
"outputId": "7488dccc-8bf0-4de2-aada-875d130be4cf"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/joblib/externals/loky/process_executor.py:703: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
" \"timeout or by a memory leak.\", UserWarning\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.6283123307030974\n",
"Integrated brier score: 0.20318280691517193\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from eli5.sklearn import PermutationImportance"
],
"metadata": {
"id": "EcS_gT830kw5"
},
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"source": [
"perm = PermutationImportance(\n",
" rsf.steps[-1][1], n_iter=100, random_state=42).fit(preprocessor.fit_transform(X_test),y_test)"
],
"metadata": {
"id": "NYZzSt_V0lN7"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import eli5\n",
"eli5.show_weights(perm, feature_names = X.columns.tolist())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"id": "F9p0vI4t0nvn",
"outputId": "0387a255-9f51-48cb-fab3-98c95a72acb1"
},
"execution_count": 22,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
" <style>\n",
" table.eli5-weights tr:hover {\n",
" filter: brightness(85%);\n",
" }\n",
"</style>\n",
"\n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
" <table class=\"eli5-weights eli5-feature-importances\" style=\"border-collapse: collapse; border: none; margin-top: 0em; table-layout: auto;\">\n",
" <thead>\n",
" <tr style=\"border: none;\">\n",
" <th style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">Weight</th>\n",
" <th style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">Feature</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 80.00%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0886\n",
" \n",
" &plusmn; 0.0426\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" progrec\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 92.58%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0215\n",
" \n",
" &plusmn; 0.0417\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tgrade\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 95.39%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0109\n",
" \n",
" &plusmn; 0.0197\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" menostat\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 95.69%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0099\n",
" \n",
" &plusmn; 0.0111\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" age\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 96.31%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0079\n",
" \n",
" &plusmn; 0.0112\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" horTh\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(120, 100.00%, 96.36%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" 0.0078\n",
" \n",
" &plusmn; 0.0197\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" pnodes\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(0, 100.00%, 97.96%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" -0.0034\n",
" \n",
" &plusmn; 0.0065\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" estrec\n",
" </td>\n",
" </tr>\n",
" \n",
" <tr style=\"background-color: hsl(0, 100.00%, 95.52%); border: none;\">\n",
" <td style=\"padding: 0 1em 0 0.5em; text-align: right; border: none;\">\n",
" -0.0105\n",
" \n",
" &plusmn; 0.0186\n",
" \n",
" </td>\n",
" <td style=\"padding: 0 0.5em 0 0.5em; text-align: left; border: none;\">\n",
" tsize\n",
" </td>\n",
" </tr>\n",
" \n",
" \n",
" </tbody>\n",
"</table>\n",
" \n",
"\n",
" \n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
" \n",
"\n",
"\n",
"\n"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [
"data = perm.results_\n",
"data = pd.DataFrame(data, columns=X_trn.columns)\n",
"meds = data.median()\n",
"meds = meds.sort_values(ascending=False)\n",
"data = data[meds.index]\n",
"fig, ax = plt.subplots()\n",
"data.boxplot(ax=ax)\n",
"ax.set_title('Permutation Importances')\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "dYKh6YRq1afe",
"outputId": "f3d77340-e10c-4045-febc-c600b617fd7c"
},
"execution_count": 23,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.feature_selection import SelectFromModel\n",
"\n",
"sel = SelectFromModel(perm, max_features=6, prefit=True)\n",
"X_trn_trans = sel.transform(X_trn)\n",
"X_test_trans = sel.transform(X_test)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-nyBcGgW1d7I",
"outputId": "38fee0f6-3142-4603-f9eb-ca61b85517d1"
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/sklearn/base.py:444: UserWarning: X has feature names, but SelectFromModel was fitted without feature names\n",
" f\"X has feature names, but {self.__class__.__name__} was fitted without\"\n",
"/usr/local/lib/python3.7/dist-packages/sklearn/base.py:444: UserWarning: X has feature names, but SelectFromModel was fitted without feature names\n",
" f\"X has feature names, but {self.__class__.__name__} was fitted without\"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# names of the raw columns kept by the selector; this assumes sel's support mask\n",
"# is aligned with X.columns, i.e. the importances it thresholds are in that order\n",
"X.columns[sel.get_support()]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0zejgr1b2uOo",
"outputId": "b148c566-87ca-424b-8895-56eae2cb7e22"
},
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Index(['progrec', 'tgrade'], dtype='object')"
]
},
"metadata": {},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"source": [
"scaling_cols = [c for c in X.columns[sel.get_support()] if X[c].dtype.kind in ['i', 'f']]\n",
"cat_cols = [c for c in X.columns[sel.get_support()] if X[c].dtype.kind not in ['i', 'f']]"
],
"metadata": {
"id": "v0X44JiO35mj"
},
"execution_count": 26,
"outputs": []
},
{
"cell_type": "code",
"source": [
"preprocessor = ColumnTransformer(\n",
" [('cat-preprocessor', OrdinalEncoder(), cat_cols),\n",
" ('standard-scaler', StandardScaler(), scaling_cols)],\n",
" remainder='passthrough', sparse_threshold=0)"
],
"metadata": {
"id": "USgbwaUT4eRj"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"source": [
"rsf = make_pipeline(preprocessor, RandomSurvivalForest(random_state=42))\n",
"rsf.fit(pd.DataFrame(X_trn_trans, columns=X.columns[sel.get_support()]), y_trn)\n",
"my_score(rsf, pd.DataFrame(X_trn_trans, columns=X.columns[sel.get_support()]), y_trn, pd.DataFrame(X_test_trans, columns=X.columns[sel.get_support()]), y_test)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5iF9aF6J1lYD",
"outputId": "2da908fa-b772-486e-ce68-6c8b764338ef"
},
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Concordance index: 0.557237074549523\n",
"Integrated brier score: 0.2407429762535294\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import shap"
],
"metadata": {
"id": "kzfBU7ei6abi"
},
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"source": [
"cat_cols = [c for c in X.columns if X[c].dtype.kind not in ['i', 'f'] and c != \"tgrade\"]"
],
"metadata": {
"id": "TnwCnwm1XBsR"
},
"execution_count": 30,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"preprocessor = ColumnTransformer(\n",
" [('cat-preprocessor', OneHotEncoder(drop='first'), cat_cols)],\n",
" remainder='passthrough', sparse_threshold=0)"
],
"metadata": {
"id": "AZF1KRX8XKEf"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"grade_str = X[\"tgrade\"].astype(object).values[:, None]\n",
"grade_num = OrdinalEncoder(categories=[[\"I\", \"II\", \"III\"]]).fit_transform(grade_str)\n",
"X_no_grade = X.drop(columns=\"tgrade\")\n",
"Xt = pd.DataFrame(preprocessor.fit_transform(X_no_grade), columns=preprocessor.get_feature_names_out())\n",
"Xt[\"tgrade\"] = grade_num"
],
"metadata": {
"id": "yfgfqi5CUca_"
},
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X_trn1, X_test1, y_trn1, y_test1 = train_test_split(Xt, y, random_state=42)"
],
"metadata": {
"id": "SOGfyRo9PzbU"
},
"execution_count": 33,
"outputs": []
},
{
"cell_type": "code",
"source": [
"X100 = shap.utils.sample(X_trn1, 100)"
],
"metadata": {
"id": "BwuErm8yQEhC"
},
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"source": [
"rsf = RandomSurvivalForest(random_state=42)\n",
"rsf.fit(X_trn1, y_trn1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "C2OumfIFQMYU",
"outputId": "14b94a03-3b66-4ea2-9dd8-c0c289a14e98"
},
"execution_count": 35,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomSurvivalForest(random_state=42)"
]
},
"metadata": {},
"execution_count": 35
}
]
},
{
"cell_type": "code",
"source": [
"# compute the SHAP values\n",
"explainer = shap.Explainer(rsf.predict, X100)"
],
"metadata": {
"id": "rtZhm9uEQRyV"
},
"execution_count": 36,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# slow: the exact explainer took ~6 minutes over the test set in the recorded run\n",
"shap_values = explainer(X_test1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "blTnZQxZNm1-",
"outputId": "43afc93f-eecb-49f9-8f0f-2c73f06ed3bd"
},
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Exact explainer: 173it [05:47, 2.06s/it]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# per-feature contributions to the prediction for the first test row\n",
"shap.plots.waterfall(shap_values[0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 389
},
"id": "IPiTWws_Li-z",
"outputId": "d3ebe870-58a7-4b6d-d6e7-877e0da29293"
},
"execution_count": 38,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x396 with 3 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# distribution of each feature's SHAP values across the whole test set\n",
"shap.plots.beeswarm(shap_values)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 324
},
"id": "NhlkwF3yLxMW",
"outputId": "ba109969-ed30-4d38-b0f9-d3ac10415ea3"
},
"execution_count": 39,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x338.4 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment