Skip to content

Instantly share code, notes, and snippets.

View Tathagatd96's full-sized avatar

Tathagat Dasgupta Tathagatd96

View GitHub Profile
###Compiling the training and testing functions
# NOTE(review): this fragment reads `loss` and `updates`, which are assigned by
# the cost-function and update-rule code; that code must run before this line.
# Training step: maps (inputs, targets) to the scalar `loss` and applies
# `updates` to the network parameters on every call.
train_fn = theano.function([input_var, target_var], loss, updates=updates)
# Evaluation graph. deterministic=True is Lasagne's switch for disabling
# stochastic layer behavior (e.g. dropout) at test time.
test_prediction = lasagne.layers.get_output(net['out'], deterministic=True)
# Per-sample cross-entropy of the evaluation predictions, averaged below.
test_loss = lasagne.objectives.categorical_crossentropy(test_prediction,
target_var)
test_loss = test_loss.mean()
# Accuracy: fraction of samples whose arg-max over axis 1 equals the target,
# computed in theano.config.floatX.
# NOTE(review): test_loss / test_acc are not compiled into a theano.function
# in this fragment; a validation function presumably follows elsewhere.
test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
dtype=theano.config.floatX)
# Get the update rule: Adam (not plain stochastic gradient descent) applied to
# every trainable parameter reachable from the network's output layer.
# NOTE(review): adam() is called without learning_rate, so Lasagne's default
# step size is used and the `lr` hyperparameter is ignored — confirm intent.
params = lasagne.layers.get_all_params(net['out'], trainable=True)
updates = lasagne.updates.adam(loss, params)
###Defining the cost function and the update rule
#Define hyperparameters. These could also be symbolic variables
# NOTE(review): neither `lr` nor `weight_decay` is referenced by any statement
# in this file. The adam() update uses its default step size, and no
# regularization term is added to `loss`. Confirm whether they should be wired in.
lr = 1e-2
weight_decay = 1e-5
#Loss function: mean cross-entropy
# Training-mode forward pass (no deterministic flag is passed).
prediction = lasagne.layers.get_output(net['out'])
# Per-sample categorical cross-entropy against the integer targets
# (target_var is a T.ivector), reduced to a scalar by averaging over the batch.
loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
loss = loss.mean()
# Input layer: wraps the symbolic `input_var`, declared with shape `data_size`.
net['data'] = lasagne.layers.InputLayer(data_size, input_var=input_var)
# Convolution + Pooling (no normalization layer is defined here).
# NOTE(review): with Lasagne's defaults (stride 1, no padding, pool stride equal
# to pool_size), a 28x28 input gives 26x26 -> 13x13 after conv1/pool1 and
# 10x10 -> 5x5 after conv2/pool2. Confirm against the configured data_size.
net['conv1'] = lasagne.layers.Conv2DLayer(net['data'], num_filters=6, filter_size=3)
net['pool1'] = lasagne.layers.Pool2DLayer(net['conv1'], pool_size=2)
net['conv2'] = lasagne.layers.Conv2DLayer(net['pool1'], num_filters=10, filter_size=4)
net['pool2'] = lasagne.layers.Pool2DLayer(net['conv2'], pool_size=2)
# NOTE(review): net['out'] is read by the loss, update and evaluation code but
# is not assigned in this fragment. An output layer (presumably with
# `output_size` units) still has to be attached to net['pool2'].
# Conv Net Structure
# Mini-batch size and width of the final layer (presumably 10 MNIST classes).
batch_size, output_size = 100, 10
# Input shape in Lasagne's (batch, channels, rows, cols) layout; the batch
# dimension is left unbound (None).
data_size = (None, 1, 28, 28)
# Symbolic Theano inputs: a 4-D image tensor and a vector of int32 labels.
input_var = T.tensor4(name='inputs')
target_var = T.ivector(name='targets')
# Layers are registered in this dict under string keys.
net = dict()
def load_dataset():
    """Download (if needed) and load the pickled MNIST train/val/test splits.

    Reads ``mnist.pkl.gz`` from the current directory. If the file is missing,
    it is downloaded from deeplearning.net first.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val, X_test, y_test)``.

        - Images are reshaped to ``(-1, 1, 28, 28)`` so they match the 4-D
          network input (``data_size`` / ``T.tensor4``).
        - Labels are cast to int32 so they match the ``T.ivector`` target
          variable.
    """
    import sys

    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
    filename = 'mnist.pkl.gz'
    if not os.path.exists(filename):
        print("Downloading MNIST dataset...")
        try:  # Python 3
            from urllib.request import urlretrieve
        except ImportError:  # Python 2
            from urllib import urlretrieve
        urlretrieve(url, filename)
    with gzip.open(filename, 'rb') as f:
        # NOTE(security): pickle.load can execute arbitrary code, and this file
        # is fetched over plain HTTP. Only load a copy you trust.
        if sys.version_info[0] >= 3:
            # The MNIST pickle was written by Python 2. Python 3 needs latin1
            # to decode its byte strings (Python 2's load has no encoding arg).
            data = pickle.load(f, encoding='latin1')
        else:
            data = pickle.load(f)
    # The pickle holds three (images, labels) pairs: train, validation, test.
    X_train, y_train = data[0]
    X_val, y_val = data[1]
    X_test, y_test = data[2]
    # Flat 784-pixel rows -> (N, 1, 28, 28) images for the conv net.
    X_train = X_train.reshape((-1, 1, 28, 28))
    X_val = X_val.reshape((-1, 1, 28, 28))
    X_test = X_test.reshape((-1, 1, 28, 28))
    y_train = y_train.astype('int32')
    y_val = y_val.astype('int32')
    y_test = y_test.astype('int32')
    return X_train, y_train, X_val, y_val, X_test, y_test
Classifier Accuracy:
0.834886817577
'God is love'=>soc.religion.christian
'OpenGL on the GPU is fast'=>comp.graphics
(2257, 35788)
4690
(2257, 35788)