Created
February 28, 2022 14:25
-
-
Save torridgristle/cbc46cc94b8af7190d22dc0be3ab9a64 to your computer and use it in GitHub Desktop.
Sobel and Farid edge-detection modules for PyTorch. The Sobel module uses the Scharr kernel by default (`scharr=True`); Scharr has better rotational symmetry than the classic Sobel kernel.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level default device: CUDA when PyTorch reports it as available,
# otherwise the CPU. Kept as a public name for code that imports it.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
class Sobel(nn.Module):
    """Depthwise 3x3 first-derivative filter (Scharr or Sobel weights).

    Each input channel is padded by one pixel per side with ``padding_mode``
    and filtered independently (``groups=C``), so spatial size is preserved.

    Args:
        structure: If True, return ``[gx*|gx|, gy*|gy|, gx*gy]`` (signed
            squares plus the cross term) instead of ``[gx, gy]``.
        scharr: If True (default), smooth with the Scharr weights
            ``[3, 10, 3] / 16`` (sum 1); otherwise with the unnormalized
            Sobel weights ``[1, 2, 1]`` (sum 4).
        padding_mode: Mode passed to ``F.pad`` (e.g. ``'reflect'``,
            ``'replicate'``, ``'circular'``, ``'constant'``). ``'reflect'``
            requires each spatial dimension to be larger than the pad width.

    Input is expected to be a floating-point ``(N, C, H, W)`` tensor (the
    channel count is read from ``x.shape[1]``). Output is ``(N, 2*C, H, W)``
    laid out as all ``gx`` channels then all ``gy`` channels, or
    ``(N, 3*C, H, W)`` when ``structure=True``.

    ``gx`` differentiates along W and ``gy`` along H for BOTH kernel choices.
    Because ``conv2d`` is a cross-correlation with derivative taps
    ``[1, 0, -1]``, ``gx`` is (left - right) and ``gy`` is (top - bottom),
    i.e. the negative of the usual forward-difference sign.
    """

    def __init__(self, structure=False, scharr=True, padding_mode='reflect'):
        super().__init__()
        self.structure = structure
        self.padding_mode = padding_mode
        derivative = torch.tensor([1.0, 0.0, -1.0])
        if scharr:
            smoothing = torch.tensor([3.0, 10.0, 3.0]) / 16
        else:
            # Unnormalized Sobel weights, kept for backward-compatible magnitudes.
            smoothing = torch.tensor([1.0, 2.0, 1.0])
        # kernel[i, j] = smoothing[i] * derivative[j]: smooth along H,
        # differentiate along W. Building both variants the same way keeps
        # the channel meaning independent of the `scharr` flag (previously
        # the Sobel variant was the transpose of the Scharr one).
        kernel = torch.outer(smoothing, derivative).reshape(1, 1, 3, 3)
        # Non-persistent buffer: follows module.to()/.cuda()/.double() but is
        # not written to state_dict, so existing checkpoints still load.
        self.register_buffer('kernel', kernel, persistent=False)

    def forward(self, x):
        channels = x.shape[1]
        # Match the input's device/dtype so the module works without an
        # explicit .to() call and with half/double inputs.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        kernel_x = kernel.expand(channels, 1, 3, 3)
        kernel_y = kernel.transpose(2, 3).expand(channels, 1, 3, 3)
        x_pad = F.pad(x, [1, 1, 1, 1], mode=self.padding_mode)
        x_x = F.conv2d(x_pad, kernel_x, groups=channels)
        x_y = F.conv2d(x_pad, kernel_y, groups=channels)
        if self.structure:
            return torch.cat([x_x * x_x.abs(), x_y * x_y.abs(), x_x * x_y], 1)
        return torch.cat([x_x, x_y], 1)
class Farid(nn.Module):
    """Depthwise 5x5 first-derivative filter built from 5-tap Farid coefficients.

    Each input channel is padded by two pixels per side with ``padding_mode``
    and filtered independently (``groups=C``), so spatial size is preserved.

    Args:
        padding_mode: Mode passed to ``F.pad`` (e.g. ``'reflect'``,
            ``'replicate'``, ``'circular'``, ``'constant'``). ``'reflect'``
            requires each spatial dimension to be larger than the pad width (2).

    Input is expected to be a floating-point ``(N, C, H, W)`` tensor (the
    channel count is read from ``x.shape[1]``). Output is ``(N, 2*C, H, W)``:
    the first C channels use the base kernel, which differentiates along H
    (derivative taps on rows, weighted top - bottom). The next C channels use
    its transpose, which differentiates along W (left - right).

    NOTE(review): channel order here (H-derivative first) differs from the
    Sobel module's default Scharr kernel, whose first half differentiates
    along W. Confirm before treating the two outputs as interchangeable.
    """

    def __init__(self, padding_mode='reflect'):
        super().__init__()
        self.padding_mode = padding_mode
        # 5-tap prefilter (smoothing) and first-derivative coefficients.
        prefilter = torch.tensor([[0.0376593171958126, 0.249153396177344, 0.426374573253687,
                                   0.249153396177344, 0.0376593171958126]])
        derivative = torch.tensor([[0.109603762960254, 0.276690988455557, 0.0,
                                    -0.276690988455557, -0.109603762960254]])
        # (5, 1) * (1, 5) -> kernel[i, j] = derivative[i] * prefilter[j]:
        # differentiate along H, smooth along W.
        kernel = (derivative.T * prefilter).reshape(1, 1, 5, 5)
        # Non-persistent buffer: follows module.to()/.cuda()/.double() but is
        # not written to state_dict, so existing checkpoints still load.
        self.register_buffer('kernel', kernel, persistent=False)

    def forward(self, x):
        channels = x.shape[1]
        # Match the input's device/dtype so the module works without an
        # explicit .to() call and with half/double inputs.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        x_pad = F.pad(x, [2, 2, 2, 2], mode=self.padding_mode)
        x_x = F.conv2d(x_pad, kernel.expand(channels, -1, -1, -1), groups=channels)
        x_y = F.conv2d(x_pad, kernel.transpose(2, 3).expand(channels, -1, -1, -1), groups=channels)
        return torch.cat([x_x, x_y], 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment