Skip to content

Instantly share code, notes, and snippets.

@jdhao
Last active October 25, 2022 07:41
Show Gist options
  • Save jdhao/d96300624b6aaee50a253df671b47b5c to your computer and use it in GitHub Desktop.
A naive implementation just for illustrating how forward and backward pass of max-pooling layer in CNN works
import numpy as np
import numba
class MaxPooling(object):
    """Naive max-pooling layer with forward and backward passes.

    Illustrative (loop-based, unvectorized) implementation. The input is a
    4-D tensor of shape (N, C, H, W); pooling runs independently per sample
    and per channel with no padding ("valid" pooling).

    NOTE: the original gist decorated ``forward``/``backward`` with
    ``@numba.jit``; numba cannot jit bound methods that access NumPy object
    attributes through ``self``, so the decorators were removed.
    """

    def __init__(self, X, kernel_size=(2, 2), stride=(2, 2)):
        """Store the input and precompute output geometry.

        Parameters
        ----------
        X : ndarray of shape (N, C, H, W)
            Input feature maps.
        kernel_size : tuple of int (kh, kw)
            Height and width of the pooling window.
        stride : tuple of int (sh, sw)
            Vertical and horizontal stride.

        Raises
        ------
        ValueError
            If ``X`` is not 4-dimensional.
        """
        if len(X.shape) != 4:
            raise ValueError("Input must be a tensor of shape N*C*H*W!")
        self.X = X
        self.input_shape = self.X.shape
        N, C, H, W = self.input_shape
        self.kernel_size_h, self.kernel_size_w = kernel_size
        self.stride_h, self.stride_w = stride
        # Standard "valid" pooling output size. Integer floor division
        # replaces the original np.floor(...) + int(...) round trip.
        self.out_h = (H - self.kernel_size_h) // self.stride_h + 1
        self.out_w = (W - self.kernel_size_w) // self.stride_w + 1
        self.out = np.empty((N, C, self.out_h, self.out_w))
        # For each output element, remember where its max came from in the
        # input: index 0 on axis 2 holds the row, index 1 the column.
        # int32 (was int16) so inputs with H or W >= 32768 still index
        # correctly instead of silently overflowing.
        self.max_index = np.empty((N, C, 2, self.out_h, self.out_w),
                                  dtype=np.int32)

    def forward(self):
        """Compute the max-pooled output and record argmax positions.

        Returns
        -------
        ndarray of shape (N, C, out_h, out_w)
            The pooled feature maps (also stored as ``self.out``).
        """
        N, C, _, _ = self.X.shape
        for n in range(N):
            for c in range(C):
                for h in range(self.out_h):
                    h_start = h * self.stride_h
                    h_end = h_start + self.kernel_size_h
                    for w in range(self.out_w):
                        w_start = w * self.stride_w
                        w_end = w_start + self.kernel_size_w
                        window = self.X[n, c, h_start:h_end, w_start:w_end]
                        self.out[n, c, h, w] = window.max()
                        # argmax gives a flat index into the window; unravel
                        # it back to (row, col) within the window, then shift
                        # by the window origin to get input coordinates.
                        r, s = np.unravel_index(window.argmax(), window.shape)
                        self.max_index[n, c, 0, h, w] = r + h_start
                        self.max_index[n, c, 1, h, w] = s + w_start
        return self.out

    def backward(self, in_grad):
        """Route upstream gradients back to the argmax input positions.

        Parameters
        ----------
        in_grad : ndarray of shape (N, C, out_h, out_w)
            Gradient of the loss w.r.t. the pooled output.

        Returns
        -------
        ndarray with the same shape as the input ``X``
            Gradient of the loss w.r.t. the input.

        Raises
        ------
        ValueError
            If ``in_grad`` does not match the forward output shape.
        """
        if in_grad.shape != self.out.shape:
            raise ValueError("in_grad should have shape {}, but instead "
                             "we get {}".format(self.out.shape, in_grad.shape))
        out_grad = np.zeros_like(self.X)
        N, C, _, _ = self.X.shape
        for n in range(N):
            for c in range(C):
                for h in range(self.out_h):
                    for w in range(self.out_w):
                        r = self.max_index[n, c, 0, h, w]
                        s = self.max_index[n, c, 1, h, w]
                        # Scatter directly with += instead of allocating a
                        # whole zeros_like(X) tensor per output element as
                        # the original did (accidentally quadratic). The +=
                        # keeps the original accumulation semantics when
                        # overlapping windows pick the same input element.
                        out_grad[n, c, r, s] += in_grad[n, c, h, w]
        return out_grad
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment