Created
March 9, 2018 02:36
-
-
Save Kautenja/3bd6c96cc01890e0538e62c086300bbc to your computer and use it in GitHub Desktop.
a pandas wrapper for the sklearn confusion_matrix method
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A simple method wrapping around `confusion_matrix`.""" | |
def confusion_dataframe(*args, **kwargs):
    """
    Generate an sklearn confusion matrix as a DataFrame.

    Rows are labelled 'Actual' and columns 'Prediction'. A 2x2 matrix keeps
    the short 'N' / 'P' axis labels; any other size is labelled with the
    class labels themselves, so multiclass input no longer fails with a
    shape mismatch when the frame is built.

    Args:
        *args: positional args for `confusion_matrix`
        **kwargs: keyword args for `confusion_matrix`

    Returns: a confusion matrix formatted into a pandas DataFrame

    """
    import pandas as pd
    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels

    # generate the confusion matrix with the given arguments
    mat = confusion_matrix(*args, **kwargs)

    if mat.shape == (2, 2):
        # binary case: keep the original negative/positive shorthand
        classes = ['N', 'P']
    else:
        classes = kwargs.get('labels')
        if classes is None:
            # no explicit labels: recover the classes from the inputs, which
            # may have been passed either positionally or by keyword
            targets = [
                kwargs[name] if name in kwargs else args[position]
                for position, name in enumerate(('y_true', 'y_pred'))
            ]
            classes = unique_labels(*targets)
        classes = list(classes)

    # convert the confusion matrix to a pandas dataframe
    df = pd.DataFrame(mat, columns=classes, index=classes)
    # label the x,y indexes
    df.columns.name = 'Prediction'
    df.index.name = 'Actual'
    return df
__all__ = ['confusion_dataframe'] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment