#!/usr/bin/env python3
from keras.layers import Input, LeakyReLU, Conv2D
from keras.engine import InputSpec, Layer
from keras.models import Model as KerasModel
import keras.backend as K
import numpy as np

if K.backend() == "plaidml.keras.backend":
    import plaidml
    import plaidml.op

    def pad(data, paddings, mode="CONSTANT", name=None, constant_value=0):
        """PlaidML pad with a tf.pad-like signature; only reflection padding is supported."""
        if mode.upper() != "REFLECT":
            raise NotImplementedError("pad only supports mode == 'REFLECT'")
        if constant_value != 0:
            raise NotImplementedError("pad does not support constant_value != 0")
        return plaidml.op.reflection_padding(data, paddings)
else:
    from tensorflow import pad


class ReflectionPadding2D(Layer):
    """Reflection-padding layer for 2D inputs (e.g. images)."""

    def __init__(self, stride=2, kernel_size=5, **kwargs):
        '''
        # Arguments
            stride: stride of the following convolution (default 2)
            kernel_size: kernel size of the following convolution (default 5)
        '''
        self.stride = stride
        self.kernel_size = kernel_size
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        super(ReflectionPadding2D, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """Output shape after padding, assuming the "channels_last" data format."""
        input_shape = self.input_spec[0].shape
        in_width, in_height = input_shape[2], input_shape[1]
        kernel_width, kernel_height = self.kernel_size, self.kernel_size
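        # Total padding follows the "same"-padding convention used by
        # TensorFlow/Keras: pad by (kernel - stride) when the input size is
        # divisible by the stride, otherwise by (kernel - input % stride),
        # clamped at zero, so the following "valid" convolution yields the
        # same spatial size that "same" padding would.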
        if (in_height % self.stride) == 0:
            padding_height = max(kernel_height - self.stride, 0)
        else:
            padding_height = max(kernel_height - (in_height % self.stride), 0)
        if (in_width % self.stride) == 0:
            padding_width = max(kernel_width - self.stride, 0)
        else:
            padding_width = max(kernel_width - (in_width % self.stride), 0)
        return (input_shape[0],
                input_shape[1] + padding_height,
                input_shape[2] + padding_width,
                input_shape[3])

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        in_width, in_height = input_shape[2], input_shape[1]
        kernel_width, kernel_height = self.kernel_size, self.kernel_size
        if (in_height % self.stride) == 0:
            padding_height = max(kernel_height - self.stride, 0)
        else:
            padding_height = max(kernel_height - (in_height % self.stride), 0)
        if (in_width % self.stride) == 0:
            padding_width = max(kernel_width - self.stride, 0)
        else:
            padding_width = max(kernel_width - (in_width % self.stride), 0)
        padding_top = padding_height // 2
        padding_bot = padding_height - padding_top
        padding_left = padding_width // 2
        padding_right = padding_width - padding_left
        return pad(x,
                   [[0, 0],
                    [padding_top, padding_bot],
                    [padding_left, padding_right],
                    [0, 0]],
                   'REFLECT')

    def get_config(self):
        config = {'stride': self.stride,
                  'kernel_size': self.kernel_size}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
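

# A minimal sanity check of the shape arithmetic (illustrative numbers, not part
# of the original script): with the defaults stride=2 and kernel_size=5, a 64x64
# input is reflection-padded to 67x67, and the following 5x5 stride-2 "valid"
# convolution produces 32x32, the same spatial size that "same" padding gives.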
USE_REFLECTIVE_PADDING = True
INPUT_SHAPE = (64, 64, 3)
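

# The conv helper below swaps Keras' built-in "same" padding for explicit
# reflection padding: when USE_REFLECTIVE_PADDING is set, the input is
# reflection-padded first and the convolution itself runs with "valid" padding.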
def conv(x, filters, kernel_size=5, strides=2, padding="same"):
    if USE_REFLECTIVE_PADDING:
        x = ReflectionPadding2D(stride=strides, kernel_size=kernel_size)(x)
        padding = "valid"
    x = Conv2D(filters, kernel_size=kernel_size, strides=strides, padding=padding)(x)
    return x


if __name__ == '__main__':
    x = inp = Input(INPUT_SHAPE)
    x = conv(x, 128)
    x = conv(x, 256)
    model = KerasModel(inp, x)
    model.summary()

    train_x = np.ones((8,) + INPUT_SHAPE)
    train_y = np.ones((8,) + tuple(K.int_shape(x)[1:]))
    model.compile("Adam", "mse")
    for i in range(20):
        print("Train batch %i" % i)
        model.train_on_batch(train_x, train_y)
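
    # Optional round-trip sketch (the file name is arbitrary and write access to
    # the working directory is assumed): because get_config() is implemented, the
    # model can be saved and reloaded by passing the custom layer in custom_objects.
    from keras.models import load_model
    model.save("reflection_padding_demo.h5")
    model = load_model("reflection_padding_demo.h5",
                       custom_objects={"ReflectionPadding2D": ReflectionPadding2D})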
    exit()