Created December 28, 2020 19:37
Gist: patrickmineault/21b8d78f423ac8ea4b006f9ec1a1a1a7
Downsample a stack of 2D images in PyTorch
import torch
import torch.nn.functional as F


def downsample_2d(X, sz):
    """
    Downsamples a stack of square images.

    Args:
        X: a stack of images, a tensor of shape (batch, channels, ny, ny).
        sz: the desired side length of the output images (an int).

    Returns:
        The downsampled images, a tensor of shape (batch, channels, sz, sz).
    """
    # 3x3 binomial smoothing kernel, one copy per channel (depthwise conv).
    kernel = torch.tensor([[.25, .5, .25],
                           [.5, 1, .5],
                           [.25, .5, .25]],
                          device=X.device, dtype=X.dtype).reshape(1, 1, 3, 3)
    kernel = kernel.repeat((X.shape[1], 1, 1, 1))
    while sz < X.shape[-1] / 2:
        # Downsample by a factor of 2 with smoothing.
        # Convolving an all-ones image gives the total in-bounds kernel weight
        # at each output pixel; it is smaller where the padding falls outside.
        mask = torch.ones(1, *X.shape[1:], device=X.device, dtype=X.dtype)
        mask = F.conv2d(mask, kernel, groups=X.shape[1], stride=2, padding=1)
        X = F.conv2d(X, kernel, groups=X.shape[1], stride=2, padding=1)
        # Normalize the edges and corners.
        X = X / mask
    # Final resize to exactly sz x sz.
    return F.interpolate(X, size=sz, mode='bilinear', align_corners=False)
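Dividing by `mask` is what keeps the borders from darkening: with zero padding, output pixels on the image border see fewer in-bounds kernel taps than interior pixels do. The sketch below isolates a single halving step in a hypothetical helper, `smooth_halve` (not part of the gist), and applies it to a constant image, assuming PyTorch is installed. It returns both the raw convolution and the mask-normalized result so the difference is visible.

```python
import torch
import torch.nn.functional as F


def smooth_halve(X):
    """One smoothing + stride-2 step with edge normalization (hypothetical helper)."""
    C = X.shape[1]
    # Same 3x3 binomial kernel as downsample_2d; its weights sum to 4.
    kernel = torch.tensor([[.25, .5, .25],
                           [.5, 1., .5],
                           [.25, .5, .25]],
                          device=X.device, dtype=X.dtype).reshape(1, 1, 3, 3)
    kernel = kernel.repeat(C, 1, 1, 1)  # one copy per channel for a depthwise conv
    # Total in-bounds kernel weight at each output pixel: 4 in the interior,
    # less where the zero padding falls outside the image.
    ones = torch.ones(1, *X.shape[1:], device=X.device, dtype=X.dtype)
    mask = F.conv2d(ones, kernel, groups=C, stride=2, padding=1)
    raw = F.conv2d(X, kernel, groups=C, stride=2, padding=1)
    return raw, raw / mask


X = torch.ones(2, 3, 8, 8)  # a batch of constant images
raw, normalized = smooth_halve(X)
print(normalized.shape)  # torch.Size([2, 3, 4, 4])
# Without normalization the top-left corner gets only 2.25 of the kernel's
# total weight of 4, and the rest of the top row and left column get 3.
print(raw[0, 0, 0, 0].item(), raw[0, 0, 3, 3].item())  # 2.25 4.0
```

With an even side length, only the top and left borders lose weight: the last output pixel is centered one pixel in from the bottom/right edge, so its 3x3 window stays in bounds.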
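To see how many smoothing steps the `while` loop performs before the final `F.interpolate` call, the sketch below traces only the image side length. `halving_schedule` is a hypothetical helper that mirrors the loop condition and the standard `conv2d` output-size formula for kernel size 3, stride 2 and padding 1; it does not touch any tensors.

```python
def halving_schedule(ny, sz):
    """Side lengths visited by downsample_2d's loop (hypothetical helper)."""
    sizes = [ny]
    n = ny
    # Same condition as the while loop: halve while sz < n / 2.
    while sz < n / 2:
        # conv2d output size with kernel 3, stride 2, padding 1:
        # floor((n + 2*1 - 3) / 2) + 1, which equals ceil(n / 2).
        n = (n + 2 * 1 - 3) // 2 + 1
        sizes.append(n)
    return sizes


print(halving_schedule(256, 20))  # [256, 128, 64, 32]; interpolate then resizes 32 -> 20
print(halving_schedule(100, 30))  # [100, 50]; interpolate then resizes 50 -> 30
```

Because the loop exits as soon as `sz >= n / 2`, the closing bilinear resize shrinks by at most a factor of 2 (and upsamples if `sz` exceeds the input size). This is presumably so that bilinear interpolation, which by default does no prefiltering of its own, is only asked to make small reductions.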