Skip to content

Instantly share code, notes, and snippets.

@jperl
Forked from Dref360/coordconv2d.py
Created May 14, 2019 21:32
Show Gist options
  • Save jperl/e5ee6f37d24fe944ef100d5e9175fca4 to your computer and use it in GitHub Desktop.
Save jperl/e5ee6f37d24fe944ef100d5e9175fca4 to your computer and use it in GitHub Desktop.
Un-scaled version of CoordConv2D
import keras.backend as K
import tensorflow as tf
from keras.layers import Layer
"""Not tested, I'll play around with GANs soon with it."""
class CoordConv2D(Layer):
    """2-D convolution that appends two coordinate channels to its input.

    For every spatial position the layer creates two extra channels holding
    that position's index along the first and the second spatial axis. Each
    index is divided by the size of its axis and then mapped with
    ``2 * x - 1``, so the values start at -1 and stop just short of +1.
    The coordinate channels are concatenated to the input along the channel
    axis and the result is passed through an ordinary ``Conv2D``.

    Both ``channels_last`` and ``channels_first`` image data formats are
    supported; the format is read from ``K.image_data_format()``.

    # Arguments
        channel: number of output filters of the wrapped ``Conv2D``.
        kernel: kernel size forwarded to the wrapped ``Conv2D``.
        padding: padding mode forwarded to the wrapped ``Conv2D``.
        **kwargs: forwarded unchanged to ``Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, channel, kernel, padding='valid', **kwargs):
        # Initialise the base layer first so that its bookkeeping exists
        # before the sub-layer is attached. The layer name is left to Keras
        # (either ``kwargs['name']`` or an auto-generated unique name); a
        # hard-coded name would collide as soon as two instances are used.
        super(CoordConv2D, self).__init__(**kwargs)
        self.layer = Conv2D(channel, kernel, padding=padding)

    @property
    def trainable_weights(self):
        """Trainable weights of the wrapped convolution (empty when frozen)."""
        # Delegating explicitly makes the kernel/bias visible to the model on
        # Keras versions that do not track layers stored as attributes.
        if not self.trainable:
            return []
        return self.layer.trainable_weights

    @property
    def non_trainable_weights(self):
        """Non-trainable weights of the wrapped convolution.

        When this layer is frozen, all weights of the wrapped convolution are
        reported here instead.
        """
        if not self.trainable:
            return self.layer.weights
        return self.layer.non_trainable_weights

    def call(self, input):
        """Append the coordinate channels to `input` and convolve the result.

        # Arguments
            input: 4-D tensor laid out according to ``K.image_data_format()``.

        # Returns
            Output tensor of the wrapped ``Conv2D``.
        """
        channels_first = K.image_data_format() == 'channels_first'

        # Dynamic shape, so unknown batch / spatial sizes are supported.
        input_shape = tf.unstack(K.shape(input))
        if channels_first:
            bs, _, w, h = input_shape
        else:
            bs, w, h, _ = input_shape

        # Build the coordinates in the input's dtype so that the final
        # concatenation never mixes dtypes.
        dtype = K.dtype(input)

        # tf.where on an all-True mask enumerates every (batch, i, j) index
        # triple in row-major order; dropping the batch column leaves the two
        # spatial indices for each position.
        indices = tf.where(tf.ones([bs, w, h], dtype=tf.bool))
        canvas = K.reshape(K.cast(indices, dtype), [bs, w, h, 3])[..., 1:]

        # Normalise the canvas: divide by the axis sizes, then map to [-1, 1).
        extent = K.cast(K.reshape(K.stack([w, h]), [1, 1, 1, 2]), dtype)
        canvas = (canvas / extent) * 2 - 1

        if channels_first:
            # (batch, w, h, 2) -> (batch, 2, w, h) to match the input layout.
            canvas = K.permute_dimensions(canvas, (0, 3, 1, 2))
            channel_axis = 1
        else:
            channel_axis = -1

        # Concatenate channel-wise and run the convolution.
        augmented = K.concatenate([input, canvas], axis=channel_axis)
        return self.layer(augmented)

    def compute_output_shape(self, input_shape):
        """Return the output shape reported by the wrapped ``Conv2D``."""
        return self.layer.compute_output_shape(input_shape)
from keras.layers import Input, Conv2D
import numpy as np
from keras import Model
if __name__ == '__main__':
    # Smoke test: wrap a single CoordConv2D layer in a model and push a batch
    # of three 32x32x3 arrays of ones through it. Guarded so that importing
    # this module does not build a model or run a prediction.
    inp = Input([32, 32, 3])
    layer = CoordConv2D(63, 3, padding='same')
    x = layer(inp)
    mod = Model(inp, x)
    mod.compile('sgd', 'mse')
    res = mod.predict(np.ones([3, 32, 32, 3]))
    # With 'same' padding and 63 filters this is expected to print
    # (3, 32, 32, 63) under the channels_last data format.
    print(res.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment