Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active August 18, 2021 10:07
Show Gist options
  • Save vadimkantorov/c9b82a32d983971cab455f49e14b0c0e to your computer and use it in GitHub Desktop.
Save vadimkantorov/c9b82a32d983971cab455f49e14b0c0e to your computer and use it in GitHub Desktop.
Image gradient filters in PyTorch
# https://scikit-image.org/docs/dev/api/skimage.filters.html
import torch
import torch.nn.functional as F
def sobel_filter() -> '2133':
    """Return the 3x3 Sobel derivative kernels as a (2, 1, 3, 3) conv2d weight.

    Index 0 is the x-derivative kernel and index 1 is its transpose, the
    y-derivative kernel. The kernels are laid out for cross-correlation,
    which is what ``F.conv2d`` computes, hence "flipped" relative to the
    textbook convolution kernel. The dtype is int64; callers cast as needed.
    """
    # Sobel is separable: [1, 2, 1] smoothing down the rows times a
    # [-1, 0, 1] central difference along the columns.
    smoothing = torch.tensor([1, 2, 1])
    difference = torch.tensor([-1, 0, 1])
    kernel_x = torch.outer(smoothing, difference)
    kernel_y = kernel_x.t()
    # Add the singleton in-channel axis expected by conv2d weights.
    return torch.stack((kernel_x, kernel_y))[:, None]
def scharr_filter() -> '2133':
    """Return the 3x3 Scharr derivative kernels as a (2, 1, 3, 3) conv2d weight.

    Index 0 is the x-derivative kernel and index 1 is its transpose, the
    y-derivative kernel. The kernels are laid out for cross-correlation,
    which is what ``F.conv2d`` computes. The dtype is int64; callers cast
    as needed.
    """
    # Every row must be antisymmetric (negative left, positive right) so that
    # the kernel sums to zero and gives no response on a constant image.
    # The middle row previously read [10, 0, 10], which broke both dx and dy.
    flipped_scharr_x = torch.tensor([
        [ -3, 0,  3],
        [-10, 0, 10],
        [ -3, 0,  3],
    ])
    return torch.stack([flipped_scharr_x, flipped_scharr_x.t()]).unsqueeze(1)
def image_gradients(img : 'BCHW', kernel, mode = None) -> 'BC2HW':
    """Filter every channel of ``img`` with a pair of derivative kernels.

    Args:
        img: floating-point tensor whose last two dims are (H, W). Every
            leading dim (e.g. B, C) is treated as an independent image.
        kernel: conv2d weight of shape (2, 1, kh, kw). Index 0 yields dx and
            index 1 yields dy, e.g. the output of ``sobel_filter()``. It is
            cast to ``img``'s dtype and device.
        mode: ``'magnitude'`` returns sqrt(dx**2 + dy**2) with shape
            (..., H, W). ``'angle'`` returns atan2(dy, dx) with shape
            (..., H, W). ``'magnitude_angle'`` stacks both along dim -3.
            Any other value, including ``None``, returns the raw (dx, dy)
            components stacked along dim -3, shape (..., 2, H, W).

    Spatial size is preserved for odd-sized kernels. Borders are zero-padded.
    """
    # Pad by half the kernel extent so odd kernels of any size (3x3, 5x5,
    # 3x5, ...) keep H and W. This equals the old padding=1 for 3x3 kernels.
    kh, kw = kernel.shape[-2], kernel.shape[-1]
    padding = (kh // 2, kw // 2)

    # Fold all leading dims into the batch so each channel is filtered alone.
    flat = img.flatten(end_dim = -3).unsqueeze(1)
    weight = kernel.to(dtype = img.dtype, device = img.device)
    components = F.conv2d(flat, weight, padding = padding).unflatten(0, img.shape[:-2])

    dx, dy = components.unbind(dim = -3)
    magnitude = lambda dx, dy: (dy ** 2 + dx ** 2) ** 0.5
    angle = lambda dx, dy: torch.atan2(dy, dx)
    if mode == 'magnitude': return magnitude(dx, dy)
    if mode == 'angle': return angle(dx, dy)
    if mode == 'magnitude_angle': return torch.stack([magnitude(dx, dy), angle(dx, dy)], dim = -3)
    return components
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment