# Gated autoencoder model
# Source: GitHub gist sandeepnmenon/023495c69a877980535a7bf2da1389fa
# (last active March 13, 2019 20:21).
# GitHub notes that this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears; review it in an editor
# that reveals hidden Unicode characters.
def get_gated_connections(gatePercentageFactor, inputLayer):
    """Split ``inputLayer`` into a gated fraction and the remaining complement.

    Args:
        gatePercentageFactor: scalar gate value; the gated fraction is
            ``inputLayer * gatePercentageFactor``.
        inputLayer: Keras tensor to split.

    Returns:
        Tuple ``(gateFactor, fractionG, complement)``:
        the ``Input`` layer wrapping the gate value as a backend variable,
        the gated fraction ``inputLayer * gate``, and
        the complement ``inputLayer - fractionG``.
    """
    # The backend variable is created before the Input that wraps it, exactly
    # as in the nested-call form, so layer creation order is unchanged.
    gate_value = K.variable([gatePercentageFactor])
    gate_input = Input(tensor=gate_value)
    gated = Lambda(lambda pair: pair[0] * pair[1])([inputLayer, gate_input])
    remainder = Lambda(lambda pair: pair[0] - pair[1])([inputLayer, gated])
    return gate_input, gated, remainder
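# Naming legend for get_cnn_dsc_architecture():
#   x  - convolution (encoder) layer
#   y  - transposed-convolution (decoder) layer
#   gf - gating factor
#   fg - fraction of the input passed through the gate
#   c  - complement, i.e. the remaining fraction not passed through the gate
#   jt - joining tensor that adds a decoder layer to the complement of the
#        matching convolution layer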
def get_cnn_dsc_architecture():
    """Build and compile the gated convolutional denoising autoencoder.

    Encoder:
        Five 4x4 ``relu`` convolutions (64, 64, 128, 256 and 512 filters),
        separated by 2x2 max-pooling.
        ``get_gated_connections`` splits each of the first four convolutions
        using gate factors 0.1, 0.2, 0.3 and 0.4. The complement of each
        split is added back into the decoder at the matching resolution.

    Decoder:
        2x2 up-sampling followed by 4x4 ``relu`` transposed convolutions,
        ending in a 3-channel output.

    The model is compiled with SGD and mean-squared-error loss, and its
    summary is printed.

    Returns:
        The compiled Keras ``Model``. Its inputs are the RGB image plus the
        four gate-factor inputs; its output is the reconstructed image.
    """
    # Height and width are left unspecified, but they must be multiples of 16.
    # Four 2x2 'same' poolings followed by four 2x2 up-samplings restore the
    # original size only in that case, and the Add() joins below need
    # matching shapes.
    input_img = Input(shape=(None, None, 3))  # adapt this if using `channels_first` image data format

    # Encoder. Each gated stage yields gf (gate input), fg (gated fraction)
    # and c (complement, used later as a skip connection).
    x1 = Conv2D(64, (4, 4), activation='relu', padding='same')(input_img)
    gf1, fg1, c1 = get_gated_connections(0.1, x1)
    x = MaxPooling2D((2, 2), padding='same')(fg1)
    x2 = Conv2D(64, (4, 4), activation='relu', padding='same')(x)
    gf2, fg2, c2 = get_gated_connections(0.2, x2)
    x = MaxPooling2D((2, 2), padding='same')(fg2)
    x3 = Conv2D(128, (4, 4), activation='relu', padding='same')(x)
    gf3, fg3, c3 = get_gated_connections(0.3, x3)
    # NOTE(review): stages 3 and 4 pool the ungated conv output (x3 / x4).
    # Stages 1 and 2 pool the gated fraction (fg1 / fg2) instead, so here
    # fg3 and fg4 only feed their complements. This is kept as-is to match
    # the recorded model summary; confirm whether it is intentional.
    x = MaxPooling2D((2, 2), padding='same')(x3)
    x4 = Conv2D(256, (4, 4), activation='relu', padding='same')(x)
    gf4, fg4, c4 = get_gated_connections(0.4, x4)
    x = MaxPooling2D((2, 2), padding='same')(x4)
    x5 = Conv2D(512, (4, 4), activation='relu', padding='same')(x)

    # Decoder: upsample, then add the complement from the encoder stage at
    # the same resolution.
    x = UpSampling2D((2, 2))(x5)
    y1 = Conv2DTranspose(256, (4, 4), activation='relu', padding='same')(x)
    jt4 = Add()([y1, c4])
    x = UpSampling2D((2, 2))(jt4)
    y2 = Conv2DTranspose(128, (4, 4), activation='relu', padding='same')(x)
    jt3 = Add()([y2, c3])
    x = UpSampling2D((2, 2))(jt3)
    y3 = Conv2DTranspose(64, (4, 4), activation='relu', padding='same')(x)
    jt2 = Add()([y3, c2])
    x = UpSampling2D((2, 2))(jt2)
    jt1 = Add()([x, c1])
    y4 = Conv2DTranspose(64, (4, 4), activation='relu', padding='same')(jt1)
    y5 = Conv2DTranspose(3, (4, 4), activation='relu', padding='same')(y4)

    # The four gate Inputs wrap the fixed gate-factor variables created in
    # get_gated_connections().
    sym_autoencoder = Model([input_img, gf1, gf2, gf3, gf4], y5)
    # NOTE(review): 'accuracy' is of doubtful meaning for a continuous pixel
    # target. It is kept so existing logs and metrics stay comparable.
    sym_autoencoder.compile(optimizer='sgd', loss='mean_squared_error',
                            metrics=['accuracy', 'mean_squared_error'])
    # summary() prints the table itself and returns None. The previous
    # Python-2-only `print sym_autoencoder.summary()` was a SyntaxError on
    # Python 3, and on Python 2 it also printed a stray "None".
    sym_autoencoder.summary()
    return sym_autoencoder
# Build the model, checkpoint the weights with the lowest training loss, and
# train it to map the noisy images back to the clean ones.
# NOTE(review): x_train_noisy, x_train, x_test_noisy and x_test are not
# defined in the visible code. They are assumed to be prepared elsewhere as
# image arrays whose height and width are multiples of 16; confirm.
# NOTE(review): './models' presumably must already exist for the checkpoint
# file to be written; verify.
sym_autoencoder = get_cnn_dsc_architecture()
model_checkpoint = ModelCheckpoint('./models/gated_cnn_autoencoder.hdf5',
                                   monitor='loss', verbose=1,
                                   save_best_only=True)
sym_autoencoder.fit(x_train_noisy, x_train,
                    epochs=200,
                    batch_size=20,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    callbacks=[TensorBoard(log_dir='/tmp/gated_cnn_autoencoder',
                                           histogram_freq=0,
                                           write_graph=True),
                               # This previously referenced `model_checkpoint1`,
                               # an undefined name that raised NameError.
                               model_checkpoint])
'''
Model summary recorded from get_cnn_dsc_architecture()
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_9 (InputLayer) (None, None, None, 3) 0
____________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, None, None, 64 3136 input_9[0][0]
____________________________________________________________________________________________________
input_10 (InputLayer) (1,) 0
____________________________________________________________________________________________________
lambda_9 (Lambda) (None, None, None, 64 0 conv2d_21[0][0]
input_10[0][0]
____________________________________________________________________________________________________
max_pooling2d_11 (MaxPooling2D) (None, None, None, 64 0 lambda_9[0][0]
____________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, None, None, 64 65600 max_pooling2d_11[0][0]
____________________________________________________________________________________________________
input_11 (InputLayer) (1,) 0
____________________________________________________________________________________________________
lambda_11 (Lambda) (None, None, None, 64 0 conv2d_22[0][0]
input_11[0][0]
____________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, None, None, 64 0 lambda_11[0][0]
____________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, None, None, 12 131200 max_pooling2d_12[0][0]
____________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, None, None, 12 0 conv2d_23[0][0]
____________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, None, None, 25 524544 max_pooling2d_13[0][0]
____________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, None, None, 25 0 conv2d_24[0][0]
____________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, None, None, 51 2097664 max_pooling2d_14[0][0]
____________________________________________________________________________________________________
input_13 (InputLayer) (1,) 0
____________________________________________________________________________________________________
up_sampling2d_11 (UpSampling2D) (None, None, None, 51 0 conv2d_25[0][0]
____________________________________________________________________________________________________
lambda_15 (Lambda) (None, None, None, 25 0 conv2d_24[0][0]
input_13[0][0]
____________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTransp (None, None, None, 25 2097408 up_sampling2d_11[0][0]
____________________________________________________________________________________________________
lambda_16 (Lambda) (None, None, None, 25 0 conv2d_24[0][0]
lambda_15[0][0]
____________________________________________________________________________________________________
add_5 (Add) (None, None, None, 25 0 conv2d_transpose_6[0][0]
lambda_16[0][0]
____________________________________________________________________________________________________
input_12 (InputLayer) (1,) 0
____________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, None, None, 25 0 add_5[0][0]
____________________________________________________________________________________________________
lambda_13 (Lambda) (None, None, None, 12 0 conv2d_23[0][0]
input_12[0][0]
____________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTransp (None, None, None, 12 524416 up_sampling2d_12[0][0]
____________________________________________________________________________________________________
lambda_14 (Lambda) (None, None, None, 12 0 conv2d_23[0][0]
lambda_13[0][0]
____________________________________________________________________________________________________
add_6 (Add) (None, None, None, 12 0 conv2d_transpose_7[0][0]
lambda_14[0][0]
____________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, None, None, 12 0 add_6[0][0]
____________________________________________________________________________________________________
conv2d_transpose_8 (Conv2DTransp (None, None, None, 64 131136 up_sampling2d_13[0][0]
____________________________________________________________________________________________________
lambda_12 (Lambda) (None, None, None, 64 0 conv2d_22[0][0]
lambda_11[0][0]
____________________________________________________________________________________________________
add_7 (Add) (None, None, None, 64 0 conv2d_transpose_8[0][0]
lambda_12[0][0]
____________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, None, None, 64 0 add_7[0][0]
____________________________________________________________________________________________________
lambda_10 (Lambda) (None, None, None, 64 0 conv2d_21[0][0]
lambda_9[0][0]
____________________________________________________________________________________________________
add_8 (Add) (None, None, None, 64 0 up_sampling2d_14[0][0]
lambda_10[0][0]
____________________________________________________________________________________________________
conv2d_transpose_9 (Conv2DTransp (None, None, None, 64 65600 add_8[0][0]
____________________________________________________________________________________________________
conv2d_transpose_10 (Conv2DTrans (None, None, None, 3) 3075 conv2d_transpose_9[0][0]
====================================================================================================
Total params: 5,643,779
Trainable params: 5,643,779
Non-trainable params: 0
____________________________________________________________________________________________________
None
'''
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.