Filter rows from both X and y in scikit-learn
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"toc": true
},
"source": [
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
"<div class=\"toc\"><ul class=\"toc-item\"><li><span><a href=\"#Problem-Statement\" data-toc-modified-id=\"Problem-Statement-1\">Problem Statement</a></span></li><li><span><a href=\"#Current-State-of-the-Art\" data-toc-modified-id=\"Current-State-of-the-Art-2\">Current State of the Art</a></span></li><li><span><a href=\"#Working-Code-Solution\" data-toc-modified-id=\"Working-Code-Solution-3\">Working Code Solution</a></span></li><li><span><a href=\"#Sources-of-Inspiration\" data-toc-modified-id=\"Sources-of-Inspiration-4\">Sources of Inspiration</a></span></li></ul></div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Problem Statement\n",
"-----\n",
"\n",
"Is there a way to apply a transformer to both X and y in the pipeline at the same time? \n",
"\n",
"For example, is it possible to remove rows based on a threshold value for a certain column I would have to remove those rows in both X and y at the same index?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Current State of the Art\n",
"-----\n",
"\n",
"The scikit-learn core team is working on it\n",
"> As noted elsewhere, transformers that change the number of samples are not currently supported\n",
"\n",
"Source: https://github.com/scikit-learn/scikit-learn/issues/3855\n",
"\n",
"It turns out that imblearn has already solved this with its [imblearn.pipe.Pipeline](https://imbalanced-learn.org/stable/references/generated/imblearn.pipeline.Pipeline.html#imblearn.pipeline.Pipeline)\n",
"\n",
"> the number of samples can vary during training, which usually is a limitation of the current scikit-learn pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Working Code Solution\n",
"-----"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"reset -fs"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from sklearn.base import TransformerMixin, BaseEstimator\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"from imblearn import FunctionSampler\n",
"from imblearn.pipeline import Pipeline as imbPipeline"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"class Debug(BaseEstimator, TransformerMixin):\n",
" \"Allow introspection of transformation in middle of a pipeline\"\n",
"\n",
" def transform(self, X):\n",
" self.shape = X.shape\n",
" return X\n",
"\n",
" def fit(self, X, y=None, **fit_params):\n",
" return self"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"def filter_data(X, y):\n",
" \"Filter data - remove from both X and y.\"\n",
" \n",
" # Hard code dropping last row\n",
" return X[:-1], y[:-1]\n",
"\n",
"# All functions that perform compuation should have tests\n",
"X = np.array([[0, 0], \n",
" [1, 1],\n",
" [0, 1]])\n",
" \n",
"y = np.array([0, \n",
" 1,\n",
" 0])\n",
"X_filtered, y_filtered = filter_data(X, y) \n",
"np.testing.assert_array_equal(X_filtered,\n",
" np.array([[0, 0],\n",
" [1, 1]]))\n",
"np.testing.assert_array_equal(y_filtered,\n",
" np.array([0, 1])) "
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
"# Toy data\n",
"X = np.array([[0, 0], \n",
" [1, 1],\n",
" [0, 1]])\n",
" \n",
"y = np.array([0, \n",
" 1,\n",
" 0])"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
"pipe = imbPipeline([('filter', FunctionSampler(func=filter_data)),\n",
" (\"debug\", Debug()),\n",
" ('clf', DecisionTreeClassifier()),\n",
" ])\n",
"pipe.fit(X, y)\n",
"\n",
"# Test that the filter function removed a row\n",
"assert pipe.named_steps[\"debug\"].shape == (2, 2)"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('filter',\n",
" FunctionSampler(func=<function outlier_rejection at 0x7f90a6673790>)),\n",
" ('clf', DecisionTreeClassifier())])"
]
},
"execution_count": 127,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hard coded filtering is useful for debuggin / testing\n",
"# A model is far more interesting / useful\n",
"from sklearn.datasets import make_classification\n",
"from sklearn.ensemble import IsolationForest\n",
"\n",
"X, y = make_classification(n_features=2, n_redundant=0, n_informative=2,\n",
" random_state=1, n_clusters_per_class=1)\n",
"\n",
"def outlier_rejection(X, y):\n",
" \"Drop based on a model\"\n",
" model = IsolationForest(max_samples=100, contamination=0.4)\n",
" model.fit(X)\n",
" y_pred = model.predict(X)\n",
" return X[y_pred == 1], y[y_pred == 1]\n",
"\n",
"pipe = imbPipeline([('filter', FunctionSampler(func=outlier_rejection)),\n",
" ('clf', DecisionTreeClassifier()),\n",
" ])\n",
"pipe.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sources of Inspiration\n",
"----\n",
"\n",
"- https://stackoverflow.com/questions/48743032/get-intermediate-data-state-in-scikit-learn-pipeline\n",
"- https://stackoverflow.com/questions/62819600/detect-and-remove-outliers-as-step-of-a-pipeline\n",
"- https://imbalanced-learn.org/stable/auto_examples/applications/plot_outlier_rejections.html\n",
"- https://datascience.stackexchange.com/questions/57924/difference-between-sklearn-make-pipeline-and-imblearn-make-pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<br>\n",
"<br> \n",
"<br>\n",
"\n",
"----"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": false,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": true,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}
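
The problem statement in the notebook asks about dropping rows by a threshold on one column, while the notebook itself only shows a hard-coded filter and a model-based one. A minimal sketch of the threshold case might look like the following. The function name `filter_by_threshold`, its `column` and `threshold` parameters, and the demo arrays are illustrative assumptions, not part of the original notebook.

```python
import numpy as np


def filter_by_threshold(X, y, column=0, threshold=0.5):
    """Keep only rows whose value in `column` is at or below `threshold`.

    Hypothetical helper: the same boolean mask is applied to X and y,
    so the two stay aligned row for row.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    mask = X[:, column] <= threshold
    return X[mask], y[mask]


# Small made-up example: the middle row exceeds the threshold in column 0
X_demo = np.array([[0.1, 1.0],
                   [0.9, 2.0],
                   [0.4, 3.0]])
y_demo = np.array([0, 1, 0])

X_kept, y_kept = filter_by_threshold(X_demo, y_demo, column=0, threshold=0.5)
np.testing.assert_array_equal(X_kept, np.array([[0.1, 1.0],
                                                [0.4, 3.0]]))
np.testing.assert_array_equal(y_kept, np.array([0, 0]))
```

Because the function takes `(X, y)` and returns the filtered pair, it should slot into the notebook's pipeline in the same way as `filter_data`, for example as `FunctionSampler(func=filter_by_threshold, kw_args={"column": 0, "threshold": 0.5})`. As with the other samplers, the filter would apply only while fitting; rows passed to `predict` are not dropped.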