Created
July 21, 2021 06:46
-
-
Save angeligareta/b42785185ee245e846455cf2d6b343ff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "stratified_sampling.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SPVHS92rqQ7L" | |
}, | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"from sklearn import datasets\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"\n", | |
def get_dataset_partitions_pd(df, train_split=0.8, val_split=0.1, test_split=0.1, target_variable=None):
    """Partition a DataFrame into train / validation / test sets.

    Parameters
    ----------
    df : pd.DataFrame
        Data to split. Rows are shuffled (with a fixed seed) first.
    train_split, val_split, test_split : float
        Fraction of rows for each partition; must sum to 1 within
        floating-point tolerance. val_split and test_split no longer
        need to be equal.
    target_variable : hashable, optional
        Column to stratify on. When given, each value group is split
        independently so every partition preserves the class
        distribution of ``df``.

    Returns
    -------
    (train_ds, val_ds, test_ds) : tuple of pd.DataFrame
    """
    # Exact float equality (`== 1`) rejects valid inputs such as
    # 0.7 + 0.15 + 0.15, whose binary sum is not exactly 1.0 — compare
    # with a tolerance instead.
    assert abs(train_split + val_split + test_split - 1.0) < 1e-9

    # Specify seed to always have the same split distribution between runs
    df_sample = df.sample(frac=1, random_state=12)

    def _three_way(frame):
        # Cut points: [0, lo) -> train, [lo, hi) -> val, [hi, n) -> test.
        # Using train_split + val_split for the second cut (instead of
        # 1 - val_split) is correct even when val_split != test_split,
        # which removes the old equal-val/test restriction. Plain .iloc
        # slicing replaces np.split, which relies on the deprecated
        # DataFrame.swapaxes under the hood.
        n = len(frame)
        lo, hi = int(train_split * n), int((train_split + val_split) * n)
        return frame.iloc[:lo], frame.iloc[lo:hi], frame.iloc[hi:]

    # If target variable is provided, generate stratified sets
    if target_variable is not None:
        # Split each target-value group separately, then reassemble, so
        # every partition keeps the original class proportions.
        parts = [_three_way(g) for _, g in df_sample.groupby(target_variable)]
        train_ds = pd.concat([p[0] for p in parts])
        val_ds = pd.concat([p[1] for p in parts])
        test_ds = pd.concat([p[2] for p in parts])
    else:
        train_ds, val_ds, test_ds = _three_way(df_sample)

    return train_ds, val_ds, test_ds
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_HFvDin4n4Xd", | |
"outputId": "03aebdc2-d739-4380-e8e6-370181d1a9a9" | |
}, | |
"source": [ | |
# Load the iris dataset and report the class balance of the target column.
dataset = datasets.load_iris()
X = pd.DataFrame(dataset.data)
y = pd.DataFrame(dataset.target)

# Normalised per-class frequencies of the target (column 0 of `y`).
original_distribution = y[0].value_counts().sort_index() / len(y)
print(f'Distribution in original set: \n{original_distribution}')
], | |
"execution_count": 72, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Distribution in original set: \n", | |
"0 0.333333\n", | |
"1 0.333333\n", | |
"2 0.333333\n", | |
"Name: 0, dtype: float64\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "AeUTn588tsWK", | |
"outputId": "bea9e0eb-69e6-4953-d9c9-24831849bab6" | |
}, | |
"source": [ | |
# Split the target frame WITHOUT stratification and show how the class
# balance of each partition drifts from the original distribution.
train_ds, val_ds, test_ds = get_dataset_partitions_pd(y)

report = (
    f'Distribution in training set: \n{train_ds[0].value_counts().sort_index() / len(train_ds)}\n\n'
    f'Distribution in validation set: \n{val_ds[0].value_counts().sort_index() / len(val_ds)}\n\n'
    f'Distribution in testing set: \n{test_ds[0].value_counts().sort_index() / len(test_ds)}'
)
print(report)
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Distribution in training set: \n", | |
"0 0.341667\n", | |
"1 0.358333\n", | |
"2 0.300000\n", | |
"Name: 0, dtype: float64\n", | |
"\n", | |
"Distribution in validation set: \n", | |
"0 0.333333\n", | |
"1 0.266667\n", | |
"2 0.400000\n", | |
"Name: 0, dtype: float64\n", | |
"\n", | |
"Distribution in testing set: \n", | |
"0 0.266667\n", | |
"1 0.200000\n", | |
"2 0.533333\n", | |
"Name: 0, dtype: float64\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Rd0RAVJBo3XU", | |
"outputId": "84760724-16b6-424d-a5bc-f0ba99f5f0aa" | |
}, | |
"source": [ | |
# Split the target frame WITH stratification on column 0; each partition
# should preserve the original class distribution exactly.
train_ds, val_ds, test_ds = get_dataset_partitions_pd(y, target_variable=0)

report = (
    f'Distribution in training set: \n{train_ds[0].value_counts().sort_index() / len(train_ds)}\n\n'
    f'Distribution in validation set: \n{val_ds[0].value_counts().sort_index() / len(val_ds)}\n\n'
    f'Distribution in testing set: \n{test_ds[0].value_counts().sort_index() / len(test_ds)}'
)
print(report)
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Distribution in training set: \n", | |
"0 0.333333\n", | |
"1 0.333333\n", | |
"2 0.333333\n", | |
"Name: 0, dtype: float64\n", | |
"\n", | |
"Distribution in validation set: \n", | |
"0 0.333333\n", | |
"1 0.333333\n", | |
"2 0.333333\n", | |
"Name: 0, dtype: float64\n", | |
"\n", | |
"Distribution in testing set: \n", | |
"0 0.333333\n", | |
"1 0.333333\n", | |
"2 0.333333\n", | |
"Name: 0, dtype: float64\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment