Skip to content

Instantly share code, notes, and snippets.

@el3ment
Created July 28, 2018 01:24
Show Gist options
  • Save el3ment/a00d7ca33248721bb31d616757052982 to your computer and use it in GitHub Desktop.
Save el3ment/a00d7ca33248721bb31d616757052982 to your computer and use it in GitHub Desktop.
# Apply model.swizzle twice to `y`, then take an independent copy so the
# in-place write below does not touch the tensor swizzle returned.
# NOTE(review): `model` and `y` are defined outside this view — confirm.
t = model.swizzle(model.swizzle(y))
t = t.clone()
# Fill t[0, :, 0, 0] with 0..15 so each index along dim 1 is identifiable.
# NOTE(review): requires t.size(1) == 16; presumably dim 1 is channels — confirm.
t[0, :, 0, 0] = torch.arange(16)
def even_odd(x, dim=1):
    """Split ``x`` into its even- and odd-indexed halves along ``dim``.

    Both results are strided views into ``x``'s storage (no data is copied),
    selecting the same elements as slicing with ``0::2`` and ``1::2`` along
    ``dim``. Because they are views, in-place writes to either result are
    visible in ``x``.

    Args:
        x: input tensor of any layout; its actual strides are honoured, so
            non-contiguous inputs are handled correctly.
        dim: dimension to split. Negative values count from the end.

    Returns:
        ``(evens, odds)``. When ``x.size(dim)`` is odd, ``evens`` has one more
        element along ``dim`` than ``odds``, exactly as slicing would.
    """
    size, stride = list(x.size()), list(x.stride())
    n = size[dim]
    step = stride[dim]

    # Taking every other element along `dim` just doubles that dim's stride.
    stride[dim] = step * 2

    evens_size, odds_size = list(size), list(size)
    evens_size[dim] = (n + 1) // 2  # indices 0, 2, 4, ... (ceil(n / 2) of them)
    odds_size[dim] = n // 2         # indices 1, 3, 5, ... (floor(n / 2) of them)

    base = x.storage_offset()
    evens = x.as_strided(evens_size, stride, base)
    # The first odd element is one step along `dim` past the first even one.
    # Using the real stride (not a product of sizes) keeps this correct for
    # non-contiguous inputs and for negative / last-axis `dim`.
    odds = x.as_strided(odds_size, stride, base + step)
    return evens, odds
# Benchmark input. even_odd splits dim 1 by default.
# NOTE(review): presumably laid out as (batch, channels, H, W) — confirm.
t = torch.randn([128, 64, 64, 64])

# Quick visual check: the dim-1 values at index [0, :, 0, 0] for each half.
evens, odds = even_odd(t)
print(evens[0, :, 0, 0], odds[0, :, 0, 0])

# `t` is a CPU tensor, so only CPU time is profiled.
with torch.autograd.profiler.profile() as prof:
    for _ in range(10):
        evens, odds = even_odd(t)
print(prof.key_averages().table(sort_by='cpu_time_total'))

# Reference implementation: plain strided slicing along dim 1.
with torch.autograd.profiler.profile() as prof:
    for _ in range(10):
        tevens = t[:, 0::2, :, :]
        todds = t[:, 1::2, :, :]
print(prof.key_averages().table(sort_by='cpu_time_total'))

# Both approaches must select exactly the same elements (expect zero error).
print('even error', torch.abs(tevens - evens).sum())
print('odd error', torch.abs(todds - odds).sum())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment