Created
February 16, 2022 14:12
-
-
Save alonsosilvaallende/89954323b9acb178f2fed7cbbe56391b to your computer and use it in GitHub Desktop.
Untitled25.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Untitled25.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyNlFIjwbBLXg6WyofJxLUVg", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/89954323b9acb178f2fed7cbbe56391b/untitled25.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "ALAV8_c8tRjL" | |
}, | |
"outputs": [], | |
"source": [ | |
"!pip install -q scikit-survival" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import sksurv\n", | |
"sksurv.__version__" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
}, | |
"id": "eK0n3so81idr", | |
"outputId": "2de00a47-f4c0-4dda-edb0-ba284f46baf0" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"'0.17.0'" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 2 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"import statsmodels.api as sm\n", | |
"pharmacoSmoking = sm.datasets.get_rdataset(\"pharmacoSmoking\", \"asaur\")\n", | |
"data = pharmacoSmoking.data\n", | |
"data = data.drop(columns=[\"id\",\"ageGroup2\",\"ageGroup4\"]) # Drop redundant information and ids\n", | |
"data.head(3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 178 | |
}, | |
"id": "tToD8R-dtden", | |
"outputId": "e90ddeea-a1be-4357-d536-95dafd136a3c" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", | |
" import pandas.util.testing as tm\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div id=\"df-b99d3495-b817-4d4a-9a7c-e84c4a164452\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>ttr</th>\n", | |
" <th>relapse</th>\n", | |
" <th>grp</th>\n", | |
" <th>age</th>\n", | |
" <th>gender</th>\n", | |
" <th>race</th>\n", | |
" <th>employment</th>\n", | |
" <th>yearsSmoking</th>\n", | |
" <th>levelSmoking</th>\n", | |
" <th>priorAttempts</th>\n", | |
" <th>longestNoSmoke</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>182</td>\n", | |
" <td>0</td>\n", | |
" <td>patchOnly</td>\n", | |
" <td>36</td>\n", | |
" <td>Male</td>\n", | |
" <td>white</td>\n", | |
" <td>ft</td>\n", | |
" <td>26</td>\n", | |
" <td>heavy</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>14</td>\n", | |
" <td>1</td>\n", | |
" <td>patchOnly</td>\n", | |
" <td>41</td>\n", | |
" <td>Male</td>\n", | |
" <td>white</td>\n", | |
" <td>other</td>\n", | |
" <td>27</td>\n", | |
" <td>heavy</td>\n", | |
" <td>3</td>\n", | |
" <td>90</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>5</td>\n", | |
" <td>1</td>\n", | |
" <td>combination</td>\n", | |
" <td>25</td>\n", | |
" <td>Female</td>\n", | |
" <td>white</td>\n", | |
" <td>other</td>\n", | |
" <td>12</td>\n", | |
" <td>heavy</td>\n", | |
" <td>3</td>\n", | |
" <td>21</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b99d3495-b817-4d4a-9a7c-e84c4a164452')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-b99d3495-b817-4d4a-9a7c-e84c4a164452 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-b99d3495-b817-4d4a-9a7c-e84c4a164452');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
" ttr relapse grp ... levelSmoking priorAttempts longestNoSmoke\n", | |
"0 182 0 patchOnly ... heavy 0 0\n", | |
"1 14 1 patchOnly ... heavy 3 90\n", | |
"2 5 1 combination ... heavy 3 21\n", | |
"\n", | |
"[3 rows x 11 columns]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sksurv.datasets import get_x_y\n", | |
"\n", | |
"X, y = get_x_y(data, attr_labels=[\"relapse\", \"ttr\"], pos_label=True)" | |
], | |
"metadata": { | |
"id": "6Fj-HZgftnG0" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for c in X.columns:\n", | |
" if X[c].dtype.kind not in ['i', 'f']:\n", | |
" X[c] = X[c].astype(\"category\")" | |
], | |
"metadata": { | |
"id": "evjHOcuutqh0" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.preprocessing import StandardScaler\n", | |
"from sksurv.preprocessing import OneHotEncoder\n", | |
"from sksurv.ensemble import RandomSurvivalForest\n", | |
"from sksurv.metrics import integrated_brier_score\n", | |
"\n", | |
"def get_ibs(seed, times):\n", | |
" X_trn, X_val, y_trn, y_val = train_test_split(X, y, random_state=seed)\n", | |
" print(f\"Minimum training time: {y_trn['ttr'].min()}\")\n", | |
" print(f\"Maximum training time: {y_trn['ttr'].max()}\")\n", | |
" print(f\"Minimum validation time: {y_val['ttr'].min()}\")\n", | |
" print(f\"Maximum validation time: {y_val['ttr'].max()}\")\n", | |
" enc = OneHotEncoder()\n", | |
" scaler = StandardScaler()\n", | |
" X_trn = enc.fit_transform(X_trn)\n", | |
" X_trn = pd.DataFrame(scaler.fit_transform(X_trn), columns=X_trn.columns)\n", | |
" X_val = enc.transform(X_val)\n", | |
" X_val = pd.DataFrame(scaler.transform(X_val), columns=X_val.columns)\n", | |
" rsf = RandomSurvivalForest(random_state=42)\n", | |
" rsf.fit(X_trn, y_trn)\n", | |
" survs = rsf.predict_survival_function(X_val)\n", | |
" preds = np.asarray([[fn(t) for t in times] for fn in survs])\n", | |
" return integrated_brier_score(y_trn, y_val, preds, times)" | |
], | |
"metadata": { | |
"id": "ZVoepAnVtutM" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"get_ibs(0, np.arange(0,170))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Hx6LHbg_wmSd", | |
"outputId": "c3bab9ab-13e4-4bab-cc7a-c49419335d6a" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Minimum training time: 0.0\n", | |
"Maximum training time: 182.0\n", | |
"Minimum validation time: 0.0\n", | |
"Maximum validation time: 182.0\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.21497683508486282" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"get_ibs(0,np.arange(0,180))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 380 | |
}, | |
"id": "FbaIV83zuhLT", | |
"outputId": "a38d5ac7-5421-49ff-aac8-8d84a831a048" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Minimum training time: 0.0\n", | |
"Maximum training time: 182.0\n", | |
"Minimum validation time: 0.0\n", | |
"Maximum validation time: 182.0\n" | |
] | |
}, | |
{ | |
"output_type": "error", | |
"ename": "ValueError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-8-8702b8f1e249>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_ibs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m180\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-6-70c030061f69>\u001b[0m in \u001b[0;36mget_ibs\u001b[0;34m(seed, times)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_trn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0msurvs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_survival_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msurvs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mintegrated_brier_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-6-70c030061f69>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_trn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0msurvs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_survival_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msurvs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mintegrated_brier_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-6-70c030061f69>\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_trn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0msurvs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_survival_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msurvs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mintegrated_brier_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/sksurv/functions.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 66\u001b[0m raise ValueError(\n\u001b[0;32m---> 67\u001b[0;31m \"x must be within [%f; %f]\" % (self.x[0], self.x[-1]))\n\u001b[0m\u001b[1;32m 68\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearchsorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mside\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'left'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mnot_exact\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mValueError\u001b[0m: x must be within [0.000000; 170.000000]" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"get_ibs(1, np.arange(0,170))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 380 | |
}, | |
"id": "pOA1saaA0cic", | |
"outputId": "ea029fe2-2104-478c-cefe-ff2aeb89cb40" | |
}, | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Minimum training time: 0.0\n", | |
"Maximum training time: 182.0\n", | |
"Minimum validation time: 4.0\n", | |
"Maximum validation time: 182.0\n" | |
] | |
}, | |
{ | |
"output_type": "error", | |
"ename": "ValueError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-9-e5a48495f626>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_ibs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m170\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-6-70c030061f69>\u001b[0m in \u001b[0;36mget_ibs\u001b[0;34m(seed, times)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0msurvs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrsf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_survival_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_val\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mpreds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msurvs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mintegrated_brier_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_trn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36mintegrated_brier_score\u001b[0;34m(survival_train, survival_test, estimate, times)\u001b[0m\n\u001b[1;32m 734\u001b[0m \"\"\"\n\u001b[1;32m 735\u001b[0m \u001b[0;31m# Computing the brier scores\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 736\u001b[0;31m \u001b[0mtimes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbrier_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbrier_score\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msurvival_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msurvival_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36mbrier_score\u001b[0;34m(survival_train, survival_test, estimate, times)\u001b[0m\n\u001b[1;32m 613\u001b[0m \"\"\"\n\u001b[1;32m 614\u001b[0m \u001b[0mtest_event\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_y_survival\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msurvival_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 615\u001b[0;31m \u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_estimate_2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_time\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 616\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mestimate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 617\u001b[0m \u001b[0mestimate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mestimate\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36m_check_estimate_2d\u001b[0;34m(estimate, test_time, time_points)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_check_estimate_2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_time\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_points\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mestimate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mensure_2d\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_nd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mtime_points\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_times\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_time\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_points\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0mcheck_consistent_length\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_time\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mestimate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/sksurv/metrics.py\u001b[0m in \u001b[0;36m_check_times\u001b[0;34m(test_time, times)\u001b[0m\n\u001b[1;32m 70\u001b[0m raise ValueError(\n\u001b[1;32m 71\u001b[0m 'all times must be within follow-up time of test data: [{}; {}['.format(\n\u001b[0;32m---> 72\u001b[0;31m test_time.min(), test_time.max()))\n\u001b[0m\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mValueError\u001b[0m: all times must be within follow-up time of test data: [4.0; 182.0[" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment