Created
November 29, 2017 17:58
-
-
Save morganmcg1/ad401d3bf8b3054ba13099f42aa7c253 to your computer and use it in GitHub Desktop.
For multi-class problems, decide how the model performed in a binary (long/short) situation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For multi-class problems, decide whether the model was correct going long or short.
def correct_quadrant_classifications(cls_labels, cm):
    """Score a multi-class confusion matrix exactly and by "quadrant".

    Two rates are computed from ``cm``:

    * the exact rate: the sum of the diagonal entries ``cm[i][i]`` divided
      by ``cm.sum()``;
    * the quadrant rate: the sum of the top-left and bottom-right
      half-by-half sub-blocks of ``cm`` divided by ``cm.sum()``. An entry
      counts when its row index and column index fall in the same half of
      the class list.

    NOTE(review): this presumably expects the classes to be ordered so that
    the first half are one direction (e.g. short) and the second half the
    other (e.g. long) -- confirm against the caller.

    Args:
        cls_labels: Sequence of class labels; only its length is used.
        cm: Square confusion matrix supporting ``cm[i][j]`` indexing and
            ``cm.sum()`` (e.g. a 2-D NumPy array) with one row/column per
            entry of ``cls_labels``.

    Returns:
        A tuple ``(correct_clf_rate, correct_quad_rate)``. When the number
        of classes is odd the class list cannot be split into two equal
        halves, so a message is printed and ``correct_quad_rate`` is
        ``None``.
    """
    n_classes = len(cls_labels)
    total = cm.sum()

    # Exact-class accuracy does not depend on the parity of the class count,
    # so compute it unconditionally (previously it was left undefined for an
    # odd number of classes, which made the final return raise).
    correct_clf = sum(cm[i][i] for i in range(n_classes))
    correct_clf_rate = correct_clf / total

    if n_classes % 2 != 0:
        print('cannot calculate correct quadrant classification on odd number of classes')
        return correct_clf_rate, None

    half = n_classes // 2

    # Add up the two diagonal sub-blocks: rows and columns both in the first
    # half, then rows and columns both in the second half.
    correct_quad = 0
    for start, stop in ((0, half), (half, n_classes)):
        for i in range(start, stop):
            for j in range(start, stop):
                correct_quad = correct_quad + cm[i][j]

    print(
        '\nTotal number of predictions in the correct quadrant was {:.0f} out of {:.0f},\n'
        'this corresponds to a {:.1f}% classification accuracy rate\n'.format(
            correct_quad, total, correct_quad / total * 100))

    # Rate of predictions in the correct quadrant, e.g. long or short.
    correct_quad_rate = correct_quad / total
    return correct_clf_rate, correct_quad_rate
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment