Skip to content

Instantly share code, notes, and snippets.

@J3698
Created March 22, 2021 20:28
Show Gist options
Save J3698/211ad78855edfd940f91b1fe76d5ee6e to your computer and use it in GitHub Desktop.
# Convert each encoder layer in `features` into a mirrored decoder-style layer,
# replacing entries in place (indices, and therefore layer order, are kept).
# NOTE(review): the sequence is not reversed here; presumably the caller has
# already reversed `features` so it runs from the deepest layer back toward
# the input — confirm.
for i, layer in enumerate(features):
    if isinstance(layer, nn.MaxPool2d):
        # Replace pooling with nearest-neighbour upsampling by the pool's
        # stride (2x2 for the usual VGG pools), so spatial size grows back.
        pool_stride = layer.stride
        if not isinstance(pool_stride, (tuple, list)):
            pool_stride = (pool_stride, pool_stride)
        features[i] = nn.Upsample(scale_factor=tuple(pool_stride), mode='nearest')
    elif isinstance(layer, nn.Conv2d):
        # Mirror the convolution: swap in/out channels and keep the spatial
        # hyper-parameters (including dilation) and whether a bias is used.
        conv2d = nn.Conv2d(
            layer.out_channels,
            layer.in_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            dilation=layer.dilation,
            bias=layer.bias is not None,
            padding_mode='reflect',
        )
        with torch.no_grad():
            # Initialise from the source weights with the in/out channel axes
            # swapped. Spatial taps are not flipped, so for kernels larger
            # than 1x1 this is not the exact adjoint of the source conv.
            # The source bias (length out_channels) cannot be reused here, so
            # the new bias keeps PyTorch's default initialisation.
            conv2d.weight.copy_(layer.weight.transpose(0, 1))
        features[i] = conv2d
    elif isinstance(layer, nn.ReLU):
        # Turn off in-place activation; presumably so tensors held elsewhere
        # (e.g. captured intermediate features) are not overwritten — confirm.
        layer.inplace = False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment