Created
June 22, 2018 16:56
-
-
Save post2web/a92be14008646a3d10b4183c8d35375f to your computer and use it in GitHub Desktop.
Max/mean pooling with NumPy for 2D and 3D data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# from https://stackoverflow.com/questions/42463172/how-to-perform-max-mean-pooling-on-a-2d-array-using-numpy | |
import numpy as np | |
def asStride(arr, sub_shape, stride):
    """Return a strided view of ``arr`` whose two leading axes index windows.

    See also skimage.util.shape.view_as_windows().

    Parameters
    ----------
    arr : ndarray
        Input array with at least two dimensions. Windows slide over the
        first two axes; any trailing axes are carried through unchanged.
    sub_shape : tuple of 2 ints
        Window size ``(rows, cols)``.
    stride : sequence of 2 ints
        Step between consecutive windows along the first two axes.

    Returns
    -------
    ndarray
        View of shape ``(out_rows, out_cols, rows, cols) + arr.shape[2:]``,
        where ``out_rows = 1 + (arr.shape[0] - rows) // stride[0]`` (and
        likewise for columns). Built with ``as_strided``, so it shares
        memory with ``arr`` instead of copying.
    """
    row_bytes, col_bytes = arr.strides[:2]
    n_rows, n_cols = arr.shape[:2]
    win_rows, win_cols = sub_shape
    step_r, step_c = stride[0], stride[1]

    out_rows = (n_rows - win_rows) // step_r + 1
    out_cols = (n_cols - win_cols) // step_c + 1

    # Outer two axes jump by whole strides; inner two walk the window
    # element by element; any remaining axes keep their original layout.
    view_shape = (out_rows, out_cols, win_rows, win_cols) + arr.shape[2:]
    view_strides = (
        step_r * row_bytes,
        step_c * col_bytes,
        row_bytes,
        col_bytes,
    ) + arr.strides[2:]
    return np.lib.stride_tricks.as_strided(arr, view_shape, strides=view_strides)
def poolingOverlap(mat, ksize, stride=None, method='max', pad=False):
    '''Overlapping pooling on 2D or 3D data.

    Windows slide over the first two axes of <mat>; any trailing axes
    (e.g. channels) are reduced independently per window.

    <mat>: ndarray, input array to pool (at least 2 dimensions).
    <ksize>: int or tuple of 2, kernel size in (ky, kx). An int is used
             for both axes.
    <stride>: int, tuple of 2, or None, stride of pooling window (sy, sx).
              An int is used for both axes. If None, same as <ksize>
              (non-overlapping pooling).
    <method>: str, 'max' for max-pooling,
              'mean' for mean-pooling. Any other value currently falls
              back to mean-pooling.
    <pad>: bool, pad <mat> or not. If no pad, output has size
           (n-f)//s+1, n being <mat> size, f being kernel size, s stride;
           trailing elements not covered by a full window are dropped.
           If pad, output has size ceil(n/s); the padding is NaN and is
           ignored by the NaN-aware reduction.
    Return <result>: pooled array of shape (out_y, out_x) + mat.shape[2:].
    '''
    m, n = mat.shape[:2]
    ky, kx = (ksize, ksize) if np.isscalar(ksize) else ksize
    if stride is None:
        stride = (ky, kx)
    sy, sx = (stride, stride) if np.isscalar(stride) else stride

    if pad:
        # Exact integer ceil(m / sy) and ceil(n / sx).
        ny = -(-m // sy)
        nx = -(-n // sx)
        size_y = (ny - 1) * sy + ky
        size_x = (nx - 1) * sx + kx
        mat_pad = np.full((size_y, size_x) + mat.shape[2:], np.nan)
        # When the kernel is smaller than the stride, the padded extent can
        # be shorter than <mat>. Rows/columns beyond it are never covered by
        # any window, so copy only the overlapping part instead of failing
        # with a broadcasting error.
        mat_pad[:m, :n, ...] = mat[:size_y, :size_x, ...]
    else:
        # Trim trailing rows/columns that cannot fill a complete window.
        mat_pad = mat[:(m - ky) // sy * sy + ky, :(n - kx) // sx * sx + kx, ...]

    view = asStride(mat_pad, (ky, kx), (sy, sx))
    # nan-aware reductions so NaN padding (pad=True) does not affect results.
    if method == 'max':
        result = np.nanmax(view, axis=(2, 3))
    else:
        result = np.nanmean(view, axis=(2, 3))
    return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
very nice 👍