Created June 29, 2016 02:18
-
-
Save matsub/206a1dac75093d74d8ae2ab9c5a2ae35 to your computer and use it in GitHub Desktop.
A parser for the MNIST handwritten digits dataset. See http://yann.lecun.com/exdb/mnist/.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import os | |
import struct | |
class Image:
    """Reader for the MNIST handwritten-digit dataset stored as IDX files.

    The four uncompressed IDX files are looked up in a single directory under
    their standard names (``train-images-idx3-ubyte``,
    ``train-labels-idx1-ubyte``, ``t10k-images-idx3-ubyte``,
    ``t10k-labels-idx1-ubyte``).  :attr:`train` and :attr:`test` each return
    a fresh, lazy generator of ``((pixels, width, height), label)`` pairs.

    See http://yann.lecun.com/exdb/mnist/ for the file format.
    """

    # Magic numbers expected in the big-endian int32 header of each file kind.
    IMAGE_MAGIC = 2051
    LABEL_MAGIC = 2049

    def __init__(self, dir='./'):
        """Record the paths of the training and test IDX files.

        Args:
            dir: Directory containing the four MNIST files.  The name shadows
                the ``dir`` builtin but is kept for keyword-argument
                compatibility.
        """
        self.train_files = {
            'images': os.path.join(dir, 'train-images-idx3-ubyte'),
            'labels': os.path.join(dir, 'train-labels-idx1-ubyte'),
        }
        self.test_files = {
            'images': os.path.join(dir, 't10k-images-idx3-ubyte'),
            'labels': os.path.join(dir, 't10k-labels-idx1-ubyte'),
        }

    @property
    def train(self):
        """Generator of ``((pixels, width, height), label)`` for the training set."""
        return self._get_dataset(self.train_files)

    @property
    def test(self):
        """Generator of ``((pixels, width, height), label)`` for the test set."""
        return self._get_dataset(self.test_files)

    def _get_dataset(self, path):
        """Yield ``(image, label)`` pairs from the files named in *path*.

        Args:
            path: Mapping with ``'images'`` and ``'labels'`` file paths.

        Pairing stops at the shorter of the two files, as ``zip`` does.
        Both loader generators are closed explicitly when iteration ends
        (normally, early, or via an exception). ``zip`` never resumes the
        second generator once the first is exhausted, so without this its
        file would otherwise stay open.
        """
        images = self._load_images(path['images'])
        labels = self._load_labels(path['labels'])
        try:
            yield from zip(images, labels)
        finally:
            images.close()
            labels.close()

    def _load_images(self, fname):
        """Yield ``(pixels, width, height)`` for every image in *fname*.

        ``pixels`` is a tuple of ``width * height`` unsigned byte values in
        the order they appear in the file.  The 3rd/4th header fields are
        named width/height here; NOTE(review): the IDX spec lists them as
        rows/columns, which is equivalent for square MNIST images.

        Raises:
            RuntimeError: if the header magic number is not 2051 (raised on
                first iteration, since this is a generator).
            struct.error: if the header or an image record is truncated.
        """
        with open(fname, 'rb') as f:
            magic, size, width, height = struct.unpack('>4i', f.read(16))
            if magic != self.IMAGE_MAGIC:
                raise RuntimeError("'%s' is not an MNIST image set." % fname)
            chunk = width * height
            # Every image has the same pixel count, so compile the format once.
            pixel_struct = struct.Struct('>%dB' % chunk)
            for _ in range(size):
                yield pixel_struct.unpack(f.read(chunk)), width, height

    def _load_labels(self, fname):
        """Yield each integer label stored in *fname*.

        Raises:
            RuntimeError: if the header magic number is not 2049 (raised on
                first iteration, since this is a generator).
            struct.error: if the file holds fewer labels than its header says.
        """
        with open(fname, 'rb') as f:
            magic, size = struct.unpack('>2i', f.read(8))
            if magic != self.LABEL_MAGIC:
                raise RuntimeError("'%s' is not an MNIST label set." % fname)
            # Read exactly `size` bytes: trailing data must not break unpack.
            labels = struct.unpack('>%dB' % size, f.read(size))
        # The file is already closed here; labels are small enough to hold.
        yield from labels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys

from mnist import Image


def main(dataset_dir='path/to/MNIST-Dataset'):
    """Print every training and test sample found in *dataset_dir*.

    Args:
        dataset_dir: Directory holding the four uncompressed MNIST IDX files.
    """
    dataset = Image(dataset_dir)
    # Both splits yield the same ((img, width, height), label) shape, so one
    # loop body serves train followed by test.
    for split in (dataset.train, dataset.test):
        for (img, width, height), label in split:
            print(img, width, height, label)


if __name__ == '__main__':
    # An optional first command-line argument overrides the placeholder path.
    main(*sys.argv[1:2])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment