Last active
July 18, 2021 12:35
-
-
Save GaelVaroquaux/047d13d738d89ddcd4bc297edcd53233 to your computer and use it in GitHub Desktop.
Linear deconfounding in a fit-transform API
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A scikit-learn like transformer to remove a confounding effect on X. | |
""" | |
from sklearn.base import BaseEstimator, TransformerMixin, clone | |
from sklearn.linear_model import LinearRegression | |
import numpy as np | |
class DeConfounder(BaseEstimator, TransformerMixin):
    """A transformer removing the effect of y on X.

    ``fit`` trains a regression model that predicts X from the confound y.
    ``transform`` subtracts that model's prediction of X from X, keeping
    only the part of X the model cannot explain from y.

    Unlike most scikit-learn transformers, ``transform`` (and therefore
    ``fit_transform``) needs ``y``: the confound must also be known at
    transform time.

    Parameters
    ----------
    confound_model : estimator or None, default=None
        Regressor modelling X as a function of y. It is cloned in ``fit``,
        so the object passed in is never fitted or modified. ``None`` means
        a fresh ``LinearRegression()``.

    Attributes
    ----------
    confound_model_ : estimator
        The fitted copy of the confound model.
    """

    def __init__(self, confound_model=None):
        # ``None`` instead of a ``LinearRegression()`` default: a default
        # instance would be created once, at class-definition time, and
        # shared by every DeConfounder, so a nested
        # ``set_params(confound_model__<param>=...)`` would mutate it for all
        # later instances.
        self.confound_model = confound_model

    @staticmethod
    def _confounds_2d(y):
        """Return y as a 2-D array; a 1-D vector becomes a single column."""
        # np.asarray lets lists and pandas Series through; recent pandas
        # versions reject ``series[:, np.newaxis]`` indexing.
        y = np.asarray(y)
        if y.ndim == 1:
            y = y[:, np.newaxis]
        return y

    def fit(self, X, y):
        """Fit the confound model to predict X from y.

        Returns
        -------
        self
        """
        if self.confound_model is None:
            confound_model = LinearRegression()
        else:
            confound_model = clone(self.confound_model)
        confound_model.fit(self._confounds_2d(y), X)
        self.confound_model_ = confound_model
        return self

    def transform(self, X, y):
        """Return X minus the part of X predicted from y by the fitted model."""
        X_confounds = self.confound_model_.predict(self._confounds_2d(y))
        return X - X_confounds

    def fit_transform(self, X, y):
        """Fit on (X, y), then remove the confound effect from that same X.

        Overrides ``TransformerMixin.fit_transform``, which calls
        ``transform(X)`` without ``y`` and would therefore raise TypeError.
        """
        return self.fit(X, y).transform(X, y)
def test_deconfounder():
    """Self-test: the deconfounded X should be orthogonal to the confound y.

    Runs an in-sample case (fit and transform on the same random data) and
    an out-of-sample case (X exactly linear in y, fit on the first 90
    samples, transform all 100).
    """
    random_state = np.random.RandomState(0)

    def check_orthogonal(residuals, confound):
        # Each column of the residuals must have a ~zero dot product with y.
        np.testing.assert_almost_equal(residuals.T @ confound, 0)

    # In-sample case: X and y are independent Gaussian noise.
    features = random_state.normal(size=(100, 10))
    target = random_state.normal(size=100)
    model = DeConfounder()
    model.fit(features, target)
    check_orthogonal(model.transform(features, target), target)

    # Out-of-sample case: X is an exact linear function of y, so a model fit
    # on the leading samples should also clean the held-out tail.
    target = random_state.normal(size=100)
    weights = random_state.normal(size=10)
    features = np.outer(target, weights)
    n_train = len(target) - 10
    model.fit(features[:n_train], target[:n_train])
    check_orthogonal(model.transform(features, target), target)
if __name__ == '__main__':
    # Executing this file as a script runs its self-test.
    test_deconfounder()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment