Last active
November 13, 2019 19:15
-
-
Save N-McA/9bd3a81d3062340e4affaaaaad332107 to your computer and use it in GitHub Desktop.
Concatenates the (x, y) coordinate normalised to 0-1 to each spatial location in the image. Allows a network to learn spatial bias. Has been explored in at least one paper, "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution" https://arxiv.org/abs/1807.03247
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras.backend as kb | |
from keras.layers import Layer | |
def _kb_linspace(num):
    """Return `num` evenly spaced backend values running from 0 to 1 inclusive.

    # Arguments
        num: number of points; may be a Python int or a scalar backend
            tensor (it is cast to `kb.floatx()` before use).

    # Returns
        A 1-D tensor of dtype `kb.floatx()` holding `i / (num - 1)` for
        `i = 0 .. num - 1`. When `num` is 1 the result is `[0.]`.
    """
    num = kb.cast(num, kb.floatx())
    # With a single point `num - 1` is zero and 0 / 0 would yield NaN;
    # clamp the denominator so the lone coordinate is 0 instead.
    denominator = kb.maximum(num - 1, 1)
    return kb.arange(0, num, dtype=kb.floatx()) / denominator
def _kb_grid_coords(width, height):
    """Build a `[width, height, 2]` tensor of per-position coordinates.

    Channel 0 holds the coordinate along the first (width) axis and
    channel 1 the coordinate along the second (height) axis, each
    produced by `_kb_linspace`.

    # Arguments
        width: size of the first grid axis (int or scalar tensor).
        height: size of the second grid axis (int or scalar tensor).

    # Returns
        A tensor reshaped to `[width, height, 2]`.
    """
    # Every width coordinate repeated `height` times in a row, flattened:
    # [x0, x0, ..., x1, x1, ...].
    width_column = kb.expand_dims(_kb_linspace(num=width), -1)
    first_axis = kb.reshape(kb.tile(width_column, [1, height]), [-1])

    # The whole run of height coordinates repeated `width` times:
    # [y0, y1, ..., y0, y1, ...].
    second_axis = kb.tile(_kb_linspace(num=height), [width])

    # Pair the two flat sequences up, then fold back into a grid.
    flat_pairs = kb.stack([first_axis, second_axis], axis=-1)
    return kb.reshape(flat_pairs, [width, height, 2])
class ConcatSpatialCoordinate(Layer):
    """Concatenates the (x, y) coordinate normalised to 0-1 to each spatial location in the image.

    Allows a network to learn spatial bias. Has been explored in at least one paper,
    "An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution"
    https://arxiv.org/abs/1807.03247

    Improves performance where spatial bias is appropriate.
    Works with dynamic shapes and with any batch size.

    The input is expected to be 4-D, `(batch, w, h, channels)`; the output
    is `(batch, w, h, channels + 2)`, with the two coordinate channels
    appended last.

    # Example

    ```python
        x_input = Input([None, None, 1])
        x = ConcatSpatialCoordinate()(x_input)
        model = Model(x_input, x)
        output = model.predict(np.zeros([1, 3, 3, 1]))
        spatial_features = output[0, :, :, -2:]
        assert np.all(spatial_features[0, 0] == [0, 0])
        assert np.all(spatial_features[-1, -1] == [1, 1])
        assert np.all(spatial_features[0, -1] == [0, 1])
        # Because this example was 3x3, coordinates are [0, 0.5, 1], so
        assert np.all(spatial_features[1, 1] == [0.5, 0.5])
    ```

    # Raises
        Exception: if `kb.image_data_format()` is not `'channels_last'`.
    """

    def __init__(self, **kwargs):
        # The coordinate channels are appended on the last axis, so only the
        # channels-last layout is supported.
        if kb.image_data_format() != 'channels_last':
            raise Exception((
                "Only compatible with"
                " kb.image_data_format() == 'channels_last'"))
        super(ConcatSpatialCoordinate, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; defer to the base implementation.
        super(ConcatSpatialCoordinate, self).build(input_shape)

    def call(self, x):
        """Append the two coordinate channels to `x` along the last axis."""
        # Use the runtime shape so undefined (None) static dims still work.
        dynamic_input_shape = kb.shape(x)
        batch_size = dynamic_input_shape[0]
        w = dynamic_input_shape[-3]
        h = dynamic_input_shape[-2]
        bias = _kb_grid_coords(width=w, height=h)
        # Concatenation needs every non-channel dimension to match, so the
        # single grid is repeated once per example in the batch.
        bias = kb.tile(kb.expand_dims(bias, 0), [batch_size, 1, 1, 1])
        return kb.concatenate([x, bias], axis=-1)

    def compute_output_shape(self, input_shape):
        """Return the input shape with two extra channels."""
        batch_size, w, h, channels = input_shape
        # An unknown channel count stays unknown instead of raising on None + 2.
        out_channels = None if channels is None else channels + 2
        return (batch_size, w, h, out_channels)
def test_ConcatSpatialCoordinate():
    """Check the coordinate channels produced for a single 3x3 input."""
    import numpy as np
    from keras.layers import Input
    from keras.models import Model

    inputs = Input([None, None, 1])
    model = Model(inputs, ConcatSpatialCoordinate()(inputs))
    prediction = model.predict(np.zeros([1, 3, 3, 1]))
    # The last two channels of the first (only) example are the coordinates.
    coords = prediction[0, :, :, -2:]

    expectations = [
        # Corner values hold for any grid size.
        ((0, 0), [0, 0]),
        ((-1, -1), [1, 1]),
        ((0, -1), [0, 1]),
        # A 3x3 grid has coordinates [0, 0.5, 1], so the centre is (0.5, 0.5).
        ((1, 1), [0.5, 0.5]),
    ]
    for position, expected in expectations:
        assert np.all(coords[position] == expected)


# Run the self-test when the file is executed as a script.
if __name__ == '__main__':
    test_ConcatSpatialCoordinate()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for sharing!
For visitors who stop by this page: the fork below contains a minor update to the `call` method to account for the batch size when tiling the coordinate layers onto the original input. For example, if the input is 16x128x128x1, the output from ConcatSpatialCoordinate would be 16x128x128x3. https://gist.github.com/pangyuteng/8f4f7c09b490e1baaef852d07105db77