Created November 14, 2022 01:32
-
-
Save MohamedAliRashad/e49769e7bb811b37ab0b7896ce83314e to your computer and use it in GitHub Desktop.
Spatial Pyramid Pooling for PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Credits to https://github.com/revidee/pytorch-pyramid-pooling
## I modified it to make it simpler.
class SpatialPyramidPooling(nn.Module):
    """Spatial Pyramid Pooling (SPP) layer for PyTorch.

    Pools a feature map into a fixed-length vector regardless of its spatial
    size. For every entry ``n`` of ``partition_size_list`` the map is divided
    into an ``n x n`` grid of equally sized bins (after padding the map up to
    a multiple of the bin size), each bin is pooled, and the pooled bins of all
    levels are flattened and concatenated in the order of
    ``partition_size_list``.

    Input shape: ``(N, C, H, W)``.
    Output shape: ``(N, C * sum(n * n for n in partition_size_list))``.
    """

    def __init__(self, partition_size_list: list, pooling_mode: str = "max"):
        """
        Parameters
        ----------
        partition_size_list : list
            Number of vertical and horizontal partitions (bins per side) of
            each pyramid level, e.g. ``[1, 2, 4]``. Must be non-empty and
            every entry must be a positive integer.
        pooling_mode : str, optional
            Pooling operation applied to each bin, ``"max"`` or ``"avg"``;
            by default ``"max"``. It is validated when ``forward`` runs.

        Raises
        ------
        ValueError
            If ``partition_size_list`` is empty or contains a size below 1.
        """
        super().__init__()
        if not partition_size_list:
            raise ValueError("partition_size_list must contain at least one pyramid level.")
        if any(n < 1 for n in partition_size_list):
            raise ValueError(f"Partition sizes must be positive integers, got {partition_size_list}.")
        self.partition_size_list = partition_size_list
        self.pooling_mode = pooling_mode

    def forward(self, x):
        """Apply spatial pyramid pooling to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Feature map of shape ``(N, C, H, W)``.

        Returns
        -------
        torch.Tensor
            Shape ``(N, C * sum(n * n))``: the flattened bins of every level,
            concatenated level by level.

        Raises
        ------
        RuntimeError
            If ``pooling_mode`` is neither ``"max"`` nor ``"avg"``.
        """
        if self.pooling_mode not in ("max", "avg"):
            raise RuntimeError(f'Unknown pooling type: {self.pooling_mode}, please use "max" or "avg".')
        # Collect all levels and concatenate once instead of re-copying the
        # growing result with torch.cat on every iteration.
        levels = [self._pool_level(x, n) for n in self.partition_size_list]
        return torch.cat(levels, dim=1)

    def _pool_level(self, x, n):
        """Pool ``x`` into an ``n x n`` grid of bins and flatten to ``(N, C * n * n)``."""
        height, width = int(x.size(2)), int(x.size(3))
        # Smallest kernel that covers each spatial dimension in n steps.
        h_kernel = int(math.ceil(height / n))
        w_kernel = int(math.ceil(width / n))
        # Pad up to exactly n kernels per side, splitting the padding evenly;
        # an odd leftover row/column goes to the bottom/right.
        h_total = int(h_kernel * n - height)
        w_total = int(w_kernel * n - width)
        pad = [w_total // 2, w_total - w_total // 2, h_total // 2, h_total - h_total // 2]
        kernel = (h_kernel, w_kernel)

        if self.pooling_mode == "max":
            # Pad with -inf so padding can never win the max: with zero padding
            # a border bin whose real values are all negative would report 0.
            padded = nn.functional.pad(x, pad, mode="constant", value=float("-inf"))
            out = nn.functional.max_pool2d(padded, kernel, stride=kernel)
            if h_total or w_total:
                # Bins made only of padding (possible when H or W < n) would be
                # -inf; report 0 for them instead, as zero padding did.
                valid = nn.functional.pad(
                    x.new_ones(1, 1, height, width), pad, mode="constant", value=0
                )
                valid = nn.functional.max_pool2d(valid, kernel, stride=kernel)
                out = out.masked_fill(valid == 0, 0)
        else:
            # Zero padding is counted in the average (like avg_pool2d's
            # count_include_pad=True), so partially padded bins are diluted.
            padded = nn.functional.pad(x, pad, mode="constant", value=0)
            out = nn.functional.avg_pool2d(padded, kernel, stride=kernel)
        # flatten(1) also handles an empty batch, where view(0, -1) fails.
        return out.flatten(1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.