Skip to content

Instantly share code, notes, and snippets.

@gabrieldernbach
Last active June 6, 2022 17:40
Show Gist options
  • Save gabrieldernbach/f4310c6f1283ae0e19387da0c6dd25ec to your computer and use it in GitHub Desktop.
Save gabrieldernbach/f4310c6f1283ae0e19387da0c6dd25ec to your computer and use it in GitHub Desktop.
Solving MNIST, fast and short
from torchvision.datasets import MNIST
import numpy as np
def data(train, root='.'):
    """Load one MNIST split as flat NumPy arrays.

    Args:
        train: Forwarded to ``MNIST(train=...)`` to choose between the
            training split and the test split.
        root: Directory handed to ``MNIST(root=...)`` where the dataset is
            stored (it is downloaded there if missing). Defaults to the
            current directory, which was previously hard-coded.

    Returns:
        A tuple ``(X, y)``: ``X`` holds the images reshaped to rows of 784
        values and divided by 255; ``y`` holds the targets as a NumPy array.
    """
    mnist = MNIST(root=root, download=True, train=train)
    # Flatten every image into one row of 784 values and rescale by 255.
    X = mnist.data.numpy().reshape(-1, 784) / 255
    y = mnist.targets.numpy()
    return X, y
# the fastest (nearest centroid classifier)
X, y = data(train=True)
# One centroid per class: the mean training row of each distinct label,
# stacked in the sorted label order returned by np.unique.
centroids = np.stack([X[y == c].mean(0) for c in np.unique(y)])
X, y = data(train=False)
# Mean squared distance from every test row to every centroid, expanded as
# (|x|^2 - 2 x.c + |c|^2) / n_features. Broadcasting the raw difference would
# build an (n_samples, n_features, n_classes) temporary; this form only
# allocates (n_samples, n_classes) results.
dist = (
    (X * X).sum(1, keepdims=True)
    - 2 * (X @ centroids.T)
    + (centroids * centroids).sum(1)
) / X.shape[1]
# The argmin column index is compared directly with the label, so this relies
# on the labels being 0..n_classes-1 (row i of `centroids` belongs to label i).
print("nearest centroid acc", (dist.argmin(-1) == y).mean())
# the shortest (nearest neighbor classifier)
from sklearn.neighbors import KNeighborsClassifier
# Fit on the training split and score on the test split. The score is printed
# because a bare expression's value is discarded when this runs as a script.
print(
    "nearest neighbor acc",
    KNeighborsClassifier().fit(*data(train=True)).score(*data(train=False)),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment