Created
May 12, 2015 07:19
-
-
Save skaae/88d9b381933877de828d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class ConfusionMatrix:
    """
    Simple confusion matrix class.

    Rows index the true class, columns the predicted class.
    """
    def __init__(self, n_classes, class_names=None):
        """
        :param n_classes: number of classes tracked by the matrix
        :param class_names: optional list of display names for the classes
            (defaults to "0".."n_classes-1")
        """
        self.n_classes = n_classes
        if class_names is None:
            # a real list (not a map object) so items can be reassigned below
            self.class_names = [str(i) for i in range(n_classes)]
        else:
            self.class_names = list(class_names)
        # right-pad every class name to the longest one so __str__ aligns
        max_len = max(map(len, self.class_names))
        self.max_len = max_len
        for idx, name in enumerate(self.class_names):
            if len(name) < max_len:  # bug fix: was len(self.class_names)
                self.class_names[idx] = name + " " * (max_len - len(name))
        self.mat = np.zeros((n_classes, n_classes), dtype='int')

    def __str__(self):
        """Render the matrix with class names, row totals and column totals."""
        # per-row totals (samples per true class) are appended to each row;
        # per-column totals (predictions per class) form the bottom line.
        # NOTE: the original swapped the names col_sum/row_sum vs the axes.
        row_sum = np.sum(self.mat, axis=1)
        col_sum = np.sum(self.mat, axis=0)
        s = []
        mat_str = self.mat.__str__()
        mat_str = mat_str.replace('[', '').replace(']', '').split('\n')
        for idx, row in enumerate(mat_str):
            # numpy's own rendering keeps one extra leading char on row 0
            if idx == 0:
                pad = " "
            else:
                pad = ""
            class_name = self.class_names[idx]
            class_name = " " + class_name + " |"
            row_str = class_name + pad + row
            row_str += " |" + str(row_sum[idx])
            s.append(row_str)
        bottom = [(self.max_len + 4) * " " + " ".join(map(str, col_sum))]
        hline = [(1 + self.max_len) * " " + "-" * len(bottom[0])]
        s = hline + s + hline + bottom
        # add linebreaks
        s_out = [line + '\n' for line in s]
        return "".join(s_out)

    def batchAdd(self, y_true, y_pred):
        """
        Accumulate a batch of (true, predicted) label pairs into the matrix.

        :param y_true: array-like of integer true labels (any shape)
        :param y_pred: array-like of integer predicted labels (same shape)
        """
        # accept plain lists too, as in the accuracy() docstring example
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        assert y_true.shape == y_pred.shape
        # np.max works on any shape; builtin max on a 2-D array would
        # compare whole rows and make the assert ambiguous
        assert np.max(y_true) < self.n_classes
        assert np.max(y_pred) < self.n_classes
        y_true = y_true.flatten()
        y_pred = y_pred.flatten()
        for i in range(len(y_true)):
            self.mat[y_true[i], y_pred[i]] += 1

    def batchAddMask(self, y_true, y_pred, mask):
        """Like batchAdd, but only counts entries where mask is True."""
        assert y_true.shape == y_pred.shape
        assert y_true.shape == mask.shape
        # a boolean mask selects entries; an int mask would index instead
        # (np.bool was removed in NumPy >= 1.24; plain bool compares equal)
        assert mask.dtype == bool, "performance will be wrong if this is ints"
        y_true_masked = y_true[mask]
        y_pred_masked = y_pred[mask]
        self.batchAdd(y_true_masked, y_pred_masked)

    def setFullDataset(self, y_true, mask):
        """Store the full dataset labels/mask for later consistency checks."""
        self.y_true_full = y_true
        self.mask_full = mask

    def checkMatrix(self, y_true_full, mask_full):
        """
        Sanity check: per-class sample counts accumulated in the matrix must
        equal the per-class counts in the full (masked) dataset.
        """
        assert mask_full.dtype == bool
        assert y_true_full.shape == mask_full.shape
        y_true_full = y_true_full[mask_full]
        # row sums = number of samples added per true class
        n_per_classes_mat = np.sum(self.mat, axis=1)
        n_classes = int(np.max(y_true_full) + 1)
        n_per_classes_data = np.array(
            [np.sum(y_true_full == c) for c in range(n_classes)])
        assert all(n_per_classes_data == n_per_classes_mat)
        print("")

    def zero(self):
        """Reset all counts to zero."""
        self.mat.fill(0)

    def getErrors(self):
        """
        Calculate different error types.

        :return: vectors of true positives (tp), false negatives (fn),
            false positives (fp) and true negatives (tn);
            pos 0 is the first class, pos 1 the second class, etc.
        """
        tp = np.asarray(np.diag(self.mat).flatten(), dtype='float')
        fn = np.asarray(np.sum(self.mat, axis=1).flatten(), dtype='float') - tp
        fp = np.asarray(np.sum(self.mat, axis=0).flatten(), dtype='float') - tp
        tn = np.asarray(np.sum(self.mat) * np.ones(self.n_classes).flatten(),
                        dtype='float') - tp - fn - fp
        return tp, fn, fp, tn

    def accuracy(self):
        """
        Calculates global accuracy.

        :return: accuracy
        :example: >>> conf = ConfusionMatrix(3)
                  >>> conf.batchAdd([0,0,1],[0,0,2])
                  >>> print(conf.accuracy())
        """
        tp, _, _, _ = self.getErrors()
        n_samples = np.sum(self.mat)
        return np.sum(tp) / n_samples

    def sensitivity(self):
        """Per-class recall tp/(tp+fn); NaN entries (empty classes) dropped."""
        # bug fix: unpack in getErrors' actual return order (tp, fn, fp, tn)
        tp, fn, fp, tn = self.getErrors()
        res = tp / (tp + fn)
        res = res[~np.isnan(res)]
        return res

    def specificity(self):
        """Per-class specificity tn/(tn+fp); NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        res = tn / (tn + fp)
        res = res[~np.isnan(res)]
        return res

    def positivePredictiveValue(self):
        """Per-class precision tp/(tp+fp); NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        res = tp / (tp + fp)
        res = res[~np.isnan(res)]
        return res

    def negativePredictiveValue(self):
        """Per-class NPV tn/(tn+fn); NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        res = tn / (tn + fn)
        res = res[~np.isnan(res)]
        return res

    def falsePositiveRate(self):
        """Per-class FPR fp/(fp+tn); NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        res = fp / (fp + tn)
        res = res[~np.isnan(res)]
        return res

    def falseDiscoveryRate(self):
        """Per-class FDR fp/(tp+fp); NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        res = fp / (tp + fp)
        res = res[~np.isnan(res)]
        return res

    def F1(self):
        """Per-class F1 score 2tp/(2tp+fp+fn); NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        res = (2 * tp) / (2 * tp + fp + fn)
        res = res[~np.isnan(res)]
        return res

    def matthewsCorrelation(self):
        """Per-class Matthews correlation coefficient; NaN entries dropped."""
        tp, fn, fp, tn = self.getErrors()
        numerator = tp * tn - fp * fn
        denominator = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        res = numerator / denominator
        res = res[~np.isnan(res)]
        return res

    def getMat(self):
        """Return the raw confusion matrix (rows=true, cols=predicted)."""
        return self.mat
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment