#!/usr/bin/env python
"""
Usage example employing Lasagne and TT-layer on the MNIST dataset.
This is a simplified version of
https://github.com/Lasagne/Lasagne/blob/master/examples/mnist.py
with using the TT-layer.
"""
from __future__ import print_function
import sys
import os
import time
import numpy as np
np.set_printoptions(threshold=sys.maxsize)  # print full arrays
import theano
import theano.tensor as T
import lasagne
from ttlayer import TTLayer
np.random.seed(1234)
size = 256


def load_dataset():
    # One random 256-dimensional sample, shaped (1, size) for the batch axis.
    X = np.random.rand(size)
    X_train = np.array([X] * 1)
    return X_train


def build_simple_mlp(input_var=None):
    # Baseline: a plain fully-connected 256-unit layer with an explicit
    # 256x256 random weight matrix and identity nonlinearity,
    # so the output is simply X.dot(W) + b.
    l_in = lasagne.layers.InputLayer(shape=(None, size), input_var=input_var)
    l_hid1 = lasagne.layers.DenseLayer(
        l_in, num_units=size, W=np.random.rand(size, size),
        nonlinearity=lasagne.nonlinearities.identity)
    return l_hid1


def full_main(num_epochs=500):
    np.random.seed(1234)
    X_train = load_dataset()
    input_var = T.matrix('inputs')
    lhid_1 = build_simple_mlp(input_var)
    # Build the symbolic output expression and compile it into a callable.
    output = lasagne.layers.get_output(lhid_1)
    print_fn = theano.function([input_var], output)
    res = print_fn(X_train).reshape(size)
    print("\nFULL:\n", res)
    return res


def build_mlp(input_var=None):
    # TT variant: the 256x256 weight matrix is represented in Tensor Train
    # format, with both the input and output dimensions factored as 4*4*4*4.
    l_in = lasagne.layers.InputLayer(shape=(None, size), input_var=input_var)
    l_hid1 = TTLayer(l_in, tt_input_shape=[4, 4, 4, 4],
                     tt_output_shape=[4, 4, 4, 4],
                     tt_ranks=[1, 10, 100, 10, 1],
                     nonlinearity=lasagne.nonlinearities.identity)
    return l_hid1
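
# For reference, assuming the cores are stored with the usual TT-matrix shapes
# (r_{k-1}, m_k, n_k, r_k), as in Novikov et al., "Tensorizing Neural Networks",
# the shapes and ranks above imply roughly
#     1*4*4*10 + 10*4*4*100 + 100*4*4*10 + 10*4*4*1 = 32,320
# core elements, versus 256*256 = 65,536 weights for the dense baseline.
# The exact count depends on how ttlayer stores its parameters.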


def main(num_epochs=500):
    np.random.seed(1234)
    X_train = load_dataset()
    input_var = T.matrix('inputs')
    lhid_1 = build_mlp(input_var)
    output = lasagne.layers.get_output(lhid_1)
    print_fn = theano.function([input_var], output)
    res = print_fn(X_train).reshape(size)
    print("\nTT:\n", res)
    return res


if __name__ == '__main__':
    np.random.seed(1234)
    X_tt = main()
    X_fc = full_main()
    # Difference between the TT-layer output and the dense-layer output.
    print('\nNORM:', np.linalg.norm(X_tt - X_fc))
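
    # Optional sketch (not part of the original gist): compare learnable
    # parameter counts of the two parameterizations, assuming TTLayer
    # registers its cores through Lasagne's add_param like a standard layer
    # so that lasagne.layers.count_params can see them.
    print('TT params:', lasagne.layers.count_params(build_mlp()))
    print('FC params:', lasagne.layers.count_params(build_simple_mlp()))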