Skip to content

Instantly share code, notes, and snippets.

@twmht
Last active July 21, 2016 15:53
Show Gist options
  • Select an option

  • Save twmht/d9ee32d9a14ffd8145ebb46e02549ea5 to your computer and use it in GitHub Desktop.

Select an option

Save twmht/d9ee32d9a14ffd8145ebb46e02549ea5 to your computer and use it in GitHub Desktop.
much faster
import os
import struct
from array import array
import numpy as np
class MNIST(object):
    """Reader for the MNIST handwritten-digit dataset stored as raw IDX files.

    The four uncompressed files are looked up in *path* under the file names
    set in ``__init__`` (``t10k-images-idx3-ubyte`` etc.).
    """

    # Magic numbers checked in the IDX headers: 2049 (0x801) for the label
    # files and 2051 (0x803) for the image files.
    _LABEL_MAGIC = 2049
    _IMAGE_MAGIC = 2051

    def __init__(self, path='.'):
        self.path = path

        self.test_img_fname = 't10k-images-idx3-ubyte'
        self.test_lbl_fname = 't10k-labels-idx1-ubyte'

        self.train_img_fname = 'train-images-idx3-ubyte'
        self.train_lbl_fname = 'train-labels-idx1-ubyte'

        # Filled in by load_testing(); left empty here.
        self.test_images = []
        self.test_labels = []

        # Never populated by this class: load_training() returns a lazy
        # generator instead of storing the data.
        self.train_images = []
        self.train_labels = []

    def load_testing(self):
        """Load the whole test set into memory.

        Returns:
            ``(images, labels)``: two parallel lists, where ``images[i]`` is
            an ``array('B')`` of pixel values and ``labels[i]`` is an int.
            The same lists are stored on ``self.test_images`` and
            ``self.test_labels``.

        Raises:
            ValueError: propagated from :meth:`load` on malformed files.
        """
        # load() is a generator of (image, label) pairs, so it must be
        # materialised before it can be split into two lists.
        pairs = list(self.load(os.path.join(self.path, self.test_img_fname),
                               os.path.join(self.path, self.test_lbl_fname)))
        ims = [img for img, _ in pairs]
        labels = [lbl for _, lbl in pairs]

        self.test_images = ims
        self.test_labels = labels

        return ims, labels

    def load_training(self):
        """Return a lazy generator of ``(image, label)`` training pairs.

        Unlike :meth:`load_testing`, nothing is read until the generator is
        iterated, and nothing is stored on the instance.
        """
        return self.load(os.path.join(self.path, self.train_img_fname),
                         os.path.join(self.path, self.train_lbl_fname))

    @classmethod
    def load(cls, path_img, path_lbl):
        """Lazily yield ``(image_data, label)`` pairs from an IDX file pair.

        All labels are read up front. Images are then read one at a time, so
        the image file is never held in memory as a whole.

        Args:
            path_img: path to an IDX image file (magic 2051).
            path_lbl: path to the matching IDX label file (magic 2049).

        Yields:
            ``(image_data, label)``: ``image_data`` is an ``array('B')`` of
            ``rows * cols`` pixel bytes and ``label`` is an int.

        Raises:
            ValueError: on a magic-number mismatch, when the label file holds
                fewer labels than the image header declares, or when the
                image file ends early. Because this is a generator, these
                errors surface during iteration, not when ``load`` is called.
        """
        with open(path_lbl, 'rb') as fh:
            magic, _ = struct.unpack(">II", fh.read(8))
            if magic != cls._LABEL_MAGIC:
                raise ValueError('Magic number mismatch, expected {}, '
                                 'got {}'.format(cls._LABEL_MAGIC, magic))
            labels = array("B", fh.read())

        with open(path_img, 'rb') as fh:
            magic, size, rows, cols = struct.unpack(">IIII", fh.read(16))
            if magic != cls._IMAGE_MAGIC:
                raise ValueError('Magic number mismatch, expected {}, '
                                 'got {}'.format(cls._IMAGE_MAGIC, magic))
            # Fail before yielding anything, rather than with an IndexError
            # partway through the stream.
            if len(labels) < size:
                raise ValueError('Label file has {} labels but image file '
                                 'declares {} images'.format(len(labels), size))

            n_pixels = rows * cols
            for i in range(size):
                raw = fh.read(n_pixels)
                if len(raw) != n_pixels:
                    raise ValueError('Image file truncated at image {} '
                                     '(got {} of {} bytes)'
                                     .format(i, len(raw), n_pixels))
                yield array("B", raw), labels[i]

    @classmethod
    def display(cls, img, width=28, threshold=200):
        """Render an image as ASCII art.

        Args:
            img: iterable of pixel intensities.
            width: pixels per row; a newline is emitted before each row.
            threshold: pixels strictly greater than this draw as ``'@'``;
                all others draw as ``'.'``.

        Returns:
            The rendered string. It starts with a newline when *img* is
            non-empty and is ``''`` when *img* is empty.
        """
        chars = []
        for i, pixel in enumerate(img):
            if i % width == 0:
                chars.append('\n')
            chars.append('@' if pixel > threshold else '.')
        return ''.join(chars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment