loretoparisi/41b918add11893d761d0ec12a3a4e1aa
Last active November 4, 2022 13:21
Calculate FastText Classifier Confusion Matrix
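```python
#!/usr/local/bin/python3
# @author cpuhrsch https://github.com/cpuhrsch
# @author Loreto Parisi
import argparse

import numpy as np
from sklearn.metrics import confusion_matrix

LABEL_PREFIX = "__label__"  # FastText's default label prefix (9 characters)


def parse_labels(path):
    """Read whitespace-separated labels and strip the FastText '__label__' prefix."""
    with open(path, 'r') as f:
        # Strip the prefix only when present, so unprefixed labels are not truncated.
        return np.array([x[len(LABEL_PREFIX):] if x.startswith(LABEL_PREFIX) else x
                         for x in f.read().split()])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Display confusion matrix.')
    parser.add_argument('test', help='Path to test labels')
    parser.add_argument('predict', help='Path to predictions')
    args = parser.parse_args()
    test_labels = parse_labels(args.test)
    pred_labels = parse_labels(args.predict)
    # Both files must hold exactly one label per example, in the same order.
    if len(test_labels) != len(pred_labels):
        parser.error("test and prediction files contain different numbers of labels")
    eq = test_labels == pred_labels
    print("Accuracy: " + str(eq.sum() / len(test_labels)))
    # Rows are true labels, columns are predicted labels, both in sorted order.
    labels = np.unique(np.concatenate([test_labels, pred_labels]))
    print("Labels: " + " ".join(labels))
    print(confusion_matrix(test_labels, pred_labels, labels=labels))
```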
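To make the printed matrix easier to read, here is a minimal sketch (not part of the original gist) that builds the same kind of matrix by hand with NumPy only. The function name `confusion_and_accuracy` and the sample labels are hypothetical; it assumes both inputs are equal-length sequences of label strings. Rows are true labels and columns are predicted labels, in sorted order, which should match the default ordering used by `sklearn.metrics.confusion_matrix`.

```python
import numpy as np


def confusion_and_accuracy(y_true, y_pred):
    """Build a confusion matrix by hand: rows = true labels, columns = predicted labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Sorted union of all labels seen in either sequence.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    index = {label: i for i, label in enumerate(labels)}
    matrix = np.zeros((len(labels), len(labels)), dtype=int)
    for t, p in zip(y_true, y_pred):
        matrix[index[t], index[p]] += 1
    # Correct predictions lie on the diagonal, so accuracy = trace / total.
    accuracy = np.trace(matrix) / matrix.sum()
    return labels, matrix, accuracy


labels, matrix, accuracy = confusion_and_accuracy(
    ["neg", "pos", "pos", "neu"], ["neg", "pos", "neg", "neu"]
)
print(labels)    # ['neg' 'neu' 'pos']
print(matrix)    # one 'pos' example was predicted as 'neg' (row 2, column 0)
print(accuracy)  # 0.75
```

The diagonal-over-total accuracy is the same quantity the script computes as `eq.sum() / len(test_labels)`; off-diagonal cells show which true label was confused with which prediction.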
Usage example:
Suppose our test set must be normalized because its labels are plain strings with no prefix (whereas FastText uses `__label__` as its default label prefix), and you get
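The normalization step is not spelled out in full here. One possible sketch, assuming one example per line with the label as the first space-separated token, prepends FastText's default `__label__` prefix to any label that lacks it. The helpers `add_label_prefix` and `normalize_lines` are hypothetical names introduced for illustration.

```python
LABEL_PREFIX = "__label__"  # FastText's default label prefix


def add_label_prefix(line, prefix=LABEL_PREFIX):
    """Prefix the first space-separated token (assumed to be the label) if it lacks the prefix."""
    label, sep, text = line.partition(" ")
    if not label.startswith(prefix):
        label = prefix + label
    return label + sep + text


def normalize_lines(lines, prefix=LABEL_PREFIX):
    """Apply add_label_prefix to every non-empty line."""
    return [add_label_prefix(line, prefix) for line in lines if line.strip()]


rows = ["positive great movie", "__label__negative boring plot", "neutral it was ok"]
for row in normalize_lines(rows):
    print(row)
```

To apply it to a file, the lines can be read with `open(path).read().splitlines()` and written back joined with newlines. The script's `test` argument should then point to a file containing only the labels, one per example, because `parse_labels` splits the entire file on whitespace.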