Skip to content

Instantly share code, notes, and snippets.

@pranjalAI
Created October 27, 2020 10:08
Show Gist options
  • Save pranjalAI/1602f6036f3a3a9fb8fad4407a5feaef to your computer and use it in GitHub Desktop.
Save pranjalAI/1602f6036f3a3a9fb8fad4407a5feaef to your computer and use it in GitHub Desktop.
class ResidualSep(nn.Module):
    """Residual block whose branch is a depthwise-separable convolution.

    Branch layout (in order): ReLU -> reflection padding -> depthwise 3x3
    convolution (``groups=channels``, optionally dilated) -> BatchNorm ->
    ReLU -> pointwise 1x1 convolution -> BatchNorm. The branch output is
    added to the untouched input, so the channel count is preserved.

    Args:
        channels: Number of input channels; also the number of output
            channels of every convolution in the branch.
        dilation: Dilation of the depthwise 3x3 convolution. The input is
            reflection-padded by the same amount on every side, which keeps
            the spatial size unchanged so the residual sum lines up.
    """

    def __init__(self, channels, dilation=1):
        super().__init__()
        # Depthwise stage. The first ReLU must NOT be in-place: an in-place
        # op here would overwrite ``x`` before it is reused in the skip sum.
        # Padding is applied explicitly by reflection, so the conv itself
        # uses padding=0; bias is dropped because BatchNorm follows.
        depthwise_stage = [
            nn.ReLU(),
            nn.ReflectionPad2d(dilation),
            nn.Conv2d(
                channels,
                channels,
                kernel_size=3,
                stride=1,
                padding=0,
                dilation=dilation,
                groups=channels,
                bias=False,
            ),
            nn.BatchNorm2d(channels),
        ]
        # Pointwise stage: 1x1 conv mixes information across channels.
        # This ReLU acts on an intermediate tensor, so in-place is safe.
        pointwise_stage = [
            nn.ReLU(inplace=True),
            nn.Conv2d(
                channels,
                channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(channels),
        ]
        # Flattened into one Sequential so submodule indices (and therefore
        # state_dict keys blocks.0 ... blocks.6) stay the same.
        self.blocks = nn.Sequential(*depthwise_stage, *pointwise_stage)

    def forward(self, x):
        """Return ``x`` plus the separable-convolution branch applied to ``x``."""
        branch_out = self.blocks(x)
        return x + branch_out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment