Created February 1, 2010 21:07
A naive, still-sort-of-inefficient k-NN implementation in idiomatic ("vectorized") MATLAB.
function bestclass = knn(train_data, labels, example, k)
%KNN -- do k-nearest neighbours classification
%
%   BESTCLASS = knn(TRAIN_DATA, LABELS, EXAMPLE, K)
%
% Takes TRAIN_DATA, a D x N matrix containing N training examples of
% dimension D (one per column); LABELS, an N-vector of the positive integer
% classes assigned to each column of TRAIN_DATA; EXAMPLE, a D x 1 column
% vector holding the example we are trying to classify; and K, the number
% of neighbours to use in classifying.
%
% Returns BESTCLASS, the predicted class of the test point.
%
% The K-nearest neighbour algorithm works exactly like the one-nearest
% neighbour algorithm (which chooses the class of the training example that
% has minimum Euclidean distance to the test example), but instead of using
% only the closest neighbour it takes the K closest points and returns the
% class that wins a majority vote among them.
% See http://en.wikipedia.org/wiki/KNN for more details.
%
% Compute the squared Euclidean distance to each of the N training examples
% by duplicating the test vector N times with REPMAT, subtracting, squaring
% elementwise with .^2 and summing each column with SUM, then sort the
% results. The square root is skipped because it does not change the order.
%
% NOTE: calling SORT with two output arguments, as done here, returns the
% sorted distances as the first output and, as the second, a vector of the
% original positions of the sorted elements; i.e. ind(5) is the index that
% element 5 of the sorted array occupied in the unsorted array.
%
% By David Warde-Farley -- user AT cs dot toronto dot edu (user = dwf)
% Redistributable under the terms of the 3-clause BSD license
% (see http://www.opensource.org/licenses/bsd-license.php for details)
[val, ind] = sort(sum((repmat(example,1,size(train_data,2)) - train_data).^2));
% Create a vector to store the number of examples observed for each class
% among the K neighbours.
counts = zeros(max(labels),1);
% Loop through the K closest neighbours.
for neighbour = 1:k,
    % Get the label of the current neighbour from the LABELS vector
    % and increment its count.
    counts(labels(ind(neighbour))) = counts(labels(ind(neighbour))) + 1;
end;
% Take the class with the highest count (throw away the actual count
% but store the index in BESTCLASS). If several classes tie, MAX returns
% the first of them, i.e. the smallest class label.
[junk, bestclass] = max(counts);
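To make the four steps of the function concrete, here is a hand-traced run on a tiny data set. The data, the choice `k = 5`, and the helper variables `N`, `sqdist` and `c` are hypothetical, made up for illustration; the statements are written inline as a script so the block runs on its own without the `knn` function on the path.

```matlab
% Hypothetical toy data: D = 2 features, N = 6 training examples, 2 classes.
train_data = [0 1 0 5 6 5;
              0 0 1 5 5 6];   % D x N, one training example per column
labels  = [1 1 1 2 2 2];      % class of each column of train_data
example = [0.5; 0.5];         % D x 1 test point (must be a column)
k = 5;

% Step 1: squared Euclidean distance from EXAMPLE to every column.
N = size(train_data, 2);
sqdist = sum((repmat(example, 1, N) - train_data).^2);
% sqdist is [0.5 0.5 0.5 40.5 50.5 50.5]

% Step 2: sort; ind(j) is the column of the j-th closest training example.
[val, ind] = sort(sqdist);

% Step 3: count the class labels of the k closest columns.
counts = zeros(max(labels), 1);
for neighbour = 1:k
    c = labels(ind(neighbour));
    counts(c) = counts(c) + 1;
end
% counts is [3; 2]: three votes for class 1, two for class 2

% Step 4: majority vote (MAX returns the index of the largest count).
[~, bestclass] = max(counts);
% bestclass is 1
```

With `k = 5` the two class-2 points at distance 40.5 and 50.5 get a vote each, but the three class-1 points still win.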
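The description calls this implementation still somewhat inefficient. One possible tightening, sketched here as an assumption rather than the author's code, drops the REPMAT copy of the test vector in favour of BSXFUN (plain `train_data - example` via implicit expansion on MATLAB R2016b and later) and replaces the counting loop with ACCUMARRAY. The toy data and the variable names `sqdist` and `nearest` are hypothetical. It still performs a full sort, so each query costs O(N log N).

```matlab
% Hypothetical toy data: 2 features, 6 training examples, 2 classes.
train_data = [0 1 0 5 6 5;
              0 0 1 5 5 6];   % D x N, one training example per column
labels  = [1 1 1 2 2 2];      % positive integer class of each column
example = [0.5; 0.5];         % D x 1 test point
k = 5;

% BSXFUN subtracts EXAMPLE from every column without a REPMAT copy of it.
% On MATLAB R2016b+ (and Octave), train_data - example gives the same result.
sqdist = sum(bsxfun(@minus, train_data, example).^2, 1);

% Labels of the k nearest columns (still a full sort of all N distances).
[~, ind] = sort(sqdist);
nearest = labels(ind(1:k));

% ACCUMARRAY adds 1 to counts(c) once for every neighbour whose label is c.
counts = accumarray(nearest(:), 1, [max(labels), 1]);

% Ties go to the smallest class label, because MAX returns the first maximum.
[~, bestclass] = max(counts);
```

The vote is the same as the loop's: `counts` is a column vector with one entry per class from 1 to `max(labels)`, so classes absent from the neighbourhood simply keep a count of zero.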