Skip to content

Instantly share code, notes, and snippets.

@alonsosilvaallende
Created September 14, 2021 12:42
Show Gist options
  • Save alonsosilvaallende/cd9ec5ae0f6e8b83991606bd8e0f5848 to your computer and use it in GitHub Desktop.
Save alonsosilvaallende/cd9ec5ae0f6e8b83991606bd8e0f5848 to your computer and use it in GitHub Desktop.
Split_vs_Bootstrap.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Split_vs_Bootstrap.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMVtV3J6yM7jfaezPh3JqQX",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alonsosilvaallende/cd9ec5ae0f6e8b83991606bd8e0f5848/split_vs_bootstrap.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cEwJ4mw3MRq8"
},
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel\n",
"%pip install -q scikit-survival"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "q8tsUA-yMY8i"
},
"source": [
"from sksurv.datasets import load_gbsg2\n",
"\n",
"# X: feature DataFrame; y: survival target indexed below by the \"cens\" (event) and \"time\" fields\n",
"X, y = load_gbsg2()\n",
"assert len(X) == len(y), \"features and target are misaligned\"\n",
"assert {\"cens\", \"time\"} <= set(y.dtype.names), \"target lacks the cens/time fields used below\""
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2vQfLWpMMcSp"
},
"source": [
"# dtype kinds treated as numeric (standard-scaled): signed int, unsigned int, float\n",
"numeric_kinds = {\"i\", \"u\", \"f\"}\n",
"scaling_cols = [c for c in X.columns if X[c].dtype.kind in numeric_kinds]\n",
"# every remaining column is treated as categorical (ordinal-encoded)\n",
"cat_cols = [c for c in X.columns if c not in scaling_cols]"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yzZePe3PM1NT"
},
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import OrdinalEncoder, StandardScaler\n",
"\n",
"# Ordinal-encode categorical columns, standardize numeric ones and pass any other\n",
"# column through unchanged; sparse_threshold=0 always yields a dense array.\n",
"transformers = [\n",
"    ('cat-preprocessor', OrdinalEncoder(), cat_cols),\n",
"    ('standard-scaler', StandardScaler(), scaling_cols),\n",
"]\n",
"preprocessor = ColumnTransformer(\n",
"    transformers=transformers,\n",
"    remainder='passthrough',\n",
"    sparse_threshold=0,\n",
")"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jdmci5AmTpFs"
},
"source": [
"import numpy as np\n",
"\n",
"# 10 distinct seeds drawn from 0..999, one per random train/test split\n",
"seed_generator = np.random.RandomState(42)\n",
"splits_seeds = seed_generator.permutation(1000)[:10]"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BOQcJ9wTM9Bn"
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sksurv.ensemble import RandomSurvivalForest\n",
"from sksurv.metrics import concordance_index_censored\n",
"\n",
"# Re-run the full pipeline on several random train/test splits and record the\n",
"# test-set concordance index (C-index) of each run.\n",
"ci_splits = []\n",
"for seed in splits_seeds:\n",
"    X_trn, X_test, y_trn, y_test = train_test_split(X, y, random_state=seed)\n",
"    # refit the preprocessor on the training part only, then apply it to the test part\n",
"    X_trn = preprocessor.fit_transform(X_trn)\n",
"    X_test = preprocessor.transform(X_test)\n",
"    # fixed model seed (same as the bootstrap section) so only the split varies\n",
"    rsf = RandomSurvivalForest(random_state=42)\n",
"    rsf.fit(X_trn, y_trn)\n",
"    risk_scores = rsf.predict(X_test)\n",
"    # element 0 of the returned tuple is the C-index\n",
"    ci_splits.append(concordance_index_censored(y_test[\"cens\"], y_test[\"time\"], risk_scores)[0])"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qGGkvh54O8ip",
"outputId": "8aba1b17-f1ff-45f4-cbc5-6626ecdc7f6d"
},
"source": [
"# mean test-set C-index over the random splits\n",
"np.asarray(ci_splits).mean()"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6946017467642212"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "czk3WPk6UwAK"
},
"source": [
"# Bootstrap: resample the test set of a single fixed split"
]
},
{
"cell_type": "code",
"metadata": {
"id": "E4BEIHWlQAPD"
},
"source": [
"# One fixed split; the bootstrap below resamples only its test set.\n",
"X_trn, X_test, y_trn, y_test = train_test_split(\n",
"    X, y, random_state=42)\n",
"X_trn, X_test = preprocessor.fit_transform(X_trn), preprocessor.transform(X_test)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TSMfhe--QDYy",
"outputId": "4c56e32b-f3a4-4e97-d1f8-43d3b4479870"
},
"source": [
"# fit returns the estimator itself, so it can be chained and then displayed\n",
"rsf = RandomSurvivalForest(random_state=42).fit(X_trn, y_trn)\n",
"rsf"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomSurvivalForest(bootstrap=True, max_depth=None, max_features='auto',\n",
" max_leaf_nodes=None, max_samples=None, min_samples_leaf=3,\n",
" min_samples_split=6, min_weight_fraction_leaf=0.0,\n",
" n_estimators=100, n_jobs=None, oob_score=False,\n",
" random_state=42, verbose=0, warm_start=False)"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "XFnfFkrcQfSJ"
},
"source": [
"# 10 distinct seeds drawn from 0..999, one per bootstrap resample\n",
"bootstrap_rng = np.random.RandomState(0)\n",
"bootstrap_seeds = bootstrap_rng.permutation(1000)[:10]"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YjHlvji6PNkg"
},
"source": [
"# One index array per seed: n_test row positions drawn with replacement from the test set\n",
"n_test = X_test.shape[0]\n",
"bootstrap_indexes = [\n",
"    np.random.RandomState(seed).choice(n_test, size=n_test, replace=True)\n",
"    for seed in bootstrap_seeds\n",
"]"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "s49YobPnShiZ"
},
"source": [
"# Score the single fitted model on every bootstrap resample of the test set.\n",
"# Each resample uses its own index array (the original code reused bootstrap_indexes[0]\n",
"# for every iteration, producing 10 identical scores).\n",
"# Predictions are per-row, so predict once and index into the result.\n",
"risk_test = rsf.predict(X_test)\n",
"ci_bootstrap = [\n",
"    concordance_index_censored(y_test[idx][\"cens\"], y_test[idx][\"time\"], risk_test[idx])[0]\n",
"    for idx in bootstrap_indexes\n",
"]"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IfrKoPHfTWij",
"outputId": "5df21d2a-3808-46ec-b874-a762b10e7221"
},
"source": [
"# mean test-set C-index over the bootstrap resamples\n",
"np.asarray(ci_bootstrap).mean()"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6869005705960908"
]
},
"metadata": {},
"execution_count": 13
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment