Last active
December 17, 2015 20:48
-
-
Save jnothman/5670070 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import fractions | |
| import numpy as np | |
| from sklearn.base import MetaEstimatorMixin, BaseEstimator | |
class DiscretelyWeightedTrainer(BaseEstimator, MetaEstimatorMixin):
    """Replicate samples according to `sample_weight` to fit an estimator.

    This effectively provides a `sample_weight` parameter to the `fit` method
    of estimators otherwise lacking it. It transforms the weights into integers
    and then replicates samples according to those integral weights.

    Weights are quantized: each weight is divided by a *resolution* and rounded
    to the nearest integer, so the weighting is only approximate when weights
    are not multiples of that resolution.

    Parameters
    ----------
    sub_estimator : estimator object
        Estimator that is fitted (in place) on the replicated data.
    resolution : {'min_gap', 'std'} or positive number, default='min_gap'
        Weight quantum used to turn weights into replication counts.
        ``'min_gap'`` uses the smallest difference between distinct weights
        (counting 0 as a weight); ``'std'`` uses the standard deviation of
        the weights; a number is used as given.
    """

    def __init__(self, sub_estimator, resolution='min_gap'):
        self.sub_estimator = sub_estimator
        self.resolution = resolution

    def fit(self, X, y=None, sample_weight=None):
        """Fit ``sub_estimator`` on samples replicated by their weights.

        Parameters
        ----------
        X : array-like
            Training data with samples along axis 0.
        y : array-like or None, default=None
            Targets; replicated along axis 0 together with ``X`` when given.
        sample_weight : array-like of shape (n_samples,) or None, default=None
            Finite, non-negative per-sample weights. If None,
            ``sub_estimator`` is fitted on ``X`` and ``y`` unchanged.

        Returns
        -------
        self : DiscretelyWeightedTrainer

        Raises
        ------
        ValueError
            If ``sample_weight`` is not 1-D, does not match ``X`` in length,
            contains negative or non-finite values, or discretizes to all-zero
            counts; or if the resolution is invalid.
        """
        if sample_weight is None:
            self.sub_estimator.fit(X, y)
            return self
        sample_weight = np.asarray(sample_weight)
        # Plain sequences have no ``shape``; np.repeat accepts them anyway.
        n_samples = X.shape[0] if hasattr(X, 'shape') else len(X)
        if sample_weight.ndim != 1:
            raise ValueError('sample_weight must be 1-dimensional')
        if sample_weight.shape[0] != n_samples:
            raise ValueError('Different number of samples in X and sample_weight')
        if not np.all(np.isfinite(sample_weight)) or np.any(sample_weight < 0):
            raise ValueError('sample_weight must be finite and non-negative')

        counts = np.round(sample_weight / self._calc_resolution(sample_weight))
        counts = self._factorize(counts.astype(int))
        if not counts.any():
            # Every sample would be dropped; fitting on no data is never intended.
            raise ValueError('sample_weight discretizes to all-zero counts; '
                             'use a smaller resolution')

        X = np.repeat(X, counts, axis=0)
        if y is not None:
            y = np.repeat(y, counts, axis=0)
        self.sub_estimator.fit(X, y)
        return self

    def _calc_resolution(self, sample_weight):
        """Return the positive weight quantum selected by ``self.resolution``.

        Raises
        ------
        ValueError
            If ``self.resolution`` is an unknown string, or the resulting
            quantum is not a positive number (e.g. ``'std'`` of identical
            weights, or ``'min_gap'`` when every weight is zero).
        """
        # TODO: allow for fractions of these
        resolution = self.resolution
        if isinstance(resolution, str):
            if resolution == 'std':
                resolution = np.std(sample_weight)
            elif resolution == 'min_gap':
                # Including 0 makes the gap no larger than the smallest
                # positive weight, so (for non-negative weights) every
                # positive weight maps to a count of at least 1.
                gaps = np.diff(np.unique(np.hstack([[0], sample_weight])))
                resolution = np.min(gaps) if gaps.size else 0
            else:
                raise ValueError("resolution must be 'min_gap', 'std' or a "
                                 "positive number; got %r" % (resolution,))
        # ``not x > 0`` also rejects NaN.
        if not resolution > 0:
            raise ValueError('resolution must be positive; got %r'
                             % (resolution,))
        return resolution

    @staticmethod
    def _factorize(weight):
        """Divide integer replication counts by their greatest common divisor.

        This keeps the relative weighting while replicating as few rows as
        possible. Empty or all-zero input is returned unchanged, which avoids
        a division by zero.
        """
        # fractions.gcd was removed in Python 3.9; numpy's gcd ufunc reduces
        # the whole array in one call (identity 0 for empty input).
        gcd = np.gcd.reduce(weight)
        if gcd <= 1:
            return weight
        return weight // gcd
if __name__ == '__main__':
    # Demo: fit a printing estimator so the replicated rows can be inspected.
    class EchoEstimator(BaseEstimator):
        """Minimal estimator whose fit just prints the data it receives."""

        def fit(self, X, y):
            print(X, y)

    demo_data = np.arange(20).reshape(4, 5)
    demo_weights = [5, 3, 3, 1]
    trainer = DiscretelyWeightedTrainer(EchoEstimator(), resolution='min_gap')
    trainer.fit(demo_data, sample_weight=demo_weights)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment