-
-
Save jperl/e5ee6f37d24fe944ef100d5e9175fca4 to your computer and use it in GitHub Desktop.
Un-scaled version of CoordConv2D
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras.backend as K
import tensorflow as tf
from keras.layers import Conv2D, Layer
# Author's note: not tested; intended for experimenting with GANs.


class CoordConv2D(Layer):
    """``Conv2D`` applied after appending two coordinate channels (CoordConv).

    For every spatial position ``(i, j)`` of an input with spatial size
    ``rows x cols``, two channels are concatenated onto the input:

    * ``2 * i / rows - 1`` -- position along the first spatial axis
    * ``2 * j / cols - 1`` -- position along the second spatial axis

    Each coordinate therefore lies in ``[-1, 1 - 2 / size]``. The augmented
    tensor is then fed to an internal ``Conv2D``.

    Args:
        channel: Number of output filters of the internal ``Conv2D``.
        kernel: Kernel size of the internal ``Conv2D`` (int or tuple).
        padding: Padding mode forwarded to ``Conv2D`` (e.g. ``'valid'``,
            ``'same'``).
        **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``).
    """

    def __init__(self, channel, kernel, padding='valid', **kwargs):
        # Initialise the base Layer before assigning any attributes. Newer
        # Keras versions track attribute assignment on layers and expose
        # ``name`` as read-only, and older ones overwrite ``name`` inside
        # ``Layer.__init__``. The layer name is therefore left to ``kwargs``
        # instead of being hard-coded.
        super(CoordConv2D, self).__init__(**kwargs)
        self.channel = channel
        self.kernel = kernel
        self.padding = padding
        self.layer = Conv2D(channel, kernel, padding=padding)

    @staticmethod
    def _channel_axis():
        """Return the channel axis for the active Keras image data format."""
        # Keras reports 'channels_first' / 'channels_last' (plural).
        return 1 if K.image_data_format() == 'channels_first' else -1

    def _augmented_shape(self, input_shape):
        """Return ``input_shape`` with the channel dimension increased by 2."""
        # TF1 ``Dimension`` objects expose ``.value``; ints / None pass through.
        shape = [getattr(dim, 'value', dim) for dim in input_shape]
        axis = self._channel_axis()
        if shape[axis] is not None:
            shape[axis] += 2
        return tuple(shape)

    def build(self, input_shape):
        """Build the inner ``Conv2D`` for the coordinate-augmented input."""
        self.layer.build(self._augmented_shape(input_shape))
        # NOTE(review): Keras releases that do not auto-track nested layers
        # (believed to be 2.2.x and older) will not list the inner Conv2D
        # weights in ``self.trainable_weights``. Confirm on the target
        # version before training.
        super(CoordConv2D, self).build(input_shape)

    def call(self, input):
        """Append the two coordinate channels to ``input`` and convolve."""
        channel_axis = self._channel_axis()
        shape = K.shape(input)
        batch_size = shape[0]
        if channel_axis == 1:
            rows, cols = shape[2], shape[3]
        else:
            rows, cols = shape[1], shape[2]

        dtype = K.dtype(input)
        # Position i / rows (resp. j / cols) for every row (resp. column).
        row_coords = K.cast(tf.range(rows), dtype) / K.cast(rows, dtype)
        col_coords = K.cast(tf.range(cols), dtype) / K.cast(cols, dtype)
        row_grid, col_grid = tf.meshgrid(row_coords, col_coords, indexing='ij')

        # (rows, cols, 2), mapped from [0, 1) to [-1, 1).
        canvas = K.stack([row_grid, col_grid], axis=-1) * 2 - 1
        # Repeat the grid for every sample: (batch, rows, cols, 2).
        canvas = tf.tile(K.expand_dims(canvas, 0),
                         tf.stack([batch_size, 1, 1, 1]))
        if channel_axis == 1:
            # Move the coordinate channels in front: (batch, 2, rows, cols).
            canvas = K.permute_dimensions(canvas, (0, 3, 1, 2))

        return self.layer(K.concatenate([input, canvas], axis=channel_axis))

    def compute_output_shape(self, input_shape):
        """Output shape of the inner ``Conv2D`` on the augmented input."""
        return self.layer.compute_output_shape(
            self._augmented_shape(input_shape))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(CoordConv2D, self).get_config()
        config.update({
            'channel': self.channel,
            'kernel': self.kernel,
            'padding': self.padding,
        })
        return config
from keras.layers import Input, Conv2D
import numpy as np
from keras import Model


def main():
    """Smoke-test CoordConv2D by building a one-layer model and printing its output shape.

    The shapes below place the 3 input channels last (channels_last layout).
    With ``padding='same'`` and 63 filters, a batch of 3 inputs of size
    32x32x3 produces an output of shape (3, 32, 32, 63).
    """
    inp = Input([32, 32, 3])
    layer = CoordConv2D(63, 3, padding='same')
    x = layer(inp)
    mod = Model(inp, x)
    mod.compile('sgd', 'mse')
    res = mod.predict(np.ones([3, 32, 32, 3]))
    print(res.shape)


# Run the demo only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.