Created
September 18, 2020 22:00
-
-
Save josephbima/5fa918abab7793a1f0e6b181a1fce89c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Cross-validate a shallow decision tree, report the fold-averaged metrics,
# then fit the final model on all of the data and pickle it.
#
# Expects `cv` (a scikit-learn style splitter), `X`, `Y` and `feature_names`
# to be defined earlier in the script.
# NOTE(review): X and Y are indexed with integer index arrays below, so they
# are presumably NumPy arrays -- confirm against the loading code.
total_accuracy = 0.0
total_precision = 0.0
total_recall = 0.0
n_folds = 0  # counted from the splitter instead of assuming 10 folds

# Iterate over the cv and fit the decision tree using the training set
# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
for i, (train_index, test_index) in enumerate(cv.split(X, Y)):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = Y[train_index], Y[test_index]

    tree = DecisionTreeClassifier(criterion="entropy", max_depth=3)
    print("Fold {} : Training decision tree classifier over {} points...".format(i, len(y_train)))
    sys.stdout.flush()
    tree.fit(X_train, y_train)

    print("Evaluating classifier over {} points...".format(len(y_test)))
    # predict the labels on the test data
    y_pred = tree.predict(X_test)

    # conf[t, p] counts samples whose true label is t and predicted label is p
    # (rows are ground truth, columns are predictions).
    conf = confusion_matrix(y_test, y_pred, labels=[0, 1])
    true_positives = np.diag(conf)

    accuracy = np.sum(true_positives) / float(np.sum(conf))
    # A class that is never predicted (or never present) gives 0/0; silence
    # the warning and report 0 for that class via nan_to_num.
    with np.errstate(divide="ignore", invalid="ignore"):
        # Precision normalises by what was predicted: column sums (axis=0).
        precision = np.nan_to_num(true_positives / np.sum(conf, axis=0).astype(float))
        # Recall normalises by what is actually there: row sums (axis=1).
        recall = np.nan_to_num(true_positives / np.sum(conf, axis=1).astype(float))

    total_accuracy += accuracy
    total_precision += precision  # per-class array: [class 0, class 1]
    total_recall += recall  # per-class array: [class 0, class 1]
    n_folds += 1

# Average over the number of folds the splitter actually produced.
print("The average accuracy is {}".format(total_accuracy / float(n_folds)))
print("The average precision is {}".format(total_precision / float(n_folds)))
print("The average recall is {}".format(total_recall / float(n_folds)))

# Set this to the best model we found, trained on all the data:
best_classifier = RandomForestClassifier(n_estimators=100)
best_classifier.fit(X, Y)

# NOTE(review): this exports `tree`, the decision tree fitted on the LAST
# cross-validation fold, not a member of `best_classifier`. If a tree from the
# random forest was intended, export one of `best_classifier.estimators_`.
export_graphviz(tree, out_file='tree-random.dot', feature_names=feature_names)

classifier_filename = 'classifier.pickle'
print("Saving best classifier")
with open(classifier_filename, 'wb') as f:
    pickle.dump(best_classifier, f)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment