Skip to content

Instantly share code, notes, and snippets.

@post2web
Created June 22, 2018 16:56
Show Gist options
  • Save post2web/a92be14008646a3d10b4183c8d35375f to your computer and use it in GitHub Desktop.
Save post2web/a92be14008646a3d10b4183c8d35375f to your computer and use it in GitHub Desktop.
Max/mean pooling with NumPy for 2D and 3D data
# from https://stackoverflow.com/questions/42463172/how-to-perform-max-mean-pooling-on-a-2d-array-using-numpy
import numpy as np
def asStride(arr, sub_shape, stride):
    """Return a strided view of the sliding sub-matrices of *arr*.

    Windows of shape ``sub_shape`` slide over the first two axes of ``arr``,
    advancing ``stride[0]`` rows / ``stride[1]`` columns at a time. Any
    trailing axes (beyond the first two) are carried along unchanged.

    Parameters
    ----------
    arr : ndarray
        Input array with at least two dimensions.
    sub_shape : tuple of 2 ints
        Window size ``(rows, cols)``.
    stride : sequence of ints
        Step between consecutive windows; only ``stride[0]`` (axis 0) and
        ``stride[1]`` (axis 1) are used.

    Returns
    -------
    ndarray
        View of shape ``(out_rows, out_cols, rows, cols) + arr.shape[2:]``,
        where ``out_rows = 1 + (arr.shape[0] - rows) // stride[0]`` and
        ``out_cols`` is computed the same way along axis 1. No data is
        copied: the result shares memory with ``arr``, and overlapping
        windows alias the same elements.

    See also ``skimage.util.shape.view_as_windows()``.
    """
    row_bytes, col_bytes = arr.strides[:2]
    n_rows, n_cols = arr.shape[:2]
    win_rows, win_cols = sub_shape
    step_rows, step_cols = stride[0], stride[1]

    # Number of window positions along each of the two pooled axes.
    out_rows = (n_rows - win_rows) // step_rows + 1
    out_cols = (n_cols - win_cols) // step_cols + 1

    # Outer two axes jump by whole strides; inner two walk element by
    # element inside a window; trailing axes keep their native layout.
    shape = (out_rows, out_cols, win_rows, win_cols) + arr.shape[2:]
    strides = (
        step_rows * row_bytes,
        step_cols * col_bytes,
        row_bytes,
        col_bytes,
    ) + arr.strides[2:]
    return np.lib.stride_tricks.as_strided(arr, shape, strides=strides)
def poolingOverlap(mat, ksize, stride=None, method='max', pad=False):
    """Overlapping max- or mean-pooling on 2D or 3D data.

    Pooling windows slide over the first two axes of ``mat``; any trailing
    axes (e.g. the third axis of a 3D array) are pooled independently.

    Parameters
    ----------
    mat : ndarray
        Input array to pool.
    ksize : tuple of 2 ints
        Kernel size ``(ky, kx)``.
    stride : tuple of 2 ints or None
        Step ``(sy, sx)`` between windows. If None (default), same as
        ``ksize``, i.e. non-overlapping pooling.
    method : str
        ``'max'`` for max-pooling, ``'mean'`` for mean-pooling. Any other
        value falls back to mean-pooling.
    pad : bool
        If False, windows that would run past the edge are dropped and the
        output has size ``(n - f) // s + 1`` along each pooled axis (n being
        the ``mat`` size, f the kernel size, s the stride). If True, ``mat``
        is padded with NaN so the output has size ``ceil(n / s)``; the padding
        is skipped by the NaN-aware reductions (``np.nanmax``/``np.nanmean``).

    Returns
    -------
    ndarray
        Pooled array of shape ``(out_y, out_x) + mat.shape[2:]``. With
        ``pad=True`` the result is floating point (the pad buffer is filled
        with NaN).
    """
    m, n = mat.shape[:2]
    ky, kx = ksize
    if stride is None:
        stride = (ky, kx)
    sy, sx = stride

    if pad:
        # Exact integer ceil(m / sy), ceil(n / sx).
        ny = -(-m // sy)
        nx = -(-n // sx)
        size = ((ny - 1) * sy + ky, (nx - 1) * sx + kx) + mat.shape[2:]
        mat_pad = np.full(size, np.nan)
        # When the stride exceeds the kernel, the padded extent can be
        # smaller than the input. Copy only the part of ``mat`` that fits:
        # the rows/cols left out are never covered by any window, and a
        # plain ``mat_pad[:m, :n] = mat`` would raise a broadcast error.
        mat_pad[:m, :n, ...] = mat[:size[0], :size[1], ...]
    else:
        # Trim trailing rows/cols that no full window covers.
        mat_pad = mat[:(m - ky) // sy * sy + ky, :(n - kx) // sx * sx + kx, ...]

    view = asStride(mat_pad, ksize, stride)
    if method == 'max':
        return np.nanmax(view, axis=(2, 3))
    return np.nanmean(view, axis=(2, 3))
@ManuelZierl
Copy link

very nice 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment