@jdthorpe
Last active September 27, 2019 20:08
Shim for sklearn.impute.SimpleImputer

When imputing with sklearn.impute.SimpleImputer and the option strategy="most_frequent", calling the .fit() method takes an absurdly long time. This happens because scipy.stats.mode is extremely inefficient for string variables. For example, imputing a single feature with half a million values takes ~15 minutes without the shim; with the shim it takes less than one millisecond.
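To illustrate the idea behind the shim, here is a minimal sketch of finding the most frequent value of an object array with np.unique(..., return_counts=True), which sorts the array once and counts each distinct value in a single pass. The helper name mode_of_objects and the sample data are made up for illustration and are not part of the shim.

```python
import numpy as np


def mode_of_objects(array):
    """Return (value, count) for the most frequent entry of a 1d array."""
    # np.unique returns the sorted distinct values and how often each occurs.
    values, counts = np.unique(array, return_counts=True)
    # argmax picks the first maximum, i.e. the smallest value when counts tie.
    idx = counts.argmax()
    return values[idx], int(counts[idx])


colors = np.array(["red", "blue", "red", "green", "blue", "red"], dtype=object)
value, count = mode_of_objects(colors)
print(value, count)  # red 3
```

On ties this returns the smallest of the tied values, because np.unique returns its values sorted.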

Usage:

import SimpleImputerShim  # that's it!
import warnings

import numpy as np
import sklearn.impute._base
from scipy import stats


def _most_frequent(array, extra_value, n_repeat):
    """Compute the most frequent value in a 1d array extended with
    [extra_value] * n_repeat, where extra_value is assumed to be not part
    of the array."""
    # Compute the most frequent value in array only
    if array.size > 0:
        if array.dtype == "O":
            # Object arrays (e.g. strings): count the distinct values with
            # np.unique instead of calling the slow stats.mode.
            values, counts = np.unique(array, return_counts=True)
            most_frequent_count = counts.max()
            most_frequent_value = values[np.where(counts == counts.max())[0][0]]
        else:
            with warnings.catch_warnings():
                # stats.mode raises a warning when input array contains objects due
                # to incapacity to detect NaNs. Irrelevant here since input array
                # has already been NaN-masked.
                warnings.simplefilter("ignore", RuntimeWarning)
                mode = stats.mode(array)
            most_frequent_value = mode[0][0]
            most_frequent_count = mode[1][0]
    else:
        most_frequent_value = 0
        most_frequent_count = 0

    # Compare to array + [extra_value] * n_repeat
    if most_frequent_count == 0 and n_repeat == 0:
        return np.nan
    elif most_frequent_count < n_repeat:
        return extra_value
    elif most_frequent_count > n_repeat:
        return most_frequent_value
    elif most_frequent_count == n_repeat:
        # Break ties by copying the behaviour of scipy.stats.mode:
        # return the smaller of the two values.
        if most_frequent_value < extra_value:
            return most_frequent_value
        else:
            return extra_value


# Monkey-patch scikit-learn so that SimpleImputer uses the faster version.
sklearn.impute._base._most_frequent = _most_frequent
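
The final comparison against [extra_value] * n_repeat can be easier to follow in isolation. The sketch below restates only that decision rule as a standalone function; the name combine_with_extra and the sample numbers are hypothetical. As far as I can tell, scikit-learn uses these arguments to account for the implicit zeros of a sparse column (extra_value=0, n_repeat=number of zeros), but treat that as an assumption rather than a documented contract.

```python
import math


def combine_with_extra(value, count, extra_value, n_repeat):
    """Pick the mode of `array + [extra_value] * n_repeat`, given the
    array's own mode `value` and how often it occurs (`count`)."""
    if count == 0 and n_repeat == 0:
        return math.nan                 # nothing to choose from
    if count < n_repeat:
        return extra_value              # the extra value is more frequent
    if count > n_repeat:
        return value                    # the array's own mode wins
    return min(value, extra_value)      # tie: return the smaller value


print(combine_with_extra(7, 3, 0, 5))  # 0
print(combine_with_extra(7, 3, 0, 2))  # 7
print(combine_with_extra(7, 3, 0, 3))  # 0
```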