Skip to content

Instantly share code, notes, and snippets.

@pranjalAI
Created October 27, 2020 10:34
Show Gist options
  • Save pranjalAI/d7adc76c34c8a9ce493b75a41180a91a to your computer and use it in GitHub Desktop.
Save pranjalAI/d7adc76c34c8a9ce493b75a41180a91a to your computer and use it in GitHub Desktop.
class TransformerNet(nn.Module):
    """Small feed-forward image-to-image network.

    Layer order inside ``self.blocks`` (kept fixed so ``state_dict`` keys
    ``blocks.<i>.*`` remain compatible with existing checkpoints):

    reflection pad (1) -> 3x3 conv, 3 -> ``width`` channels, no bias ->
    BatchNorm2d -> ResidualHourglass -> ResidualHourglass ->
    ResidualSep (dilation 1) -> ReLU -> 3x3 conv, ``width`` -> 3 channels,
    with bias.

    Args:
        width: number of feature channels between the input and output
            convolutions. Defaults to 8.
    """

    def __init__(self, width=8):
        super().__init__()
        # Layers are instantiated in exactly the original order so that
        # parameter initialisation consumes the RNG identically under a
        # fixed seed. Holding direct references to the first and last convs
        # avoids fragile positional lookups like ``self.blocks[1]``.
        pad = nn.ReflectionPad2d(1)
        head = nn.Conv2d(
            3, width, kernel_size=3, stride=1, padding=0, bias=False
        )
        norm = nn.BatchNorm2d(width, affine=True)
        hourglass_a = ResidualHourglass(channels=width)
        hourglass_b = ResidualHourglass(channels=width)
        sep = ResidualSep(channels=width, dilation=1)
        act = nn.ReLU(inplace=True)
        tail = nn.Conv2d(
            width, 3, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.blocks = nn.Sequential(
            pad, head, norm, hourglass_a, hourglass_b, sep, act, tail,
        )

        # Initial rescaling of the first/last convs.
        # NOTE(review): the 127.5 factors suggest the network consumes and
        # produces pixel values on a [0, 255] scale — confirm against the
        # training/inference code. The extra ``/ 8`` on the output weights
        # is not explained here.
        # torch.no_grad() + in-place ops is the supported way to edit
        # parameters outside autograd (instead of mutating ``.data``).
        with torch.no_grad():
            head.weight.div_(127.5)
            tail.weight.mul_(127.5 / 8)
            # Output bias starts at 127.5, so initial outputs are centred
            # on that value.
            tail.bias.fill_(127.5)

    def forward(self, x):
        """Run ``x`` through all layers in ``self.blocks``.

        Args:
            x: input batch with 3 channels, as required by the first
                ``nn.Conv2d(3, ...)``. Presumably a float tensor of shape
                (N, 3, H, W) — TODO confirm with callers.

        Returns:
            The output of the final convolution (3 channels). The first and
            last convolutions preserve spatial size. Whether the residual
            blocks also do depends on their definitions, which are not
            visible here.
        """
        return self.blocks(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment