@pbloem
Last active November 14, 2025 13:24
# -- assignment 1 --
import numpy as np
from urllib import request
import gzip
import pickle
import os
def load_synth(num_train=60_000, num_val=10_000, seed=0):
    """
    Load some very basic synthetic data that should be easy to classify. Two features, so that we can plot the
    decision boundary (which is an ellipse in the feature space).

    :param num_train: Number of training instances
    :param num_val: Number of test/validation instances
    :param seed: Random seed for the data generation
    :return: Two tuples and an integer: (xtrain, ytrain), (xval, yval), num_cls. The first contains a matrix of training
        data with 2 features as a numpy floating point array, and the corresponding classification labels as a numpy
        integer array. The second contains the test/validation data in the same format. The last integer is the
        number of classes (this is always 2 for this function).
    """
    np.random.seed(seed)

    THRESHOLD = 0.6
    quad = np.asarray([[1, -0.05], [1, .4]])

    ntotal = num_train + num_val

    x = np.random.randn(ntotal, 2)

    # compute the quadratic form x^T A x for every instance
    q = np.einsum('bf, fk, bk -> b', x, quad, x)

    # label is 1 if the quadratic form exceeds the threshold, 0 otherwise
    y = (q > THRESHOLD).astype(int)

    return (x[:num_train, :], y[:num_train]), (x[num_train:, :], y[num_train:]), 2
def load_mnist(final=False, flatten=True, shuffle_seed=0):
    """
    Load the MNIST data.

    :param final: If true, return the canonical test/train split. If false, split some validation data from the training
        data and keep the test data hidden.
    :param flatten: If true, each instance is flattened into a vector, so that the data is returned as a matrix with 784
        columns. If false, the data is returned as a 3-tensor preserving each image as a matrix.
    :param shuffle_seed: If >= 0, the data is shuffled. This keeps the canonical test/train split, but shuffles each
        split internally before splitting off a validation set. The given number is used as a seed. Note that the
        original data is _not_ shuffled, but ordered by writer. This means that there will be a distribution shift
        between train and val if the data is not shuffled.
    :return: Two tuples and an integer: (xtrain, ytrain), (xval, yval), num_cls. The first contains a matrix of training
        data and the corresponding classification labels as a numpy integer array. The second contains the test/validation
        data in the same format. The last integer is the number of classes (this is always 10 for this function).
    """
    if not os.path.isfile('mnist.pkl'):
        init()

    xtrain, ytrain, xtest, ytest = load()
    xtl, xsl = xtrain.shape[0], xtest.shape[0]

    if flatten:
        xtrain = xtrain.reshape(xtl, -1)
        xtest = xtest.reshape(xsl, -1)
    else:
        # the pickled data is stored flat, so restore the 28x28 image structure
        xtrain = xtrain.reshape(xtl, 28, 28)
        xtest = xtest.reshape(xsl, 28, 28)

    if shuffle_seed >= 0:
        rng = np.random.default_rng(shuffle_seed)

        p = rng.permutation(xtrain.shape[0])
        xtrain, ytrain = xtrain[p], ytrain[p]

        p = rng.permutation(xtest.shape[0])
        xtest, ytest = xtest[p], ytest[p]

    if not final: # split the last 5000 training instances off as a validation set
        return (xtrain[:-5000], ytrain[:-5000]), (xtrain[-5000:], ytrain[-5000:]), 10

    return (xtrain, ytrain), (xtest, ytest), 10
# Numpy-only MNIST loader. Courtesy of Hyeonseok Jung
# https://github.com/hsjeong5/MNIST-for-Numpy
filename = [
["training_images","train-images-idx3-ubyte.gz"],
["test_images","t10k-images-idx3-ubyte.gz"],
["training_labels","train-labels-idx1-ubyte.gz"],
["test_labels","t10k-labels-idx1-ubyte.gz"]
]
def download_mnist():
    base_url = "https://peterbloem.nl/files/mnist/" # "http://yann.lecun.com/exdb/mnist/"

    for name in filename:
        print("Downloading " + name[1] + "...")
        request.urlretrieve(base_url + name[1], name[1])

    print("Download complete.")
def save_mnist():
    mnist = {}

    # the image files have a 16-byte header, followed by the raw pixel bytes
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28*28)

    # the label files have an 8-byte header, followed by one byte per label
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)

    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist, f)

    print("Save complete.")
def init():
    download_mnist()
    save_mnist()

def load():
    with open("mnist.pkl", 'rb') as f:
        mnist = pickle.load(f)

    return mnist["training_images"], mnist["training_labels"], mnist["test_images"], mnist["test_labels"]
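
For reference, a minimal usage sketch of the two loaders above (assuming the file is saved as data.py; the module name and the shapes shown in the comments are illustrative, not part of the original gist):

from data import load_synth, load_mnist

# synthetic data: two features, two classes, labels determined by a quadratic form
(xtrain, ytrain), (xval, yval), num_cls = load_synth(num_train=60_000, num_val=10_000)
print(xtrain.shape, ytrain.shape, num_cls)  # (60000, 2) (60000,) 2

# MNIST: with final=False, the last 5000 training images are split off as validation data
(xtrain, ytrain), (xval, yval), num_cls = load_mnist(final=False, flatten=True)
print(xtrain.shape, xval.shape, num_cls)    # (55000, 784) (5000, 784) 10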
@drasgo commented Oct 27, 2020

os module is used but not imported

@pbloem (Author) commented Oct 27, 2020

My bad. Thanks for the pointer.

@b6nrb56g9p-ship-it

The link to the dataset is dead; I made new working code with this dataset: https://www.kaggle.com/datasets/hojjatk/mnist-dataset
