Created August 22, 2018 15:10
-
-
Save elena-roff/d17f257ee41686e4a3d49b4e0df4d3b9 to your computer and use it in GitHub Desktop.
Regression: performance statistics + plots
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score

# NOTE: `display` is expected to be provided by IPython/Jupyter at runtime.
def stats(value, pred):
    """Return basic regression performance statistics.

    Parameters
    ----------
    value : array-like
        Ground-truth target values.
    pred : array-like
        Predicted values, aligned element-wise with ``value``.

    Returns
    -------
    pandas.DataFrame
        One row per metric, with columns ``metric`` and ``value``:
        ``R2`` (sklearn ``r2_score``), ``RMSE`` (square root of sklearn
        ``mean_squared_error``) and ``N`` (``len(value)``).
    """
    # np.sqrt replaces the original sc.sqrt: `sc` is not defined in this file
    # (presumably `import scipy as sc`), and the NumPy aliases re-exported
    # from the top-level SciPy namespace are deprecated/removed.
    rows = [
        {'metric': 'R2', 'value': r2_score(value, pred)},
        {'metric': 'RMSE', 'value': np.sqrt(mean_squared_error(value, pred))},
        {'metric': 'N', 'value': len(value)},
    ]
    return pd.DataFrame(rows, columns=['metric', 'value'])
def ba_plot(x, y, title="", x_lab="sell_price lead", y_lab='sell_price difference', s=2, n_std=1):
    """Draw a Bland-Altman style scatter plot and display its summary table.

    Plots ``y`` against ``x`` with a dashed black line at the mean of ``y``
    and dashed red lines at ``mean +/- n_std * std``. It then displays a
    table with the mean and standard deviation of ``y`` and shows the
    figure.

    Parameters
    ----------
    x : array-like
        Values for the horizontal axis.
    y : array-like
        Differences plotted on the vertical axis.
    title, x_lab, y_lab : str
        Plot title and axis labels.
    s : float
        Marker size passed to ``Axes.scatter``.
    n_std : float
        Multiple of the standard deviation used for the red limit lines.
        The default of 1 keeps the original behaviour. Classic Bland-Altman
        limits of agreement use 1.96.
    """
    # Compute the summary statistics once. np.std is the population standard
    # deviation (ddof=0), matching the original sc.std; for a pandas Series
    # np.mean/np.std defer to Series.mean/std, which skip NaN.
    mean_diff = np.mean(y)
    std_diff = np.std(y)

    _, ax = plt.subplots()
    ax.scatter(x=x, y=y, s=s)
    # axhline spans the full axes width by default. The original passed
    # xmax=ax.get_xlim()[1], but xmax is an axes fraction in [0, 1], not a
    # data coordinate, so lines were truncated whenever the x-limit was < 1.
    ax.axhline(mean_diff, linestyle='--', color='black')
    ax.axhline(mean_diff + n_std * std_diff, linestyle='--', color='red')
    ax.axhline(mean_diff - n_std * std_diff, linestyle='--', color='red')
    ax.set_title(title)
    ax.set_xlabel(x_lab)
    ax.set_ylabel(y_lab)

    met = pd.DataFrame([
        {'ba_metric': 'mean diff', 'value': mean_diff},
        {'ba_metric': 'std diff', 'value': std_diff},
    ])
    display(met)
    plt.show()
def bland_altman(true_values, pred_values):
    """Compute the data for a Bland-Altman comparison of predictions vs truth.

    Parameters
    ----------
    true_values : pandas.Series or numpy.ndarray
        Reference (ground-truth) values.
    pred_values : pandas.Series or numpy.ndarray
        Predicted values aligned with ``true_values``.

    Returns
    -------
    pandas.DataFrame
        Columns ``x`` (the true values), ``y_abs`` (``true - pred``) and
        ``y_rel`` (``(true - pred) / true``). ``y_rel`` is NaN where the
        true value is 0, because a relative difference is undefined there.
    """
    y_abs = true_values - pred_values
    # Dividing by a zero true value gives +/-inf (or NaN for 0/0). Silence
    # NumPy's warning here and turn the infinities into NaN below, so that
    # NaN-aware statistics (e.g. Series.mean) skip them instead of
    # returning inf.
    with np.errstate(divide='ignore', invalid='ignore'):
        y_rel = y_abs / true_values
    result = pd.DataFrame({'x': true_values, 'y_abs': y_abs, 'y_rel': y_rel})
    result['y_rel'] = result['y_rel'].replace([np.inf, -np.inf], np.nan)
    return result
def corr_plot(value, pred, x_lab="", y_lab="", title="", s=2):
    """Scatter-plot predictions against true values and show the figure.

    Parameters
    ----------
    value : array-like
        Values for the horizontal axis (the true values in ``analysis``).
    pred : array-like
        Values for the vertical axis (the predictions in ``analysis``).
    x_lab, y_lab, title : str
        Axis labels and plot title.
    s : float
        Marker size passed to ``Axes.scatter``.
    """
    _, axes = plt.subplots()
    axes.scatter(x=value, y=pred, s=s)
    # Apply the title and axis labels in the same order as before.
    for setter, text in (
        (axes.set_title, title),
        (axes.set_xlabel, x_lab),
        (axes.set_ylabel, y_lab),
    ):
        setter(text)
    plt.show()
def analysis(df, target='sell_price', pred=None):
    """Run the regression report for one prediction column of ``df``.

    Displays the metric table from ``stats`` and draws a true-vs-predicted
    scatter with ``corr_plot``. It then draws two Bland-Altman plots with
    ``ba_plot``: one for the absolute differences and one for the relative
    differences, both computed as ``target - pred``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding both the true and the predicted columns.
    target : str
        Name of the ground-truth column.
    pred : str
        Name of the prediction column. Required despite the ``None`` default.

    Raises
    ------
    ValueError
        If ``pred`` is not given.
    """
    if pred is None:
        # Without this check, df[None] would fail with an opaque KeyError.
        raise ValueError("analysis() requires `pred`, the name of the prediction column")
    true_col = df[target]
    pred_col = df[pred]

    display(stats(true_col.values, pred_col.values))
    corr_plot(true_col.values, pred_col.values, title='correlation plot', y_lab=pred, x_lab=target)

    ba = bland_altman(true_col, pred_col)
    # The x axis shows the true values, so label it with the target column
    # instead of relying on ba_plot's hard-coded "sell_price lead" default.
    ba_plot(ba['x'], ba['y_abs'], title="absolute values BA", x_lab=target,
            y_lab="abs. {}-{} difference".format(target, pred))
    ba_plot(ba['x'], ba['y_rel'], title="relative values BA", x_lab=target,
            y_lab="rel. {}-{} difference".format(target, pred))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment