Skip to content

Instantly share code, notes, and snippets.

@MohamedAliRashad
Created November 14, 2022 01:32
Show Gist options
  • Save MohamedAliRashad/e49769e7bb811b37ab0b7896ce83314e to your computer and use it in GitHub Desktop.
Save MohamedAliRashad/e49769e7bb811b37ab0b7896ce83314e to your computer and use it in GitHub Desktop.
Spatial Pyramid Pooling Pytorch
## Credits to https://github.com/revidee/pytorch-pyramid-pooling
## I just modified on it to make it simpler
class SpatialPyramidPooling(nn.Module):
    """Spatial Pyramid Pooling (SPP) layer.

    Pools the input feature map at several pyramid levels (each level
    partitions the map into ``p x p`` cells) and concatenates the flattened
    results, producing a fixed-length vector per sample regardless of the
    input's spatial size.

    Parameters
    ----------
    partition_size_list : list
        Number of vertical and horizontal partitions at each pyramid level.
        Must be non-empty.
    pooling_mode : str, optional
        Pooling operation, one of ``"max"`` or ``"avg"``; by default ``"max"``.
    """

    def __init__(self, partition_size_list: list, pooling_mode: str = "max"):
        super().__init__()
        self.partition_size_list = partition_size_list
        self.pooling_mode = pooling_mode

    def forward(self, x):
        """Pool ``x`` of shape (N, C, H, W) into shape (N, C * sum(p*p))."""
        # Resolve the pooling function once, instead of re-testing the mode
        # on every pyramid level.
        if self.pooling_mode == "max":
            pool = nn.functional.max_pool2d
        elif self.pooling_mode == "avg":
            pool = nn.functional.avg_pool2d
        else:
            raise RuntimeError(f'Unknown pooling type: {self.pooling_mode}, please use "max" or "avg".')
        if not self.partition_size_list:
            # Original code raised NameError here; fail with a clear message.
            raise ValueError("partition_size_list must contain at least one level")

        batch_size = x.size(0)
        H, W = int(x.size(2)), int(x.size(3))
        outputs = []
        for parts in self.partition_size_list:
            # Kernel size so that `parts` kernels cover the (padded) map.
            h_kernel = math.ceil(H / parts)
            w_kernel = math.ceil(W / parts)
            # Symmetric zero padding so kernel * parts matches the map exactly
            # (extra >= 0 by construction of the ceil above).
            w_extra = w_kernel * parts - W
            h_extra = h_kernel * parts - H
            padded = nn.functional.pad(
                input=x,
                pad=[w_extra // 2, w_extra - w_extra // 2, h_extra // 2, h_extra - h_extra // 2],
                mode="constant",
                value=0,
            )
            out = pool(padded, (h_kernel, w_kernel), stride=(h_kernel, w_kernel), padding=(0, 0))
            outputs.append(out.view(batch_size, -1))
        # Single concatenation instead of growing the tensor in the loop.
        return torch.cat(outputs, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment