Created August 20, 2019 18:50
-
-
Save arose13/362f5a1ca95ce475ef45a9b61e18c5ea to your computer and use it in GitHub Desktop.
Computing the mean of a target variable, conditional on a categorical variable
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from sklearn.base import BaseEstimator, RegressorMixin | |
from sklearn.preprocessing import OneHotEncoder | |
from sklearn.exceptions import NotFittedError | |
class StratifiedDummyRegressor(BaseEstimator, RegressorMixin):
    """
    An extremely scalable dummy regression model that predicts the mean of
    ``y`` for each group defined by a single categorical column.

    The grouping column is one-hot encoded into a sparse matrix ``X`` and the
    coefficients ``m`` of ``y = X m`` are found by sparse least squares. Since
    each row of a one-hot ``X`` has exactly one non-zero entry, the
    least-squares coefficient of each column is the mean of ``y`` over that
    group (up to LSQR's convergence tolerance).

    Author-reported benchmark (single core, 3.4 GHz): more than 1e8 rows with
    more than 1e4 categories fits in under one minute.

    Parameters
    ----------
    stratified_col : str, optional
        Name of the DataFrame column whose values define the groups.

    Attributes
    ----------
    preprocessor : OneHotEncoder
        Encoder for ``stratified_col``; refit on every call to ``fit``.
    coef_ : numpy.ndarray or None
        Fitted mean per category, in the encoder's category order;
        ``None`` until ``fit`` has been called.
    results_ : dict or None
        Convergence diagnostics from ``scipy.sparse.linalg.lsqr``;
        ``None`` until ``fit`` has been called.
    """

    def __init__(self, stratified_col=None):
        self.stratified_col = stratified_col  # type: str
        # NOTE(review): scikit-learn convention is for __init__ to only store
        # constructor arguments. These attributes are created here because
        # predict() relies on coef_ being None before fit. A side effect is
        # that sklearn's check_is_fitted() treats an unfitted instance as
        # fitted.
        self.preprocessor = OneHotEncoder(categories='auto')
        self.coef_, self.results_ = None, None

    def _solve_means(self, x_one_hot, y):
        """
        Solve for ``m`` minimising ``||y - X m||`` with sparse LSQR.

        Stores the solution in ``coef_`` and lsqr's diagnostics in
        ``results_``.
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html
        """
        from scipy.sparse.linalg import lsqr
        solution, *diagnostics = lsqr(x_one_hot, y)
        self.coef_ = solution
        # lsqr also returns a trailing `var` array, which is only meaningful
        # when calc_var=True (not requested here); zip() drops it.
        self.results_ = dict(zip(
            ['istop', 'itn', 'r1norm', 'r2norm', 'anorm', 'acond', 'arnorm', 'xnorm'],
            diagnostics,
        ))

    def fit(self, X, y):
        """
        Fit the per-group means of ``y``.

        Parameters
        ----------
        X : pandas.DataFrame
            Must contain the column named by ``stratified_col``.
        y : array-like
            Target values, one per row of ``X``.

        Returns
        -------
        self

        Raises
        ------
        NotImplementedError
            If ``X`` is not a ``pandas.DataFrame``.
        """
        if not isinstance(X, pd.DataFrame):
            raise NotImplementedError('X must be a pd.DataFrame with named cols')
        # .values: encode the raw column so the encoder does not record
        # DataFrame feature names.
        x_one_hot = self.preprocessor.fit_transform(X[[self.stratified_col]].values)
        self._solve_means(x_one_hot, y)
        return self

    def predict(self, X):
        """
        Predict the fitted group mean for every row of ``X``.

        Parameters
        ----------
        X : pandas.DataFrame
            Must contain the column named by ``stratified_col``.

        Returns
        -------
        numpy.ndarray
            One prediction per row of ``X``.

        Raises
        ------
        NotFittedError
            If ``fit`` has not been called.
        NotImplementedError
            If ``X`` is not a ``pandas.DataFrame`` (same contract as ``fit``).

        Notes
        -----
        Categories not seen during ``fit`` are rejected by the encoder
        (``OneHotEncoder`` defaults to ``handle_unknown='error'``).
        """
        if self.coef_ is None:
            raise NotFittedError('Call fit() first')
        if not isinstance(X, pd.DataFrame):
            raise NotImplementedError('X must be a pd.DataFrame with named cols')
        x_one_hot = self.preprocessor.transform(X[[self.stratified_col]].values)
        # Each one-hot row selects exactly one coefficient: that group's mean.
        return x_one_hot @ self.coef_

    def score(self, X, y, sample_weight=None):
        """
        Return the coefficient of determination (R^2) of ``predict(X)``.

        Parameters
        ----------
        X : pandas.DataFrame
            Input rows; see ``predict``.
        y : array-like
            True target values.
        sample_weight : array-like, optional
            Per-sample weights, forwarded to ``sklearn.metrics.r2_score``.
        """
        from sklearn.metrics import r2_score
        return r2_score(y, self.predict(X), sample_weight=sample_weight)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.