@lambdaofgod
Created May 27, 2020 21:06
from tensorflow.keras import layers, losses, models

# Default hyperparameters. These values are assumed defaults (the gist does not
# define them); adjust them to match your data and hardware.
BASE_BLOCK_SIZE = 16
BASE_DROPOUT_RATE = 0.1
ACTIVATION = 'elu'
IMG_HEIGHT = 128
IMG_WIDTH = 128
N_CLASSES = 3


def build_segmentation_model(
    input_shape,
    n_classes,
    base_block_size=BASE_BLOCK_SIZE,
    base_dropout_rate=BASE_DROPOUT_RATE,
    activation=ACTIVATION
):
    """Build a U-Net segmentation model that outputs per-pixel class logits."""
    inputs = layers.Input(input_shape)
    # Center the inputs around zero (assumes pixel values already scaled to [0, 1]).
    s = layers.Lambda(lambda x: x - 0.5)(inputs)

    # Encoder block 1
    c1 = layers.Conv2D(base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(s)
    c1 = layers.BatchNormalization()(c1)
    c1 = layers.Dropout(base_dropout_rate)(c1)
    c1 = layers.Conv2D(base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c1)
    c1 = layers.BatchNormalization()(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)

    # Encoder block 2
    c2 = layers.Conv2D(2 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(p1)
    c2 = layers.BatchNormalization()(c2)
    c2 = layers.Dropout(base_dropout_rate)(c2)
    c2 = layers.Conv2D(2 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c2)
    c2 = layers.BatchNormalization()(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)

    # Encoder block 3
    c3 = layers.Conv2D(4 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(p2)
    c3 = layers.BatchNormalization()(c3)
    c3 = layers.Dropout(2 * base_dropout_rate)(c3)
    c3 = layers.Conv2D(4 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c3)
    c3 = layers.BatchNormalization()(c3)
    p3 = layers.MaxPooling2D((2, 2))(c3)

    # Encoder block 4
    c4 = layers.Conv2D(8 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(p3)
    c4 = layers.BatchNormalization()(c4)
    c4 = layers.Dropout(2 * base_dropout_rate)(c4)
    c4 = layers.Conv2D(8 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c4)
    c4 = layers.BatchNormalization()(c4)
    p4 = layers.MaxPooling2D(pool_size=(2, 2))(c4)

    # Bottleneck
    c5 = layers.Conv2D(16 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(p4)
    c5 = layers.BatchNormalization()(c5)
    c5 = layers.Dropout(0.3)(c5)
    c5 = layers.Conv2D(16 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c5)
    c5 = layers.BatchNormalization()(c5)

    # Decoder block 1: upsample and concatenate with the skip connection from c4
    u6 = layers.Conv2DTranspose(4 * base_block_size, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = layers.Conv2D(8 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(u6)
    c6 = layers.BatchNormalization()(c6)
    c6 = layers.Dropout(2 * base_dropout_rate)(c6)
    c6 = layers.Conv2D(8 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c6)
    c6 = layers.BatchNormalization()(c6)

    # Decoder block 2: skip connection from c3
    u7 = layers.Conv2DTranspose(2 * base_block_size, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = layers.Conv2D(4 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(u7)
    c7 = layers.BatchNormalization()(c7)
    c7 = layers.Dropout(2 * base_dropout_rate)(c7)
    c7 = layers.Conv2D(4 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c7)
    c7 = layers.BatchNormalization()(c7)

    # Decoder block 3: skip connection from c2
    u8 = layers.Conv2DTranspose(base_block_size, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = layers.Conv2D(2 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(u8)
    c8 = layers.BatchNormalization()(c8)
    c8 = layers.Dropout(base_dropout_rate)(c8)
    c8 = layers.Conv2D(2 * base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c8)
    c8 = layers.BatchNormalization()(c8)

    # Decoder block 4: skip connection from c1
    u9 = layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = layers.concatenate([u9, c1], axis=3)
    c9 = layers.Conv2D(base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(u9)
    c9 = layers.BatchNormalization()(c9)
    c9 = layers.Dropout(base_dropout_rate)(c9)
    c9 = layers.Conv2D(base_block_size, (3, 3), activation=activation, kernel_initializer='he_normal', padding='same')(c9)
    c9 = layers.BatchNormalization()(c9)

    # 1x1 convolution producing per-pixel logits for each class
    out = layers.Conv2D(n_classes, (1, 1))(c9)
    return models.Model(inputs=[inputs], outputs=[out])
def setup_segmentation_model(
    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    n_classes=N_CLASSES,
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy']
):
    """Build and compile the U-Net; from_logits=True matches the un-activated output layer."""
    segmentation_model = build_segmentation_model(input_shape, n_classes)
    segmentation_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return segmentation_model
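A minimal usage sketch under the assumed constants above: it feeds RGB images scaled to [0, 1] and integer class masks with a trailing channel dimension, which is what the sparse categorical loss expects against the logits output. The `images` and `masks` arrays are placeholder random data, not part of the gist.

import numpy as np

# Hypothetical data: 8 RGB images in [0, 1] and integer masks of shape (H, W, 1).
images = np.random.rand(8, IMG_HEIGHT, IMG_WIDTH, 3).astype('float32')
masks = np.random.randint(0, N_CLASSES, size=(8, IMG_HEIGHT, IMG_WIDTH, 1))

model = setup_segmentation_model()
model.summary()
model.fit(images, masks, batch_size=2, epochs=1)

# Per-pixel class predictions: argmax over the logits channel.
predicted_classes = np.argmax(model.predict(images), axis=-1)

Note that the spatial dimensions must be divisible by 16, since the encoder halves them four times before the decoder upsamples back.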