Skip to content

Instantly share code, notes, and snippets.

@t-ae
Last active March 23, 2017 07:39
Show Gist options
  • Save t-ae/10f5009790e9e571d5bbaa573e7f346d to your computer and use it in GitHub Desktop.
Save t-ae/10f5009790e9e571d5bbaa573e7f346d to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import numpy as np
from keras.models import Sequential
from keras.layers import InputLayer
from keras.backend import image_dim_ordering, set_image_dim_ordering
from pixel_shuffler import PixelShuffler
# Dimensions of the synthetic input fed to the layer under test.
batch_size = 6
in_height = 4
in_width = 3
out_channels = 5
# Upscaling factors handed to PixelShuffler as (rh, rw).
rh = 3
rw = 4


def expected_pixel_shuffle(x, rh, rw, out_channels, channels_last=True):
    """Return the reference pixel-shuffled version of ``x``.

    The result is defined element-wise, for the channels-last layout, by::

        out[b, h, w, k] == x[b, h // rh, w // rw,
                             out_channels * rw * (h % rh)
                             + out_channels * (w % rw) + k]

    For the channels-first layout the same relation holds with the channel
    axis at position 1 on both ``x`` and the result.

    Args:
        x: 4-D array-like with ``out_channels * rh * rw`` channels, laid out
            as (batch, height, width, channels) when ``channels_last`` is
            true and as (batch, channels, height, width) otherwise.
        rh: Upscaling factor along the height axis.
        rw: Upscaling factor along the width axis.
        out_channels: Number of channels of the result.
        channels_last: Whether the channel axis is the last one.

    Returns:
        A numpy array with height and width multiplied by ``rh`` and ``rw``
        and ``out_channels`` channels, in the same layout as ``x``.

    Raises:
        ValueError: If ``x`` is not 4-D or its channel count is not
            ``out_channels * rh * rw``.
    """
    x = np.asarray(x)
    if x.ndim != 4:
        raise ValueError("expected a 4-D input, got %d dimensions" % x.ndim)
    if not channels_last:
        # Work in channels-last internally so one index formula serves both
        # layouts; the result is transposed back before returning.
        x = np.transpose(x, (0, 2, 3, 1))
    if x.shape[3] != out_channels * rh * rw:
        raise ValueError(
            "expected %d input channels, got %d"
            % (out_channels * rh * rw, x.shape[3])
        )

    # Output coordinates, shaped so that they broadcast to (out_h, out_w, k).
    h = np.arange(x.shape[1] * rh).reshape(-1, 1, 1)
    w = np.arange(x.shape[2] * rw).reshape(1, -1, 1)
    k = np.arange(out_channels).reshape(1, 1, -1)
    # Input channel that supplies output element (h, w, k).
    src_k = out_channels * rw * (h % rh) + out_channels * (w % rw) + k

    expected = x[:, h // rh, w // rw, src_k]
    if not channels_last:
        expected = np.transpose(expected, (0, 3, 1, 2))
    return expected


def main():
    """Run PixelShuffler on an arange input and compare with the reference.

    The input layout follows the active Keras image dim ordering ('tf' is
    channels-last, 'th' is channels-first).

    Raises:
        ValueError: If the dim ordering is neither 'tf' nor 'th'.
        AssertionError: If the layer output differs from the reference.
    """
    ordering = image_dim_ordering()
    print("image dim ordering:", ordering)

    in_channels = out_channels * rh * rw
    if ordering == 'tf':
        input_shape = [in_height, in_width, in_channels]
    elif ordering == 'th':
        input_shape = [in_channels, in_height, in_width]
    else:
        raise ValueError("neither tf nor th")

    model = Sequential([
        InputLayer(input_shape),
        PixelShuffler((rh, rw)),
    ])
    # Every element is a distinct integer, so a wrong permutation cannot
    # accidentally compare equal.
    x = np.arange(batch_size * in_height * in_width * in_channels).reshape(
        [batch_size] + input_shape
    )
    y = model.predict(x)

    expected = expected_pixel_shuffle(
        x, rh, rw, out_channels, channels_last=(ordering == 'tf')
    )
    # Raise explicitly instead of using `assert` so the check still runs
    # under `python -O`.
    if y.shape != expected.shape:
        raise AssertionError(
            "output shape %r differs from expected %r"
            % (y.shape, expected.shape)
        )
    if not np.array_equal(y, expected):
        first_bad = tuple(int(i) for i in np.argwhere(y != expected)[0])
        raise AssertionError(
            "PixelShuffler output differs from expected at index %r"
            % (first_bad,)
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment