Skip to content

Instantly share code, notes, and snippets.

@dniku
Last active August 29, 2015 14:19
Show Gist options
  • Select an option

  • Save dniku/7b67e4db822f0e3211b8 to your computer and use it in GitHub Desktop.

Select an option

Save dniku/7b67e4db822f0e3211b8 to your computer and use it in GitHub Desktop.
Dumping MNIST digits as images in a separate folder
from __future__ import division, print_function
import os
import cv2
import numpy as np
from sklearn.datasets import fetch_mldata
from sklearn.cross_validation import train_test_split
def load_digits_full(data_home=None):
    """Fetch the 'MNIST original' dataset and return ``(images, labels)``.

    ``data_home`` is forwarded to ``fetch_mldata`` as its ``data_home``
    argument.  When it is falsy, the ``DATA_HOME`` environment variable is
    used instead; when that is unset or empty too, the current directory
    (``'.'``) is used.

    Returns a tuple of the dataset's ``data`` reshaped to ``(-1, 28, 28)``
    and its ``target`` cast to ``np.uint8``.
    """
    # Resolution order: explicit argument -> $DATA_HOME -> current directory.
    cache_dir = data_home or os.getenv('DATA_HOME') or '.'
    dataset = fetch_mldata('MNIST original', data_home=cache_dir)
    digit_labels = dataset.target.astype(np.uint8)
    digit_images = dataset.data.reshape((-1, 28, 28))
    return digit_images, digit_labels
def preprocess(image):
    """Return ``image`` unchanged.

    This is the hook applied to every digit before it is written to disk;
    at the moment it is the identity transform.
    """
    # Disabled alternative, kept for reference: crop rows/columns 4..23 and
    # resize the crop to 15x15 with linear interpolation.
    #     target_size = (15, 15)
    #     crop = image[4:24, 4:24]
    #     resize = cv2.resize(crop, target_size, interpolation=cv2.INTER_LINEAR)
    #     return resize
    processed = image
    return processed
def _dump_split(images, labels, split_dir, index_path, index_prefix):
    """Write one dataset split as PNG files plus a text index of them.

    Each image is passed through ``preprocess`` and saved in ``split_dir``
    as ``<label>_<n>.png``, where ``n`` is a zero-padded running counter
    kept separately for every label.  ``index_path`` then receives one line
    per written image in the form ``<index_prefix>/<file name> <label>``.

    The index is built from the files written by this call, in the order
    they were written, so unrelated or stale files that happen to be in
    ``split_dir`` never end up in it.

    Raises ``IOError`` if ``cv2.imwrite`` reports that an image could not
    be written.
    """
    if not os.path.exists(split_dir):
        os.makedirs(split_dir)

    per_label_count = [0] * 10
    index_lines = []
    for image, label in zip(images, labels):
        label = int(label)
        fname = '%d_%04d.png' % (label, per_label_count[label])
        per_label_count[label] += 1

        impath = os.path.join(split_dir, fname)
        # cv2.imwrite signals failure through its return value instead of
        # raising, so check it explicitly.
        if not cv2.imwrite(impath, preprocess(image)):
            raise IOError('Failed to write image: %s' % impath)

        # Index paths always use forward slashes, regardless of platform.
        index_lines.append(
            '%s %d\n' % ('/'.join((index_prefix, fname)), label))

    with open(index_path, 'w') as f:
        f.writelines(index_lines)


if __name__ == '__main__':
    root_dir = 'T:\\'
    dataset_name = 'mnist_28x28'
    dataset_dir = os.path.join(root_dir, dataset_name)
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)

    images, labels = load_digits_full()

    # The images are flattened for the split and restored to 28x28 after it.
    images = images.reshape(-1, 28*28)
    images_train, images_test, labels_train, labels_test = train_test_split(
        images, labels, train_size=60000, random_state=37)
    images_train = images_train.reshape(-1, 28, 28)
    images_test = images_test.reshape(-1, 28, 28)

    splits = [
        (images_train, labels_train, 'train'),
        (images_test, labels_test, 'test'),
    ]
    for images_dataset, labels_dataset, dataset_label in splits:
        _dump_split(
            images_dataset,
            labels_dataset,
            split_dir=os.path.join(dataset_dir, dataset_label),
            index_path=os.path.join(
                dataset_dir, '%s_index.txt' % dataset_label),
            index_prefix='/'.join((dataset_name, dataset_label)),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment