Last active March 22, 2017 20:38
-
-
Save casperkaae/97962bba200dc04e40e0b43eb1138504 to your computer and use it in GitHub Desktop.
tests.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch
from torch.autograd import Variable
def dynamic_avg_pooling_test(batch_size=7, nc=32, seq_len=10):
    """Check ``dynamic_avg_pooling`` against a plain-autograd reference.

    The reference computes, step by step,
    ``h_t = h_{t-1} * f[:, :, t] + z[:, :, t]`` starting from
    ``h_{-1} = hinit``, and stacks the states along a new last axis.
    The test asserts that the forward output of ``dynamic_avg_pooling``
    matches it. It also asserts that the gradients of ``h.sum()`` with
    respect to ``z``, ``hinit`` and ``f`` match the reference gradients.
    If CUDA is available, the op is re-run on GPU copies of the same inputs
    and compared with the CPU output and gradients.

    Args:
        batch_size: First dimension of the random inputs.
            NOTE(review): the op was reportedly crashing for batch sizes
            below 8, and 7 is the default here — confirm whether this is
            meant to cover that case.
        nc: Second dimension of the random inputs.
        seq_len: Number of recurrence steps (last dimension of ``f`` and
            ``z``); must be at least 1.

    Raises:
        ValueError: if ``seq_len`` is less than 1.
        AssertionError: if any output or gradient differs beyond tolerance.
    """
    from functions.dynamic_avg_pooling import dynamic_avg_pooling

    if seq_len < 1:
        raise ValueError('seq_len must be at least 1, got %r' % (seq_len,))

    def assert_close(actual, expected, what):
        # Same test as np.isclose's defaults (|a - b| <= atol + rtol * |b|,
        # NaNs unequal), but the failure message names the mismatching value.
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8,
                                   equal_nan=False, err_msg=what)

    f = Variable(torch.rand(batch_size, nc, seq_len), requires_grad=True)
    z = Variable(torch.rand(batch_size, nc, seq_len), requires_grad=True)
    hinit = Variable(torch.rand(batch_size, nc), requires_grad=True)
    h = dynamic_avg_pooling(f, z, hinit)

    # Separate leaf copies so the reference accumulates its own .grad.
    f_expect = Variable(f.data.clone(), requires_grad=True)
    z_expect = Variable(z.data.clone(), requires_grad=True)
    hinit_expect = Variable(hinit.data.clone(), requires_grad=True)

    h_list = []
    h_prev = hinit_expect
    for t in range(seq_len):
        h_prev = h_prev * f_expect[:, :, t] + z_expect[:, :, t]
        h_list.append(h_prev)
    # Each state has shape (batch_size, nc); stacking on a new axis 2 gives
    # (batch_size, nc, seq_len) without the deprecated non-inplace resize.
    h_expect = torch.stack(h_list, 2)
    assert_close(h.data.numpy(), h_expect.data.numpy(), 'forward output h')

    loss_expect = h_expect.sum()
    loss_expect.backward()
    loss = h.sum()
    loss.backward()
    assert_close(z.grad.data.numpy(), z_expect.grad.data.numpy(),
                 'gradient w.r.t. z')
    assert_close(hinit.grad.data.numpy(), hinit_expect.grad.data.numpy(),
                 'gradient w.r.t. hinit')
    assert_close(f.grad.data.numpy(), f_expect.grad.data.numpy(),
                 'gradient w.r.t. f')

    if torch.cuda.is_available():
        print('testing GPU')
        f_gpu = Variable(f.data.cuda(), requires_grad=True)
        z_gpu = Variable(z.data.cuda(), requires_grad=True)
        hinit_gpu = Variable(hinit.data.cuda(), requires_grad=True)
        h_gpu = dynamic_avg_pooling(f_gpu, z_gpu, hinit_gpu)
        # Compare against the CPU run of the op (already checked above).
        assert_close(h.data.numpy(), h_gpu.data.cpu().numpy(),
                     'GPU forward output h')
        loss = h_gpu.sum()
        loss.backward()
        assert_close(z.grad.data.numpy(), z_gpu.grad.data.cpu().numpy(),
                     'GPU gradient w.r.t. z')
        assert_close(hinit.grad.data.numpy(),
                     hinit_gpu.grad.data.cpu().numpy(),
                     'GPU gradient w.r.t. hinit')
        assert_close(f.grad.data.numpy(), f_gpu.grad.data.cpu().numpy(),
                     'GPU gradient w.r.t. f')
    print('dynamic_avg_pooling_test PASSED!')
if __name__ == "__main__":
    # Run the check only when this file is executed directly, not on import.
    dynamic_avg_pooling_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.