Created
August 28, 2019 20:54
-
-
Save hereismari/85891b30eb9804a7e81e75a238fb8e84 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def maxpool2d(a_sh, kernel_size: int = 1, stride: int = 1, padding: int = 0,
              dilation: int = 1, ceil_mode=False):
    """Apply 2D max pooling over an input composed of several input planes.

    The interface mirrors ``torch.nn.MaxPool2d`` /
    ``torch.nn.functional.max_pool2d``, except that ``stride`` defaults to 1
    here instead of to ``kernel_size``. Each of ``kernel_size``, ``stride``,
    ``padding`` and ``dilation`` may be an int or an ``(h, w)`` pair.

    The pooling uses only three operations: padding, strided slicing, and an
    elementwise ``torch.max`` taken over every kernel offset.
    NOTE(review): the parameter name ``a_sh`` suggests a shared tensor type;
    this implementation has only been checked with plain ``torch.Tensor``
    inputs -- confirm other tensor types support these operations.

    Args:
        a_sh: 4-D input of shape ``(batch, channels, height, width)``.
        kernel_size: the size of the window to take a max over.
        stride: the stride of the window.
        padding: implicit padding added on both sides of each spatial
            dimension. As in PyTorch, padded positions are filled with the
            lowest representable value, so they never win the max.
        dilation: spacing between the elements inside the window.
        ceil_mode: when True, use ceil instead of floor to compute the
            output shape.

    Returns:
        A tensor of shape ``(batch, channels, h_out, w_out)``.

    Raises:
        ValueError: if ``a_sh`` is not 4-D, or if the computed output size
            is not positive.
    """
    import torch
    import torch.nn.functional as F

    if len(a_sh.shape) != 4:
        raise ValueError(
            f"maxpool2d expects a 4-D input, got shape {tuple(a_sh.shape)}")

    # Normalise every hyper-parameter to an (h, w) pair.
    pair = torch.nn.modules.utils._pair
    h_kernel, w_kernel = pair(kernel_size)
    h_stride, w_stride = pair(stride)
    h_pad, w_pad = pair(padding)
    h_dil, w_dil = pair(dilation)

    _, _, h_in, w_in = a_sh.shape

    def out_size(size, kernel, step, pad, dil):
        # Integer form of the torch.nn.MaxPool2d output-shape formula:
        # `//` floors toward -inf, and -(-n // d) is the matching ceiling.
        span = size + 2 * pad - dil * (kernel - 1) - 1
        out = (-(-span // step) if ceil_mode else span // step) + 1
        # In ceil mode PyTorch drops a last window that would start past
        # the input plus its left/top padding.
        if ceil_mode and (out - 1) * step >= size + pad:
            out -= 1
        if out <= 0:
            raise ValueError(
                f"maxpool2d output size is too small ({out}) "
                f"for input size {size}")
        return out

    h_out = out_size(h_in, h_kernel, h_stride, h_pad, h_dil)
    w_out = out_size(w_in, w_kernel, w_stride, w_pad, w_dil)

    # In ceil mode the last window may extend past the padded input; extend
    # the bottom/right padding so every window can be sliced in full.
    h_extra = max(0, (h_out - 1) * h_stride + h_dil * (h_kernel - 1) + 1
                  - (h_in + 2 * h_pad))
    w_extra = max(0, (w_out - 1) * w_stride + w_dil * (w_kernel - 1) + 1
                  - (w_in + 2 * w_pad))

    padded = a_sh
    if h_pad or w_pad or h_extra or w_extra:
        fill = (float("-inf") if a_sh.is_floating_point()
                else torch.iinfo(a_sh.dtype).min)
        # F.pad takes (left, right, top, bottom) for the last two dims.
        padded = F.pad(a_sh, (w_pad, w_pad + w_extra, h_pad, h_pad + h_extra),
                       value=fill)

    # Slice extents that yield exactly h_out rows / w_out columns per offset.
    h_span = (h_out - 1) * h_stride + 1
    w_span = (w_out - 1) * w_stride + 1

    result = None
    for i in range(h_kernel):
        top = i * h_dil
        for j in range(w_kernel):
            left = j * w_dil
            window = padded[:, :, top:top + h_span:h_stride,
                            left:left + w_span:w_stride]
            # Clone the first window so a 1x1 kernel never returns a view
            # of the caller's tensor.
            result = (window.clone() if result is None
                      else torch.max(result, window))
    return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment