Skip to content

Instantly share code, notes, and snippets.

@tmwatchanan
Created September 30, 2019 06:46
Show Gist options
  • Select an option

  • Save tmwatchanan/7b5544e39b4c3d3e9368c2b73d577d33 to your computer and use it in GitHub Desktop.

Select an option

Save tmwatchanan/7b5544e39b4c3d3e9368c2b73d577d33 to your computer and use it in GitHub Desktop.
U-SegNet model architecture (keras)
def _conv_bn_relu(x, filters, kernel):
    """Apply a same-padded Conv2D -> BatchNormalization -> ReLU stack to ``x``.

    The three layers are always created in this fixed order. Keep it stable:
    Keras auto-names layers by creation order, and HDF5 weight files are
    matched against the model's layer order.
    """
    x = Conv2D(filters, (kernel, kernel), padding="same")(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)


def _make_adam(learning_rate):
    """Build an Adam optimizer that works across Keras versions.

    Keras >= 2.3, tf.keras and Keras 3 accept ``learning_rate``. Older
    standalone Keras only knows ``lr`` and raises TypeError for unknown
    optimizer keyword arguments, so fall back to ``lr`` in that case.
    """
    try:
        return Adam(learning_rate=learning_rate)
    except TypeError:
        return Adam(lr=learning_rate)


def create_model(
    pretrained_weights=None,
    input_size=(),
    num_classes=2,
    learning_rate=1e-4,
    batch_normalization=False,
    is_summary=True,
):
    """Build and compile a U-SegNet-style encoder/decoder segmentation model.

    Encoder: five stages of 3x3 Conv-BN-ReLU blocks (24, 48, 96, 128 and 256
    filters). Each stage is followed by ``MaxPoolingWithArgmax2D``, whose
    argmax mask is reused by ``MaxUnpooling2D`` in the matching decoder stage.
    Encoder feature maps at four resolutions are concatenated into the
    decoder as skip connections. A final 1x1 convolution, BatchNormalization
    and softmax produce per-pixel class scores with ``num_classes`` channels.

    Args:
        pretrained_weights: Optional path passed to ``model.load_weights``
            after compiling. Falsy values skip weight loading.
        input_size: Shape passed to ``Input`` (excluding the batch
            dimension). The concatenations on axis 3 assume channels-last,
            e.g. ``(H, W, C)``. NOTE(review): with five 2x2 poolings, H and W
            presumably must be divisible by 32 so unpooled maps line up with
            the skip connections; confirm.
        num_classes: Number of output channels of the final 1x1 conv/softmax.
        learning_rate: Learning rate for the Adam optimizer.
        batch_normalization: Currently ignored; BatchNormalization is always
            applied. Kept for API compatibility.
        is_summary: If true, print ``model.summary()``.

    Returns:
        The compiled Keras ``Model`` (categorical cross-entropy loss,
        accuracy metric), with ``pretrained_weights`` loaded if given.
    """
    # define params
    kernel = 3
    pool_size = (2, 2)

    # create architecture -- encoder
    inputs = Input(input_size)
    conv_1 = _conv_bn_relu(inputs, 24, kernel)
    conv_2 = _conv_bn_relu(conv_1, 24, kernel)
    pool_1, mask_1 = MaxPoolingWithArgmax2D(pool_size)(conv_2)

    conv_3 = _conv_bn_relu(pool_1, 48, kernel)
    conv_4 = _conv_bn_relu(conv_3, 48, kernel)
    pool_2, mask_2 = MaxPoolingWithArgmax2D(pool_size)(conv_4)

    conv_5 = _conv_bn_relu(pool_2, 96, kernel)
    conv_6 = _conv_bn_relu(conv_5, 96, kernel)
    conv_7 = _conv_bn_relu(conv_6, 96, kernel)
    pool_3, mask_3 = MaxPoolingWithArgmax2D(pool_size)(conv_7)

    conv_8 = _conv_bn_relu(pool_3, 128, kernel)
    conv_9 = _conv_bn_relu(conv_8, 128, kernel)
    conv_10 = _conv_bn_relu(conv_9, 128, kernel)
    pool_4, mask_4 = MaxPoolingWithArgmax2D(pool_size)(conv_10)

    conv_11 = _conv_bn_relu(pool_4, 256, kernel)
    conv_12 = _conv_bn_relu(conv_11, 256, kernel)
    conv_13 = _conv_bn_relu(conv_12, 256, kernel)
    pool_5, mask_5 = MaxPoolingWithArgmax2D(pool_size)(conv_13)

    # decoder: each MaxUnpooling2D reuses the argmax mask of the matching
    # encoder pooling; skips concatenate encoder features on axis 3.
    unpool_1 = MaxUnpooling2D(pool_size)([pool_5, mask_5])
    conv_14 = _conv_bn_relu(unpool_1, 256, kernel)
    conv_15 = _conv_bn_relu(conv_14, 256, kernel)
    conv_16 = _conv_bn_relu(conv_15, 128, kernel)

    unpool_2 = MaxUnpooling2D(pool_size)([conv_16, mask_4])
    conv_17 = _conv_bn_relu(unpool_2, 128, kernel)
    merge_17 = Concatenate(axis=3)([conv_10, conv_17])
    conv_18 = _conv_bn_relu(merge_17, 128, kernel)
    conv_19 = _conv_bn_relu(conv_18, 96, kernel)

    unpool_3 = MaxUnpooling2D(pool_size)([conv_19, mask_3])
    conv_20 = _conv_bn_relu(unpool_3, 96, kernel)
    merge_20 = Concatenate(axis=3)([conv_7, conv_20])
    conv_21 = _conv_bn_relu(merge_20, 96, kernel)
    conv_22 = _conv_bn_relu(conv_21, 48, kernel)

    unpool_4 = MaxUnpooling2D(pool_size)([conv_22, mask_2])
    conv_23 = _conv_bn_relu(unpool_4, 48, kernel)
    merge_23 = Concatenate(axis=3)([conv_4, conv_23])
    conv_24 = _conv_bn_relu(merge_23, 24, kernel)

    unpool_5 = MaxUnpooling2D(pool_size)([conv_24, mask_1])
    conv_25 = _conv_bn_relu(unpool_5, 24, kernel)
    merge_25 = Concatenate(axis=3)([conv_2, conv_25])

    # classifier head: 1x1 conv to num_classes channels, then BN + softmax
    # (no ReLU, so the softmax sees unconstrained scores).
    conv_26 = Conv2D(num_classes, (1, 1), padding="valid")(merge_25)
    conv_26 = BatchNormalization()(conv_26)
    outputs = Activation("softmax")(conv_26)

    # `inputs=`/`outputs=` work on every Keras 2.x/3.x; the legacy
    # `input=`/`output=` keywords raise TypeError on Keras >= 2.3 / tf.keras.
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=_make_adam(learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    if is_summary:
        model.summary()
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment