Created
December 17, 2022 22:03
-
-
Save FlynnOConnell/d1f98ef449bb3abd667e502cfc73e6b3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
"""
# _validate.py
Module (nn_utils):
Validation functions for SVM. Determine which steps have occurred, and
prevent steps from occurring before previous attributes have been set.
"""
import numpy as np | |
class _validate:
    """Namespace of static sanity checks used around estimator training.

    Every method is a ``@staticmethod``; the class is never instantiated and
    only groups the checks under one name.
    """

    @staticmethod
    def _full_predict(pred, true):
        """Check that predictions cover exactly the same labels as the truth.

        Args:
            pred: Array-like of predicted labels.
            true: Array-like of ground-truth labels.

        Raises:
            AssertionError: If the set of unique labels in ``pred`` differs
                from the set of unique labels in ``true``.
        """
        pred_labels = np.unique(pred)
        true_labels = np.unique(true)
        # ``np.unique`` returns sorted arrays, so an exact array comparison is
        # a set comparison.  ``array_equal`` yields a single bool even when
        # the two arrays differ in length, unlike an elementwise ``==``.
        if not np.array_equal(pred_labels, true_labels):
            raise AssertionError(
                f"Predicted labels {pred_labels} do not match "
                f"true labels {true_labels}"
            )

    @staticmethod
    def _validate_shape(x, y):
        """Check that ``x`` and ``y`` have the same length along axis 0.

        Args:
            x: Object exposing a ``shape`` attribute.
            y: Object exposing a ``shape`` attribute.

        Raises:
            AssertionError: If ``x.shape[0] != y.shape[0]``.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs when Python is started with ``-O``.
        if x.shape[0] != y.shape[0]:
            raise AssertionError(
                f"First dimensions differ: {x.shape[0]} != {y.shape[0]}"
            )

    @staticmethod
    def _check_fitted(model):
        """Check that ``model`` exposes a ``fit`` attribute.

        Note: despite the name, this only verifies that ``model`` looks like
        an estimator (has ``fit``); it does not verify that ``fit`` has been
        called.

        Args:
            model: Object to inspect.

        Raises:
            TypeError: If ``model`` has no ``fit`` attribute.
        """
        if not hasattr(model, "fit"):
            raise TypeError(f"{model} is not an estimator instance.")

    @staticmethod
    def _num_classes(classes):
        """Classify a label collection as binary or multivariate.

        Prints the detected class type to stdout as a side effect.

        Args:
            classes: Array-like of class labels.

        Returns:
            ``"binary"`` when exactly two distinct labels are present,
            ``"multivariate"`` when more than two are present.

        Raises:
            AttributeError: If fewer than two distinct labels are present.
        """
        n_classes = len(np.unique(classes))
        if n_classes == 2:
            print("Estimator class type: Binary")
            return "binary"
        if n_classes > 2:
            print("Estimator class type: Multivariate")
            return "multivariate"
        raise AttributeError(
            "Number of classes do not match binary or multivariate"
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment