Skip to content

Instantly share code, notes, and snippets.

@michelkana
Created April 27, 2021 07:13
Show Gist options
  • Save michelkana/4b12a2d16b09cb0d41116356262464a2 to your computer and use it in GitHub Desktop.
Save michelkana/4b12a2d16b09cb0d41116356262464a2 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
def running_predictions(prediction_dataset, targets):
    """Return the majority-vote accuracy of the first k trees, for every k.

    Args:
        prediction_dataset: 2-D array-like exposing ``.shape``; rows are
            samples and columns are the individual trees' predictions
            (assumed to be 0/1 votes -- the 0.5 threshold only makes sense
            for binary labels).
        targets: True labels, one per row of ``prediction_dataset``. Any
            array-like is accepted (ndarray, list, pandas Series); it is
            converted with ``np.asarray`` before being reshaped.

    Returns:
        One accuracy value per column: entry ``k - 1`` is the fraction of
        samples whose majority vote over the first ``k`` trees equals the
        target.
    """
    n_trees = prediction_dataset.shape[1]
    # Share of 1-votes among the first k trees, for k = 1..n_trees.
    running_percent_1s = (
        np.cumsum(prediction_dataset, axis=1) / np.arange(1, n_trees + 1)
    )
    # Strict majority: an exact 50/50 tie resolves to class 0.
    running_conclusions = running_percent_1s > 0.5
    # np.asarray keeps list / Series targets working; they have no usable
    # ``.reshape`` of their own. The column shape broadcasts across trees.
    target_column = np.asarray(targets).reshape(-1, 1)
    running_correctness = running_conclusions == target_column
    return np.mean(running_correctness, axis=0)
def plot_bagging_predictions(bagging_size, bagging_train, y_train, bagging_test, y_test,
                             sm_best_tree_accuracy_test, sm_overfit_accuracy_test, title):
    """Plot bagging accuracy against ensemble size, with single-tree baselines.

    The train and test accuracy curves come from ``running_predictions`` and
    are drawn over x = 1..``bagging_size``; the two single-tree accuracies
    are drawn as horizontal reference lines. A new 15x5 figure is created on
    every call. Nothing is returned and ``plt.show()`` is not called here.

    Args:
        bagging_size: Number of bootstraps; sets the x axis and its ticks.
        bagging_train: Per-tree predictions on the training set.
        y_train: Training-set targets.
        bagging_test: Per-tree predictions on the test set.
        y_test: Test-set targets.
        sm_best_tree_accuracy_test: y position of the 'best single tree' line.
        sm_overfit_accuracy_test: y position of the 'overfit single tree'
            line (drawn in red).
        title: Title of the axes.
    """
    ensemble_sizes = range(1, bagging_size + 1)
    train_curve = running_predictions(bagging_train, y_train)
    test_curve = running_predictions(bagging_test, y_test)

    _, ax = plt.subplots(1, 1, figsize=(15, 5))

    # Accuracy curves: (values, line/marker format, legend label).
    curves = (
        (train_curve, '-o', 'Bagging accuracy on training set'),
        (test_curve, '-x', 'Bagging accuracy on test set'),
    )
    for accuracy, fmt, label in curves:
        ax.plot(ensemble_sizes, accuracy, fmt, label=label, alpha=0.9)

    # Horizontal baselines for the two single-tree models.
    ax.axhline(y=sm_best_tree_accuracy_test, label='Best single tree accuracy', alpha=0.9)
    ax.axhline(y=sm_overfit_accuracy_test, label='Overfit single tree accuracy',
               c='r', alpha=0.9)

    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Number of bootstraps', fontsize=14)
    ax.set_ylabel('Accuracy', fontsize=14)
    ax.set_xticks(ensemble_sizes)
    ax.legend()
# Draw the bagging performance chart. Every value passed here must already
# exist in the enclosing session; none of them is defined in this file.
plot_bagging_predictions(
    bagging_size=bagging_size,
    bagging_train=bagging_train,
    y_train=y_train,
    bagging_test=bagging_test,
    y_test=y_test,
    sm_best_tree_accuracy_test=sm_best_tree_accuracy_test,
    sm_overfit_accuracy_test=sm_overfit_accuracy_test,
    title='Bagging ensemble performance',
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment