Skip to content

Instantly share code, notes, and snippets.

@eugeneyan
Created February 21, 2021 19:15
Show Gist options
Save eugeneyan/00c3200abf15a73fb32cab472f16a2a8 to your computer and use it in GitHub Desktop.
Test that RandomForest accuracy and AUC ROC increase as the number of trees increases
def test_dt_increase_acc(dummy_titanic):
    """Check that RandomForest accuracy and AUC ROC do not drop as trees are added.

    Fits a ``RandomForest`` with 1, 3, 7 and 15 trees, keeping every other
    hyperparameter fixed, on the train split of ``dummy_titanic``. It then
    asserts that both accuracy and AUC ROC on the test split form a
    non-decreasing sequence in the number of trees.

    NOTE(review): despite the ``dt`` in its name, this test exercises
    ``RandomForest``. The name is kept so existing test selection keeps working.
    NOTE(review): row/column subsampling is enabled and no seed is passed here,
    so the strict ordering check may be flaky. Confirm whether RandomForest
    seeds itself.

    Args:
        dummy_titanic: presumably a pytest fixture; unpacked as a 4-tuple
            ``(X_train, y_train, X_test, y_test)``.
    """
    X_train, y_train, X_test, y_test = dummy_titanic
    tree_counts = [1, 3, 7, 15]
    acc_list = []
    auc_list = []
    for num_trees in tree_counts:
        rf = RandomForest(num_trees=num_trees, depth_limit=7, col_subsampling=0.7, row_subsampling=0.7)
        rf.fit(X_train, y_train)
        pred = rf.predict(X_test)
        # predict() output is used directly as a score for AUC and rounded to
        # 0/1 for accuracy. This assumes scores lie in [0, 1] -- TODO confirm.
        # np.round rounds half to even, so an exact 0.5 becomes 0.
        pred_binary = np.round(pred)
        acc_list.append(accuracy_score(y_test, pred_binary))
        auc_list.append(roc_auc_score(y_test, pred))
    # sorted(x) == x holds exactly when x is non-decreasing (ties allowed).
    # The observed scores are included in the message to make failures diagnosable.
    assert sorted(acc_list) == acc_list, (
        'Accuracy should increase as number of trees increases. '
        f'Got {dict(zip(tree_counts, acc_list))}'
    )
    assert sorted(auc_list) == auc_list, (
        'AUC ROC should increase as number of trees increases. '
        f'Got {dict(zip(tree_counts, auc_list))}'
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment