Last active
August 29, 2015 14:19
-
-
Save dniku/7b67e4db822f0e3211b8 to your computer and use it in GitHub Desktop.
Dumps MNIST digits as PNG images into separate train/test folders, with one index file per split
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import division, print_function | |
| import os | |
| import cv2 | |
| import numpy as np | |
| from sklearn.datasets import fetch_mldata | |
| from sklearn.cross_validation import train_test_split | |
def load_digits_full(data_home=None):
    """Return every MNIST digit as a 28x28 image together with its label.

    Data is obtained through scikit-learn's ``fetch_mldata('MNIST original')``
    and cached under ``data_home``.

    NOTE(review): ``fetch_mldata`` was deprecated in scikit-learn 0.20 and
    removed in 0.22 (mldata.org went offline); this only works on an old
    scikit-learn or with an already-cached copy in ``data_home``.

    Args:
        data_home: cache directory. When falsy, the ``DATA_HOME`` environment
            variable is used, and if that is also unset/empty, the current
            directory ``'.'``.

    Returns:
        ``(images, labels)``: the data reshaped to ``(-1, 28, 28)`` and the
        targets converted to ``np.uint8``.
    """
    # First truthy candidate wins: explicit argument, then env var, then cwd.
    cache_dir = data_home or os.getenv('DATA_HOME') or '.'
    bunch = fetch_mldata('MNIST original', data_home=cache_dir)
    return bunch.data.reshape((-1, 28, 28)), bunch.target.astype(np.uint8)
def preprocess(image, crop=None, target_size=None):
    """Optionally crop and resize one digit image before it is saved.

    With the default arguments the image is returned unchanged (the very same
    object), so the dump script writes the raw digits as loaded.

    Example of the previously hard-coded recipe: ``crop=(4, 24)`` followed by
    ``target_size=(15, 15)`` turns a 28x28 digit into a 15x15 one.

    Args:
        image: 2-D image array (one digit).
        crop: optional ``(start, stop)`` pair; when given, the image is cut to
            the square ``image[start:stop, start:stop]``.
        target_size: optional ``(width, height)`` passed as ``dsize`` to
            ``cv2.resize`` (bilinear interpolation), applied after cropping.

    Returns:
        The possibly cropped and/or resized image.
    """
    if crop is not None:
        start, stop = crop
        image = image[start:stop, start:stop]
    if target_size is not None:
        # cv2.resize takes dsize as (width, height).
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    return image
def _dump_split(images, labels, dataset_dir, dataset_name, split_name):
    """Write one split's digits as PNG files plus an index file listing them.

    Images go to ``<dataset_dir>/<split_name>/<label>_<nnnn>.png``, numbered
    per label in the order they appear. ``<dataset_dir>/<split_name>_index.txt``
    receives one ``<path> <label>`` line per written image, where ``path`` is
    ``dataset_name/split_name/filename`` joined with forward slashes on every OS.

    Args:
        images: sequence of 2-D images, one per sample.
        labels: matching sequence of integer digit labels (0-9).
        dataset_dir: directory that holds the split folder and index file.
        dataset_name: first component of the paths written to the index.
        split_name: split folder / index prefix (e.g. ``'train'``).

    Raises:
        IOError: if ``cv2.imwrite`` reports that an image could not be written.
    """
    split_dir = os.path.join(dataset_dir, split_name)
    if not os.path.exists(split_dir):
        os.makedirs(split_dir)
    counts = [0] * 10  # per-digit running number used in file names
    entries = []
    for image, label in zip(images, labels):
        label = int(label)
        fname = '%d_%04d.png' % (label, counts[label])
        counts[label] += 1
        impath = os.path.join(split_dir, fname)
        # cv2.imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(impath, preprocess(image)):
            raise IOError('could not write image %s' % impath)
        entries.append(('/'.join((dataset_name, split_name, fname)), label))
    index_path = os.path.join(dataset_dir, '%s_index.txt' % split_name)
    # Index what was just written rather than os.listdir(): deterministic
    # order, and stale files from earlier runs are not picked up.
    with open(index_path, 'w') as f:
        f.writelines('%s %d\n' % entry for entry in entries)


def main(root_dir='T:\\', dataset_name='mnist_28x28'):
    """Dump MNIST as PNG files split into ``train`` and ``test`` folders.

    Output lands in ``<root_dir>/<dataset_name>/``; see ``_dump_split`` for
    the file and index layout.

    NOTE(review): this is a seeded random 60000/10000 split
    (``random_state=37``), not MNIST's canonical train/test partition —
    confirm that is intended.

    Args:
        root_dir: parent directory for the dataset folder.
        dataset_name: name of the dataset folder and index path prefix.
    """
    dataset_dir = os.path.join(root_dir, dataset_name)
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)
    images, labels = load_digits_full()
    # Split on flattened 784-value rows, then restore the 28x28 shape below.
    flat = images.reshape(-1, 28 * 28)
    images_train, images_test, labels_train, labels_test = train_test_split(
        flat, labels, train_size=60000, random_state=37)
    splits = [
        (images_train, labels_train, 'train'),
        (images_test, labels_test, 'test'),
    ]
    for split_images, split_labels, split_name in splits:
        _dump_split(split_images.reshape(-1, 28, 28), split_labels,
                    dataset_dir, dataset_name, split_name)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment