@casperkaae
Last active March 22, 2017 20:38
tests.py
import numpy as np
import torch
from torch.autograd import Variable


def dynamic_avg_pooling_test():
    """Check dynamic_avg_pooling against a step-by-step Python reference, on CPU and (if available) GPU."""
    from functions.dynamic_avg_pooling import dynamic_avg_pooling

    batch_size = 7  # crashes if below 8
    nc = 32
    timesteps = 10
    f = Variable(torch.rand(batch_size, nc, timesteps), requires_grad=True)
    z = Variable(torch.rand(batch_size, nc, timesteps), requires_grad=True)
    hinit = Variable(torch.rand(batch_size, nc), requires_grad=True)
    h = dynamic_avg_pooling(f, z, hinit)

    # Reference: unroll the recurrence h_t = h_{t-1} * f_t + z_t in Python.
    f_expect = Variable(f.data.clone(), requires_grad=True)
    z_expect = Variable(z.data.clone(), requires_grad=True)
    hinit_expect = Variable(hinit.data.clone(), requires_grad=True)
    h_list = []
    h_prev = hinit_expect
    for t in range(timesteps):
        ht = h_prev * f_expect[:, :, t] + z_expect[:, :, t]
        h_list.append(ht.view(batch_size, nc, 1))
        h_prev = ht
    h_expect = torch.cat(h_list, 2)

    # Forward pass must match the reference.
    assert np.all(np.isclose(h.data.numpy(), h_expect.data.numpy()))

    # Backward pass: gradients w.r.t. all three inputs must match the reference.
    loss_expect = h_expect.sum()
    loss_expect.backward()
    loss = h.sum()
    loss.backward()
    assert np.all(np.isclose(z.grad.data.numpy(), z_expect.grad.data.numpy()))
    assert np.all(np.isclose(hinit.grad.data.numpy(), hinit_expect.grad.data.numpy()))
    assert np.all(np.isclose(f.grad.data.numpy(), f_expect.grad.data.numpy()))

    if torch.cuda.is_available():
        print('testing GPU')
        # Re-run forward and backward on the GPU and compare against the CPU results.
        f_gpu = Variable(f.data.cuda(), requires_grad=True)
        z_gpu = Variable(z.data.cuda(), requires_grad=True)
        hinit_gpu = Variable(hinit.data.cuda(), requires_grad=True)
        h_gpu = dynamic_avg_pooling(f_gpu, z_gpu, hinit_gpu)
        assert np.all(np.isclose(h.data.numpy(), h_gpu.data.cpu().numpy()))

        loss = h_gpu.sum()
        loss.backward()
        assert np.all(np.isclose(z.grad.data.numpy(), z_gpu.grad.data.cpu().numpy()))
        assert np.all(np.isclose(hinit.grad.data.numpy(), hinit_gpu.grad.data.cpu().numpy()))
        assert np.all(np.isclose(f.grad.data.numpy(), f_gpu.grad.data.cpu().numpy()))

    print('dynamic_avg_pooling_test PASSED!')
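
# The custom op under test (functions.dynamic_avg_pooling) is not included in
# this gist. As an illustrative sketch only, the function below is a minimal
# pure-PyTorch version of the recurrence the test assumes it computes,
# h_t = h_{t-1} * f_t + z_t with h_0 = hinit; the real op is presumably a
# fused/custom autograd implementation with the same semantics. The name
# dynamic_avg_pooling_reference is hypothetical.
def dynamic_avg_pooling_reference(f, z, hinit):
    # f, z: (batch, channels, time); hinit: (batch, channels)
    batch_size, nc, timesteps = f.size()
    h_prev = hinit
    outputs = []
    for t in range(timesteps):
        # One step of the running average: scale the previous state by f_t, then add z_t.
        h_prev = h_prev * f[:, :, t] + z[:, :, t]
        outputs.append(h_prev.view(batch_size, nc, 1))
    # Stack the per-timestep states back into a (batch, channels, time) tensor.
    return torch.cat(outputs, 2)
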
if __name__ == "__main__":
    dynamic_avg_pooling_test()