Last active
January 11, 2016 13:14
-
-
Save danstowell/192ad65527965086693d to your computer and use it in GitHub Desktop.
Test showing that Lasagne/Theano automatically performs one-hot upsampling to reverse max-pooling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from numpy import float32
import theano
import theano.tensor as T
import lasagne

######################################################
# Test of reversing maxpooling.
# Build a network that applies
#   transformation -> maxpool -> untransformation -> unmaxpool (InverseLayer)
# The affine transformations around the pooling layer are there to check that
# the reversal is not happening by accident: the pooled values are shifted and
# scaled, then mapped back before InverseLayer un-pools them.
input_var = T.tensor4('X')
datashape = (2, 3, 4, 5)
network = lasagne.layers.InputLayer(datashape, input_var)
network = lasagne.layers.ExpressionLayer(network, lambda x: (x - 1) * 0.5)
network = lasagne.layers.MaxPool2DLayer(network, pool_size=3, stride=2)
maxpool_layer = network  # keep a handle so InverseLayer can undo this pooling
network = lasagne.layers.ExpressionLayer(network, lambda x: (x * 2) + 1)
network = lasagne.layers.InverseLayer(network, maxpool_layer)

# Create a sparse dataset: all zeros except for up to 10 randomly chosen
# entries (positions may repeat) set to odd numbers in {1, 3, 5, 7, 9}.
# For an odd integer x, (x - 1) * 0.5 is an integer, so the
# transform/untransform round trip is exact in floating point.
# (Named input_data rather than input to avoid shadowing the builtin.)
input_data = np.zeros(datashape, dtype=float32)
datashape_flat = np.prod(datashape)  # total number of elements; np.product was removed in NumPy 2.0
for _ in range(10):
    input_data.flat[np.random.randint(datashape_flat)] = np.random.randint(5) * 2 + 1

# Compile and run the network, keeping both the pooled latents and the
# un-pooled output.
output_fn = theano.function([input_var], lasagne.layers.get_output(network))
latents_fn = theano.function([input_var], lasagne.layers.get_output(maxpool_layer))
output = output_fn(input_data)
latents = latents_fn(input_data)

# Print the input, the latents' shape, and the output so the un-pooled
# values can be compared with the input by eye.
# The call form of print works on both Python 2 and Python 3.
print("input:")
print(input_data)
print("latents (shape %s)" % (str(latents.shape)))
# print(latents)
print("output:")
print(output)
# print("error (input-output):")
# print(input_data - output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment