Created June 18, 2018 12:49
-
-
Save jeffbaumes/83bd453ed70af48352d4329136ad4a84 to your computer and use it in GitHub Desktop.
Pulling out seaborn regression plotter methods into utility functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def fit_regression(x, y, n_boot=1000, units=None, ci=95, order=1,
                   logistic=False, lowess=False, robust=False, logx=False):
    """Fit a regression of ``y`` on ``x`` and evaluate it on a grid.

    The model is chosen by the first option that applies, checked in this
    order: polynomial (``order > 1``), logistic GLM (``logistic``), lowess
    smoother (``lowess``), robust linear model (``robust``), linear fit on
    ``log(x)`` (``logx``); otherwise an ordinary least-squares line.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values.  ``x`` must provide ``.min()`` and
        ``.max()`` (e.g. a NumPy array).
    n_boot : int
        Number of bootstrap resamples used to build the confidence interval.
    units : array-like or None
        Passed through unchanged to the bootstrap helpers.
    ci : number or None
        Interval size passed to ``utils.ci``.  ``None`` skips bootstrapping
        and no error bands are returned.
    order : int
        Polynomial degree; values above 1 select the polynomial model.
    logistic, lowess, robust, logx : bool
        Model-selection flags (see above).

    Returns
    -------
    grid : ndarray
        Points where the fit is evaluated: 100 evenly spaced values spanning
        ``x.min()`` to ``x.max()``, or the grid produced by the lowess
        smoother.
    yhat : ndarray
        Fitted values at ``grid``.
    err_bands : ndarray or None
        ``utils.ci`` applied to the bootstrap predictions, or ``None`` when
        ``ci`` is ``None`` or the lowess model is used.
    """
    # Evaluation grid for every model except lowess, which supplies its own.
    grid = np.linspace(x.min(), x.max(), 100)

    if order > 1:
        yhat, yhat_boots = fit_poly(x, y, ci, grid, order, n_boot, units)
    elif logistic:
        from statsmodels.genmod.generalized_linear_model import GLM
        from statsmodels.genmod.families import Binomial
        yhat, yhat_boots = fit_statsmodels(x, y, ci, grid, GLM, n_boot, units,
                                           family=Binomial())
    elif lowess:
        # The lowess smoother is not bootstrapped, so there are no bootstrap
        # predictions to summarise into error bands.
        grid, yhat = fit_lowess(x, y)
        return grid, yhat, None
    elif robust:
        from statsmodels.robust.robust_linear_model import RLM
        yhat, yhat_boots = fit_statsmodels(x, y, ci, grid, RLM, n_boot, units)
    elif logx:
        yhat, yhat_boots = fit_logx(x, y, ci, grid, n_boot, units)
    else:
        yhat, yhat_boots = fit_fast(x, y, ci, grid, n_boot, units)

    # Confidence interval at each grid point, from the bootstrap predictions.
    err_bands = None if ci is None else utils.ci(yhat_boots, ci, axis=0)
    return grid, yhat, err_bands
def fit_fast(x, y, ci, grid, n_boot, units):
    """Fit a straight line by least squares and evaluate it on ``grid``.

    Coefficients come from the Moore-Penrose pseudo-inverse of a design
    matrix with an intercept column.

    Returns ``(yhat, yhat_boots)``.  ``yhat`` holds the predictions at
    ``grid``.  ``yhat_boots`` holds one row of predictions per bootstrap
    resample from ``algo.bootstrap``, or is ``None`` when ``ci`` is ``None``.
    """
    def solve_beta(design_matrix, response):
        # (intercept, slope) coefficients via the pseudo-inverse.
        return np.linalg.pinv(design_matrix).dot(response)

    design = np.c_[np.ones(len(x)), x]
    grid_design = np.c_[np.ones(len(grid)), grid]

    point_fit = grid_design.dot(solve_beta(design, y))
    if ci is not None:
        # Each bootstrap draw yields a coefficient vector; transpose so the
        # coefficients sit in rows, predict, then transpose back to one row
        # per draw.
        boot_betas = algo.bootstrap(design, y, func=solve_beta,
                                    n_boot=n_boot, units=units).T
        return point_fit, grid_design.dot(boot_betas).T
    return point_fit, None
def fit_poly(x, y, ci, grid, order, n_boot, units):
    """Fit a polynomial of degree ``order`` with ``np.polyfit``.

    Parameters
    ----------
    x, y : array-like
        Predictor and response values.
    ci : number or None
        If ``None``, no bootstrap is run.
    grid : array-like
        Points at which the fitted polynomial is evaluated.
    order : int
        Polynomial degree passed to ``np.polyfit``.
    n_boot, units
        Passed through to ``algo.bootstrap``.

    Returns
    -------
    (yhat, yhat_boots)
        ``yhat`` is the fit evaluated at ``grid``.  ``yhat_boots`` is the
        ``algo.bootstrap`` output of per-resample predictions, or ``None``
        when ``ci`` is ``None``.
    """
    def reg_func(_x, _y):
        # Fit on the (possibly resampled) data, always evaluate on the
        # fixed grid so bootstrap predictions line up point-for-point.
        return np.polyval(np.polyfit(_x, _y, order), grid)

    yhat = reg_func(x, y)
    if ci is None:
        return yhat, None
    yhat_boots = algo.bootstrap(x, y, func=reg_func,
                                n_boot=n_boot, units=units)
    return yhat, yhat_boots
def fit_statsmodels(x, y, ci, grid, model, n_boot, units, **kwargs):
    """Fit a statsmodels model class and predict on ``grid``.

    ``model`` is called as ``model(endog, exog, **kwargs)``, e.g. ``GLM`` or
    ``RLM``.  An intercept column is prepended to ``x`` and to ``grid``.
    If the fit hits perfect separation, the predictions for that fit are
    all NaN instead of an error.

    Returns ``(yhat, yhat_boots)``.  ``yhat_boots`` is the
    ``algo.bootstrap`` output, or ``None`` when ``ci`` is ``None``.
    """
    import warnings
    import statsmodels.tools.sm_exceptions as sme

    X = np.c_[np.ones(len(x)), x]
    grid = np.c_[np.ones(len(grid)), grid]

    # Older statsmodels raises PerfectSeparationError; statsmodels >= 0.14
    # emits PerfectSeparationWarning instead.  Treat both as "fit undefined".
    sep_warning = getattr(sme, "PerfectSeparationWarning", None)
    err_classes = (sme.PerfectSeparationError,)
    if sep_warning is not None:
        err_classes += (sep_warning,)

    def reg_func(_x, _y):
        try:
            with warnings.catch_warnings():
                if sep_warning is not None:
                    # Escalate the warning so it is caught below.
                    warnings.simplefilter("error", sep_warning)
                yhat = model(_y, _x, **kwargs).fit().predict(grid)
        except err_classes:
            yhat = np.full(len(grid), np.nan)
        return yhat

    yhat = reg_func(X, y)
    if ci is None:
        return yhat, None
    yhat_boots = algo.bootstrap(X, y, func=reg_func,
                                n_boot=n_boot, units=units)
    return yhat, yhat_boots
def fit_lowess(x, y):
    """Smooth ``y`` against ``x`` with statsmodels' lowess.

    Lowess picks its own evaluation points, so they are returned together
    with the smoothed values as ``(grid, yhat)``.  These are the first and
    second columns of the array that ``lowess`` returns.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess
    smoothed = lowess(y, x)
    rows = smoothed.T
    return rows[0], rows[1]
def fit_logx(x, y, ci, grid, n_boot, units):
    """Fit ``y ~ a + b * log(x)`` by least squares and evaluate on ``grid``.

    The design matrix is kept in linear space, so ``algo.bootstrap``
    resamples raw rows.  The log transform of the predictor column happens
    inside the solver.

    Returns ``(yhat, yhat_boots)``.  ``yhat_boots`` is ``None`` when ``ci``
    is ``None``.
    """
    design = np.c_[np.ones(len(x)), x]
    log_grid = np.c_[np.ones(len(grid)), np.log(grid)]

    def solve_beta(dm, response):
        # Replace the predictor column with its log before solving.
        logged = np.c_[dm[:, 0], np.log(dm[:, 1])]
        return np.linalg.pinv(logged).dot(response)

    point_fit = log_grid.dot(solve_beta(design, y))
    if ci is not None:
        boot_betas = algo.bootstrap(design, y, func=solve_beta,
                                    n_boot=n_boot, units=units).T
        return point_fit, log_grid.dot(boot_betas).T
    return point_fit, None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.