Skip to content

Instantly share code, notes, and snippets.

@joao-timescale
Last active August 21, 2017 15:01
Show Gist options
  • Save joao-timescale/9e6a662a3361ab96f3787a42a320ef1a to your computer and use it in GitHub Desktop.
Save joao-timescale/9e6a662a3361ab96f3787a42a320ef1a to your computer and use it in GitHub Desktop.
Test spatial transformer gradients
import numpy as np
import theano
import theano.printing
import theano.tensor as T
from theano.tests import unittest_tools as utt
from theano.tensor.nnet.abstract_spatialtf import (AbstractSpatialTransformerGradIOp, AbstractSpatialTransformerGradTOp)
from theano.tensor.nnet.spatialtf import (spatialtf, spatialtf_cpu)
from theano.gpuarray import dnn
from theano.compile.debugmode import DebugMode
# --- Test configuration -----------------------------------------------------
# Dimensions of the test batch; `img` and `theta` below are derived from
# these, so changing them changes the test input consistently.
num_images = 1
num_channels = 1
height = 4
width = 4
# Scale factors forwarded to spatialtf (presumably output-grid size relative
# to the input; confirm against spatialtf's documentation).
scale_height = 1
scale_width = 1

# Deterministic input holding 1, 2, 3, ... in row-major order, so gradients
# are easy to inspect by hand. For the default 1x1x4x4 configuration this is
# exactly [[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]].
img = np.arange(
    1, num_images * num_channels * height * width + 1,
    dtype=theano.config.floatX,
).reshape((num_images, num_channels, height, width))

# One 2x3 affine matrix per image; this one is the identity transform.
transform = [[1, 0, 0],
             [0, 1, 0]]
theta = np.asarray(num_images * [transform], dtype=theano.config.floatX)

# --- Symbolic graph -----------------------------------------------------------
t_inp = T.tensor4('img')
t_theta = T.tensor3('theta')

out = spatialtf(t_inp, t_theta, scale_height, scale_width)
# Reduce the output to a scalar so T.grad can differentiate it.
out_mean = T.mean(out)
mean_gi = T.grad(out_mean, [t_inp])    # gradient w.r.t. the input images
mean_gt = T.grad(out_mean, [t_theta])  # gradient w.r.t. the affine params

# Use DebugMode to test AbstractSpatialTransformerOp.debug_perform
gi_fn = theano.function([t_inp, t_theta], mean_gi, mode='DebugMode')
gi_out = gi_fn(img, theta)

gt_fn = theano.function([t_inp, t_theta], mean_gt, mode='DebugMode')
gt_out = gt_fn(img, theta)
print(gt_out)
def grad_functor(inputs, theta):
    """Apply the spatial transformer; used as the function under test by
    ``utt.verify_grad``.

    Args:
        inputs: symbolic 4-D image tensor (supplied by verify_grad).
        theta: symbolic 3-D tensor of affine transformation parameters
            (supplied by verify_grad; shadows the module-level ``theta``).

    Returns:
        The symbolic output of ``spatialtf`` using the module-level
        ``scale_height`` / ``scale_width`` factors.
    """
    return spatialtf(inputs, theta, scale_height, scale_width)
# Numerically check spatialtf's symbolic gradients (w.r.t. both the images
# and the transformation parameters) via finite differences around the point
# (img, theta).
utt.verify_grad(
    grad_functor,
    [img, theta],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment