Base utilities class for all classifiers
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score


class Base:
    """Base class that houses common utilities for reading in test data
    and calculating model accuracy and F1 scores.
    """
    def __init__(self) -> None:
        pass

    def read_data(self, fname: str, lower_case: bool = False,
                  colnames=['truth', 'text']) -> pd.DataFrame:
        "Read test data into a Pandas DataFrame"
        df = pd.read_csv(fname, sep='\t', header=None, names=colnames)
        df['truth'] = df['truth'].str.replace('__label__', '')
        # Categorical data type for truth labels
        df['truth'] = df['truth'].astype(int).astype('category')
        # Optionally lowercase the test data (if the model was trained on lowercased text)
        if lower_case:
            df['text'] = df['text'].str.lower()
        return df

    def accuracy(self, df: pd.DataFrame) -> None:
        "Print prediction accuracy (percentage) and macro F1 score"
        acc = accuracy_score(df['truth'], df['pred']) * 100
        f1 = f1_score(df['truth'], df['pred'], average='macro')
        print("Accuracy: {}\nMacro F1-score: {}".format(acc, f1))