Created
July 27, 2018 15:40
-
-
Save muayyad-alsadi/39506bdf906568fcdeb7b1a8f6aeba36 to your computer and use it in GitHub Desktop.
TensorFlow 2D random pooling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import math | |
# Accepted values for the `padding` argument (same spelling as tf.nn.max_pool).
PADDING_VALID = 'VALID'
PADDING_SAME = 'SAME'
# Accepted values for the `data_format` argument; rnd_pooling2d only
# implements NHWC and rejects anything else.
FORMAT_NHWC = 'NHWC'
FORMAT_NCHW = 'NCHW'
def _pair(value, what):
    '''
    Normalize a window/stride argument into an ``(height, width)`` tuple.

    :param value: an int (used for both axes) or a 2-element list/tuple.
    :param what: argument name, used only in error messages.
    :returns: ``(h, w)`` tuple of positive ints.
    :raises ValueError: a sequence of the wrong length, or a non-positive entry.
    '''
    if isinstance(value, (list, tuple)):
        if len(value) != 2:
            raise ValueError('%s must be an int or a (height, width) pair, got %r' % (what, value))
        h, w = int(value[0]), int(value[1])
    else:
        h = w = int(value)
    if h <= 0 or w <= 0:
        raise ValueError('%s must be positive, got %r' % (what, value))
    return h, w


# TODO: support N=? (dimensions that are only known at run time)
def rnd_pooling2d(inputs, ksize=2, strides=2, padding=PADDING_VALID, data_format=FORMAT_NHWC, name=None):
    '''
    Random 2D pooling: similar to `tf.nn.max_pool` and `tf.layers.max_pooling2d`,
    but each output element is a uniformly random pick from its pool window
    instead of the window maximum.

    :param inputs: 4-D NHWC tensor (or array). All four dimensions must be
        statically known, because each one is converted with ``int()``.
    :param ksize: pool window size, an int or a ``(height, width)`` pair.
    :param strides: window stride, an int or a ``(height, width)`` pair.
    :param padding: ``'VALID'`` or ``'SAME'``. Output size follows the usual
        TF rule for each mode. With ``'SAME'``, windows start at
        ``row * stride`` (they are not centred). Sampled positions past the
        bottom/right edge are clamped to the last row/column.
    :param data_format: only ``'NHWC'`` is implemented.
    :param name: name given to the final ``tf.gather_nd`` op.
    :returns: tensor of shape ``(N, out_h, out_w, C)`` with the dtype of
        `inputs`. For each output pixel the same spatial position is picked
        in every channel.
    :raises NotImplementedError: unsupported `data_format` or `padding`.
    :raises ValueError: bad `ksize`/`strides`, or no window position fits in
        the input (e.g. a VALID window larger than the image).
    '''
    if data_format != FORMAT_NHWC:
        raise NotImplementedError('only data_format=NHWC is implemented')
    if padding != PADDING_VALID and padding != PADDING_SAME:
        raise NotImplementedError('padding should be either VALID or SAME')
    kh, kw = _pair(ksize, 'ksize')
    sh, sw = _pair(strides, 'strides')
    # static sizes; an unknown dimension makes int() fail (see TODO above)
    N, H, W, C = map(int, inputs.shape)
    # extent of valid window top-left corners along each spatial axis
    if padding == PADDING_SAME:
        span_h, span_w = H, W
    else:
        span_h, span_w = H - kh + 1, W - kw + 1
    if span_h <= 0 or span_w <= 0:
        raise ValueError('pool window %dx%d with padding %s leaves no output positions for a %dx%d input'
                         % (kh, kw, padding, H, W))
    rows = np.arange(0, span_h, sh)
    cols = np.arange(0, span_w, sw)
    out_h, out_w = len(rows), len(cols)
    # corners[n, i, j, c] = (n, rows[i], cols[j], c): the top-left cell of the
    # window for every output element, shape (N, out_h, out_w, C, 4)
    grids = np.meshgrid(np.arange(N), rows, cols, np.arange(C), indexing='ij')
    corners = tf.constant(np.stack(grids, axis=-1).astype(np.int32), dtype=tf.int32)
    # one uniform draw per output pixel over the kh*kw cells of its window,
    # split into a (row, col) offset inside the window
    pick = tf.random_uniform((N, out_h, out_w, 1), 0, kh * kw, tf.int32)
    dy = pick // kw
    dx = pick % kw
    zeros = tf.zeros_like(pick)
    # offset (0, dy, dx, 0) with a size-1 channel axis; it broadcasts against
    # corners so all channels of an output pixel sample the same position
    offset = tf.expand_dims(tf.concat([zeros, dy, dx, zeros], 3), 3)
    ix = corners + offset
    if padding == PADDING_SAME:
        # windows hanging over the bottom/right edge: clamp into the input
        ix = tf.minimum(ix, tf.constant([N - 1, H - 1, W - 1, C - 1], dtype=tf.int32))
    return tf.gather_nd(inputs, ix, name=name)
if __name__ == '__main__':
    # Demo: read ./test.png, build a batch of the image and a desaturated copy,
    # apply 3x3/stride-3 random pooling and display the first pooled image.
    from PIL import Image
    with open('test.png', 'rb') as f:
        raw_data = f.read()
    # the context manager closes the session (and frees its resources) on exit
    with tf.Session() as session:
        # channels=3 forces RGB output, because adjust_saturation requires
        # exactly 3 channels; PNGs with alpha or in greyscale would otherwise fail
        img1 = tf.image.decode_image(raw_data, channels=3)
        img2 = tf.image.adjust_saturation(img1, 0.0)
        img1, img2 = session.run([img1, img2])
        inputs = tf.stack([img1, img2])
        # print(x) with a single argument behaves identically on Python 2 and 3
        print(session.run(inputs).shape)
        out = session.run(rnd_pooling2d(inputs, 3, 3, padding='VALID'))
    print(out[0].shape)
    out1 = Image.fromarray(np.array(out[0]))
    out1.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment