Skip to content

Instantly share code, notes, and snippets.

@danstowell
Created January 11, 2016 12:00
Show Gist options
  • Save danstowell/91a2303e9362933bba72 to your computer and use it in GitHub Desktop.
A version of https://gist.github.com/danstowell/192ad65527965086693d that appears to exhibit buggy shape behaviour.
import numpy as np
from numpy import float32
import theano
import theano.tensor as T
import lasagne
######################################################
# Test of reversing maxpooling.
# Build a network that applies: transformation -> maxpool -> untransformation
# -> unmaxpool. The affine transformations around the pooling step make sure
# that any faithful reconstruction is not happening by accident.
input_var = T.tensor4('X')
# A smaller alternative shape, handy for eyeballing printed arrays: (2, 3, 2, 5)
datashape = (16, 1, 132, 160)
network = lasagne.layers.InputLayer(datashape, input_var)
network = lasagne.layers.ExpressionLayer(network, lambda x: (x - 1) * 0.5)
# NOTE(review): lasagne documents MaxPool1DLayer as pooling the trailing axis of
# a 3D (batch, channels, length) input, but datashape here is 4D. This is
# presumably the shape behaviour under investigation -- left as-is on purpose.
network = lasagne.layers.MaxPool1DLayer(network, pool_size=3, stride=2)
maxpool_layer = network  # keep a handle so its output and shape can be inspected
network = lasagne.layers.ExpressionLayer(network, lambda x: (x * 2) + 1)
# Undo the pooling of maxpool_layer (unmaxpool) on the untransformed latents.
network = lasagne.layers.InverseLayer(network, maxpool_layer)
# Create a sparse dataset: all zeros except for a few randomly placed,
# odd-valued entries.
# NOTE: `input` shadows the Python builtin; the name is kept because the
# evaluation and reporting code later in this script refers to it.
input = np.zeros(datashape, dtype=float32)
# np.product was deprecated in NumPy 1.25 and removed in NumPy 2.0;
# np.prod is the supported spelling.
datashape_flat = np.prod(datashape)
for _ in range(10):
    # Odd values (1, 3, 5, 7, 9) stay integral under (x - 1) * 0.5 and map back
    # exactly under x * 2 + 1, so the transformations are exact in floating point.
    input.flat[np.random.randint(datashape_flat)] = np.random.randint(5) * 2 + 1
# Run the dataset through the network, keeping both the final (unpooled)
# output and the intermediate pooled latents.
# Symbolic graph for the full round trip, and for the pooled layer alone.
output_expr = lasagne.layers.get_output(network)
latents_expr = lasagne.layers.get_output(maxpool_layer)
# Compile each expression into a callable taking the input batch.
output_fn = theano.function([input_var], output_expr)
latents_fn = theano.function([input_var], latents_expr)
# Evaluate both on the same data.
output = output_fn(input)
latents = latents_fn(input)
# Report the shapes at each stage so that the shape lasagne *claims* for the
# pooled latents can be compared directly with the shape Theano produced.
claimed_latents_shape = tuple(lasagne.layers.get_output_shape(maxpool_layer))
print("input (shape %s)" % (str(input.shape)))
print(" shape of latents is claimed to be %s" % str(claimed_latents_shape))
print("latents (shape %s)" % (str(latents.shape)))
print(" latents shape matches claim: %s" % (claimed_latents_shape == latents.shape))
print("output (shape %s)" % (str(output.shape)))
# Only compare values element-wise when the round trip preserved the shape;
# otherwise the subtraction would either fail or silently broadcast.
if output.shape == input.shape:
    print("max abs error (input-output): %s" % np.abs(input - output).max())
else:
    print("output shape %s differs from input shape %s; skipping error check"
          % (str(output.shape), str(input.shape)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment