Last active
September 22, 2021 08:45
-
-
Save alonsosilvaallende/9b2569d1ba05ba3ef0c3edde2982247b to your computer and use it in GitHub Desktop.
Camila/Simulations/Untitled-Copy1.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:32.702545Z", | |
"end_time": "2021-09-22T08:41:34.187759Z" | |
}, | |
"trusted": true | |
}, | |
"id": "6e0bc87c", | |
"cell_type": "code", | |
"source": "%load_ext autoreload\n%autoreload 2\n%matplotlib inline", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:34.190431Z", | |
"end_time": "2021-09-22T08:41:34.888932Z" | |
}, | |
"trusted": true | |
}, | |
"id": "e18e3b09", | |
"cell_type": "code", | |
"source": "import numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:34.891548Z", | |
"end_time": "2021-09-22T08:41:35.902744Z" | |
}, | |
"trusted": true | |
}, | |
"id": "2f3e8872", | |
"cell_type": "code", | |
"source": "from sksurv.datasets import load_gbsg2\n\nX, y = load_gbsg2()", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:35.908803Z", | |
"end_time": "2021-09-22T08:41:35.975551Z" | |
}, | |
"trusted": true | |
}, | |
"id": "a81b4a6e", | |
"cell_type": "code", | |
"source": "scaling_cols = [c for c in X.columns if X[c].dtype.kind in ['i', 'f']]\ncat_cols = [c for c in X.columns if X[c].dtype.kind not in ['i', 'f']]", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:35.978038Z", | |
"end_time": "2021-09-22T08:41:36.067754Z" | |
}, | |
"trusted": true | |
}, | |
"id": "5fe58065", | |
"cell_type": "code", | |
"source": "from sklearn.compose import ColumnTransformer\nfrom sklearn.preprocessing import OrdinalEncoder\nfrom sklearn.preprocessing import StandardScaler\n\npreprocessor = ColumnTransformer(\n [('cat-preprocessor', OrdinalEncoder(), cat_cols),\n ('standard-scaler', StandardScaler(), scaling_cols)],\n remainder='passthrough', sparse_threshold=0)", | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:36.070960Z", | |
"end_time": "2021-09-22T08:41:36.137212Z" | |
}, | |
"trusted": true | |
}, | |
"id": "edf80455", | |
"cell_type": "code", | |
"source": "param_grid_cph = {\n 'alpha': (0.01, 0.1, 0.5),\n}\n\nparam_grid_rsf = {\n 'max_features': (\"sqrt\", 0.5, 1),\n 'min_samples_leaf': (1,3,5),\n 'n_estimators': (50, 100, 200),\n 'max_depth': (3,5,7,10),\n}\n\nparam_grid_gbs = {\n 'learning_rate': (0.05, 0.1, 0.15),\n 'max_features': (\"sqrt\", 0.5, 1),\n 'min_samples_leaf': (1,3,5),\n 'n_estimators': (50, 100, 200),\n 'subsample': (0.7,0.9,1),\n 'max_depth': (3,5,7,10),\n}\n\nparam_distributions = [param_grid_cph, param_grid_rsf, param_grid_gbs]", | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:36.139749Z", | |
"end_time": "2021-09-22T08:41:36.221693Z" | |
}, | |
"trusted": true | |
}, | |
"id": "8fd06a66", | |
"cell_type": "code", | |
"source": "seeds = np.random.RandomState(0).permutation(1000)[:3]", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:36.225358Z", | |
"end_time": "2021-09-22T08:41:36.310144Z" | |
}, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def ipec(model, X_val, y_trn, y_val):\n rsf_survfunc_test = model.predict_survival_function(X_val)\n y_trn_ip = np.array([[i,j] for i, j in zip(y_trn['time'], y_trn['cens'])])\n y_test_ip = np.array([[i,j] for i, j in zip(y_val['time'], y_val['cens'])])\n times = np.concatenate((np.array([0]), rsf_survfunc_test[0].x))\n rsf_survfunc_y = \\\n [np.concatenate((np.array([1]), rsf_survfunc_test[i].y)) for i in range(len(rsf_survfunc_test))]\n tau = [times[-1]] \n ipec_model_trn = compute_IPEC_scores(y_trn_ip, y_test_ip, times, rsf_survfunc_y, tau)[tau[0]]/tau[0]\n ipec_model_val = compute_IPEC_scores(y_trn_ip, y_test_ip, times, rsf_survfunc_y, tau)[tau[0]]/tau[0]\n return ipec_model_trn, ipec_model_val", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:36.312518Z", | |
"end_time": "2021-09-22T08:41:36.371074Z" | |
}, | |
"trusted": true | |
}, | |
"id": "e7a6caec", | |
"cell_type": "code", | |
"source": "from itertools import product\n\ndef my_product(inp):\n return (dict(zip(inp.keys(), values)) for values in product(*inp.values()))", | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:36.377285Z", | |
"end_time": "2021-09-22T08:41:36.439441Z" | |
}, | |
"trusted": true | |
}, | |
"id": "b5f3a4ec", | |
"cell_type": "code", | |
"source": "def my_custom_function(model, candidate_params, best_value, best_params):\n rfr = model(**candidate_params)\n score_rfr = []\n for train_idx, test_idx in kf.split(X_trn):\n X_train, X_test = X_trn.iloc[train_idx], X_trn.iloc[test_idx]\n y_train, y_test = y_trn[train_idx], y_trn[test_idx]\n rfr.fit(X_train, y_train)\n ipec_model_trn, ipec_model_val = ipec(rfr, X_test, y_train, y_test)\n score_rfr.append(ipec_model_val)\n if np.mean(score_rfr) < best_value:\n best_value = np.mean(score_rfr)\n best_params = candidate_params\n return best_value, best_params", | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:41:36.442099Z", | |
"end_time": "2021-09-22T08:42:54.884963Z" | |
}, | |
"scrolled": true, | |
"trusted": true | |
}, | |
"id": "ecc0e02d", | |
"cell_type": "code", | |
"source": "from sklearn.model_selection import train_test_split\nfrom sklearn.model_selection import RandomizedSearchCV\nfrom sklearn.model_selection import KFold\nfrom sksurv.linear_model import CoxPHSurvivalAnalysis\nfrom sksurv.ensemble import RandomSurvivalForest\nfrom sksurv.ensemble import GradientBoostingSurvivalAnalysis\nfrom lifelines.utils import concordance_index\nfrom util import compute_IPEC_scores\n\ndf_ci = pd.DataFrame(columns=['Model','Train/Test','Default/Best','Score'], index=None)\ndf_ipec = pd.DataFrame(columns=['Model','Train/Test','Default/Best','Score'], index=None)\nfor seed in seeds:\n X_trn, X_val, y_trn, y_val = train_test_split(X, y, random_state=seed)\n X_trn = pd.DataFrame(preprocessor.fit_transform(X_trn))\n X_val = pd.DataFrame(preprocessor.transform(X_val))\n for i, modelo in enumerate([CoxPHSurvivalAnalysis, RandomSurvivalForest, GradientBoostingSurvivalAnalysis]):\n model = modelo(alpha=.1) if i == 0 else modelo(random_state=42)\n model.fit(X_trn, y_trn)\n # Concordance index\n ci_model_trn = concordance_index(y_trn['time'], -model.predict(X_trn), y_trn['cens'])\n ci_model_val = concordance_index(y_val['time'], -model.predict(X_val), y_val['cens'])\n df_ci = df_ci.append({'Model':f'{model}','Train/Test':'Train',\n 'Default/Best':'Default','Score':ci_model_trn}, ignore_index=True)\n df_ci = df_ci.append({'Model':f'{model}','Train/Test':'Test',\n 'Default/Best':'Default','Score':ci_model_val}, ignore_index=True)\n # IPEC score\n ipec_model_trn, ipec_model_val = ipec(model, X_val, y_trn, y_val)\n df_ipec = df_ipec.append({'Model':f'{model}','Train/Test':'Train',\n 'Default/Best':'Default','Score':ipec_model_trn}, ignore_index=True)\n df_ipec = df_ipec.append({'Model':f'{model}','Train/Test':'Test',\n 'Default/Best':'Default','Score':ipec_model_val}, ignore_index=True)\n # Hyperparameter optimization concordance index\n model = modelo(alpha=.1) if i == 0 else modelo(random_state=42)\n rs_model = RandomizedSearchCV(model, param_distributions=param_distributions[i], n_jobs=-1, cv=2, n_iter=3)\n rs_model.fit(X_trn, y_trn)\n ci_rs_model_trn = concordance_index(y_trn['time'], -rs_model.predict(X_trn), y_trn['cens'])\n ci_rs_model_val = concordance_index(y_val['time'], -rs_model.predict(X_val), y_val['cens'])\n df_ci = df_ci.append({'Model':f'{model}','Train/Test':'Train',\n 'Default/Best':'Best','Score':ci_rs_model_trn}, ignore_index=True)\n df_ci = df_ci.append({'Model':f'{model}','Train/Test':'Test',\n 'Default/Best':'Best','Score':ci_rs_model_val}, ignore_index=True)\n # Hyperparameter optimization IPEC score\n kf = KFold(n_splits=3)\n kf.get_n_splits(X_trn)\n best_value = np.inf\n best_params = {}\n k = np.prod([len(param_distributions[i][key]) for key in param_distributions[i].keys()])\n if k <= 50:\n # GridSearchCV\n for params in my_product(param_distributions[i]):\n best_value, best_params = my_custom_function(modelo, params, best_value, best_params)\n else:\n # RandomSearchCV\n for _seed in range(5):\n params = {f'{key}':param_distributions[i][key][np.random.RandomState(_seed).randint(len(param_distributions[i][key]))] for key in param_distributions[i].keys()}\n best_value, best_params = my_custom_function(modelo,params, best_value, best_params)\n rsf = modelo(**best_params)\n rsf.fit(X_trn, y_trn)\n ipec_model_trn, ipec_model_val = ipec(rsf, X_val, y_trn, y_val)\n df_ipec = df_ipec.append({'Model':f'{model}','Train/Test':'Train',\n 'Default/Best':'Best','Score':ipec_model_trn}, ignore_index=True)\n df_ipec = df_ipec.append({'Model':f'{model}','Train/Test':'Test',\n 'Default/Best':'Best','Score':ipec_model_val}, ignore_index=True)", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:42:54.887711Z", | |
"end_time": "2021-09-22T08:42:55.152963Z" | |
}, | |
"trusted": true | |
}, | |
"id": "cbd7af2c", | |
"cell_type": "code", | |
"source": "df_ci[df_ci['Train/Test']=='Test'].groupby(['Model', 'Default/Best']).mean()", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 12, | |
"data": { | |
"text/plain": " Score\nModel Default/Best \nCoxPHSurvivalAnalysis(alpha=0.1) Best 0.657254\n Default 0.657220\nGradientBoostingSurvivalAnalysis(random_state=42) Best 0.672136\n Default 0.684195\nRandomSurvivalForest(random_state=42) Best 0.678548\n Default 0.683404", | |
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th></th>\n <th>Score</th>\n </tr>\n <tr>\n <th>Model</th>\n <th>Default/Best</th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th rowspan=\"2\" valign=\"top\">CoxPHSurvivalAnalysis(alpha=0.1)</th>\n <th>Best</th>\n <td>0.657254</td>\n </tr>\n <tr>\n <th>Default</th>\n <td>0.657220</td>\n </tr>\n <tr>\n <th rowspan=\"2\" valign=\"top\">GradientBoostingSurvivalAnalysis(random_state=42)</th>\n <th>Best</th>\n <td>0.672136</td>\n </tr>\n <tr>\n <th>Default</th>\n <td>0.684195</td>\n </tr>\n <tr>\n <th rowspan=\"2\" valign=\"top\">RandomSurvivalForest(random_state=42)</th>\n <th>Best</th>\n <td>0.678548</td>\n </tr>\n <tr>\n <th>Default</th>\n <td>0.683404</td>\n </tr>\n </tbody>\n</table>\n</div>" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"start_time": "2021-09-22T08:42:55.156686Z", | |
"end_time": "2021-09-22T08:42:55.288332Z" | |
}, | |
"trusted": true | |
}, | |
"id": "526e63bb", | |
"cell_type": "code", | |
"source": "df_ipec[df_ipec['Train/Test']=='Test'].groupby(['Model', 'Default/Best']).mean()", | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 13, | |
"data": { | |
"text/plain": " Score\nModel Default/Best \nCoxPHSurvivalAnalysis(alpha=0.1) Best 0.472445\n Default 0.472865\nGradientBoostingSurvivalAnalysis(random_state=42) Best 0.355065\n Default 0.423120\nRandomSurvivalForest(random_state=42) Best 0.238962\n Default 0.253419", | |
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th></th>\n <th>Score</th>\n </tr>\n <tr>\n <th>Model</th>\n <th>Default/Best</th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th rowspan=\"2\" valign=\"top\">CoxPHSurvivalAnalysis(alpha=0.1)</th>\n <th>Best</th>\n <td>0.472445</td>\n </tr>\n <tr>\n <th>Default</th>\n <td>0.472865</td>\n </tr>\n <tr>\n <th rowspan=\"2\" valign=\"top\">GradientBoostingSurvivalAnalysis(random_state=42)</th>\n <th>Best</th>\n <td>0.355065</td>\n </tr>\n <tr>\n <th>Default</th>\n <td>0.423120</td>\n </tr>\n <tr>\n <th rowspan=\"2\" valign=\"top\">RandomSurvivalForest(random_state=42)</th>\n <th>Best</th>\n <td>0.238962</td>\n </tr>\n <tr>\n <th>Default</th>\n <td>0.253419</td>\n </tr>\n </tbody>\n</table>\n</div>" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"id": "54dcee40", | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.9.7", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"gist": { | |
"id": "9b2569d1ba05ba3ef0c3edde2982247b", | |
"data": { | |
"description": "Camila/Simulations/Untitled-Copy1.ipynb", | |
"public": true | |
} | |
}, | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/9b2569d1ba05ba3ef0c3edde2982247b" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment