Skip to content

Instantly share code, notes, and snippets.

@rcolomina
Forked from sandeepnmenon/cnn.py
Created March 11, 2020 11:14
Show Gist options
  • Save rcolomina/dbf5069c694f377181b5bbbe08e037dd to your computer and use it in GitHub Desktop.
Keras CNN with skip connections and gates
def get_cnn_architecture(weights_path=None, input_shape=(64, 64, 3), gate_factor=0.3):
    """Build a convolutional encoder-decoder with a gated long skip connection.

    The activations of the first convolution are split by a fixed gate:
    ``x1 * gate_factor`` flows down the encoder/decoder path, and the
    complement ``x1 - x1 * gate_factor`` is added back just before the last
    two transposed convolutions.

    Args:
        weights_path: Optional path to a weights file. When given, the weights
            are loaded into the freshly built model via ``Model.load_weights``.
        input_shape: Shape of one input image (channels-last by default;
            adapt if using the ``channels_first`` image data format). Height
            and width should be divisible by 16 (four 2x2 pool/upsample
            stages), otherwise the skip connection's ``Add`` sees mismatched
            spatial sizes.
        gate_factor: Constant fraction of the first-layer activations routed
            through the encoder; the remainder goes through the skip path.

    Returns:
        An uncompiled Keras ``Model`` producing a 3-channel output.
    """
    input_img = Input(shape=input_shape)

    # Stem convolution whose activations are split by the gate.
    x1 = Conv2D(64, (3, 3), activation='relu', padding='same')(input_img)

    # The gate is a constant, so apply it with a Lambda instead of an extra
    # `Input(tensor=K.variable(...))` that was never passed to Model(...).
    # `arguments=` keeps the constant serializable with the layer config.
    fraction_g = Lambda(lambda t, g: t * g, arguments={'g': gate_factor})(x1)
    complement = Lambda(lambda t: t[0] - t[1])([x1, fraction_g])

    # Encoder: pool -> conv, four times, ending in a 512-filter bottleneck.
    x = MaxPooling2D((2, 2), padding='same')(fraction_g)
    for filters in (64, 128, 256):
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)

    # Decoder: upsample -> transposed conv, mirroring the encoder widths.
    for filters in (256, 128, 64):
        x = UpSampling2D((2, 2))(x)
        x = Conv2DTranspose(filters, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)

    # Long skip connection: re-inject the ungated part of the stem features.
    joined_tensor = Add()([x, complement])
    y4 = Conv2DTranspose(64, (3, 3), activation='relu', padding='same')(joined_tensor)
    y5 = Conv2DTranspose(3, (3, 3), activation='relu', padding='same')(y4)

    model = Model(input_img, y5)
    if weights_path is not None:
        model.load_weights(weights_path)

    # summary() prints itself and returns None; wrapping it in print was wrong.
    model.summary()
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.