Last active
August 6, 2020 15:42
-
-
Save alberduris/06f5095ddbc293501d65c8d2741899f3 to your computer and use it in GitHub Desktop.
[SkLearn TrainTestSplit OneHot Behaviour] #JupyterNotebook #CodeSnippet #@todo: BUG: `stratify=data_holder.Y_data` instead of `stratify=data_holder.Y_data` #@bug: strange behaviour? If `stratify` is one-hot encoded, the split changes w.r.t. string- or int-encoded labels #@bug: See SkLearn_TrainTestSplit_OneHot_Behaviour.ipynb #Others
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
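"# Sklearn Train/Test split: does the `stratify` label encoding change the split?\n", | |
"\n", | |
"Question: with a fixed `random_state`, does `train_test_split` return the same split when the `stratify` labels are given as strings, as integer codes, or as one-hot rows?\n", | |
"\n", | |
"In the run saved below, string and integer labels give the same split, while one-hot labels give a different one. All three test splits are still stratified (2 test samples per class). For a 2-D `stratify` array, scikit-learn maps each row to a string such as `'0.0 0.0 1.0'` before finding the classes. Those strings sort in the reverse order of the integer codes, so the shared random generator is consumed class by class in a different order and selects different samples. The split is different, not broken." | |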
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.preprocessing import LabelEncoder\n", | |
"from keras.utils import np_utils\n", | |
"\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Seed numpy's global RNG (same value as the split's random_state) so that\n", | |
"# Restart & Run All regenerates the same labels instead of a new random draw.\n", | |
"np.random.seed(2019)\n", | |
"\n", | |
"N_SAMPLES = 20\n", | |
"\n", | |
"ids = np.arange(N_SAMPLES)\n", | |
"\n", | |
"classes = ['label1', 'label2', 'label3']\n", | |
"labels = np.random.choice(classes, size=N_SAMPLES)\n", | |
"\n", | |
"label_encoder = LabelEncoder()\n", | |
"# Integer codes follow the sorted class names, e.g. 'label1' -> 0\n", | |
"encoded_labels = label_encoder.fit_transform(labels)\n", | |
"\n", | |
"# One-hot rows: column k is 1.0 for integer code k\n", | |
"one_hot_labels = np_utils.to_categorical(encoded_labels)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Labels: ['label3' 'label3' 'label2' 'label3' 'label2' 'label2' 'label3' 'label3'\n", | |
" 'label3' 'label1' 'label2' 'label3' 'label1' 'label1' 'label2' 'label1'\n", | |
" 'label2' 'label3' 'label1' 'label2']\n", | |
"Encoded labels: [2 2 1 2 1 1 2 2 2 0 1 2 0 0 1 0 1 2 0 1]\n", | |
"OneHot labels: \n", | |
"[[0. 0. 1.]\n", | |
" [0. 0. 1.]\n", | |
" [0. 1. 0.]\n", | |
" [0. 0. 1.]\n", | |
" [0. 1. 0.]\n", | |
" [0. 1. 0.]\n", | |
" [0. 0. 1.]\n", | |
" [0. 0. 1.]\n", | |
" [0. 0. 1.]\n", | |
" [1. 0. 0.]\n", | |
" [0. 1. 0.]\n", | |
" [0. 0. 1.]\n", | |
" [1. 0. 0.]\n", | |
" [1. 0. 0.]\n", | |
" [0. 1. 0.]\n", | |
" [1. 0. 0.]\n", | |
" [0. 1. 0.]\n", | |
" [0. 0. 1.]\n", | |
" [1. 0. 0.]\n", | |
" [0. 1. 0.]]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Show the same labels in all three encodings, one report line each\n", | |
"report_lines = [\n", | |
"    'Labels: {}'.format(labels),\n", | |
"    'Encoded labels: {}'.format(encoded_labels),\n", | |
"    'OneHot labels: \\n{}'.format(one_hot_labels),\n", | |
"]\n", | |
"print('\\n'.join(report_lines))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([ 6, 5, 7, 11, 12, 3, 18, 4, 15, 10, 8, 1, 14, 16]),\n", | |
" array([ 2, 13, 0, 17, 9, 19]))" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tr_ids, te_ids, tr_labels, te_labels = train_test_split(ids, labels, stratify=labels, test_size=0.3, random_state=2019)\n", | |
"# Also show per-class counts in the test split, so stratification can be compared across encodings\n", | |
"tr_ids, te_ids, np.unique(te_labels, return_counts=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([ 6, 5, 7, 11, 12, 3, 18, 4, 15, 10, 8, 1, 14, 16]),\n", | |
" array([ 2, 13, 0, 17, 9, 19]))" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tr_ids, te_ids, tr_labels, te_labels = train_test_split(ids, encoded_labels, stratify=encoded_labels, test_size=0.3, random_state=2019)\n", | |
"# Also show per-class counts in the test split, so stratification can be compared across encodings\n", | |
"tr_ids, te_ids, np.unique(te_labels, return_counts=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([19, 2, 15, 18, 5, 7, 11, 1, 17, 13, 10, 4, 8, 6]),\n", | |
" array([ 3, 9, 12, 0, 16, 14]))" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tr_ids, te_ids, tr_labels, te_labels = train_test_split(ids, one_hot_labels, stratify=one_hot_labels, test_size=0.3, random_state=2019)\n", | |
"# Map one-hot rows back to integer codes before counting, so the counts are comparable to the other encodings\n", | |
"tr_ids, te_ids, np.unique(te_labels.argmax(axis=1), return_counts=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([[0., 0., 1.],\n", | |
" [1., 0., 0.],\n", | |
" [1., 0., 0.],\n", | |
" [0., 0., 1.],\n", | |
" [0., 1., 0.],\n", | |
" [0., 1., 0.]], dtype=float32),\n", | |
" array(['label3', 'label1', 'label1', 'label3', 'label2', 'label2'],\n", | |
" dtype='<U6'))" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Decode the one-hot test labels back to their class names\n", | |
"decoded_te_labels = label_encoder.classes_[te_labels.argmax(axis=1)]\n", | |
"te_labels, decoded_te_labels" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment