Skip to content

Instantly share code, notes, and snippets.

@josephbima
Created September 18, 2020 22:00
Show Gist options
  • Select an option

  • Save josephbima/5fa918abab7793a1f0e6b181a1fce89c to your computer and use it in GitHub Desktop.

Select an option

Save josephbima/5fa918abab7793a1f0e6b181a1fce89c to your computer and use it in GitHub Desktop.
# Running sums of the per-fold metrics. Accuracy is a scalar; precision and
# recall are per-class vectors (one entry per label in [0, 1]), so those two
# totals become arrays after the first fold is added.
total_accuracy = 0.0
total_precision = 0.0
total_recall = 0.0
# Count the folds actually produced rather than assuming there are exactly 10.
n_folds = 0

# Iterate over the cv splits: fit a fresh decision tree on each training
# partition and score it on the matching held-out partition.
# Assumes X and Y support integer-array indexing (e.g. NumPy arrays).
# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
for i, (train_index, test_index) in enumerate(cv.split(X, Y)):
    n_folds += 1
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = Y[train_index], Y[test_index]

    tree = DecisionTreeClassifier(criterion="entropy", max_depth=3)
    print("Fold {} : Training decision tree classifier over {} points...".format(i, len(y_train)))
    # Flush so progress shows up immediately even when stdout is buffered.
    sys.stdout.flush()
    tree.fit(X_train, y_train)

    print("Evaluating classifier over {} points...".format(len(y_test)))
    # Predict the labels on the test data.
    y_pred = tree.predict(X_test)

    # scikit-learn's confusion matrix has TRUE labels on the rows and
    # PREDICTED labels on the columns: conf[i, j] counts samples of class i
    # that were predicted as class j.
    conf = confusion_matrix(y_test, y_pred, labels=[0, 1])
    correct = np.diag(conf)
    accuracy = np.sum(correct) / float(np.sum(conf))
    # precision = correct / everything predicted as the class (column sums).
    # recall    = correct / everything truly in the class     (row sums).
    # A class with no predictions / no samples gives 0/0 = nan, which
    # nan_to_num maps to 0.
    precision = np.nan_to_num(correct / np.sum(conf, axis=0).astype(float))
    recall = np.nan_to_num(correct / np.sum(conf, axis=1).astype(float))

    total_accuracy += accuracy
    total_precision += precision
    total_recall += recall

# Average over the folds that actually ran; the max() guard keeps the
# zero-fold case from raising and leaves the totals (all 0.0) as the result.
fold_divisor = float(max(n_folds, 1))
print("The average accuracy is {}".format(total_accuracy / fold_divisor))
print("The average precision is {}".format(total_precision / fold_divisor))
print("The average recall is {}".format(total_recall / fold_divisor))
# Final model: a random forest with 100 trees, trained on all of the data
# (no held-out split).
best_classifier = RandomForestClassifier(n_estimators=100)
best_classifier.fit(X, Y)

# NOTE(review): this exports `tree` -- the last decision tree fitted earlier
# in this script -- not a member of the random forest trained just above,
# despite the "tree-random.dot" file name. Confirm that is intended.
export_graphviz(
    tree,
    out_file="tree-random.dot",
    feature_names=feature_names,
)

# Persist the fitted forest to disk with pickle.
classifier_filename = "classifier.pickle"
print("Saving best classifier")
with open(classifier_filename, "wb") as pickle_file:
    pickle.dump(best_classifier, pickle_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment