Last active: August 21, 2017 15:01
-
-
Save joao-timescale/9e6a662a3361ab96f3787a42a320ef1a to your computer and use it in GitHub Desktop.
Test spatial transformer gradients
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

import theano
import theano.printing
import theano.tensor as T
from theano.tests import unittest_tools as utt
from theano.tensor.nnet.abstract_spatialtf import (
    AbstractSpatialTransformerGradIOp,
    AbstractSpatialTransformerGradTOp,
)
from theano.tensor.nnet.spatialtf import spatialtf, spatialtf_cpu
from theano.gpuarray import dnn
from theano.compile.debugmode import DebugMode
# Problem size and spatial-transformer output scaling factors.
num_images = 1
num_channels = 1
height = 4
width = 4
scale_height = 1
scale_width = 1

# Deterministic input holding 1..N in row-major order, shaped
# (num_images, num_channels, height, width). For the sizes above this is
# exactly [[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]].
# It is derived from the size constants so that editing them keeps the image
# consistent with `theta` (which already depends on num_images).
# For a random input instead, use:
# img = np.random.random((num_images, num_channels, height, width)).astype(theano.config.floatX)
img = np.arange(
    1, num_images * num_channels * height * width + 1,
    dtype=theano.config.floatX,
).reshape((num_images, num_channels, height, width))

# One 2x3 affine matrix per image. This is the identity affine transform,
# presumably a no-op resampling for spatialtf -- TODO confirm.
transform = [[1, 0, 0],
             [0, 1, 0]]
theta = np.asarray(num_images * [transform], dtype=theano.config.floatX)
# Symbolic inputs: a 4-D image batch and a 3-D stack of affine matrices.
t_inp = T.tensor4('img')
t_theta = T.tensor3('theta')
# NOTE(review): t_dy and dy are not referenced anywhere in this script;
# presumably left over from testing the gradient ops with an explicit
# upstream gradient -- confirm before removing.
t_dy = T.tensor4('dy')
dy = np.random.random(img.shape).astype(theano.config.floatX)

# Reduce the transformer output to a scalar (its mean) so T.grad can
# differentiate it w.r.t. the image and w.r.t. theta separately. Each grad
# call is given a one-element list, so each result is a one-element list.
out = spatialtf(t_inp, t_theta, scale_height, scale_width)
out_mean = T.mean(out)
mean_gi = T.grad(out_mean, [t_inp])
mean_gt = T.grad(out_mean, [t_theta])
# Use DebugMode to test AbstractSpatialTransformerOp.debug_perform.
# Gradient of mean(spatialtf(img, theta)) w.r.t. the input image.
gi_fn = theano.function([t_inp, t_theta], mean_gi, mode='DebugMode')
# apply_nodes = gi_fn.maker.fgraph.apply_nodes
gi_out = gi_fn(img, theta)
# print(gi_out)

# Gradient of mean(spatialtf(img, theta)) w.r.t. the affine matrices.
gt_fn = theano.function([t_inp, t_theta], mean_gt, mode='DebugMode')
gt_out = gt_fn(img, theta)
print(gt_out)
def grad_functor(inputs, theta):
    """Apply the spatial transformer; the function passed to ``utt.verify_grad``.

    Uses the module-level ``scale_height`` and ``scale_width`` settings.
    The ``theta`` parameter shadows the module-level ``theta`` array; here it
    is whatever variable verify_grad passes for the second point.
    """
    return spatialtf(inputs, theta, scale_height, scale_width)
# Check spatialtf's symbolic gradients w.r.t. both the image and theta
# against a finite-difference estimate at the point (img, theta).
utt.verify_grad(grad_functor, [img, theta])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.