Last active
September 6, 2015 21:25
-
-
Save domluna/123c3075a4d48ff0ae76 to your computer and use it in GitHub Desktop.
convnet with cgt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import print_function, absolute_import | |
| import cgt | |
| from cgt import nn | |
| from cgt.distributions import categorical | |
| import numpy as np | |
| from load import load_mnist | |
| import time | |
# Training hyperparameters.
epochs = 10
batch_size = 128

# Ask for plain integer class labels (not one-hot rows).
Xtrain, Xtest, ytrain, ytest = load_mnist(onehot=False)

# Shuffle the training examples; the fixed seed makes the order repeatable
# from run to run. Images and labels are indexed by the same permutation so
# they stay aligned.
np.random.seed(42)
shuffle_order = np.random.permutation(len(Xtrain))
Xtrain, ytrain = Xtrain[shuffle_order], ytrain[shuffle_order]

# The convnet takes 4-d input, so view each flat 784-value row as a
# single-channel 28x28 image.
image_shape = (-1, 1, 28, 28)
Xtrainimg = Xtrain.reshape(*image_shape)
Xtestimg = Xtest.reshape(*image_shape)
# Model: a small VGG-like convnet.
# VGG nets have conv layers with 3x3 kernels and 1x1 padding & stride;
# their max-pooling layers have 2x2 kernels and 2x2 strides.
#
# VGG is a large model, so here we'll just do a small part of it.


def _conv3x3_relu(inp, n_in, n_out):
    """Apply a 3x3 / stride 1 / pad 1 convolution to ``inp``, then a rectifier.

    ``n_in`` and ``n_out`` are the input and output channel counts handed to
    ``nn.SpatialConvolution``; weights use ``nn.IIDGaussian(std=.1)``.
    """
    layer = nn.SpatialConvolution(
        n_in, n_out,
        kernelshape=(3, 3), stride=(1, 1), pad=(1, 1),
        weight_init=nn.IIDGaussian(std=.1)
    )
    return nn.rectify(layer(inp))


def _max_pool2x2(inp):
    """Max-pool ``inp`` with a 2x2 kernel and 2x2 stride."""
    return nn.max_pool_2d(inp, kernelshape=(2, 2), stride=(2, 2))


# Symbolic inputs: a batch of 1x28x28 images (batch size left open) and a
# vector of integer labels.
X = cgt.tensor4('X', fixed_shape=(None, 1, 28, 28))
y = cgt.vector('y', dtype='i8')

# Two conv -> relu -> pool stages.
conv1 = _conv3x3_relu(X, 1, 32)
pool1 = _max_pool2x2(conv1)
conv2 = _conv3x3_relu(pool1, 32, 32)
pool2 = _max_pool2x2(conv2)

# Collapse everything except the batch axis so the result can feed a dense
# layer; the feature count is obtained from cgt rather than hard-coded.
n_batch, n_chan, n_rows, n_cols = pool2.shape
flat = pool2.reshape([n_batch, n_chan * n_rows * n_cols])
nfeats = cgt.infer_shape(flat)[1]

# 10-way softmax classifier on top of the flattened features.
output_layer = nn.Affine(nfeats, 10)
probs = nn.softmax(output_layer(flat))

# Objective: mean negative log-likelihood of the labels under ``probs``.
cost = -categorical.loglik(y, probs).mean()

# Error: fraction of examples whose arg-max prediction differs from the label.
y_preds = cgt.argmax(probs, axis=1)
mismatches = cgt.not_equal(y, y_preds)
err = cgt.cast(mismatches, cgt.floatX).mean()

params = nn.get_parameters(cost)
updates = nn.rmsprop(cost, params)  # This time we'll do rmsprop

# training function: returns nothing, only applies the parameter updates
f = cgt.function(inputs=[X, y], outputs=[], updates=updates)

# compute the cost and error
cost_and_err = cgt.function(inputs=[X, y], outputs=[cost, err])
# Train for ``epochs`` passes over the shuffled training set, reporting test
# cost and error after every pass.
# ``range`` replaces ``xrange``, which does not exist on Python 3; the
# sequences here are small, so this is also fine on Python 2.
for i in range(epochs):
    t0 = time.time()
    for start in range(0, Xtrain.shape[0], batch_size):
        # Slicing past the end of the array simply yields a shorter final batch.
        end = batch_size + start
        f(Xtrainimg[start:end], ytrain[start:end])
    # Timing covers the training updates only, not the evaluation below.
    elapsed = time.time() - t0
    costval, errval = cost_and_err(Xtestimg, ytest)
    print("Epoch {} took {}, test cost = {}, test error = {}".format(i, elapsed, costval, errval))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import os | |
# Root directory under which dataset folders are looked up.
datasets_dir = 'datasets/'


def one_hot(x, n):
    """Return a one-hot encoding of the integer labels in ``x``.

    Args:
        x: Integer class labels as a list, tuple or numpy array of any
            shape; the labels are flattened before encoding.
        n: Number of classes, i.e. the width of each encoded row.

    Returns:
        A float array of shape ``(number_of_labels, n)`` in which row ``i``
        holds a 1 in column ``x[i]`` and 0 everywhere else. Empty input
        yields an array of shape ``(0, n)``.
    """
    # asarray accepts any sequence (the old ``type(x) == list`` check let
    # tuples fall through to a failing ``.flatten()`` call) and does not copy
    # data that is already an array.
    labels = np.asarray(x).ravel()
    if labels.size == 0:
        # An empty list becomes a float array, which numpy refuses to use as
        # an index, so return the empty encoding directly.
        return np.zeros((0, n))
    o_h = np.zeros((len(labels), n))
    o_h[np.arange(len(labels)), labels] = 1
    return o_h
def load_mnist(ntrain=60000, ntest=10000, onehot=True):
    """Load the MNIST images and labels from raw files under ``datasets_dir``.

    The four ``*-ubyte`` files are expected in ``<datasets_dir>/mnist/``.

    Args:
        ntrain: Number of leading training examples to keep.
        ntest: Number of leading test examples to keep.
        onehot: If true, labels are returned one-hot encoded with 10
            columns (via ``one_hot``); otherwise as the label arrays read
            from disk.

    Returns:
        A tuple ``(trX, teX, trY, teY)``. ``trX`` and ``teX`` are float
        arrays with one flattened 28*28 image per row, scaled to [0, 1] by
        dividing the byte values by 255.

    Raises:
        IOError / OSError: if a data file cannot be opened.
        ValueError: if a file does not hold exactly 60000 (training) or
            10000 (test) records after its header is skipped.
    """
    data_dir = os.path.join(datasets_dir, 'mnist/')

    def read_payload(filename, header_bytes):
        # Open in binary mode and close deterministically: the data is raw
        # bytes, and the previous code left four file handles open.
        with open(os.path.join(data_dir, filename), 'rb') as fd:
            raw = np.fromfile(file=fd, dtype=np.uint8)
        # Skip the leading header bytes (16 for image files, 8 for label
        # files); everything after that is one byte per pixel / per label.
        return raw[header_bytes:]

    trX = read_payload('train-images-idx3-ubyte', 16).reshape((60000, 28 * 28)).astype(float)
    trY = read_payload('train-labels-idx1-ubyte', 8).reshape((60000))
    teX = read_payload('t10k-images-idx3-ubyte', 16).reshape((10000, 28 * 28)).astype(float)
    teY = read_payload('t10k-labels-idx1-ubyte', 8).reshape((10000))

    # Scale pixel bytes (0..255) to floats in [0, 1].
    trX = trX / 255.
    teX = teX / 255.

    # Keep only the requested number of leading examples.
    trX = trX[:ntrain]
    trY = trY[:ntrain]
    teX = teX[:ntest]
    teY = teY[:ntest]

    if onehot:
        trY = one_hot(trY, 10)
        teY = one_hot(teY, 10)
    else:
        trY = np.asarray(trY)
        teY = np.asarray(teY)

    return trX, teX, trY, teY
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment