Skip to content

Instantly share code, notes, and snippets.

@pranjalAI
Created October 17, 2020 17:27
Show Gist options
  • Save pranjalAI/31a6d14f4c9264cca4bb12e0d6c2f5ef to your computer and use it in GitHub Desktop.
class ResidualSep(nn.Module):
    """Residual block built from a depthwise-separable 3x3 convolution.

    Spatial size and channel count are preserved (reflection padding
    matches the dilation), so the input can be added back as an
    identity skip connection.
    """

    def __init__(self, channels, dilation=1):
        super().__init__()
        # Pre-activation, depthwise 3x3 (groups=channels), then a
        # pointwise 1x1 channel mix. Layer order matches the usual
        # ReLU -> conv -> BN separable pattern.
        layers = [
            nn.ReLU(),
            nn.ReflectionPad2d(dilation),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1,
                      padding=0, dilation=dilation,
                      groups=channels, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(channels),
        ]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        # Identity skip: output = input + transformed features.
        residual = self.blocks(x)
        return x + residual
class ResidualHourglass(nn.Module):
    """Residual hourglass: downsample -> separable bottleneck -> upsample.

    The inner path halves spatial resolution (stride-2 conv), processes
    features at a reduced channel width (``channels * mult``), then
    restores resolution with a stride-2 transposed conv so the result
    can be added back to the input.

    NOTE(review): the skip addition requires even input height/width so
    the transposed conv exactly restores the original size.
    """

    def __init__(self, channels, mult=0.5):
        super().__init__()
        hidden_channels = int(channels * mult)
        self.blocks = nn.Sequential(
            nn.ReLU(),
            # Downsample: stride-2 3x3 conv, reflection-padded.
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, hidden_channels, kernel_size=3, stride=2,
                      padding=0, dilation=1,
                      groups=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            # Bottleneck: three separable residual blocks; the middle
            # one is dilated for a larger receptive field.
            ResidualSep(channels=hidden_channels, dilation=1),
            ResidualSep(channels=hidden_channels, dilation=2),
            ResidualSep(channels=hidden_channels, dilation=1),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            # Expand back to the full channel count at low resolution.
            nn.Conv2d(hidden_channels, channels, kernel_size=3, stride=1,
                      padding=0, dilation=1,
                      groups=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            # Upsample: 2x2 transposed conv with stride 2 doubles H and W.
            nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2,
                               padding=0, groups=1, bias=True),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        # Identity skip around the whole hourglass.
        inner = self.blocks(x)
        return x + inner
class TransformerNet(nn.Module):
    """Small image-to-image transformer network (style-transfer style).

    Maps a 3-channel image to a 3-channel image through an entry conv,
    two residual hourglasses, one separable residual block, and an exit
    conv.

    Args:
        width: number of feature channels used throughout the trunk.
    """

    def __init__(self, width=8):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, width, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(width, affine=True),
            ResidualHourglass(channels=width),
            ResidualHourglass(channels=width),
            ResidualSep(channels=width, dilation=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1, bias=True),
        )
        # Normalization baked into the first/last conv weights:
        # - entry conv (blocks[1]) divides by 127.5, rescaling 8-bit
        #   [0, 255] pixel input into roughly [0, 2];
        # - exit conv (blocks[-1]) scales back up (the /8 tempers the
        #   output magnitude) and its bias of 127.5 recenters the output
        #   into the [0, 255] pixel range.
        # NOTE(review): assumes input/output images are in [0, 255] —
        # confirm against the training/inference pipeline.
        self.blocks[1].weight.data /= 127.5
        self.blocks[-1].weight.data *= 127.5 / 8
        self.blocks[-1].bias.data.fill_(127.5)

    def forward(self, x):
        return self.blocks(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment