Skip to content

Instantly share code, notes, and snippets.

@skaae
Created January 7, 2015 20:48
Show Gist options
  • Save skaae/b1ed48fe958b6c3d1041 to your computer and use it in GitHub Desktop.
Save skaae/b1ed48fe958b6c3d1041 to your computer and use it in GitHub Desktop.
confusion matrix
--[[ A Confusion Matrix class
Example:
conf = ConfusionMatrix( {'cat','dog','person'} ) -- new matrix (class registered globally as 'ConfusionMatrix' below)
conf:zero() -- reset matrix
for i = 1,N do
conf:add( neuralnet:forward(sample), label ) -- accumulate errors
end
print(conf) -- print matrix
image.display(conf:render()) -- render matrix
]]
--local ConfusionMatrix = torch.class('optim.MyConfusionMatrix')
local ConfusionMatrix = torch.class('ConfusionMatrix')

--- Create an empty confusion matrix.
-- @param nclasses number of classes, or a table of class names (in which
--   case the number of classes is the length of that table)
-- @param classes optional table of class names; defaults to '1' .. 'n'
function ConfusionMatrix:__init(nclasses, classes)
   if type(nclasses) == 'table' then
      classes = nclasses
      nclasses = #classes
   end
   self.nclasses = nclasses
   -- mat[target][prediction] counts observations (rows are true classes)
   self.mat = torch.FloatTensor(nclasses, nclasses):zero()
   -- per-class statistics filled in by updateValids()
   self.valids = torch.FloatTensor(nclasses):zero()
   self.unionvalids = torch.FloatTensor(nclasses):zero()
   self.totalValid = 0
   self.averageValid = 0
   -- initialised here as well so every summary field exists before the
   -- first updateValids() call (it was previously left nil)
   self.averageUnionValid = 0
   if classes then
      self.classes = classes
   else
      local names = {}
      for i = 1, nclasses do names[i] = tostring(i) end
      self.classes = names
   end
end
-- Convert a class index or a 1-D score vector into a class index.
-- Vectors are copied into a reusable FloatTensor buffer stored on
-- self[bufname] and reduced with argmax.
local function toClassIndex(self, bufname, value)
   if type(value) == 'number' then
      return value
   end
   local buf = self[bufname]
   if not buf then
      buf = torch.FloatTensor(self.nclasses)
      self[bufname] = buf
   end
   buf:copy(value)
   local _, idx = buf:max(1)
   return idx[1]
end

--- Record one (prediction, target) observation.
-- Each argument may independently be a class index (number) or a vector of
-- nclasses scores / a one-hot vector, in which case its argmax is used.
-- Previously a numeric prediction paired with a vector target crashed.
-- @param prediction class index or vector of class scores
-- @param target class index or vector of class scores
function ConfusionMatrix:add(prediction, target)
   local p = toClassIndex(self, 'prediction_1d', prediction)
   local t = toClassIndex(self, 'target_1d', target)
   self.mat[t][p] = self.mat[t][p] + 1
end
-- Reduce a batch tensor to a 1-D vector of class indices:
--   1-D            -> used as-is (already class indices)
--   2-D, N x 1     -> flattened to its single column
--   2-D, N x K     -> argmax along dimension 2
-- `what` names the argument in the error raised for any other rank.
local function batchToIndices(tensor, what)
   local dims = tensor:dim()
   if dims == 1 then
      return tensor
   end
   if dims ~= 2 then
      error(what .. " has invalid number of dimensions")
   end
   if tensor:size(2) == 1 then
      return tensor:select(2, 1)
   end
   local _, idx = tensor:max(2)
   idx:resize(idx:size(1))
   return idx
end

--- Record a whole batch of (prediction, target) observations.
-- @param predictions 1-D class indices, N x 1 indices, or N x K class scores
-- @param targets 1-D class indices, N x 1 indices, or N x K one-hot rows
function ConfusionMatrix:batchAdd(predictions, targets)
   local preds = batchToIndices(predictions, "predictions")
   local targs = batchToIndices(targets, "targets")
   local counts = self.mat
   for k = 1, preds:size(1) do
      local t, p = targs[k], preds[k]
      counts[t][p] = counts[t][p] + 1
   end
end
--- Reset all counts and cached summary statistics.
-- averageUnionValid is now reset too; previously it kept its stale value.
function ConfusionMatrix:zero()
   self.mat:zero()
   self.valids:zero()
   self.unionvalids:zero()
   self.totalValid = 0
   self.averageValid = 0
   self.averageUnionValid = 0
end
--- True when `value` is NaN: NaN is the only value not equal to itself.
local function isNaN(value)
   return not (value == value)
end
--- Overwrite NaN entries x[{1,i}] (i = 1 .. self.nclasses) with 0, in place.
-- @param x a 1 x nclasses tensor, modified and returned
-- @param self the ConfusionMatrix (only nclasses is read)
-- @return x
local function remNaN(x, self)
   local col = 1
   while col <= self.nclasses do
      local key = {1, col}
      if isNaN(x[key]) then
         x[key] = 0
      end
      col = col + 1
   end
   return x
end
--- One-vs-rest counts per class, each returned as a 1 x nclasses tensor.
-- With rows = targets and columns = predictions:
--   tp = diagonal, fn = row sum - diagonal, fp = column sum - diagonal,
--   tn = grand total - tp - fn - fp.
-- @return tp, tn, fp, fn
local function getErrors(self)
   local n = self.nclasses
   local counts = self.mat
   local hits = torch.diag(counts)
   local tp = hits:clone():resize(1, n)
   local fn = (torch.sum(counts, 2) - hits):t()
   local fp = torch.sum(counts, 1) - hits
   local grand = torch.Tensor(1, n):fill(torch.sum(counts)):typeAs(tp)
   local tn = grand - tp - fn - fp
   return tp, tn, fp, fn
end
--- Per-class one-vs-rest counts.
-- @return tp, tn, fp, fn, each a 1 x nclasses tensor
function ConfusionMatrix:getConfusion()
   local tp, tn, fp, fn = getErrors(self)
   return tp, tn, fp, fn
end
-- Names accepted by ConfusionMatrix:printscore; each is a method of this
-- class that returns a 1 x nclasses tensor of per-class scores.
local printableScores = {
   sensitivity = true,
   specificity = true,
   positivePredictiveValue = true,
   negativePredictiveValue = true,
   falsePositiveRate = true,
   falseDiscoveryRate = true,
   classAccuracy = true,
   F1 = true,
   matthewsCorrelation = true,
}

--- Print a per-class score as a small text table.
-- @param name one of the keys of printableScores (e.g. "F1")
-- @param mytitle optional prefix for the printed title
-- Raises an error naming `name` if it is not a known score.
function ConfusionMatrix:printscore(name, mytitle)
   if not printableScores[name] then
      error("Unknown error type: " .. tostring(name), 2)
   end
   local score = self[name](self)
   if mytitle then
      name = mytitle .. ": " .. name
   end

   local ln = "|"  -- header row: padded class names
   local ls = "|"  -- value row: scores padded to the header cell width
   for i = 1, self.nclasses do
      local val = string.format("%.4f", score[{1, i}])
      local class = tostring(self.classes[i])
      local pad = math.max(1, 4 - math.floor(#class / 2))
      class = string.rep(" ", pad) .. class .. string.rep(" ", pad + 1 - #class % 2)
      ln = ln .. class .. "|"
      ls = ls .. " " .. val
      if (#ls + 1) < #ln then
         ls = ls .. string.rep(" ", #ln - #ls - 1)
      end
      ls = ls .. "|"
   end

   local line = string.rep("-", #ln)
   -- Centre the title over the table. This was math.min(0, ...), which is
   -- never positive, so the title was never indented.
   local indent = math.max(0, math.floor(#ls / 2) - math.floor(#name / 2))
   print(line)
   print(string.rep(" ", indent) .. name)
   print(line)
   print(ln)
   print(line)
   print(ls)
   print(line)
end
--- Global accuracy: fraction of all recorded observations on the diagonal.
-- Returns NaN (0/0) if no observations have been recorded.
function ConfusionMatrix:accuracy()
   local tp = getErrors(self)
   return tp:sum() / self.mat:sum()
end
--- Per-class (one-vs-rest) Matthews correlation coefficient:
--   (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
-- When the denominator is zero, the numerator is zero too. The resulting
-- 0/0 NaN is reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:matthewsCorrelation()
   local tp, tn, fp, fn = getErrors(self)
   local numerator = torch.cmul(tp, tn) - torch.cmul(fp, fn)
   local denominator = torch.sqrt((tp + fp):cmul(tp + fn):cmul(tn + fp):cmul(tn + fn))
   return remNaN(torch.cdiv(numerator, denominator), self)
end
--- Per-class sensitivity (recall / true positive rate): TP / (TP + FN).
-- Classes with no target examples (0/0) are reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:sensitivity()
   local tp, _, _, fn = getErrors(self)
   return remNaN(torch.cdiv(tp, tp + fn), self)
end
--- Per-class specificity (true negative rate): TN / (TN + FP).
-- A 0/0 result is reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:specificity()
   local _, tn, fp = getErrors(self)
   return remNaN(torch.cdiv(tn, tn + fp), self)
end
--- Per-class positive predictive value (precision): TP / (TP + FP).
-- Classes never predicted (0/0) are reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:positivePredictiveValue()
   local tp, _, fp = getErrors(self)
   return remNaN(torch.cdiv(tp, tp + fp), self)
end
--- Per-class negative predictive value: TN / (TN + FN).
-- A 0/0 result is reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:negativePredictiveValue()
   local _, tn, _, fn = getErrors(self)
   return remNaN(torch.cdiv(tn, tn + fn), self)
end
--- Per-class false positive rate (fall-out): FP / (FP + TN).
-- A 0/0 result is reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:falsePositiveRate()
   local _, tn, fp = getErrors(self)
   return remNaN(torch.cdiv(fp, fp + tn), self)
end
--- Per-class false discovery rate: FP / (TP + FP).
-- Classes never predicted (0/0) are reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:falseDiscoveryRate()
   local tp, _, fp = getErrors(self)
   return remNaN(torch.cdiv(fp, tp + fp), self)
end
--- Per-class (one-vs-rest) accuracy: (TP + TN) / (TP + TN + FP + FN).
-- (The old comment said "(TP + FN) / ...", which did not match the code.)
-- A 0/0 result, i.e. an empty matrix, is reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:classAccuracy()
   local tp, tn, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tp + tn, tp + tn + fp + fn), self)
end
--- Per-class F1 score: 2*TP / (2*TP + FP + FN).
-- A 0/0 result is reported as 0.
-- @return 1 x nclasses tensor
function ConfusionMatrix:F1()
   local tp, _, fp, fn = getErrors(self)
   return remNaN(torch.cdiv(tp * 2, tp * 2 + fp + fn), self)
end
--- Recompute the cached accuracy statistics from self.mat.
-- For each class c:
--   self.valids[c] = mat[c][c] / (row c sum)
--   self.unionvalids[c] = mat[c][c] / (row c sum + column c sum - mat[c][c])
-- Then:
--   self.totalValid = trace / grand total
--   self.averageValid, self.averageUnionValid = means over the classes whose
--     row accuracy is not NaN (rows that contain data)
-- Empty inputs produce NaN via 0/0.
function ConfusionMatrix:updateValids()
   local counts = self.mat
   local diagonalTotal = 0
   for c = 1, self.nclasses do
      local hits = counts[c][c]
      local rowTotal = counts:select(1, c):sum()
      self.valids[c] = hits / rowTotal
      self.unionvalids[c] = hits / (rowTotal + counts:select(2, c):sum() - hits)
      diagonalTotal = diagonalTotal + hits
   end
   self.totalValid = diagonalTotal / counts:sum()

   local validSum, validCount = 0, 0
   local unionSum, unionCount = 0, 0
   for c = 1, self.nclasses do
      local v = self.valids[c]
      if not isNaN(v) then
         validSum = validSum + v
         validCount = validCount + 1
         local u = self.unionvalids[c]
         if not isNaN(u) then
            unionSum = unionSum + u
            unionCount = unionCount + 1
         end
      end
   end
   self.averageValid = validSum / validCount
   self.averageUnionValid = unionSum / unionCount
end
--- Text rendering of the raw counts.
-- Each row is followed by its row accuracy and, when class names exist,
-- its class name. Three summary lines are appended at the end.
-- Calls updateValids() first, so the cached statistics are refreshed.
function ConfusionMatrix:__tostring__()
   self:updateValids()
   local n = self.nclasses
   local named = self.classes and self.classes[1]
   local out = {'ConfusionMatrix:\n', '['}
   for row = 1, n do
      out[#out + 1] = (row == 1) and '[' or ' ['
      for col = 1, n do
         out[#out + 1] = string.format('%8d', self.mat[row][col])
      end
      local close = (row == n) and ']] ' or '] '
      local pct = string.format('%2.3f', self.valids[row] * 100)
      if named then
         out[#out + 1] = close .. pct .. '% \t[class: ' .. (self.classes[row] or '') .. ']\n'
      else
         out[#out + 1] = close .. pct .. '% \n'
      end
   end
   out[#out + 1] = ' + average row correct: ' .. (self.averageValid*100) .. '% \n'
   out[#out + 1] = ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n'
   out[#out + 1] = ' + global correct: ' .. (self.totalValid*100) .. '%'
   return table.concat(out)
end
--- Render the row-normalised confusion matrix as an image tensor.
-- Loads 'image', 'qtwidget' and 'qttorch' lazily, on first call.
-- @param sortmode 'score' (default) orders classes by normalised diagonal
--   value; 'occurrence' orders them by target frequency
-- @param display if truthy, also show the rendering via image.display
-- @param block pixel size of one matrix cell (default 25)
-- @param legendwidth pixel width of the legend column (default 200)
-- @param fontsize optional legend/score font size passed to qt.QFont.
--   It was previously read from an undefined global; nil keeps the same
--   behaviour as before.
-- @return the rendered image as a tensor
function ConfusionMatrix:render(sortmode, display, block, legendwidth, fontsize)
   sortmode = sortmode or 'score'
   block = block or 25
   legendwidth = legendwidth or 200

   local legend = {
      ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]',
      ['occurrence'] = 'Confusion matrix [sorted by occurrences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]'
   }
   if not legend[sortmode] then
      error('sort mode must be one of: score | occurrence')
   end

   local classes = self.classes
   local nclasses = #classes

   -- Row-normalise a copy: confusion[t][p] becomes the fraction of target t
   -- predicted as p. Target counts go into freqs, diagonal fractions into diag.
   local confusion = self.mat:clone()
   local diag = torch.FloatTensor(nclasses)
   local freqs = torch.FloatTensor(nclasses)
   local corrects, total = 0, 0
   for target = 1, nclasses do
      local rowsum = confusion[target]:sum()
      freqs[target] = rowsum
      corrects = corrects + confusion[target][target]
      total = total + rowsum
      confusion[target]:div(math.max(rowsum, 1))
      diag[target] = confusion[target][target]
   end

   -- Global accuracy, and mean diagonal over rows that have data, in percent.
   -- Both are 0 (rather than NaN) for an empty matrix.
   local accuracy = 0
   if total > 0 then
      accuracy = corrects / total * 100
   end
   local perclass, seen = 0, 0
   for target = 1, nclasses do
      if freqs[target] > 0 then
         perclass = perclass + diag[target]
         seen = seen + 1
      end
   end
   if seen > 0 then
      perclass = perclass / seen * 100
   end
   freqs:div(math.max(total, 1))

   -- Class display order (these were previously accidental globals).
   local _, order
   if sortmode == 'score' then
      _, order = torch.sort(diag, 1, true)
   else
      _, order = torch.sort(freqs, 1, true)
   end

   -- Paint each cell as a block x block square, then draw the grid lines.
   local size = nclasses * block
   local render = torch.zeros(size, size)
   for target = 1, nclasses do
      local rows = { (target - 1) * block + 1, target * block }
      for prediction = 1, nclasses do
         render[{ rows, { (prediction - 1) * block + 1, prediction * block } }] =
            confusion[order[target]][order[prediction]]
      end
   end
   for k = 1, nclasses do
      render[{ {k * block}, {} }] = 0.1
      render[{ {}, {k * block} }] = 0.1
   end

   require 'image'
   require 'qtwidget'
   require 'qttorch'
   local width, height = render:size(2), render:size(1)
   local win1 = qtwidget.newimage(width + legendwidth, height)
   image.display{image = render, win = win1}

   local fmax = freqs:max()
   for i = 1, nclasses do
      local ci = order[i]
      -- legend background cell
      win1:setcolor{r = 0, g = 0, b = 0}
      win1:rectangle(width, (i - 1) * block, legendwidth, block)
      win1:fill()
      -- share of labels; brighter for more frequent classes
      win1:setfont(qt.QFont{serif = false, size = fontsize})
      local freqShade = 0.1
      if fmax > 0 then
         freqShade = freqs[ci] / fmax * 0.9 + 0.1
      end
      win1:setcolor{r = freqShade * 0.5 + 0.2, g = freqShade * 0.5 + 0.2, b = freqShade * 0.8 + 0.2}
      win1:moveto(width + 10, i * block - block / 3)
      win1:show(string.format('[%2.2f%% labels]', math.floor(freqs[ci] * 10000 + 0.5) / 100))
      -- class name; brighter for higher diagonal score
      win1:setfont(qt.QFont{serif = false, size = fontsize})
      local nameShade = diag[ci] * 0.8 + 0.2
      win1:setcolor{r = nameShade, g = nameShade, b = nameShade}
      win1:moveto(120 + width + 10, i * block - block / 3)
      win1:show(tostring(classes[ci]))
      -- percentage written in every cell of column i
      for j = 1, nclasses do
         local score = confusion[order[j]][ci]
         local cellShade = (1 - score) * (score * 0.8 + 0.2)
         win1:setcolor{r = cellShade, g = cellShade, b = cellShade}
         win1:moveto((i - 1) * block + block / 5, (j - 1) * block + block * 2 / 3)
         win1:show(string.format('%02.0f', math.floor(score * 100 + 0.5)))
      end
   end

   local t = win1:image():toTensor()
   if display then
      image.display{image = t, legend = string.format(legend[sortmode], accuracy, perclass)}
   end
   return t
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment