@erykml
Created June 3, 2019 16:35
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.stats.api as sms
sns.set_style('darkgrid')
plt.rcParams['figure.figsize'] = (15.0, 9.0)
def linearity_test(model, y):
    '''
    Visually inspect the linearity assumption of a linear regression model.

    Plots observed vs. predicted values and residuals vs. predicted values,
    each with a LOWESS smoother drawn in red.

    Args:
    * model - fitted OLS model (results object) from statsmodels
    * y - observed values of the target variable
    '''
    # In-sample predictions and residuals of the fitted model
    fitted_vals = model.predict()
    resids = model.resid

    fig, ax = plt.subplots(1, 2)

    # Left panel: if the relationship is linear, points scatter around the diagonal
    sns.regplot(x=fitted_vals, y=y, lowess=True, ax=ax[0], line_kws={'color': 'red'})
    ax[0].set_title('Observed vs. Predicted Values', fontsize=16)
    ax[0].set(xlabel='Predicted', ylabel='Observed')

    # Right panel: if the relationship is linear, residuals scatter around zero with no trend
    sns.regplot(x=fitted_vals, y=resids, lowess=True, ax=ax[1], line_kws={'color': 'red'})
    ax[1].set_title('Residuals vs. Predicted Values', fontsize=16)
    ax[1].set(xlabel='Predicted', ylabel='Residuals')
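# lin_reg: a fitted statsmodels OLS results object; y: the observed target it was fitted on
# (both are assumed to be defined earlier in the notebook)
linearity_test(lin_reg, y)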
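Both panels look for a systematic pattern that a straight-line fit leaves behind. As a rough, plot-free illustration of what the residuals panel (and its red LOWESS line) is meant to reveal, the pure-Python sketch below fits a simple least-squares line to data that is exactly linear and to data generated from y = x², then averages the residuals within three bins of fitted values. The helper names `fit_simple_ols` and `residuals_by_region` and the synthetic data are hypothetical, chosen only for illustration, and the binned averages are a crude stand-in for the LOWESS smoother that `sns.regplot(..., lowess=True)` draws, not LOWESS itself.

```python
# Hypothetical helpers for illustration only; no plotting, standard library only.

def fit_simple_ols(x, y):
    """Closed-form least-squares fit of y = b0 + b1 * x."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    sxy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    sxx = sum((xi - x_mean) ** 2 for xi in x)
    b1 = sxy / sxx
    b0 = y_mean - b1 * x_mean
    return b0, b1


def residuals_by_region(x, y):
    """Fit OLS, then average the residuals in the lower, middle and upper
    thirds of the fitted values (a crude binned stand-in for a LOWESS curve)."""
    b0, b1 = fit_simple_ols(x, y)
    fitted = [b0 + b1 * xi for xi in x]
    resids = [yi - fi for yi, fi in zip(y, fitted)]
    # Sort residuals by fitted value, then split into three consecutive bins
    pairs = sorted(zip(fitted, resids))
    k = len(pairs) // 3
    bins = [pairs[:k], pairs[k:2 * k], pairs[2 * k:]]
    return [sum(r for _, r in b) / len(b) for b in bins]


# Synthetic data: 61 evenly spaced points in [0, 6]
x = [i / 10 for i in range(0, 61)]
linear_y = [2.0 + 0.5 * xi for xi in x]  # truly linear relationship
quadratic_y = [xi ** 2 for xi in x]      # nonlinear relationship

print(residuals_by_region(x, linear_y))     # every bin average is essentially zero
print(residuals_by_region(x, quadratic_y))  # positive, negative, positive
```

For the linear data the bin averages are zero up to floating-point error. For the quadratic data they come out positive, negative, positive: the U-shape that would appear as a curved red line in the residuals panel, signalling that a linear specification misses part of the relationship.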