Created December 17, 2022 22:04
-
-
Save FlynnOConnell/648d1c810cde2bbf1d8d7ae5cd20473b to your computer and use it in GitHub Desktop.
Scoring utilities for a neural network.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
# scores.py | |
Module (neuralnetwork): Structures to score and keep scores for evaluated data. | |
""" | |
from __future__ import division | |
import logging | |
import pickle | |
from dataclasses import dataclass | |
from typing import Optional, Iterable, Any | |
import numpy as np | |
import pandas as pd | |
from sklearn.metrics import classification_report | |
from graphs.plot import Plot | |
logger = logging.getLogger(__name__) | |
def save(save_file_path, team):
    """Pickle ``team`` into the file at ``save_file_path``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(save_file_path, "wb") as sink:
        pickle.Pickler(sink).dump(team)
def load(save_file_path):
    """Unpickle and return the object stored at ``save_file_path``.

    Security note: unpickling can execute arbitrary code, so only call this
    on files produced by a trusted source (e.g. this module's ``save``).
    """
    with open(save_file_path, "rb") as source:
        restored = pickle.Unpickler(source).load()
    return restored
# Plain class on purpose: with no annotated class-level fields, @dataclass only
# generated an __eq__ over zero fields (every instance compared equal) and
# disabled hashing.
class Scoring(object):
    """
    Score a fitted classifier's predictions against the ground truth.

    Attributes
    ----------
    predicted : array-like
        The ``pred`` argument.
    true : array-like
        The ``true`` argument.
    classes : list
        Sorted unique labels appearing in either ``true`` or ``pred``.
    descriptor : str or None
        Name of the data split being scored (the ``desc`` argument).
    report : pd.DataFrame
        Output of :meth:`get_report`.
    mat : object or None
        Output of :meth:`get_confusion_matrix` when ``mat=True``, else None.
    """

    # Split names accepted by ``get_report``; any other non-empty descriptor
    # is rejected.
    _VALID_DESCRIPTORS = frozenset(
        {"train", "training", "test", "testing", "eval", "val", "evaluate"}
    )

    def __init__(
        self,
        pred: np.ndarray,
        true: np.ndarray,
        desc: Optional[str] = "",
        mat: bool = False,
    ) -> None:
        """
        Score ``pred`` against ``true`` and build the classification report.

        Parameters
        ----------
        pred : ndarray
            Fitted model's predicted labels.
        true : ndarray
            Ground-truth labels, passed to sklearn as ``y_true``.
        desc : str, optional
            Name of the data split being scored. When non-empty it must be
            one of ``_VALID_DESCRIPTORS``. The default is "".
        mat : bool, optional
            Whether to also build a confusion matrix (stored in ``self.mat``).
            The default is False.

        Raises
        ------
        ValueError
            If ``desc`` is non-empty and not a recognised split name.
        """
        self.predicted: Iterable[Any] = pred
        self.true: Iterable[Any] = true
        # Union of both label sets (sklearn's own default when ``labels`` is
        # omitted). Using ``pred`` alone would silently drop any class the
        # model never predicts from both the report and the matrix.
        self.classes: list = self._union_labels(true, pred)
        self.descriptor: Optional[str] = desc
        if desc is None:
            logger.info("No descriptor")
        self.report: pd.DataFrame = self.get_report()
        # Always define ``mat`` so callers can check it instead of hitting
        # AttributeError when no matrix was requested.
        self.mat: Any = self.get_confusion_matrix() if mat else None

    @staticmethod
    def _union_labels(true: Iterable[Any], pred: Iterable[Any]) -> list:
        """Return the sorted unique labels found in ``true`` or ``pred``."""
        return list(np.union1d(np.asarray(true), np.asarray(pred)))

    def get_report(self) -> pd.DataFrame:
        """
        Build the classification report as a DataFrame.

        Returns
        -------
        pd.DataFrame
            ``sklearn.metrics.classification_report(..., output_dict=True)``
            over ``self.classes``, transposed so each report key is a row.

        Raises
        ------
        ValueError
            If ``self.descriptor`` is non-empty and not a recognised split name.
        """
        if self.descriptor and self.descriptor not in self._VALID_DESCRIPTORS:
            # Explicit raise rather than ``assert``, so the check survives
            # ``python -O``.
            raise ValueError(
                f"Unknown descriptor {self.descriptor!r}; expected one of "
                f"{sorted(self._VALID_DESCRIPTORS)}"
            )
        report_dict = classification_report(
            self.true,
            self.predicted,
            target_names=self.classes,
            labels=self.classes,
            output_dict=True,
        )
        # Kept local, so calling get_report() again never overwrites
        # ``self.report`` with the raw dict.
        return pd.DataFrame(data=report_dict).transpose()

    def get_confusion_matrix(self, caption: Optional[str] = "") -> object:
        """
        Build a confusion matrix of ``self.true`` vs ``self.predicted``.

        Delegates to ``graphs.plot.Plot.confusion_matrix``, with
        ``self.classes`` as the label order.

        Parameters
        ----------
        caption : str, optional
            Caption forwarded to the plotting helper. The default is "".

        Returns
        -------
        object
            Whatever ``Plot.confusion_matrix`` returns (presumably a figure or
            matrix; TODO confirm against graphs.plot).
        """
        return Plot.confusion_matrix(
            y_true=self.true,
            y_pred=self.predicted,
            labels=self.classes,
            caption=caption,
        )
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment.