Skip to content

Instantly share code, notes, and snippets.

@ducha-aiki
Created March 5, 2018 12:36
Show Gist options
  • Save ducha-aiki/9f457f80298c52aa65819fe235f9cec1 to your computer and use it in GitHub Desktop.
Save ducha-aiki/9f457f80298c52aa65819fe235f9cec1 to your computer and use it in GitHub Desktop.
Batched version of grid sampling (`F.grid_sample`) that processes grids in chunks to save memory.
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def batched_grid_apply(img, grid, batch_size, **kwargs):
    """Sample one image with many grids via ``F.grid_sample``, in chunks.

    ``img`` is broadcast with ``expand`` (a view, no copy) to match each chunk
    of ``grid``, and at most ``batch_size`` grids are passed to any single
    ``F.grid_sample`` call. Chunk results are written into one preallocated
    output tensor. The intent is to cap the size of each individual
    ``grid_sample`` call in order to save memory.

    Args:
        img: input tensor whose leading (batch) dimension is expected to be 1,
            e.g. shape (1, C, H, W); it is expanded along that dimension.
        grid: sampling grids, one per output item along dim 0, in the layout
            ``F.grid_sample`` expects for ``img`` (e.g. (N, H_out, W_out, 2)
            for a 4-D ``img``).
        batch_size: maximum number of grids per ``F.grid_sample`` call;
            must be >= 1. If ``len(grid) <= batch_size`` one call is made.
        **kwargs: forwarded unchanged to ``F.grid_sample`` (e.g. ``mode``,
            ``padding_mode``, ``align_corners``).

    Returns:
        Tensor with ``len(grid)`` items along dim 0, equal to
        ``F.grid_sample(img.expand(len(grid), ...), grid, **kwargs)`` and
        carrying the dtype/device that ``F.grid_sample`` produces.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    n_patches = len(grid)
    if n_patches <= batch_size:
        # Small enough to do in one call; no chunking needed.
        return F.grid_sample(img.expand(n_patches, *img.shape[1:]), grid, **kwargs)

    out = None
    for st in range(0, n_patches, batch_size):
        end = min(st + batch_size, n_patches)  # last chunk may be short
        chunk = F.grid_sample(
            img.expand(end - st, *img.shape[1:]), grid[st:end], **kwargs
        )
        if out is None:
            # Allocate from the first chunk so the result inherits its dtype
            # and device: a plain torch.zeros(...) is float32 on CPU, and
            # .cuda() would pick the current GPU rather than img's device.
            out = chunk.new_zeros((n_patches,) + tuple(chunk.shape[1:]))
        out[st:end] = chunk
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment