Created
September 30, 2019 06:49
-
-
Save tmwatchanan/1dc7fcb56a95c121683412eaea81210a to your computer and use it in GitHub Desktop.
SegNet model architecture (keras)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _conv_bn_relu(tensor, filters, kernel):
    """Apply a same-padded kernel x kernel Conv2D, then BatchNormalization, then ReLU."""
    tensor = Conv2D(filters, (kernel, kernel), padding="same")(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation("relu")(tensor)


def create_model(
    pretrained_weights=None,
    input_size=(),
    num_classes=2,
    learning_rate=1e-4,
    batch_normalization=False,
    is_summary=True,
):
    """Build and compile a SegNet-style encoder/decoder Keras model.

    The encoder is five stages of Conv2D -> BatchNormalization -> ReLU
    units, each stage closed by ``MaxPoolingWithArgmax2D``, which returns
    the pooled tensor together with a mask.  The decoder mirrors it: each
    stage starts with ``MaxUnpooling2D`` fed the current tensor and the
    mask of the matching encoder stage (consumed in reverse order),
    followed by Conv/BN/ReLU units.  A 1x1 convolution with
    ``num_classes`` filters, a BatchNormalization and a softmax produce
    the output.

    Args:
        pretrained_weights: Optional value passed to ``model.load_weights``
            after compilation; skipped when falsy.
        input_size: Shape handed to ``Input``.  Must be non-empty; the
            default ``()`` is only a placeholder and is rejected.
        num_classes: Number of filters of the final 1x1 convolution.
        learning_rate: Learning rate given to the Adam optimizer.
        batch_normalization: Accepted for interface compatibility only.
            This function does not read it; BatchNormalization layers are
            always inserted.
        is_summary: When true, ``model.summary()`` is called.

    Returns:
        The model, compiled with Adam, ``categorical_crossentropy`` loss
        and the ``accuracy`` metric.

    Raises:
        ValueError: If ``input_size`` is ``None`` or empty.
    """
    # Fail fast: an empty shape would otherwise only blow up inside the
    # first Conv2D with a much less helpful message.
    if input_size is None or len(input_size) == 0:
        raise ValueError(
            "input_size must be a non-empty shape, e.g. (height, width, channels)"
        )

    # define params
    kernel = 3
    pool_size = (2, 2)

    # Filter counts of the Conv/BN/ReLU units in each stage.  Layers are
    # created in exactly this order, so auto-generated layer names and
    # weight ordering stay the same as in the unrolled version.
    encoder_stages = (
        (64, 64),
        (128, 128),
        (256, 256, 256),
        (512, 512, 512),
        (512, 512, 512),
    )
    decoder_stages = (
        (512, 512, 512),
        (512, 512, 256),
        (256, 256, 128),
        (128, 64),
        (64,),
    )

    # create architecture
    inputs = Input(input_size)

    # encoder
    x = inputs
    masks = []
    for stage in encoder_stages:
        for filters in stage:
            x = _conv_bn_relu(x, filters, kernel)
        x, mask = MaxPoolingWithArgmax2D(pool_size)(x)
        masks.append(mask)

    # decoder: the most recently produced mask is consumed first
    for stage, mask in zip(decoder_stages, reversed(masks)):
        x = MaxUnpooling2D(pool_size)([x, mask])
        for filters in stage:
            x = _conv_bn_relu(x, filters, kernel)

    # classification head
    x = Conv2D(num_classes, (1, 1), padding="valid")(x)
    x = BatchNormalization()(x)
    outputs = Activation("softmax")(x)

    # Keras 2+ names these keywords `inputs`/`outputs`; the singular
    # spellings are the legacy Keras 1 names.
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        # Passed positionally because the keyword was renamed from `lr`
        # to `learning_rate` between Keras releases; it is the first
        # positional parameter under either name.
        optimizer=Adam(learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    if is_summary:
        model.summary()

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment