Skip to content

Instantly share code, notes, and snippets.

@malnakli
Created November 11, 2019 00:27
Show Gist options
  • Save malnakli/6882357f4834337f9f9e6b95aa97e7ab to your computer and use it in GitHub Desktop.
Save malnakli/6882357f4834337f9f9e6b95aa97e7ab to your computer and use it in GitHub Desktop.
import torch.nn as nn
import math
def _as_pair(value):
    """Return a per-axis hyperparameter as an ``(h, w)`` 2-tuple.

    A single int is repeated for both axes (as ``torch.nn.Conv2d`` does when
    given an int); any other value must unpack into exactly two items.
    """
    if isinstance(value, int):
        return (value, value)
    first, second = value
    return (first, second)


def _conv_out_size(size, padding, dilation, kernel_size, stride):
    """Apply the convolution output-size formula along a single axis."""
    # Integer floor division keeps the result exact for arbitrarily large
    # ints (float division followed by math.floor can round), and it floors
    # negative numerators exactly like math.floor would.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv2d_out_shape(width, height, Conv2d):
    """Compute the output shape of a 2-D convolution layer.

    Uses the output-size formula from the PyTorch ``torch.nn.Conv2d``
    documentation (https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d):

        out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1

    Index 0 of each per-axis attribute (``padding``, ``dilation``,
    ``kernel_size``, ``stride``) is applied to the height and index 1 to the
    width, matching PyTorch's ``(H, W)`` ordering of these attributes.

    Args:
        width: Input width.
        height: Input height.
        Conv2d: A ``torch.nn.Conv2d`` instance, or any object exposing
            ``out_channels``, ``kernel_size``, ``stride``, ``padding`` and
            ``dilation``. Per-axis attributes may be 2-tuples or single ints;
            ``padding`` may also be the string ``"valid"`` or ``"same"``.

    Returns:
        A tuple ``(C, W, H)``: output channels, output width, output height.
        Note the width-before-height order, which differs from PyTorch's
        ``(C, H, W)`` tensor layout.

    Raises:
        ValueError: if ``padding`` is an unsupported string, if
            ``padding="same"`` is combined with a stride other than 1, or if
            the computed output width or height is smaller than 1 (the
            kernel does not fit in the padded input).

    Example:
        >>> print(conv2d_out_shape(28, 28, nn.Conv2d(3, 32, 3, 2, 1)))
        (32, 14, 14)
    """
    padding = Conv2d.padding
    if isinstance(padding, str):
        if padding == "same":
            # 'same' keeps the spatial size unchanged; it is only
            # well-defined for stride 1.
            if _as_pair(Conv2d.stride) != (1, 1):
                raise ValueError("padding='same' requires stride 1")
            return (Conv2d.out_channels, width, height)
        if padding != "valid":
            raise ValueError(f"unsupported padding string: {padding!r}")
        padding = 0  # 'valid' means no implicit padding

    pad_h, pad_w = _as_pair(padding)
    dil_h, dil_w = _as_pair(Conv2d.dilation)
    k_h, k_w = _as_pair(Conv2d.kernel_size)
    s_h, s_w = _as_pair(Conv2d.stride)

    out_h = _conv_out_size(height, pad_h, dil_h, k_h, s_h)
    out_w = _conv_out_size(width, pad_w, dil_w, k_w, s_w)
    if out_h < 1 or out_w < 1:
        raise ValueError(
            f"input (W={width}, H={height}) is too small for this layer: "
            f"computed output (W={out_w}, H={out_h})"
        )
    return (Conv2d.out_channels, out_w, out_h)
# Example: a 3 -> 32 channel convolution with a 3x3 kernel, stride 2 and
# padding 1 halves a 28x28 input.
print(
    conv2d_out_shape(
        28,
        28,
        nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1),
    )
)
# prints: (32, 14, 14)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment