Last active
March 21, 2022 03:52
-
-
Save ortsed/51ba28668fa6e3db6cbb652bd30a6efd to your computer and use it in GitHub Desktop.
Sklearn Model Summary
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def model_summary(model, X, y, columns=None):
    """Print an OLS-style summary for a fitted single-target linear model.

    Prints R^2, adjusted R^2, AIC/BIC (only when the model exposes ``aic`` /
    ``bic`` attributes) and a table with one row for the intercept followed
    by one row per feature, giving the coefficient, its standard error, the
    t statistic and the two-sided p-value.

    Args:
        model: Fitted estimator exposing ``intercept_``, ``coef_``,
            ``predict(X)`` and ``score(X, y)``.
        X: Feature matrix of shape (n_samples, n_features); array-like or
            DataFrame. It is passed unchanged to ``model.predict`` and
            ``model.score``.
        y: Target values, one per row of ``X``.
        columns: Optional feature labels, in the same order as the columns
            of ``X``. When given and non-empty, the table rows are labelled
            ``"Intercept"`` followed by these names.

    Returns:
        None. The summary is written to stdout.

    Raises:
        ValueError: If there are not more observations than estimated
            parameters, or if the number of model parameters does not match
            the number of columns in ``X`` plus the intercept.
    """
    import numpy as np
    import pandas as pd
    from scipy import stats

    features = np.asarray(X, dtype=float)
    if features.ndim == 1:
        features = features.reshape(-1, 1)
    n_obs, n_features = features.shape

    # Intercept first, then the slopes, matching the design matrix below.
    params = np.append(model.intercept_, model.coef_)
    if len(params) != n_features + 1:
        raise ValueError(
            "expected %d parameters (intercept + one per feature), got %d"
            % (n_features + 1, len(params))
        )

    # Residual degrees of freedom: observations minus estimated parameters.
    resid_df = n_obs - n_features - 1
    if resid_df <= 0:
        raise ValueError(
            "need more observations (%d) than estimated parameters (%d)"
            % (n_obs, n_features + 1)
        )

    predictions = np.asarray(model.predict(X), dtype=float).ravel()
    target = np.asarray(y, dtype=float).ravel()

    r_squared = model.score(X, y)
    r_squared_adj = 1 - (1 - r_squared) * (n_obs - 1) / resid_df

    print("R^2: %s" % r_squared)
    print("R^2 Adjusted: %s" % r_squared_adj)
    if hasattr(model, "aic"):
        print("AIC: %s" % model.aic)
    if hasattr(model, "bic"):
        print("BIC: %s" % model.bic)

    # Build the design matrix positionally. Joining DataFrames would align
    # on index and fill the features with NaN when X has a non-default index.
    design = np.column_stack([np.ones(n_obs), features])

    mse = np.sum((target - predictions) ** 2) / resid_df
    var_b = mse * np.linalg.inv(design.T @ design).diagonal()
    sd_b = np.sqrt(var_b)
    ts_b = params / sd_b
    # Two-sided p-values on the residual degrees of freedom. The survival
    # function avoids the precision loss of 1 - cdf for large |t|.
    p_values = 2 * stats.t.sf(np.abs(ts_b), resid_df)

    # NOTE: the "Probabilites" header spelling is kept so printed output
    # stays compatible with anything that already reads it.
    table = pd.DataFrame(
        {
            "Coefficients": np.round(params, 4),
            "Standard Errors": np.round(sd_b, 3),
            "t values": np.round(ts_b, 3),
            "Probabilites": np.round(p_values, 3),
        }
    )
    # Test emptiness, not truthiness, so labels such as 0 or "" still count.
    if columns is not None and len(columns) > 0:
        table.index = ["Intercept"] + list(columns)
    print(table)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment