Skip to content

Instantly share code, notes, and snippets.

@jnothman
Last active December 17, 2015 20:48
Show Gist options
  • Select an option

  • Save jnothman/5670070 to your computer and use it in GitHub Desktop.

Select an option

Save jnothman/5670070 to your computer and use it in GitHub Desktop.
import fractions
import numpy as np
from sklearn.base import MetaEstimatorMixin, BaseEstimator
class DiscretelyWeightedTrainer(BaseEstimator, MetaEstimatorMixin):
    """Replicates samples according to `sample_weight` to fit an estimator

    This effectively provides a `sample_weight` parameter to the `fit` method
    of estimators otherwise lacking it. It transforms the weights into integers
    and then replicates samples according to those integral weights.

    Parameters
    ----------
    sub_estimator : object exposing ``fit(X, y)``
        The wrapped estimator. It is fitted in place on the replicated data.
    resolution : {'min_gap', 'std'} or number, default='min_gap'
        Unit by which the weights are divided before being rounded to
        integers:

        - ``'min_gap'``: the smallest difference between consecutive distinct
          values among the weights together with 0.
        - ``'std'``: the standard deviation of the weights.
        - any other value is used directly as the divisor.
    """

    def __init__(self, sub_estimator, resolution='min_gap'):
        self.sub_estimator = sub_estimator
        self.resolution = resolution

    def fit(self, X, y=None, sample_weight=None):
        """Fit the wrapped estimator on samples replicated by their weight.

        Parameters
        ----------
        X : array-like with a ``shape`` attribute
            Training samples; entries along axis 0 are replicated.
        y : array-like, optional
            Targets, replicated along axis 0 with the same counts as ``X``.
        sample_weight : array-like, optional
            One weight per sample. When omitted, ``X`` and ``y`` are passed
            to the wrapped estimator unchanged.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``sample_weight`` and ``X`` disagree on the number of samples,
            or if the resolution is not strictly positive.
        """
        if sample_weight is None:
            self.sub_estimator.fit(X, y)
            return self

        # Work in floats so the division below is never an integer division.
        sample_weight = np.asarray(sample_weight, dtype=float)
        if len(sample_weight) != X.shape[0]:
            raise ValueError('Different number of samples in X and sample_weight')

        resolution = self._calc_resolution(sample_weight)
        # A zero/NaN/negative divisor (e.g. 'std' of identical weights) would
        # turn into meaningless repeat counts after the integer cast.
        if not np.all(resolution > 0):
            raise ValueError('resolution must be strictly positive, got %r'
                             % (resolution,))

        counts = np.round(sample_weight / resolution).astype(int)
        counts = self._factorize(counts)

        X = np.repeat(X, counts, axis=0)
        if y is not None:
            y = np.repeat(y, counts, axis=0)
        self.sub_estimator.fit(X, y)
        return self

    def _calc_resolution(self, sample_weight):
        """Return the divisor used to express `sample_weight` in integer units.

        Raises
        ------
        ValueError
            For ``resolution='min_gap'`` when the weights contain no value
            other than 0, so that no gap exists.
        """
        # TODO: allow for fractions of these
        resolution = self.resolution
        if resolution == 'std':
            return np.std(sample_weight)
        if resolution == 'min_gap':
            # 0 is included so the smallest weight itself counts as a gap.
            levels = np.unique(np.hstack([[0], sample_weight]))
            if levels.size < 2:
                raise ValueError("resolution='min_gap' requires at least one "
                                 "non-zero sample weight")
            return np.min(np.diff(levels))
        return resolution

    @staticmethod
    def _factorize(weight):
        """Divide integer weights by their greatest common divisor.

        The input is returned unchanged when it is empty, when every weight
        is zero, or when the weights are already coprime.
        """
        weight = np.asarray(weight)
        if weight.size == 0:
            return weight
        # np.gcd replaces fractions.gcd, which no longer exists in Python 3.9+.
        common = np.gcd.reduce(weight)
        if common <= 1:
            return weight
        return weight // common
if __name__ == '__main__':
    class TestingEstimator(BaseEstimator):
        """Stub estimator that only echoes the data it is asked to fit."""

        def fit(self, X, y):
            print(X, y)

    # Small demo: four samples of five features, weighted 5:3:3:1.
    # No y is supplied, so the stub receives y=None.
    demo_samples = np.arange(20).reshape(4, 5)
    demo_weights = [5, 3, 3, 1]
    trainer = DiscretelyWeightedTrainer(TestingEstimator(),
                                        resolution='min_gap')
    trainer.fit(demo_samples, sample_weight=demo_weights)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment