Skip to content

Instantly share code, notes, and snippets.

@titu1994
Last active October 5, 2016 06:59
Show Gist options
  • Save titu1994/abe4bb62da23e5d8c1b3d23ca477f232 to your computer and use it in GitHub Desktop.
from keras.engine.topology import Layer
from keras import backend as K
class Normalize(Layer):
    '''
    Custom layer that normalizes the inputs to the VGG and GAN networks.

    For ``type="gan"`` the input is divided by ``value``.
    Otherwise (VGG path) the per-channel ImageNet BGR means are subtracted
    on the Theano backend; on other backends an approximate constant
    ``value`` is subtracted instead.

    Note: the ``type`` parameter name shadows the builtin, but is kept
    for backward compatibility with existing callers.
    '''

    def __init__(self, type="vgg", value=120, **kwargs):
        super(Normalize, self).__init__(**kwargs)
        self.type = type    # "gan" -> divide; anything else -> subtract means
        self.value = value  # divisor (gan) / approximate mean (non-theano)

    def build(self, input_shape):
        # No trainable weights.
        pass

    def call(self, x, mask=None):
        if self.type == "gan":
            return x / self.value
        else:
            if K.backend() == "theano":
                import theano.tensor as T
                # BUG FIX: T.set_subtensor returns a NEW tensor; it does not
                # modify x in place. The original code discarded the results,
                # so no mean subtraction actually happened. Rebind x each time.
                # Assumes channels-first (batch, channels, h, w) BGR input —
                # TODO confirm against caller.
                x = T.set_subtensor(x[:, 0, :, :], x[:, 0, :, :] - 103.939)
                x = T.set_subtensor(x[:, 1, :, :], x[:, 1, :, :] - 116.779)
                x = T.set_subtensor(x[:, 2, :, :], x[:, 2, :, :] - 123.680)
            else:
                # No exact substitute for set_subtensor in tensorflow,
                # so we subtract an approximate value.
                x = x - self.value
            return x

    def get_output_shape_for(self, input_shape):
        # Normalization is element-wise; shape is unchanged.
        return input_shape
from keras import backend as K
from keras.regularizers import ActivityRegularizer
# Constant-zero loss value: the real training signal comes entirely from
# the activity regularizers attached to the network's layers.
dummy_loss_val = K.variable(0.0)


def dummy_loss(y_true, y_pred):
    """Return a constant zero loss (training is driven by regularizers)."""
    return dummy_loss_val
class ContentVGGRegularizer(ActivityRegularizer):
    """Content (perceptual feature) loss regularizer.

    Johnson et al 2015 https://arxiv.org/abs/1603.08155

    Assumes the attached layer's output batch is the concatenation of
    generated-image features (first half) and true-input features
    (second half).
    """

    def __init__(self, weight=1.0):
        super(ContentVGGRegularizer, self).__init__()
        self.weight = weight  # scale of the content loss term
        self.uses_learning_phase = False

    def __call__(self, loss):
        batch_size = K.shape(self.layer.output)[0] // 2
        generated = self.layer.output[:batch_size]  # Generated by network features
        content = self.layer.output[batch_size:]    # True X input features

        # BUG FIX: unpacking K.shape(...) into four Python variables fails
        # on the TensorFlow backend (a symbolic shape tensor cannot be
        # iterated), and bound two unused names. Index the shape tensor
        # instead; only width/height are needed.
        # Assumes channels-first layout (batch, filters, width, height) —
        # TODO confirm.
        shape = K.shape(generated)
        width, height = shape[2], shape[3]

        # Cast the pixel count to float so the division is well-defined on
        # both backends (TF rejects float / int-tensor division).
        loss += self.weight * K.sum(K.square(content - generated)) / \
            K.cast(width * height, K.floatx())
        return loss

    def get_config(self):
        return {'name': self.__class__.__name__,
                'weight': self.weight}
class AdversarialLossRegularizer(ActivityRegularizer):
    """Adversarial loss regularizer.

    Adds ``weight * sum(-log(outputs))`` of the attached layer's outputs
    (presumably discriminator probabilities — verify against caller) to
    the loss.
    """

    def __init__(self, weight=1e-3):
        super(AdversarialLossRegularizer, self).__init__()
        self.weight = weight  # scale of the adversarial term
        self.uses_learning_phase = False

    def __call__(self, loss):
        predictions = self.layer.output
        # Negative log of the outputs, summed, then scaled by weight.
        adversarial_term = K.sum(-K.log(predictions))
        return loss + self.weight * adversarial_term

    def get_config(self):
        return {'name': self.__class__.__name__,
                'weight': self.weight}
class TVRegularizer(ActivityRegularizer):
    """Total-variation regularizer: enforces smoothness in image output."""

    def __init__(self, img_width, img_height, weight=2e-8):
        super(TVRegularizer, self).__init__()
        self.img_width = img_width    # expected spatial width of the output
        self.img_height = img_height  # expected spatial height of the output
        self.weight = weight          # scale of the TV term
        self.uses_learning_phase = False

    def __call__(self, loss):
        output = self.layer.output
        assert K.ndim(output) == 4
        w, h = self.img_width, self.img_height
        if K.image_dim_ordering() == 'th':
            # Channels-first: spatial axes are 2 and 3.
            dx = K.square(output[:, :, :w - 1, :h - 1] - output[:, :, 1:, :h - 1])
            dy = K.square(output[:, :, :w - 1, :h - 1] - output[:, :, :w - 1, 1:])
        else:
            # Channels-last: spatial axes are 1 and 2.
            dx = K.square(output[:, :w - 1, :h - 1, :] - output[:, 1:, :h - 1, :])
            dy = K.square(output[:, :w - 1, :h - 1, :] - output[:, :w - 1, 1:, :])
        # Squared neighbor differences raised to 1.25, summed, then averaged.
        return loss + self.weight * K.mean(K.sum(K.pow(dx + dy, 1.25)))

    def get_config(self):
        return {'name': self.__class__.__name__,
                'img_width': self.img_width,
                'img_height': self.img_height,
                'weight': self.weight}
@titu1994
Copy link
Author

titu1994 commented Oct 2, 2016

Full Network architecture:

srgan

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment