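### Compiling the training and testing functions
import theano
import theano.tensor as T
import lasagne

train_fn = theano.function([input_var, target_var], loss, updates=updates)
# deterministic=True disables stochastic layers (e.g. dropout), if any, at test time
test_prediction = lasagne.layers.get_output(net['out'], deterministic=True)
test_loss = lasagne.objectives.categorical_crossentropy(test_prediction,
                                                        target_var)
test_loss = test_loss.mean()
test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var),
                  dtype=theano.config.floatX)
# Testing function: returns loss and accuracy without updating any parameters
val_fn = theano.function([input_var, target_var], [test_loss, test_acc])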
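To see what `test_acc` measures, the same computation can be mirrored in NumPy on concrete arrays. This is a minimal sketch, assuming `probs` is a (batch, classes) array of network outputs and `targets` holds integer labels; `accuracy` is a hypothetical helper, not part of Theano or Lasagne.

```python
import numpy as np

def accuracy(probs, targets):
    """NumPy mirror of test_acc: fraction of rows whose argmax equals the target."""
    predicted = np.argmax(probs, axis=1)   # corresponds to T.argmax(test_prediction, axis=1)
    return np.mean(predicted == targets)   # corresponds to T.mean(T.eq(...))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
targets = np.array([0, 1, 2, 1], dtype=np.int32)
acc = accuracy(probs, targets)  # the last row predicts class 0 instead of 1
```

Only the position of the largest output matters for accuracy, whereas the cross-entropy loss also depends on how much probability the true class receives.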
# Get the update rule: Adam, an adaptive variant of stochastic gradient descent
params = lasagne.layers.get_all_params(net['out'], trainable=True)
updates = lasagne.updates.adam(loss, params)
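The `adam` call relies on Lasagne's default hyperparameters (`learning_rate=0.001`, `beta1=0.9`, `beta2=0.999`, `epsilon=1e-8`), so the `lr = 1e-2` defined alongside the cost function has no effect unless it is passed as `learning_rate=lr`. For reference, here is a NumPy sketch of a single Adam step in the bias-corrected form from Kingma and Ba's paper, which as far as we know is also the form Lasagne implements; `adam_step` is a hypothetical helper.

```python
import numpy as np

def adam_step(param, grad, m, v, t, learning_rate=0.001,
              beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One Adam update; t is the 1-based step counter."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    # Bias correction for the zero-initialised m and v, folded into one scalar
    a_t = learning_rate * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    param = param - a_t * m / (np.sqrt(v) + epsilon)
    return param, m, v

param = np.array([1.0, -2.0, 0.5])
grad = np.array([0.3, -4.0, 0.0])
m = np.zeros_like(param)
v = np.zeros_like(param)
new_param, m, v = adam_step(param, grad, m, v, t=1)
```

On the first step the bias correction makes each parameter move by roughly `learning_rate` against the sign of its gradient, whatever the gradient's magnitude; a parameter with zero gradient does not move.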
### Defining the cost function and the update rule
# Define hyperparameters. These could also be symbolic variables
lr = 1e-2
weight_decay = 1e-5
# Loss function: mean categorical cross-entropy over the minibatch
prediction = lasagne.layers.get_output(net['out'])
loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)
loss = loss.mean()
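`weight_decay` is defined here but never added to `loss`, so no L2 penalty is actually applied. In Lasagne this is commonly done with `loss = loss + weight_decay * lasagne.regularization.regularize_network_params(net['out'], lasagne.regularization.l2)`. The NumPy sketch below shows what that combined objective computes; `regularized_loss` and the toy weights are hypothetical.

```python
import numpy as np

def regularized_loss(probs, targets, weights, weight_decay=1e-5):
    """Mean categorical cross-entropy plus weight_decay times the summed squared weights."""
    # Cross-entropy with integer targets: -log of the probability of the true class
    ce = -np.mean(np.log(probs[np.arange(len(targets)), targets]))
    # L2 penalty: sum of squared entries over every weight array
    l2 = sum(np.sum(w ** 2) for w in weights)
    return ce + weight_decay * l2

probs = np.full((4, 10), 0.1)                   # uniform predictions over 10 classes
targets = np.array([3, 1, 4, 1], dtype=np.int32)
weights = [np.ones((2, 3)), np.full(4, 2.0)]    # toy weights: 6 + 16 = 22
loss = regularized_loss(probs, targets, weights)
```

With uniform predictions over 10 classes the cross-entropy term equals log(10) ≈ 2.303, a useful sanity check for the loss of an untrained MNIST classifier.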
# Input layer:
net['data'] = lasagne.layers.InputLayer(data_size, input_var=input_var)
# Convolution + Pooling
net['conv1'] = lasagne.layers.Conv2DLayer(net['data'], num_filters=6, filter_size=3)
net['pool1'] = lasagne.layers.Pool2DLayer(net['conv1'], pool_size=2)
net['conv2'] = lasagne.layers.Conv2DLayer(net['pool1'], num_filters=10, filter_size=4)
net['pool2'] = lasagne.layers.Pool2DLayer(net['conv2'], pool_size=2)
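Assuming Lasagne's defaults (`Conv2DLayer`: stride 1, no padding; `Pool2DLayer`: max pooling with stride equal to `pool_size` and `ignore_border=True`), the spatial size of the feature maps can be tracked by hand. A plain-Python sketch with hypothetical helper names; in the real network, `net['pool2'].output_shape` reports the same information.

```python
def conv_out(size, filter_size, stride=1, pad=0):
    """Spatial size after a convolution (stride 1 and no padding by default)."""
    return (size + 2 * pad - filter_size) // stride + 1

def pool_out(size, pool_size):
    """Spatial size after pooling with stride = pool_size, dropping any ragged border."""
    return (size - pool_size) // pool_size + 1

side = 28
side = conv_out(side, 3)   # conv1: 28 -> 26
side = pool_out(side, 2)   # pool1: 26 -> 13
side = conv_out(side, 4)   # conv2: 13 -> 10
side = pool_out(side, 2)   # pool2: 10 -> 5
pool2_shape = (None, 10, side, side)  # 10 = num_filters of conv2
```

So `pool2` yields 10 feature maps of 5 x 5, i.e. 250 values per image, which is what a following dense layer would receive.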
# Conv Net Structure
import theano.tensor as T

batch_size = 100
output_size = 10
data_size = (None, 1, 28, 28)  # (batch, channels, height, width); None = any batch size
input_var = T.tensor4(name='inputs')
target_var = T.ivector(name='targets')  # int32 class labels
net = {}
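`batch_size` is not used by the layers themselves (the batch dimension of `data_size` is `None`); it typically controls how the training loop slices the data. A minimal sketch of such a helper, loosely modelled on the one in Lasagne's MNIST example; the name `iterate_minibatches` and its choice to drop an incomplete final batch are assumptions here.

```python
import numpy as np

def iterate_minibatches(inputs, targets, batch_size, shuffle=False, seed=0):
    """Yield (inputs, targets) slices of batch_size rows; a ragged final batch is dropped."""
    assert len(inputs) == len(targets)
    indices = np.arange(len(inputs))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = indices[start:start + batch_size]
        yield inputs[excerpt], targets[excerpt]

X = np.zeros((250, 1, 28, 28), dtype=np.float32)
y = np.arange(250, dtype=np.int32) % 10
batches = list(iterate_minibatches(X, y, batch_size=100, shuffle=True))
```

In a training loop this would be used as `for inputs, targets in iterate_minibatches(X_train, y_train, batch_size, shuffle=True): train_fn(inputs, targets)`.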
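import gzip
import os
import pickle
from urllib.request import urlretrieve

def load_dataset():
    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'
    filename = 'mnist.pkl.gz'
    if not os.path.exists(filename):
        print("Downloading MNIST dataset...")
        urlretrieve(url, filename)
    with gzip.open(filename, 'rb') as f:
        # The file was pickled under Python 2; latin1 decoding lets Python 3 read it
        data = pickle.load(f, encoding='latin1')
    # The pickle holds (train, validation, test) splits, each as (images, labels)
    X_train, y_train = data[0]
    X_val, y_val = data[1]
    X_test, y_test = data[2]
    return X_train, y_train, X_val, y_val, X_test, y_test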
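The classic `mnist.pkl.gz` stores each image as a flat row of 784 pixel values, whereas the input layer expects `data_size = (None, 1, 28, 28)` and `target_var` is an `int32` vector. A minimal sketch of the conversion, assuming `theano.config.floatX` is `float32`; `to_network_format` is a hypothetical helper, and random arrays stand in for real MNIST data.

```python
import numpy as np

def to_network_format(X, y):
    """Reshape flat 784-pixel rows to (N, 1, 28, 28) and cast labels to int32."""
    X = np.asarray(X, dtype=np.float32).reshape(-1, 1, 28, 28)  # assumes floatX == float32
    y = np.asarray(y, dtype=np.int32)                           # T.ivector expects int32
    return X, y

X_flat = np.random.default_rng(0).random((5, 784))
y_raw = np.array([7, 2, 1, 0, 4], dtype=np.int64)
X, y = to_network_format(X_flat, y_raw)
```

Reshaping row-major keeps each image's pixels in order, so flattening `X[i, 0]` recovers the original row.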
SVM Accuracy:
0.912782956059
Classifier Accuracy:
0.834886817577
'God is love'=>soc.religion.christian
'OpenGL on the GPU is fast'=>comp.graphics
(2257, 35788)
4690
(2257, 35788)
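The tuple-shaped outputs read like the shapes of a document-term matrix: one row per document and one column per vocabulary term. The code that produced them is not included here, so this plain-Python sketch only illustrates how such a shape arises; `build_counts` is hypothetical, and real vectorizers (for example scikit-learn's `CountVectorizer`) tokenize and order the vocabulary differently.

```python
def build_counts(docs):
    """Assign each distinct lowercase token a column, then count tokens per document."""
    vocabulary = {}
    for doc in docs:
        for token in doc.lower().split():
            vocabulary.setdefault(token, len(vocabulary))  # first occurrence sets the column
    rows = []
    for doc in docs:
        row = [0] * len(vocabulary)
        for token in doc.lower().split():
            row[vocabulary[token]] += 1
        rows.append(row)
    shape = (len(rows), len(vocabulary))  # (documents, vocabulary terms)
    return rows, vocabulary, shape

docs = ['God is love', 'OpenGL on the GPU is fast']
rows, vocabulary, shape = build_counts(docs)
```

Here "is" appears in both documents but gets a single column, so the vocabulary has 8 terms rather than 9 tokens.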