
@zhenghaoz
Created October 14, 2018 02:45
KNN for MNIST
# KNN for MNIST
# Author: Zhenghao Zhang
using Statistics
using MLDatasets
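# KNN
# Classify each row of test_x by majority vote among its k nearest rows of
# train_x, using the Hamming distance between binary images
function knn(test_x, train_x, train_y, k)
    num = size(test_x, 1)
    pred_y = zeros(eltype(train_y), num)
    # Iterate over samples in the test set
    for i = 1:num
        # Print progress
        print("\rProgress: ", i, "/", num)
        # Hamming distances to all samples in the training set:
        # XOR marks the differing pixels, and each row sum counts them
        distances = vec(sum(test_x[i:i, :] .⊻ train_x, dims=2))
        # Indices of the k nearest neighbors
        neighbors = sortperm(distances)[1:k]
        # Majority vote over the neighbors' labels; on a tie, argmax picks the
        # label that appears earliest among the distance-sorted neighbors
        numbers = train_y[neighbors]
        candidates = unique(numbers)
        pred_y[i] = candidates[argmax([sum(numbers .== c) for c in candidates])]
    end
    # Print newline
    println()
    # Return predictions
    pred_y
end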
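# Load dataset (MNIST.traindata/testdata is the older MLDatasets interface;
# newer MLDatasets releases replace it with MNIST(split=:train) and MNIST(split=:test))
train_x, train_y = MNIST.traindata()
test_x, test_y = MNIST.testdata()
# Flatten each 28x28 image into a 784-element vector, then transpose
# so that each row is one sample
train_x = transpose(reshape(train_x, 28*28, :))
test_x = transpose(reshape(test_x, 28*28, :))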
# Binarization
threshold = 0.5
test_x = (test_x .> threshold)
train_x = (train_x .> threshold)
# Predict with K=3
pred_y = knn(test_x, train_x, train_y, 3)
# Reported error rate on the test set: 4.21%
println("Error rate: ", mean(pred_y .!= test_y)*100, "%")
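The core of knn is a Hamming-distance nearest-neighbor search followed by a majority vote. A minimal self-contained sketch of the same idea, assuming a hypothetical 4-pixel toy dataset instead of MNIST (so MLDatasets is not needed), is given here; knn_predict and the toy_* arrays are illustrative names, not part of the gist. One assumed change: it selects neighbors with partialsortperm instead of sortperm(distances)[1:k], which avoids fully sorting all 60,000 training distances but may break ties at the k-th distance differently.

```julia
# Hypothetical toy data (not MNIST): each row is one binary 4-pixel "image".
toy_train_x = Bool[1 1 0 0;
                   1 1 1 0;
                   0 0 1 1;
                   0 1 1 1]
toy_train_y = [1, 1, 7, 7]
toy_test_x = Bool[1 1 0 1;
                  0 0 1 0]

# Sketch of Hamming-distance k-NN with majority vote (hypothetical helper name).
function knn_predict(test_x::AbstractMatrix{Bool}, train_x::AbstractMatrix{Bool},
                     train_y::AbstractVector, k::Integer)
    pred = similar(train_y, size(test_x, 1))
    for i in axes(test_x, 1)
        # XOR flags the pixels that differ; summing each row counts them
        d = vec(sum(test_x[i:i, :] .⊻ train_x, dims=2))
        # Indices of the k smallest distances, without sorting the whole vector
        nn = partialsortperm(d, 1:k)
        labels = train_y[nn]
        # Majority vote; argmax returns the first candidate with the top count
        cands = unique(labels)
        counts = [count(==(c), labels) for c in cands]
        pred[i] = cands[argmax(counts)]
    end
    return pred
end

println(knn_predict(toy_test_x, toy_train_x, toy_train_y, 3))  # → [1, 7]
```

With k = 3, the first toy test row has Hamming distances [1, 2, 3, 2] to the training rows, so its neighbors carry labels 1, 1 and 7 and it is classified as 1. The second row has distances [3, 2, 1, 2], so its neighbors carry 7, 1 and 7 and it is classified as 7.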