Created
August 27, 2021 08:29
-
-
Save alonsosilvaallende/6633f66be36d99b261d7a00f36a896f3 to your computer and use it in GitHub Desktop.
RSF-IPEC.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "RSF-IPEC.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"mount_file_id": "1jGHxL2vL5kRTMX0kMVV9jZWDT1eXkzpH", | |
"authorship_tag": "ABX9TyNw4c8ph1CcUJZyR4i4R6vR", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/6633f66be36d99b261d7a00f36a896f3/rsf-ipec.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LnChyaqIicM1" | |
}, | |
"source": [ | |
"!cp drive/MyDrive/Time-to-Event-Analysis/util.py ." | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TZJRW0UOjBAr" | |
}, | |
"source": [ | |
"!cp drive/MyDrive/Time-to-Event-Analysis/npsurvival_models.py ." | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "i9ZYUFdZs8Mh", | |
"outputId": "9afca690-dff1-479d-e633-01c041d57cd8" | |
}, | |
"source": [ | |
"!pip install --quiet --upgrade pip\n", | |
"!pip uninstall --yes --quiet osqp\n", | |
"!pip install --quiet -U scikit-survival" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n", | |
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", | |
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25herror\n", | |
"\u001b[33mWARNING: Discarding https://files.pythonhosted.org/packages/91/f0/047ce90bb831ab34ca287d1d23f0c61b6546cd89494566898c0e17516990/scikit-survival-0.15.0.post0.tar.gz#sha256=572c3ac6818a9d0944fc4b8176eb948051654de857e28419ecc5060bcc6fbf37 (from https://pypi.org/simple/scikit-survival/) (requires-python:>=3.7). Command errored out with exit status 1: /usr/bin/python3 /usr/local/lib/python3.7/dist-packages/pip/_vendor/pep517/in_process/_in_process.py get_requires_for_build_wheel /tmp/tmpmgszihjl Check the logs for full command output.\u001b[0m\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "92CA1kgflbYo", | |
"outputId": "57868e52-bfca-4bbd-eb62-f9f34ad08366" | |
}, | |
"source": [ | |
"!pip install --quiet lifelines" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3EO7M5I2kGPq" | |
}, | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import matplotlib.pyplot as plt" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "t1dBhvvDkQ9e" | |
}, | |
"source": [ | |
"from sksurv.datasets import load_gbsg2\n", | |
"\n", | |
"X, y = load_gbsg2()" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wxS5EV7jkX33" | |
}, | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"\n", | |
"X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=42)" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kLOBNaj-kfDx" | |
}, | |
"source": [ | |
"scaling_cols = [c for c in X.columns if X[c].dtype.kind in ['i', 'f']]\n", | |
"cat_cols = [c for c in X.columns if X[c].dtype.kind not in ['i', 'f']]" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3_K8sDKXkg8C" | |
}, | |
"source": [ | |
"from sklearn.compose import ColumnTransformer\n", | |
"from sklearn.preprocessing import OrdinalEncoder\n", | |
"from sklearn.preprocessing import StandardScaler\n", | |
"\n", | |
"preprocessor = ColumnTransformer(\n", | |
" [('cat-preprocessor', OrdinalEncoder(), cat_cols),\n", | |
" ('standard-scaler', StandardScaler(), scaling_cols)],\n", | |
" remainder='passthrough', sparse_threshold=0)" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "-fsY8FH9tcRh" | |
}, | |
"source": [ | |
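"# Concordance index computation\n", | |
"\n", | |
"Fit a random survival forest (behind the shared preprocessing pipeline) on the training split and report the concordance index on the held-out test split." | |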
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xnpbv82skp69", | |
"outputId": "6a5206b9-0138-4254-a80c-00f914378e57" | |
}, | |
"source": [ | |
"from sklearn.pipeline import make_pipeline\n", | |
"from sksurv.ensemble import RandomSurvivalForest\n", | |
"from sksurv.metrics import concordance_index_censored\n", | |
"\n", | |
"rsf = make_pipeline(preprocessor, RandomSurvivalForest(random_state=42))\n", | |
"rsf.fit(X_trn, y_trn)\n", | |
"print(f\"Concordance index: {concordance_index_censored(y_test['cens'], y_test['time'], rsf.predict(X_test))[0]:.3f}\")" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Concordance index: 0.658\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "X8YMW3E9tYZB" | |
}, | |
"source": [ | |
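"# IPEC score computation\n", | |
"\n", | |
"Score the forest's predicted test-set survival functions with `compute_IPEC_scores` from the local `util.py`. The horizon `tau` is the last time point of the predicted curves' time grid, and the reported score is divided by `tau`." | |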
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "z8c7LOuXkxDy" | |
}, | |
"source": [ | |
"from util import compute_IPEC_scores" | |
], | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "b-PZWT1Fqa0Z" | |
}, | |
"source": [ | |
"rsf_survfunc_test = rsf.predict_survival_function(X_test, return_array=False)" | |
], | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "lxxyRFumlX5E", | |
"outputId": "1835263b-b387-4ddc-8516-7f2937e75518" | |
}, | |
"source": [ | |
"y_trn_ip = np.array([[i,j] for i, j in zip(y_trn['time'], y_trn['cens'])])\n", | |
"y_test_ip = np.array([[i,j] for i, j in zip(y_test['time'], y_test['cens'])])\n", | |
"times = np.concatenate((np.array([0]), rsf_survfunc_test[0].x))\n", | |
"rsf_survfunc_y = \\\n", | |
"[np.concatenate((np.array([1]), rsf_survfunc_test[i].y)) for i in range(len(rsf_survfunc_test))]\n", | |
"tau = [times[-1]]\n", | |
"\n", | |
"print(f\"IPEC score: {compute_IPEC_scores(y_trn_ip, y_test_ip, times, rsf_survfunc_y, tau)[tau[0]]/tau[0]:.4f}\")" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"IPEC score: 0.3027\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment