@alexattia
Created November 16, 2017 17:31
Image classification using CNN features and linear SVM
function feature_vector = feature_vector_from_cnn(net, names)
% Return a D x N matrix holding one CNN descriptor per column (one column per image).
feature_vector = [];
for i = 1:length(names)
  s = sprintf('../../TD2/practical-category-recognition-2015a/data/images/%s.jpg' , names{i});
  im = imread(s);
  im_ = single(im) ; % note: 255 range
  % resize to the network input size and subtract the mean image
  im_ = imresize(im_, net.meta.normalization.imageSize(1:2)) ;
  im_ = bsxfun(@minus,im_,net.meta.normalization.averageImage) ;
  res = vl_simplenn(net, im_) ;
  % res(19).x is the output of layer 18 (fc7 in this model), size 1 x 1 x D
  output_fc7 = res(19).x;
  d = size(output_fc7);
  output_fc7 = reshape(output_fc7, [1, d(3)]);
  feature_vector = [feature_vector; output_fc7]; % N x D, one row per image
end
% transpose to D x N; reshape would reorder the elements and mix features across images
feature_vector = feature_vector';
end
% load the training image names and labels (the .mat files are used only for their name lists)
encoding = 'bovw' ;
% category = 'motorbike' ;
category = 'aeroplane' ;
% category = 'person' ;
pos = load(['../../TD2/practical-category-recognition-2015a/data/' category '_train_' encoding '.mat']) ;
neg = load(['../../TD2/practical-category-recognition-2015a/data/background_train_' encoding '.mat']) ;
names = {pos.names{:}, neg.names{:}};
labels = [ones(1,numel(pos.names)), - ones(1,numel(neg.names))] ;
% load a pretrained cnn
net = load('data/imagenet-vgg-f.mat') ;
vl_simplenn_display(net) ;
% compute a CNN feature vector for each positive and negative training image
pos.histograms = feature_vector_from_cnn(net, pos.names);
neg.histograms = feature_vector_from_cnn(net, neg.names);
fprintf('Training images have been passed through the CNN\n');
histograms = [pos.histograms, neg.histograms] ;
% L2 normalization
histograms = bsxfun(@times, histograms, 1./sqrt(sum(histograms.^2,1))) ;
clear pos, clear neg;
% Test data
pos = load(['../../TD2/practical-category-recognition-2015a/data/' category '_val_' encoding '.mat']) ;
neg = load(['../../TD2/practical-category-recognition-2015a/data/background_val_' encoding '.mat']) ;
testNames = {pos.names{:}, neg.names{:}};
testLabels = [ones(1,numel(pos.names)), - ones(1,numel(neg.names))] ;
pos.histograms = feature_vector_from_cnn(net, pos.names);
neg.histograms = feature_vector_from_cnn(net, neg.names);
fprintf('Test images have been passed through the CNN\n');
testHistograms = [pos.histograms, neg.histograms] ;
% L2 normalization
testHistograms = bsxfun(@times, testHistograms, 1./sqrt(sum(testHistograms.^2,1))) ;
clear pos, clear neg;
% Train the linear SVM
C = 1 ;
[w, bias] = trainLinearSVM(histograms, labels, C) ;
fprintf("SVM trained\n");
% Evaluate the scores on the training and test data
scores = w' * histograms + bias ;
testScores = w' * testHistograms + bias ;
% Visualize the ranked list of images
figure(3) ; clf ; set(3,'name','Ranked test images (subset)') ;
displayRankedImageList(testNames, testScores) ;
figure(4) ; clf ; set(4,'name','Precision-recall on test data') ;
vl_pr(testLabels, testScores) ;
% Print results
[~,~,info] = vl_pr(testLabels, testScores) ;
fprintf('Test AP: %.2f\n', info.ap) ; % info.ap is the average precision; info.auc is the area under the PR curve
[~,perm] = sort(testScores,'descend') ;
fprintf('Correctly retrieved in the top 36: %d\n', sum(testLabels(perm(1:36)) > 0)) ;