Last active: February 11, 2020 14:40
-
-
Save Dref360/b330e75cb121c03a0066d9587a7bfee5 to your computer and use it in GitHub Desktop.
Un-scaled version of CoordConv2D
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras.backend as K | |
import tensorflow as tf | |
from tensorflow.keras.layers import Layer | |
"""Not tested, I'll play around with GANs soon with it.""" | |
from tensorflow.keras.layers import Conv2D | |
import numpy as np | |
class CoordConv2D(Layer):
    """Conv2D that first appends two coordinate channels to its input.

    For every pixel, a row-coordinate channel and a column-coordinate channel
    are built as ``index / size * 2 - 1``. Each channel therefore spans
    ``[-1, 1 - 2/size]`` along its axis. Both channels are concatenated onto
    the channel axis of ``inputs`` before the wrapped Conv2D is applied.

    Args:
        channel: Number of output filters of the wrapped Conv2D.
        kernel: Kernel size forwarded to Conv2D.
        padding: Conv2D padding mode (e.g. 'valid' or 'same').
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, channel, kernel, padding='valid', **kwargs):
        # Initialise the base Layer first so Keras attribute tracking exists
        # before the sub-layer is assigned to an attribute.
        super(CoordConv2D, self).__init__(**kwargs)
        self.channel = channel
        self.kernel_size = kernel
        self.padding = padding
        self.layer = Conv2D(channel, kernel, padding=padding)

    @staticmethod
    def _channels_first():
        # Keras reports the plural forms 'channels_first' / 'channels_last'.
        return tf.keras.backend.image_data_format() == 'channels_first'

    def call(self, inputs, **kwargs):
        """Concatenate coordinate channels onto ``inputs`` and convolve.

        Args:
            inputs: 4-D image batch, laid out according to the Keras
                ``image_data_format`` setting.

        Returns:
            The output of the wrapped Conv2D.
        """
        channels_first = self._channels_first()
        shape = tf.shape(inputs)
        if channels_first:
            batch, rows, cols = shape[0], shape[2], shape[3]
        else:
            batch, rows, cols = shape[0], shape[1], shape[2]

        # Per-axis index divided by that axis' size, then mapped by *2 - 1.
        row_coord = tf.cast(tf.range(rows), tf.float32) / tf.cast(rows, tf.float32)
        col_coord = tf.cast(tf.range(cols), tf.float32) / tf.cast(cols, tf.float32)
        row_grid, col_grid = tf.meshgrid(row_coord, col_coord, indexing='ij')
        canvas = tf.stack([row_grid, col_grid], axis=-1) * 2 - 1  # (rows, cols, 2)
        canvas = tf.tile(canvas[tf.newaxis], [batch, 1, 1, 1])  # (batch, rows, cols, 2)

        if channels_first:
            canvas = tf.transpose(canvas, [0, 3, 1, 2])
            channel_axis = 1
        else:
            channel_axis = -1

        # Match the input dtype so concatenation works for non-float32 inputs.
        canvas = tf.cast(canvas, inputs.dtype)
        augmented = tf.concat([inputs, canvas], axis=channel_axis)
        return self.layer(augmented)

    def compute_output_shape(self, input_shape):
        """Return the Conv2D output shape for the coordinate-augmented input."""
        shape = tf.TensorShape(input_shape).as_list()
        channel_axis = 1 if self._channels_first() else -1
        # The wrapped Conv2D sees two extra (coordinate) channels.
        if shape[channel_axis] is not None:
            shape[channel_axis] += 2
        return self.layer.compute_output_shape(tf.TensorShape(shape))

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(CoordConv2D, self).get_config()
        config.update({
            'channel': self.channel,
            'kernel': self.kernel_size,
            'padding': self.padding,
        })
        return config
class CustomModel(tf.keras.Model):
    """Minimal demo model: a single CoordConv2D (63 filters, 3x3, 'same')."""

    def __init__(self):
        super().__init__()
        self.l = CoordConv2D(63, 3, padding='same')

    def call(self, inputs):
        """Apply the wrapped CoordConv2D to ``inputs`` and return its output."""
        return self.l(inputs)
def main():
    """Smoke test: run CustomModel on a dummy batch and print the output shape."""
    model = CustomModel()
    model.compile('sgd', 'mse')
    model.run_eagerly = True
    dummy_batch = np.ones([3, 32, 32, 3])
    prediction = model.predict(dummy_batch)
    print(prediction.shape)


if __name__ == '__main__':
    main()
Hi Thomas,
I updated the gist to TF2.
Thank you!
Thanks!
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
@Dref360, I've tried modifying the above for TensorFlow 2.1 (see my fork). Any chance you know what causes the error that is raised when it runs?