Skip to content

Instantly share code, notes, and snippets.

@ltiao
Created January 16, 2019 14:37
Show Gist options
  • Save ltiao/86f7f3fdd3c59bed81abc0844a2a8d1c to your computer and use it in GitHub Desktop.
Save ltiao/86f7f3fdd3c59bed81abc0844a2a8d1c to your computer and use it in GitHub Desktop.
Recursive generator function inside a closure. Either the most beautiful piece of code I've written, or the most contrived and convoluted. There is no in between.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import StratifiedShuffleSplit"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"random_state = 42"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"X, y = load_iris(return_X_y=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(150, 4)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"StratifiedShuffleSplit(n_splits=1, random_state=0, test_size=0.4,\n",
" train_size=None)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sss = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=0)\n",
"sss"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([128, 101, 52, 28, 94, 38, 76, 141, 148, 75, 126, 87, 69,\n",
" 45, 8, 115, 4, 127, 79, 84, 108, 82, 140, 59, 131, 10,\n",
" 22, 97, 13, 95, 63, 135, 33, 15, 56, 105, 16, 27, 32,\n",
" 78, 104, 26, 92, 60, 41, 58, 119, 93, 112, 11, 146, 72,\n",
" 83, 116, 62, 91, 120, 48, 57, 7, 133, 106, 31, 132, 80,\n",
" 73, 66, 111, 107, 20, 30, 25, 42, 14, 70, 138, 35, 137,\n",
" 2, 18, 124, 122, 74, 143, 43, 117, 29, 125, 96, 34]),\n",
" array([121, 109, 36, 144, 1, 9, 39, 147, 98, 89, 23, 149, 118,\n",
" 44, 61, 100, 65, 37, 113, 142, 64, 24, 145, 46, 99, 53,\n",
" 102, 19, 54, 139, 40, 130, 71, 86, 110, 47, 136, 51, 81,\n",
" 123, 50, 49, 68, 103, 129, 85, 88, 0, 17, 6, 3, 134,\n",
" 90, 21, 5, 55, 114, 12, 67, 77]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_index, test_index = next(sss.split(X, y))\n",
"train_index, test_index"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(90, 4)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[train_index].shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60, 4)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[test_index].shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 1, 2]), array([20, 20, 20]))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.unique(y[test_index], return_counts=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0, 1, 2]), array([30, 30, 30]))"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.unique(y[train_index], return_counts=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def recursive_stratified_shuffle_split(sizes, random_state=None):\n",
"\n",
" head, *tail = sizes\n",
" sss = StratifiedShuffleSplit(n_splits=1, test_size=head, random_state=random_state)\n",
"\n",
" def split(X, y):\n",
"\n",
" a_index, b_index = next(sss.split(X, y))\n",
"\n",
" yield a_index\n",
"\n",
" if tail:\n",
"\n",
" split_tail = recursive_stratified_shuffle_split(sizes=tail, random_state=random_state)\n",
" \n",
" for ind in split_tail(X[b_index], y[b_index]):\n",
" \n",
" yield b_index[ind]\n",
"\n",
" else:\n",
"\n",
" yield b_index\n",
" \n",
" return split"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# sizes are absolute sample counts: first split the 150 samples 70/80, then split the remaining 80 into 60/20\n",
"split = recursive_stratified_shuffle_split(sizes=[80, 20])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[array([139, 138, 7, 34, 109, 128, 24, 132, 76, 96, 22, 101, 83,\n",
" 140, 146, 46, 67, 8, 61, 44, 88, 85, 1, 9, 35, 74,\n",
" 145, 0, 65, 6, 57, 136, 73, 4, 54, 43, 69, 55, 75,\n",
" 131, 99, 60, 18, 79, 125, 5, 111, 63, 12, 149, 13, 89,\n",
" 106, 25, 122, 113, 119, 49, 80, 11, 59, 52, 115, 142, 38,\n",
" 45, 20, 118, 130, 123]),\n",
" array([ 64, 40, 15, 114, 36, 124, 50, 2, 107, 53, 141, 30, 87,\n",
" 62, 17, 39, 134, 105, 19, 70, 66, 42, 129, 116, 86, 37,\n",
" 21, 94, 72, 41, 71, 84, 68, 110, 148, 82, 98, 137, 31,\n",
" 48, 47, 102, 127, 23, 133, 27, 51, 95, 121, 77, 120, 32,\n",
" 104, 16, 58, 147, 33, 103, 92, 135]),\n",
" array([ 56, 14, 112, 143, 93, 26, 108, 78, 144, 100, 117, 29, 126,\n",
" 97, 28, 10, 81, 3, 90, 91])]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(split(X, y))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(70, 4), (60, 4), (20, 4)]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[X[index].shape for index in split(X, y)]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(array([0, 1, 2]), array([24, 23, 23])),\n",
" (array([0, 1, 2]), array([20, 20, 20])),\n",
" (array([0, 1, 2]), array([6, 7, 7]))]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[np.unique(y[index], return_counts=True) for index in split(X, y)]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# sizes are fractions: first split 60/40, then split the remaining 40% 50/50\n",
"split = recursive_stratified_shuffle_split(sizes=[0.4, 0.5])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(90, 4), (30, 4), (30, 4)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[X[index].shape for index in split(X, y)]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(array([0, 1, 2]), array([30, 30, 30])),\n",
" (array([0, 1, 2]), array([10, 10, 10])),\n",
" (array([0, 1, 2]), array([10, 10, 10]))]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[np.unique(y[index], return_counts=True) for index in split(X, y)]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment