# GitHub gist by @sandeepnmenon (gist 023495c69a877980535a7bf2da1389fa),
# last active March 13, 2019.
# Gated autoencoder model
def get_gated_connections(gatePercentageFactor, inputLayer):
    """Split a Keras tensor into a gated fraction and its complement.

    Args:
        gatePercentageFactor: scalar gate value. It is wrapped in a
            one-element ``K.variable`` and exposed as an ``Input`` layer.
        inputLayer: the Keras tensor to gate.

    Returns:
        A 3-tuple ``(gate_input, fraction, complement)`` where
        ``fraction = inputLayer * gate`` and
        ``complement = inputLayer - fraction``.
    """
    # The gate is an Input that wraps an existing backend variable
    # rather than creating a new placeholder.
    gate_input = Input(tensor=K.variable([gatePercentageFactor]))

    # Element-wise product with the single gate value.
    fraction = Lambda(lambda pair: pair[0] * pair[1])([inputLayer, gate_input])

    # Whatever the gate did not pass through, i.e. inputLayer - fraction.
    complement = Lambda(lambda pair: pair[0] - pair[1])([inputLayer, fraction])

    return gate_input, fraction, complement
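# Variable-name legend for get_cnn_dsc_architecture:
#   x  - output of an encoder convolution layer
#   y  - output of a decoder (transposed-convolution, "de-conv") layer
#   gf - gating-factor input
#   fg - fraction of the encoder activation passed through the gate (x * gf)
#   c  - complement, i.e. the remaining fraction from the gate (x - fg)
#   jt - joining tensor: sum of the previous decoder layer's output and the
#        complement of the matching encoder convolution layer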
def get_cnn_dsc_architecture(gate_factors=(0.1, 0.2, 0.3, 0.4)):
    """Build and compile the gated convolutional autoencoder.

    Encoder structure:
      - four Conv2D + 2x2 MaxPooling stages, then a 512-filter bottleneck
        convolution;
      - after each of the four encoder convolutions, a gate (see
        get_gated_connections) splits the activation into two parts:
        - the fraction ``gf``, which continues down the encoder through the
          pooling layer;
        - the complement ``1 - gf``, which is added to the decoder stage of
          the same resolution as a skip connection.

    NOTE(review): the skip-connection Add layers require the encoder and
    decoder maps to have equal spatial size, so input height and width
    presumably need to be divisible by 16 (four 2x2 poolings with 'same'
    padding vs four 2x upsamplings) — confirm for your data.

    Args:
        gate_factors: sequence of exactly four gate fractions for encoder
            stages 1-4. Defaults to the original values (0.1, 0.2, 0.3, 0.4).

    Returns:
        The compiled Keras ``Model``. Its inputs are the image input
        (shape (None, None, 3)) and the four gate-factor inputs.

    Raises:
        ValueError: if ``gate_factors`` does not contain exactly four values.
    """
    gate_factors = tuple(gate_factors)
    if len(gate_factors) != 4:
        raise ValueError(
            "gate_factors must contain exactly 4 values, got %d" % len(gate_factors))
    g1, g2, g3, g4 = gate_factors

    # adapt this if using `channels_first` image data format
    input_img = Input(shape=(None, None, 3))

    # Encoder. In every stage the gated fraction (fg) is what gets pooled and
    # passed down; the complement (c) is kept for the decoder skip connection.
    x1 = Conv2D(64, (4, 4), activation='relu', padding='same')(input_img)
    gf1, fg1, c1 = get_gated_connections(g1, x1)
    x = MaxPooling2D((2, 2), padding='same')(fg1)

    x2 = Conv2D(64, (4, 4), activation='relu', padding='same')(x)
    gf2, fg2, c2 = get_gated_connections(g2, x2)
    x = MaxPooling2D((2, 2), padding='same')(fg2)

    x3 = Conv2D(128, (4, 4), activation='relu', padding='same')(x)
    gf3, fg3, c3 = get_gated_connections(g3, x3)
    # Pool fg3 (not the raw x3), as in stages 1 and 2, so the activation is
    # actually split between the down path and the skip path.
    x = MaxPooling2D((2, 2), padding='same')(fg3)

    x4 = Conv2D(256, (4, 4), activation='relu', padding='same')(x)
    gf4, fg4, c4 = get_gated_connections(g4, x4)
    # Same fix as stage 3: pool the gated fraction fg4, not the raw x4.
    x = MaxPooling2D((2, 2), padding='same')(fg4)

    # Bottleneck.
    x5 = Conv2D(512, (4, 4), activation='relu', padding='same')(x)

    # Decoder: upsample, transposed conv, then add the complement from the
    # encoder stage at the same resolution.
    x = UpSampling2D((2, 2))(x5)
    y1 = Conv2DTranspose(256, (4, 4), activation='relu', padding='same')(x)
    jt4 = Add()([y1, c4])

    x = UpSampling2D((2, 2))(jt4)
    y2 = Conv2DTranspose(128, (4, 4), activation='relu', padding='same')(x)
    jt3 = Add()([y2, c3])

    x = UpSampling2D((2, 2))(jt3)
    y3 = Conv2DTranspose(64, (4, 4), activation='relu', padding='same')(x)
    jt2 = Add()([y3, c2])

    # At full resolution the first-stage complement is added before the last
    # two transposed convolutions.
    x = UpSampling2D((2, 2))(jt2)
    jt1 = Add()([x, c1])
    y4 = Conv2DTranspose(64, (4, 4), activation='relu', padding='same')(jt1)
    # NOTE(review): relu on the output layer restricts reconstructions to >= 0.
    y5 = Conv2DTranspose(3, (4, 4), activation='relu', padding='same')(y4)

    sym_autoencoder = Model([input_img, gf1, gf2, gf3, gf4], y5)
    # NOTE(review): 'accuracy' is not meaningful for an MSE reconstruction
    # objective; it is kept only so the logged metric names do not change.
    sym_autoencoder.compile(optimizer='sgd', loss='mean_squared_error',
                            metrics=['accuracy', 'mean_squared_error'])
    # summary() prints the table itself and returns None, so it is called
    # directly instead of wrapping it in a (Python-2-only) print statement.
    sym_autoencoder.summary()
    return sym_autoencoder
# Build the model and train it as a denoising autoencoder: noisy images are
# the inputs and clean images are the targets.
# NOTE(review): x_train_noisy, x_train, x_test_noisy and x_test are not defined
# in this file section; presumably they are prepared earlier — confirm.
sym_autoencoder = get_cnn_dsc_architecture()

# Save the model whenever the training loss improves on the best seen so far.
model_checkpoint = ModelCheckpoint('./models/gated_cnn_autoencoder.hdf5',
                                   monitor='loss', verbose=1,
                                   save_best_only=True)

sym_autoencoder.fit(x_train_noisy, x_train,
                    epochs=200,
                    batch_size=20,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    # Pass the checkpoint object bound above; the previous
                    # reference to `model_checkpoint1` raised NameError.
                    callbacks=[TensorBoard(log_dir='/tmp/gated_cnn_autoencoder',
                                           histogram_freq=0,
                                           write_graph=True),
                               model_checkpoint])
# Captured console output of sym_autoencoder.summary() from a run of this
# script. It is kept as a bare module-level string literal, so it has no
# effect at runtime. Keras truncates the column text, and whitespace was
# collapsed when the output was pasted.
# In this captured run, max_pooling2d_13 and max_pooling2d_14 take conv2d_23
# and conv2d_24 directly rather than the gated lambda outputs.
# The trailing "None" is the return value of summary() that the run printed.
# NOTE(review): regenerate this text if the architecture changes.
'''
Model Summary
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_9 (InputLayer) (None, None, None, 3) 0
____________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, None, None, 64 3136 input_9[0][0]
____________________________________________________________________________________________________
input_10 (InputLayer) (1,) 0
____________________________________________________________________________________________________
lambda_9 (Lambda) (None, None, None, 64 0 conv2d_21[0][0]
input_10[0][0]
____________________________________________________________________________________________________
max_pooling2d_11 (MaxPooling2D) (None, None, None, 64 0 lambda_9[0][0]
____________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, None, None, 64 65600 max_pooling2d_11[0][0]
____________________________________________________________________________________________________
input_11 (InputLayer) (1,) 0
____________________________________________________________________________________________________
lambda_11 (Lambda) (None, None, None, 64 0 conv2d_22[0][0]
input_11[0][0]
____________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, None, None, 64 0 lambda_11[0][0]
____________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, None, None, 12 131200 max_pooling2d_12[0][0]
____________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, None, None, 12 0 conv2d_23[0][0]
____________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, None, None, 25 524544 max_pooling2d_13[0][0]
____________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, None, None, 25 0 conv2d_24[0][0]
____________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, None, None, 51 2097664 max_pooling2d_14[0][0]
____________________________________________________________________________________________________
input_13 (InputLayer) (1,) 0
____________________________________________________________________________________________________
up_sampling2d_11 (UpSampling2D) (None, None, None, 51 0 conv2d_25[0][0]
____________________________________________________________________________________________________
lambda_15 (Lambda) (None, None, None, 25 0 conv2d_24[0][0]
input_13[0][0]
____________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTransp (None, None, None, 25 2097408 up_sampling2d_11[0][0]
____________________________________________________________________________________________________
lambda_16 (Lambda) (None, None, None, 25 0 conv2d_24[0][0]
lambda_15[0][0]
____________________________________________________________________________________________________
add_5 (Add) (None, None, None, 25 0 conv2d_transpose_6[0][0]
lambda_16[0][0]
____________________________________________________________________________________________________
input_12 (InputLayer) (1,) 0
____________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, None, None, 25 0 add_5[0][0]
____________________________________________________________________________________________________
lambda_13 (Lambda) (None, None, None, 12 0 conv2d_23[0][0]
input_12[0][0]
____________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTransp (None, None, None, 12 524416 up_sampling2d_12[0][0]
____________________________________________________________________________________________________
lambda_14 (Lambda) (None, None, None, 12 0 conv2d_23[0][0]
lambda_13[0][0]
____________________________________________________________________________________________________
add_6 (Add) (None, None, None, 12 0 conv2d_transpose_7[0][0]
lambda_14[0][0]
____________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, None, None, 12 0 add_6[0][0]
____________________________________________________________________________________________________
conv2d_transpose_8 (Conv2DTransp (None, None, None, 64 131136 up_sampling2d_13[0][0]
____________________________________________________________________________________________________
lambda_12 (Lambda) (None, None, None, 64 0 conv2d_22[0][0]
lambda_11[0][0]
____________________________________________________________________________________________________
add_7 (Add) (None, None, None, 64 0 conv2d_transpose_8[0][0]
lambda_12[0][0]
____________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, None, None, 64 0 add_7[0][0]
____________________________________________________________________________________________________
lambda_10 (Lambda) (None, None, None, 64 0 conv2d_21[0][0]
lambda_9[0][0]
____________________________________________________________________________________________________
add_8 (Add) (None, None, None, 64 0 up_sampling2d_14[0][0]
lambda_10[0][0]
____________________________________________________________________________________________________
conv2d_transpose_9 (Conv2DTransp (None, None, None, 64 65600 add_8[0][0]
____________________________________________________________________________________________________
conv2d_transpose_10 (Conv2DTrans (None, None, None, 3) 3075 conv2d_transpose_9[0][0]
====================================================================================================
Total params: 5,643,779
Trainable params: 5,643,779
Non-trainable params: 0
____________________________________________________________________________________________________
None
'''
# End of gist.