Created
March 5, 2021 06:26
-
-
Save louity/c6b0c91810c9957f57c56c952323b29e to your computer and use it in GitHub Desktop.
K-nearest-neighbors on CIFAR-10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 58.4 % accuracy with a K-nearest-neighbor classifier on CIFAR-10.
# Images are ZCA-whitened and L2-normalized before the neighbor search.
import pickle | |
import numpy as np | |
import os | |
from sklearn.neighbors import KNeighborsClassifier | |
def compute_whitening_op(X, reg=0.1):
    """Compute the feature mean and a regularized ZCA whitening matrix.

    The covariance of the centered data is eigendecomposed and each
    eigen-direction is rescaled by ``1 / sqrt(eigenvalue + reg)``.

    Args:
        X: 2-D array of shape ``(n_samples, n_features)``.
        reg: Value added to every covariance eigenvalue before the inverse
            square root is taken; it bounds the gain applied to
            low-variance directions.

    Returns:
        A ``(mean, whitening_op)`` tuple of float32 arrays. ``mean`` has
        shape ``(1, n_features)`` and ``whitening_op`` is a symmetric
        ``(n_features, n_features)`` matrix, meant to be applied as
        ``(X - mean).dot(whitening_op)``.
    """
    # Accumulate the statistics in double precision; only the results are
    # cast back to float32.
    X = X.astype('float64')
    mean = X.mean(axis=0, keepdims=True)
    X_ = X - mean
    covariance = 1.0 / X_.shape[0] * X_.T.dot(X_)
    # The covariance is symmetric, so use eigh: it always returns real
    # eigenvalues and orthonormal eigenvectors. The general-purpose eig
    # gives neither guarantee (it may return complex output or a
    # non-orthogonal basis for repeated eigenvalues), which would make
    # V.T an invalid stand-in for the inverse of V below.
    E, V = np.linalg.eigh(covariance)
    print(f'Eigvals min {E.min()} max {E.max()}')
    # Round-off can push a zero eigenvalue slightly negative; clamp so the
    # square root stays real even when reg is tiny.
    inv_sqrt_zca_eigs = 1.0 / np.sqrt(np.maximum(E, 0.0) + reg)
    # Scaling the columns of V is equivalent to V.dot(np.diag(...)) without
    # building the dense diagonal matrix.
    whitening_op = (V * inv_sqrt_zca_eigs).dot(V.T).astype('float32')
    return mean.astype('float32'), whitening_op
# Directory holding the extracted CIFAR-10 python batches.
data_dir = './cifar-10-batches-py'
if not os.path.exists(data_dir):
    print('You need to Download CIFAR 10 at the url "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" and to extract it')
    # The bare exit() helper is injected by the `site` module for
    # interactive use and may be absent (e.g. under `python -S`); it also
    # reports success. Raise SystemExit with a non-zero status instead so
    # callers can detect the missing dataset.
    raise SystemExit(1)

# One test batch and five training batches named data_batch_1 .. data_batch_5.
test_filename = os.path.join(data_dir, 'test_batch')
train_batch_filenames = [os.path.join(data_dir, f'data_batch_{i}') for i in range(1, 6)]
# Unpickle every training batch, then stack pixels and labels across batches.
# NOTE: pickle can execute arbitrary code while loading; only use batch files
# obtained from a trusted source.
train_batches = []
for batch_path in train_batch_filenames:
    with open(batch_path, 'rb') as batch_file:
        train_batches.append(pickle.load(batch_file, encoding='bytes'))

X_train = np.concatenate([batch[b'data'] for batch in train_batches])
y_train = np.concatenate([batch[b'labels'] for batch in train_batches])

# Rescale pixel values by 1/255 and store them as float32.
X_train = X_train.astype('float32') / 255.0
# Whitening statistics are estimated on the original (un-augmented)
# training images only.
reg = 0.1
mean, whitening_op = compute_whitening_op(X_train, reg)

# flip aug: view each flat row as a (3, 32, 32) image, reverse its last axis,
# and flatten back. The flipped copies are appended after the originals and
# the labels are repeated in the same order.
train_images = X_train.reshape((-1, 3, 32, 32))
flipped_rows = train_images[..., ::-1].reshape(X_train.shape)
X_train = np.concatenate([X_train, flipped_rows])
y_train = np.tile(y_train, 2)
# Load the single test batch (same pickle caveat as any pickle file: only
# load data from a trusted source).
with open(test_filename, 'rb') as test_file:
    test_batch = pickle.load(test_file, encoding='bytes')

# Pixels get the same 1/255 rescaling and float32 cast as the training set.
X_test = np.array(test_batch[b'data']).astype('float32') / 255.0
y_test = np.array(test_batch[b'labels'])
print(X_test.shape, y_test.shape)
def _whiten_and_normalize(X, offset, op):
    """Center X by offset, apply the whitening matrix op, and L2-normalize rows."""
    X_whit = (X - offset).dot(op)
    # The small constant keeps an all-zero row from dividing by zero.
    X_whit /= np.linalg.norm(X_whit, axis=1, keepdims=True) + 1e-9
    return X_whit


# Project both splits with the training-set mean and whitening matrix.
X_train_whit = _whiten_and_normalize(X_train, mean, whitening_op)
X_test_whit = _whiten_and_normalize(X_test, mean, whitening_op)

# Distance-weighted vote among the k nearest training samples, using all cores.
k = 50
neigh = KNeighborsClassifier(n_neighbors=k, n_jobs=-1, weights='distance')
neigh.fit(X_train_whit, y_train)

y_test_pred = neigh.predict(X_test_whit)
acc = np.mean(y_test_pred == y_test)
print(f'k={k}, whit_reg={reg}, accuracy: {100*acc:.2f} %')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment