@aksimhal
Forked from wrgr/pr_eval_func.m
Last active August 29, 2015 14:15

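function metrics = pr_eval(detectVol, truthVol)
% PR_EVAL  Object-level precision and recall of a detection volume.
%   metrics = pr_eval(detectVol, truthVol) binarizes detectVol at each
%   threshold in thresholdVals, labels the connected components of the
%   result, and matches them against the connected components of the
%   binary ground-truth volume truthVol. metrics.precision, metrics.recall
%   and metrics.thresh hold one entry per threshold.
%
% W. Gray Roncal - 02.12.2015
% Feedback welcome and encouraged. Let's make this better!
% Other options (size filters, morphological priors, etc.) can be added.
%
% Example with toy 2-D data:
%   z = rand(100,100);
%   t = z > 0.8;
%   metrics = pr_eval(z, t);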
tic
%% params
count = 1;
c = 0;
pad = [0, 0, 0];  % voxels trimmed from each face (rows, columns, slices) to avoid edge effects; set according to your needs
minSize2D = 0;    % minimum per-slice object size in pixels; smaller objects are removed (more complex filters are possible)
overlap = 1;      % voxels of overlap required for a truth/detection match
thresholdVals = 0:0.05:1; % detection thresholds to sweep; choose based on your needs
maxCount = numel(thresholdVals);
% Crop both volumes by pad on every face
detectVol = detectVol(pad(1)+1:end-pad(1), pad(2)+1:end-pad(2), pad(3)+1:end-pad(3));
truthVol = truthVol(pad(1)+1:end-pad(1), pad(2)+1:end-pad(2), pad(3)+1:end-pad(3));
% Label the ground-truth objects once (18-connected in 3D)
truthObj = bwconncomp(truthVol, 18);
for threshold = thresholdVals
    c = c + 1;
    fprintf('NOW PROCESSING SEARCH %d of %d...\n', c, maxCount)
    % POST PROCESSING: binarize the detection volume at this threshold
    temp_prob = detectVol >= threshold;
    % Apply the 2D size filter first: with 4-connectivity, components do
    % not extend across slices, so objects smaller than minSize2D pixels
    % in a slice are removed
    cc = bwconncomp(temp_prob, 4);
    for jj = 1:cc.NumObjects
        if numel(cc.PixelIdxList{jj}) < minSize2D
            temp_prob(cc.PixelIdxList{jj}) = 0;
        end
    end
    % Re-run connected components in 3D (18-connectivity) on the filtered volume
    detectcc = bwconncomp(temp_prob, 18);
    detectMtx = labelmatrix(detectcc);
    fprintf('Number of synapses detected: %d\n', detectcc.NumObjects);
    % 3D metrics
    % TP: truth objects overlapped by at least `overlap` detected voxels
    % FN: truth objects without sufficient overlap
    TP = 0; FP = 0; FN = 0; TP2 = 0;
    for j = 1:truthObj.NumObjects
        temp = detectMtx(truthObj.PixelIdxList{j});
        if sum(temp > 0) >= overlap
            TP = TP + 1;
            % TODO: something fancier.
            % Each detected object can be matched only once, so remove the
            % matched objects from detectMtx here. This neither penalizes
            % nor rewards fragmented detections.
            detectIdxUsed = unique(temp);
            detectIdxUsed(detectIdxUsed == 0) = [];
            for jjj = 1:numel(detectIdxUsed)
                detectMtx(detectcc.PixelIdxList{detectIdxUsed(jjj)}) = 0;
            end
        else
            FN = FN + 1;
        end
    end
    % FP: detected objects that overlap the ground truth by fewer than
    % `overlap` voxels. Matching detections only increment TP2 for
    % bookkeeping; TP was already counted in the truth loop above.
    for j = 1:detectcc.NumObjects
        temp = truthVol(detectcc.PixelIdxList{j});
        if sum(temp > 0) >= overlap
            TP2 = TP2 + 1;
        else
            FP = FP + 1;
        end
    end
    % Object-level precision and recall (NaN when a denominator is 0)
    metrics.precision(count) = TP ./ (TP + FP);
    metrics.recall(count) = TP ./ (TP + FN);
    metrics.thresh(count) = threshold;
    fprintf('precision: %f recall: %f threshold: %f\n', metrics.precision(count), metrics.recall(count), threshold);
    count = count + 1;
end
toc % report elapsed time (pairs with tic above)
% To plot the precision-recall curve:
% figure, plot(metrics.recall, metrics.precision, 'o'), grid on
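The matching rule in the two counting loops of pr_eval is easiest to follow on a tiny input. The sketch below is a hypothetical, toolbox-free restatement for a single threshold, not part of the original gist: instead of calling bwconncomp and labelmatrix, it assumes both inputs are already label arrays (0 for background, each positive integer one object), and the names truthLbl, detectLbl and available, as well as the 1-D toy data, are illustrative only.

```matlab
% Hypothetical toy example (not from the original gist): the object-level
% counting rule of pr_eval applied to two pre-labeled arrays.
% 0 = background; each positive integer labels one object.
truthLbl  = [1 1 0 0 2 2 0 0];  % two ground-truth objects
detectLbl = [1 0 0 0 0 0 2 2];  % two detected objects
overlap = 1;                    % voxels of overlap required for a match

TP = 0; FN = 0; FP = 0;
available = detectLbl;          % detections are consumed once matched

% Truth loop: a truth object is a TP if enough unused detected voxels overlap it
for t = reshape(unique(truthLbl(truthLbl > 0)), 1, [])
    hits = available(truthLbl == t);
    if nnz(hits) >= overlap
        TP = TP + 1;
        used = unique(hits(hits > 0));
        available(ismember(available, used)) = 0;  % each detection matches once
    else
        FN = FN + 1;
    end
end

% Detection loop: a detection is an FP if it overlaps the truth by fewer
% than `overlap` voxels (all detections are checked, consumed or not)
for d = reshape(unique(detectLbl(detectLbl > 0)), 1, [])
    if nnz(truthLbl(detectLbl == d)) < overlap
        FP = FP + 1;
    end
end

precision = TP / (TP + FP);
recall = TP / (TP + FN);
fprintf('TP=%d FN=%d FP=%d precision=%.2f recall=%.2f\n', TP, FN, FP, precision, recall);
```

Here truth object 1 is matched, truth object 2 is missed, and detection 2 is a false positive, so precision and recall are both 0.5. Because matched detections are consumed, several fragments covering one truth object yield a single TP and no FP, whereas one detection spanning two truth objects is credited only to the first of them.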