# Created July 9, 2016 20:29
# Save gist EderSantana/3cebf581aeb2aec896e77e25635994ba to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals hidden
# Unicode characters. Learn more about bidirectional Unicode characters.
import os

# Deconv2D calls tf.nn.conv2d_transpose directly, so Keras must run on the
# TensorFlow backend. The variable is read when keras is first imported,
# which is why it is set before the keras imports below.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
from keras.engine import Layer, InputSpec
from keras import backend as K, regularizers, constraints, initializations, activations
class Deconv2D(Layer):
    """Transposed 2-D convolution ("deconvolution") layer built on TensorFlow.

    Wraps ``tf.nn.conv2d_transpose`` as a Keras 1.x layer with a trainable
    kernel ``W`` and bias ``b``, followed by an activation.

    # Arguments
        nb_filter: number of output channels.
        nb_row: kernel height.
        nb_col: kernel width.
        init: kernel initializer name, resolved with
            ``keras.initializations.get``.
        activation: activation name, applied after the bias add.
        weights: optional initial weights, passed to ``set_weights`` in
            ``build``.
        border_mode: 'valid' or 'same'; forwarded to TensorFlow as the
            padding mode. Along each spatial axis the output length is
            ``length * stride`` for 'same' and
            ``(length - 1) * stride + kernel`` for 'valid'.
        subsample: (row_stride, col_stride).
        dim_ordering: 'tf' (channels last) or 'th' (channels first). The
            TensorFlow op used here is channels-last, so for 'th' the input,
            kernel and output are transposed around it.
        W_regularizer, b_regularizer, activity_regularizer: optional
            regularizers for the kernel, the bias and the layer output.
        W_constraint, b_constraint: optional constraints for kernel and bias.

    # Raises
        ValueError: if ``border_mode`` or ``dim_ordering`` is unsupported.
    """

    def __init__(self, nb_filter, nb_row, nb_col,
                 init='glorot_uniform', activation='linear', weights=None,
                 border_mode='valid', subsample=(1, 1), dim_ordering='tf',
                 W_regularizer=None, b_regularizer=None, activity_regularizer=None,
                 W_constraint=None, b_constraint=None, **kwargs):
        # Explicit raises rather than asserts: asserts vanish under `python -O`.
        # Validate before anything consumes these values.
        if border_mode not in {'valid', 'same'}:
            raise ValueError('Invalid border mode for Deconv2D: ' + str(border_mode))
        if dim_ordering not in {'tf', 'th'}:
            raise ValueError('dim_ordering must be in {tf, th}, got: ' + str(dim_ordering))
        self.nb_filter = nb_filter
        self.nb_row = nb_row
        self.nb_col = nb_col
        self.init = initializations.get(init, dim_ordering=dim_ordering)
        self.activation = activations.get(activation)
        self.border_mode = border_mode
        self.subsample = tuple(subsample)
        self.dim_ordering = dim_ordering
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.input_spec = [InputSpec(ndim=4)]
        self.initial_weights = weights
        super(Deconv2D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the kernel and bias and register regularizers/constraints.

        Kernel layout:
            'th': (nb_filter, input_channels, nb_row, nb_col)
            'tf': (nb_row, nb_col, nb_filter, input_channels), which is the
                  filter layout ``tf.nn.conv2d_transpose`` expects.
        """
        if self.dim_ordering == 'th':
            stack_size = input_shape[1]
            self.W_shape = (self.nb_filter, stack_size, self.nb_row, self.nb_col)
        elif self.dim_ordering == 'tf':
            stack_size = input_shape[3]
            self.W_shape = (self.nb_row, self.nb_col, self.nb_filter, stack_size)
        else:
            raise ValueError('Invalid dim_ordering: ' + str(self.dim_ordering))
        self.W = self.init(self.W_shape, name='{}/w'.format(self.name))
        self.b = K.zeros((self.nb_filter,), name='{}/biases'.format(self.name))
        self.trainable_weights = [self.W, self.b]

        self.regularizers = []
        if self.W_regularizer:
            self.W_regularizer.set_param(self.W)
            self.regularizers.append(self.W_regularizer)
        if self.b_regularizer:
            self.b_regularizer.set_param(self.b)
            self.regularizers.append(self.b_regularizer)
        if self.activity_regularizer:
            self.activity_regularizer.set_layer(self)
            self.regularizers.append(self.activity_regularizer)

        self.constraints = {}
        if self.W_constraint:
            self.constraints[self.W] = self.W_constraint
        if self.b_constraint:
            self.constraints[self.b] = self.b_constraint

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def _deconv_length(self, length, stride, kernel):
        """Return the transposed-convolution output length along one axis.

        ``length`` may be an int, a scalar tensor, or ``None`` (unknown
        static size, in which case ``None`` is returned).
        """
        if length is None:
            return None
        if self.border_mode == 'same':
            return length * stride
        # 'valid': the inverse of a VALID convolution, which maps
        # (length - 1) * stride + kernel back to `length`.
        return (length - 1) * stride + kernel

    def get_output_shape_for(self, input_shape):
        """Return the output shape tuple, in this layer's dim_ordering."""
        if self.dim_ordering == 'th':
            rows, cols = input_shape[2], input_shape[3]
        elif self.dim_ordering == 'tf':
            rows, cols = input_shape[1], input_shape[2]
        else:
            raise ValueError('Invalid dim_ordering: ' + str(self.dim_ordering))
        rows = self._deconv_length(rows, self.subsample[0], self.nb_row)
        cols = self._deconv_length(cols, self.subsample[1], self.nb_col)
        if self.dim_ordering == 'th':
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)

    def call(self, x, mask=None):
        """Apply the transposed convolution, the bias and the activation to ``x``."""
        if self.dim_ordering == 'th':
            # conv2d_transpose (default data format) is channels-last: convert
            # NCHW input to NHWC, and the (out, in, rows, cols) kernel to
            # (rows, cols, out, in).
            x = tf.transpose(x, (0, 2, 3, 1))
            kernel = tf.transpose(self.W, (2, 3, 0, 1))
        elif self.dim_ordering == 'tf':
            kernel = self.W
        else:
            raise ValueError('Invalid dim_ordering: ' + str(self.dim_ordering))

        # The batch size (and possibly the spatial size) is usually None
        # statically, but conv2d_transpose needs a concrete output shape.
        # So build it from the runtime shape of x.
        dynamic_shape = tf.shape(x)
        out_rows = self._deconv_length(dynamic_shape[1], self.subsample[0], self.nb_row)
        out_cols = self._deconv_length(dynamic_shape[2], self.subsample[1], self.nb_col)
        # `tf.pack` was renamed `tf.stack` in TensorFlow 1.0; support both.
        stack = getattr(tf, 'stack', None) or tf.pack
        output_shape = stack([dynamic_shape[0], out_rows, out_cols, self.nb_filter])

        output = tf.nn.conv2d_transpose(
            x, kernel, output_shape=output_shape,
            strides=[1, self.subsample[0], self.subsample[1], 1],
            padding=self.border_mode.upper())

        # A dynamic output_shape loses static shape info; restore what is
        # known statically.
        static_in = x.get_shape().as_list()
        output.set_shape((
            static_in[0],
            self._deconv_length(static_in[1], self.subsample[0], self.nb_row),
            self._deconv_length(static_in[2], self.subsample[1], self.nb_col),
            self.nb_filter))

        output = output + K.reshape(self.b, (1, 1, 1, self.nb_filter))
        if self.dim_ordering == 'th':
            # Back to NCHW.
            output = tf.transpose(output, (0, 3, 1, 2))
        return self.activation(output)

    def get_config(self):
        """Return the layer configuration merged with the base Layer config."""
        config = {'nb_filter': self.nb_filter,
                  'nb_row': self.nb_row,
                  'nb_col': self.nb_col,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'border_mode': self.border_mode,
                  'subsample': self.subsample,
                  'dim_ordering': self.dim_ordering,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'b_constraint': self.b_constraint.get_config() if self.b_constraint else None}
        base_config = super(Deconv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.
# Hi @MikeAmy, check this commit: keras-team/keras#3251.
# It should work for both Theano and TensorFlow.