Created
February 12, 2015 08:53
-
-
Save nubela/fa3439a47a378da8aced to your computer and use it in GitHub Desktop.
MNIST/Pylearn2: a simple two-layer MNIST model runner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import struct

import numpy as np
import theano
from pylearn2.utils import serial
def read_mnist_images(path="mnist/", filename='t10k-images-idx3-ubyte',
                      dtype='float32'):
    """Read an MNIST IDX3 image file into a NumPy array.

    The file starts with a 16-byte header of four big-endian 32-bit ints:
    magic number (2051 for image files), image count, rows and columns.
    The header is followed by ``count * rows * cols`` uint8 pixel values.

    :param path: directory containing the MNIST files (default "mnist/").
    :param filename: image file inside ``path``; defaults to the 10k test set.
    :param dtype: dtype of the returned array (default float32). Values are
        converted, not rescaled, so pixels stay in the 0..255 range.
    :returns: array of shape ``(count, rows, cols)``.
    :raises ValueError: if the magic number is wrong or the file holds fewer
        pixels than the header announces.
    """
    image_path = os.path.join(path, filename)
    with open(image_path, 'rb') as f:
        magic, number, rows, cols = struct.unpack('>iiii', f.read(16))
        if magic != 2051:
            raise ValueError("%s is not an IDX3 image file (magic %d, "
                             "expected 2051)" % (image_path, magic))
        expected = number * rows * cols
        # Read exactly the announced pixel count; trailing bytes are ignored.
        data = f.read(expected)
    if len(data) != expected:
        raise ValueError("%s is truncated: expected %d pixel bytes, got %d"
                         % (image_path, expected, len(data)))
    pixels = np.frombuffer(data, dtype=np.uint8).reshape((number, rows, cols))
    # astype copies, so the result is writable even though frombuffer
    # returns a read-only view of the bytes.
    return pixels.astype(np.dtype(dtype))
def get_testing_input(axes=('b', 0, 1, 'c')):
    """Load the MNIST test images as a 4-D "topological view" array.

    The 3-D ``(count, rows, cols)`` array from read_mnist_images() gets a
    trailing singleton axis. That gives the layout ``('b', 0, 1, 'c')``:
    image index, row, column, channel. The result is then transposed into
    the requested ``axes`` order.

    :param axes: desired axis order; must be a permutation of
        ``('b', 0, 1, 'c')``. The default keeps the natural layout, so the
        array is returned untransposed (as before).
    :returns: float array of shape ``(count, rows, cols, 1)`` in the default
        order, or its transpose for another ``axes`` order.
    :raises ValueError: if ``axes`` is not a permutation of the default.
    """
    default = ('b', 0, 1, 'c')
    axes = tuple(axes)
    if len(axes) != len(default) or set(axes) != set(default):
        raise ValueError("axes must be a permutation of %r, got %r"
                         % (default, axes))
    images = read_mnist_images()
    m, r, c = images.shape
    topo_view = images.reshape(m, r, c, 1)
    return topo_view.transpose(*[default.index(axis) for axis in axes])
def get_testing_labels(path="mnist/", filename='t10k-labels-idx1-ubyte'):
    """Read an MNIST IDX1 label file as a column vector.

    The file starts with an 8-byte header of two big-endian 32-bit ints:
    magic number (2049 for label files) and label count. That many uint8
    labels follow.

    :param path: directory containing the MNIST files (default "mnist/").
    :param filename: label file inside ``path``; defaults to the 10k test set.
    :returns: writable uint8 array of shape ``(count, 1)``.
    :raises ValueError: if the magic number is wrong or the file holds fewer
        labels than the header announces.
    """
    label_path = os.path.join(path, filename)
    with open(label_path, 'rb') as f:
        # The header must be consumed first so the file pointer sits just
        # past the first 8 bytes, at the start of the label data.
        magic, number = struct.unpack('>ii', f.read(8))
        if magic != 2049:
            raise ValueError("%s is not an IDX1 label file (magic %d, "
                             "expected 2049)" % (label_path, magic))
        data = f.read(number)
    if len(data) != number:
        raise ValueError("%s is truncated: expected %d labels, got %d"
                         % (label_path, number, len(data)))
    # frombuffer yields a read-only view; copy so callers can modify it,
    # as they could with the np.fromfile result previously returned.
    return np.frombuffer(data, dtype=np.uint8).reshape(-1, 1).copy()
if __name__ == "__main__":
    # Load the trained model from disk.
    # NOTE(review): serial.load presumably unpickles this file -- only load
    # model files from trusted sources.
    neural_network_model = serial.load("./model.saved")

    # Test the trained model on the first MNIST test image.
    inputs = get_testing_input()
    targets = get_testing_labels()

    # Flatten the first image into a single row (a batch of one) of float32,
    # the same input layout the original list-wrapping code produced.
    sample_input = inputs[0].reshape(1, -1).astype(np.float32)

    output = neural_network_model.fprop(theano.shared(sample_input)).eval()
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(output)
    # The first MNIST test image is expected to resolve to 7; show the
    # ground-truth label for comparison.
    print("expected label: %d" % targets[0, 0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment