Skip to content

Instantly share code, notes, and snippets.

@muayyad-alsadi
Created July 27, 2018 15:40
Show Gist options
  • Save muayyad-alsadi/39506bdf906568fcdeb7b1a8f6aeba36 to your computer and use it in GitHub Desktop.
Save muayyad-alsadi/39506bdf906568fcdeb7b1a8f6aeba36 to your computer and use it in GitHub Desktop.
tensorflow 2D random pooling
import tensorflow as tf
import numpy as np
import math
# Names accepted for the `padding` argument of the pooling function.
PADDING_VALID = 'VALID'
PADDING_SAME = 'SAME'
# Tensor layout names for `data_format`; rnd_pooling2d only implements NHWC.
FORMAT_NHWC = 'NHWC'
FORMAT_NCHW = 'NCHW'
# TODO: support an unknown batch size (N=None in the static input shape)
def rnd_pooling2d(inputs, ksize=2, strides=2, padding=PADDING_VALID, data_format=FORMAT_NHWC, name=None):
    '''
    Random 2D pooling.

    Similar to `tf.nn.max_pool` and `tf.layers.max_pooling2d`, but every
    output pixel is a uniformly random pick from its pool window instead of
    the window maximum. One (y, x) offset is drawn per output pixel and is
    shared by all channels, so a picked pixel keeps its channels together.

    Args:
      inputs: 4-D tensor in NHWC layout. Every dimension of its static shape
        is converted with `int()`, so the shape must be fully known when the
        graph is built (an unknown batch size is not supported).
      ksize: int, side length of the square pool window.
      strides: int, step between window corners along both H and W.
      padding: 'VALID' or 'SAME'.
        VALID: window corners are placed only where the whole window fits.
        SAME: corners are placed over the whole image and indices that fall
        outside are clamped to the last row/column, so border pixels are
        picked more often near the bottom/right edge.
      data_format: only 'NHWC' is implemented.
      name: optional name for the final `tf.gather_nd` op.

    Returns:
      Tensor of shape (N, out_h, out_w, C), where out_h and out_w are the
      number of window corners along H and W.

    Raises:
      NotImplementedError: if `data_format` is not NHWC or `padding` is
        neither VALID nor SAME.
      ValueError: if `ksize` or `strides` is smaller than 1, or if the window
        does not fit into the input with VALID padding.
    '''
    if data_format != FORMAT_NHWC:
        raise NotImplementedError('only data_format=NHWC is implemented')
    if padding not in (PADDING_VALID, PADDING_SAME):
        raise NotImplementedError('padding should be either VALID or SAME')
    # short cut
    k, s = ksize, strides
    if k < 1 or s < 1:
        raise ValueError('ksize and strides must be >= 1, got ksize=%r strides=%r' % (k, s))
    # sizes
    N, H, W, C = map(int, inputs.shape)
    if padding == PADDING_SAME:
        CH, CW = H, W
    else:
        CH, CW = H - k + 1, W - k + 1
        if CH <= 0 or CW <= 0:
            # Without this check the index grid would be empty and the
            # failure would surface later with an unrelated message.
            raise ValueError('ksize=%r does not fit into a %dx%d input with VALID padding' % (k, H, W))
    # corner points: one [n, y, x, c] index per output element,
    # laid out as an array of shape (N, out_h, out_w, C, 4)
    rows = np.arange(0, CH, s, dtype=np.int32)
    cols = np.arange(0, CW, s, dtype=np.int32)
    grid = np.meshgrid(
        np.arange(N, dtype=np.int32), rows, cols, np.arange(C, dtype=np.int32),
        indexing='ij')
    corners = tf.constant(np.stack(grid, axis=-1), dtype=tf.int32)
    out_h, out_w = len(rows), len(cols)
    # random [y, x] offset from each corner; the size-1 channel axis lets the
    # addition below broadcast the same offset to every channel
    offsets = tf.random_uniform((N, out_h, out_w, 1, 2), 0, k, tf.int32)
    # turn [y, x] into [0, y, x, 0] so it lines up with [n, y, x, c]
    offsets = tf.pad(offsets, [[0, 0], [0, 0], [0, 0], [0, 0], [1, 1]])
    # index of the randomly shifted corner point
    ix = corners + offsets
    if padding == PADDING_SAME:
        # clamp indices that ran past the image border; a single 4-vector
        # broadcasts against every index
        ix = tf.minimum(ix, tf.constant([N - 1, H - 1, W - 1, C - 1], dtype=tf.int32))
    return tf.gather_nd(inputs, ix, name=name)
if __name__ == '__main__':
    # Demo: pool 'test.png' and a desaturated copy of it with a 3x3 window
    # and stride 3, then display the pooled version of the original image.
    from PIL import Image
    with open('test.png', 'rb') as f:
        raw_data = f.read()
    # The context manager releases the session's resources when done.
    with tf.Session() as session:
        img1 = tf.image.decode_image(raw_data)
        # NOTE(review): adjust_saturation presumably expects a 3-channel RGB
        # image -- confirm test.png is not grayscale or RGBA.
        img2 = tf.image.adjust_saturation(img1, 0.0)
        # Evaluate to numpy arrays first: rnd_pooling2d converts every
        # dimension of inputs.shape with int(), so the stacked batch needs a
        # fully known static shape (presumably decode_image does not give one).
        img1, img2 = session.run([img1, img2])
        inputs = tf.stack([img1, img2])
        print(session.run(inputs).shape)
        out = session.run(rnd_pooling2d(inputs, 3, 3, padding='VALID'))
    print(out[0].shape)
    out1 = Image.fromarray(np.array(out[0]))
    out1.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment