Created November 16, 2017 17:31
-
-
Save alexattia/dfab3c3dbe0f647b8c761e1e932b9961 to your computer and use it in GitHub Desktop.
Image classification using CNN features and linear SVM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function feature_vector = feature_vector_from_cnn(net, names, imageDir, layerIdx)
% FEATURE_VECTOR_FROM_CNN  Compute one CNN activation vector per image.
%   FEATURE_VECTOR = FEATURE_VECTOR_FROM_CNN(NET, NAMES) loads every image
%   <imageDir>/NAMES{i}.jpg and converts it to single (0-255 range). It then
%   resizes it to NET.meta.normalization.imageSize(1:2), subtracts
%   NET.meta.normalization.averageImage, runs VL_SIMPLENN and keeps the
%   activation res(19).x flattened into a column vector.
%   The result is a D-by-N matrix with one column per entry of NAMES
%   (N = numel(NAMES), D = number of elements of that activation).
%
%   FEATURE_VECTOR_FROM_CNN(NET, NAMES, IMAGEDIR) reads the images from
%   IMAGEDIR. Omit it or pass [] to use the default directory
%   '../../TD2/practical-category-recognition-2015a/data/images'.
%
%   FEATURE_VECTOR_FROM_CNN(NET, NAMES, IMAGEDIR, LAYERIDX) keeps
%   res(LAYERIDX).x instead of res(19).x. Omit it or pass [] for 19.
%
%   An empty NAMES returns [].
%
%   NOTE(review): res(19) is presumably the fc7 output of imagenet-vgg-f
%   (the original variable was named output_fc7); confirm the layer index
%   against vl_simplenn_display(net) when using a different network.

if nargin < 3 || isempty(imageDir)
  imageDir = '../../TD2/practical-category-recognition-2015a/data/images' ;
end
if nargin < 4 || isempty(layerIdx)
  layerIdx = 19 ;
end

if isempty(names)
  feature_vector = [] ;
  return ;
end

% Preprocessing parameters do not change between images.
inputSize = net.meta.normalization.imageSize(1:2) ;
avgImage = net.meta.normalization.averageImage ;

% Store one column per image and concatenate once at the end, instead of
% growing a matrix inside the loop (quadratic copying).
columns = cell(1, numel(names)) ;
for i = 1:numel(names)
  im = imread(fullfile(imageDir, [names{i} '.jpg'])) ;
  im_ = single(im) ; % note: 255 range
  im_ = imresize(im_, inputSize) ;
  im_ = bsxfun(@minus, im_, avgImage) ;
  res = vl_simplenn(net, im_) ;
  % Flatten the activation (1x1xD for a fully-connected layer) to Dx1.
  columns{i} = reshape(res(layerIdx).x, [], 1) ;
end

% Horizontal concatenation puts each image in its own column directly.
% Reshaping an N-by-D matrix to D-by-N would NOT transpose it: reshape keeps
% column-major element order, which scrambles values across images.
feature_vector = cat(2, columns{:}) ;
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Train a linear SVM on CNN features (category vs. background) and
% evaluate it on the validation split. Uses vl_simplenn_display, vl_pr,
% trainLinearSVM, displayRankedImageList and feature_vector_from_cnn,
% which must be on the MATLAB path.

% load data
dataDir = '../../TD2/practical-category-recognition-2015a/data' ;
encoding = 'bovw' ;
% category = 'motorbike' ;
category = 'aeroplane' ;
% category = 'person' ;

% Column-wise L2 normalization. The max(..., eps) guard keeps an all-zero
% feature column from turning into NaNs.
l2normalize = @(x) bsxfun(@times, x, 1 ./ max(sqrt(sum(x.^2, 1)), eps)) ;

% Only the .names field of these .mat files is used below; the histograms
% field is overwritten with CNN features.
pos = load(fullfile(dataDir, [category '_train_' encoding '.mat'])) ;
neg = load(fullfile(dataDir, ['background_train_' encoding '.mat'])) ;
names = {pos.names{:}, neg.names{:}} ;
labels = [ones(1,numel(pos.names)), - ones(1,numel(neg.names))] ;

% load a pretrained cnn
net = load('data/imagenet-vgg-f.mat') ;
vl_simplenn_display(net) ;

% create feature vector for each positive and negative training image
pos.histograms = feature_vector_from_cnn(net, pos.names) ;
neg.histograms = feature_vector_from_cnn(net, neg.names) ;
fprintf("Training images has been passed through the CNN\n");
histograms = l2normalize([pos.histograms, neg.histograms]) ;
clear pos, clear neg;

% Test data
pos = load(fullfile(dataDir, [category '_val_' encoding '.mat'])) ;
neg = load(fullfile(dataDir, ['background_val_' encoding '.mat'])) ;
testNames = {pos.names{:}, neg.names{:}} ;
testLabels = [ones(1,numel(pos.names)), - ones(1,numel(neg.names))] ;
pos.histograms = feature_vector_from_cnn(net, pos.names) ;
neg.histograms = feature_vector_from_cnn(net, neg.names) ;
fprintf("Testing images has been passed through the CNN\n");
testHistograms = l2normalize([pos.histograms, neg.histograms]) ;
clear pos, clear neg;

% Train the linear SVM
C = 1 ;
[w, bias] = trainLinearSVM(histograms, labels, C) ;
fprintf("SVM trained\n");

% Score the training and test images with the learned linear classifier
% (training scores are left in the workspace for inspection).
scores = w' * histograms + bias ;
testScores = w' * testHistograms + bias ;

% Visualize the ranked list of images
figure(3) ; clf ; set(3,'name','Ranked test images (subset)') ;
displayRankedImageList(testNames, testScores) ;
figure(4) ; clf ; set(4,'name','Precision-recall on test data') ;
vl_pr(testLabels, testScores) ;

% Print results
[~, ~, info] = vl_pr(testLabels, testScores) ;
fprintf('Test AP: %.2f\n', info.auc) ;
[~, perm] = sort(testScores, 'descend') ;
% Clamp K so a test set with fewer than 36 images does not index out of range.
topK = min(36, numel(perm)) ;
fprintf('Correctly retrieved in the top %d: %d\n', topK, sum(testLabels(perm(1:topK)) > 0)) ;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment