Skip to content

Instantly share code, notes, and snippets.

@snakers4
Created September 12, 2017 09:12
Show Gist options
  • Save snakers4/fb0dc2eb260635608bad05f001ccc1e0 to your computer and use it in GitHub Desktop.
Save snakers4/fb0dc2eb260635608bad05f001ccc1e0 to your computer and use it in GitHub Desktop.
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D, BatchNormalization,multiply
from keras.optimizers import RMSprop
from model.losses import bce_dice_loss, dice_loss, weighted_bce_dice_loss, weighted_dice_loss, dice_coeff
import params
# Reference frame size from which the model builders below derive their
# default input shapes; they add 2 to the width (1918 + 2 = 1920).
orig_width, orig_height = 1918, 1280
def gate_unit(f_i, f_i1):
    """Combine two feature maps into a single gated feature map.

    Both inputs go through their own 3x3 convolution (8 filters, 'same'
    padding) followed by batch normalisation and ReLU.  The ``f_i1`` branch
    is then upsampled by a factor of 2 along both spatial axes and the two
    branches are multiplied element-wise.

    Args:
        f_i: Tensor for the first branch; kept at its own resolution.
        f_i1: Tensor for the second branch; upsampled 2x before the
            product.  Callers in this file pass the next pooling level
            here, i.e. a map at half the spatial size of ``f_i``.

    Returns:
        The tensor produced by ``multiply`` on the two processed branches.
    """
    def conv_bn_relu(tensor):
        # One independent Conv -> BN -> ReLU stack per call (no weight sharing).
        tensor = Conv2D(8, (3, 3), padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    # The f_i branch is built first, then the f_i1 branch, so layers are
    # created in the same order as before.
    same_scale = conv_bn_relu(f_i)
    coarse = conv_bn_relu(f_i1)
    coarse = UpSampling2D((2, 2))(coarse)
    return multiply([same_scale, coarse])
def gated_refinement_unit(r_f, m_f, num_classes):
    """Refine a coarse prediction with gated features and upsample it 2x.

    ``m_f`` is reduced to ``num_classes`` channels by a 3x3 convolution
    ('same' padding) followed by batch normalisation and ReLU.  The result
    is concatenated with ``r_f`` along axis 3, passed through another 3x3
    convolution with ``num_classes`` filters (no activation argument is
    given), and finally upsampled by 2 along both spatial axes.

    Args:
        r_f: Prediction tensor from the previous (coarser) step.
        m_f: Gated feature tensor, concatenated with ``r_f`` along axis 3.
        num_classes: Number of filters of both convolutions.

    Returns:
        The upsampled output tensor of the second convolution.
    """
    gated = m_f
    reduction_stack = (
        Conv2D(num_classes, (3, 3), padding='same'),
        BatchNormalization(),
        Activation('relu'),
    )
    for layer in reduction_stack:
        gated = layer(gated)

    merged = concatenate([r_f, gated], axis=3)
    refined = Conv2D(num_classes, (3, 3), padding='same')(merged)
    return UpSampling2D((2, 2))(refined)
def g_frnet(input_shape=(orig_width+2, orig_height, 3),
            num_classes=1):
    """Build and compile the ReLU variant of the gated refinement network.

    The encoder has seven stages with 8, 16, 32, 64, 128, 256 and 512
    filters.  Every stage applies Conv(3x3, 'same') -> BatchNorm -> ReLU
    twice and is followed by a 2x2 max pooling with stride 2.  A 1x1
    sigmoid convolution on the un-pooled output of the last stage gives the
    initial prediction.  Six ``gate_unit`` / ``gated_refinement_unit``
    steps then refine it, walking the pooled encoder outputs from the
    deepest pair back to the shallowest; each step upsamples by 2.

    NOTE(review): the default ``input_shape`` lists ``orig_width + 2``
    first, whereas ``g_frnet_selu`` lists ``orig_height`` first -- confirm
    which axis order the data pipeline actually feeds.
    NOTE(review): only the initial prediction has a sigmoid; the six
    refinement outputs come from a convolution without an activation
    argument but share the same ``bce_dice_loss`` -- confirm this is
    intended.

    Args:
        input_shape: Shape passed to ``Input`` (batch axis excluded).
        num_classes: Channel count of every prediction output.

    Returns:
        A compiled ``Model`` with seven outputs, ordered from the last
        refinement step down to the initial prediction, weighted
        0.15 (x6) and 0.1 in the total loss.
    """
    inputs = Input(shape=input_shape)

    # ---- Encoder -------------------------------------------------------
    stage_filters = (8, 16, 32, 64, 128, 256, 512)
    pooled = []        # pooled output of every stage, shallowest first
    features = None    # un-pooled output of the most recent stage
    x = inputs
    for filters in stage_filters:
        for _ in range(2):
            x = Conv2D(filters, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
        features = x
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)
        pooled.append(x)

    # ---- Initial prediction from the last stage (before its pooling) ---
    prediction = Conv2D(num_classes, (1, 1), activation='sigmoid')(features)
    outputs = [prediction]

    # ---- Decoder: (pool6, pool7), (pool5, pool6), ..., (pool1, pool2) --
    for finer, coarser in reversed(list(zip(pooled, pooled[1:]))):
        gate = gate_unit(f_i=finer, f_i1=coarser)
        prediction = gated_refinement_unit(r_f=prediction, m_f=gate,
                                           num_classes=num_classes)
        outputs.append(prediction)

    # Model outputs run from the last refinement step to the initial one.
    outputs.reverse()

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=RMSprop(lr=0.0001),
        loss=bce_dice_loss,
        metrics=[dice_coeff],
        loss_weights=[0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1],
    )
    return model
def gate_unit_selu(f_i, f_i1):
    """Combine two feature maps into a gated feature map (SELU variant).

    Each input is processed by two consecutive Conv(3x3, 64 filters,
    'same') -> BatchNorm -> SELU stacks.  The ``f_i1`` branch is then
    upsampled by 2 along both spatial axes and multiplied element-wise
    with the ``f_i`` branch.

    Args:
        f_i: Tensor for the first branch; kept at its own resolution.
        f_i1: Tensor for the second branch; upsampled 2x before the
            product.  Callers in this file pass the next pooling level
            here, i.e. a map at half the spatial size of ``f_i``.

    Returns:
        The tensor produced by ``multiply`` on the two processed branches.
    """
    def double_conv(tensor):
        # Two independent Conv -> BN -> SELU stacks applied back to back.
        for _ in range(2):
            tensor = Conv2D(64, (3, 3), padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation('selu')(tensor)
        return tensor

    # The f_i branch is built first, then the f_i1 branch, so layers are
    # created in the same order as before.
    same_scale = double_conv(f_i)
    coarse = double_conv(f_i1)
    coarse = UpSampling2D((2, 2))(coarse)
    return multiply([same_scale, coarse])
def gated_refinement_unit_selu(r_f, m_f, num_classes):
    """Refine a coarse prediction with gated features (SELU variant).

    ``m_f`` is reduced to ``num_classes`` channels by a 3x3 convolution
    ('same' padding), batch normalisation and SELU.  It is concatenated
    with ``r_f`` along axis 3, fed through a second 3x3 convolution with
    ``num_classes`` filters (no activation argument is given) and upsampled
    by 2 along both spatial axes.

    Args:
        r_f: Prediction tensor from the previous (coarser) step.
        m_f: Gated feature tensor, concatenated with ``r_f`` along axis 3.
        num_classes: Number of filters of both convolutions.

    Returns:
        The upsampled output tensor of the second convolution.
    """
    reduced = Conv2D(num_classes, (3, 3), padding='same')(m_f)
    reduced = BatchNormalization()(reduced)
    reduced = Activation('selu')(reduced)

    stacked = concatenate([r_f, reduced], axis=3)
    refined = Conv2D(num_classes, (3, 3), padding='same')(stacked)
    return UpSampling2D((2, 2))(refined)
def g_frnet_selu(input_shape=(orig_height,orig_width+2,3),
                 num_classes=1):
    """Build and compile the SELU variant of the gated refinement network.

    The encoder has seven stages with 8, 16, 32, 64, 128, 256 and 512
    filters.  Every stage applies Conv(3x3, 'same') -> BatchNorm -> SELU
    twice and is followed by a 2x2 max pooling with stride 2.  Encoder
    convolutions carry explicit names ``f<stage>_conv_<1|2>``.  A 1x1
    sigmoid convolution named ``ru_0_conv`` on the un-pooled output of the
    last stage gives the initial prediction.  Six ``gate_unit_selu`` /
    ``gated_refinement_unit_selu`` steps then refine it, walking the pooled
    encoder outputs from the deepest pair back to the shallowest; each step
    upsamples by 2.

    NOTE(review): only the initial prediction has a sigmoid; the six
    refinement outputs come from a convolution without an activation
    argument but share the same ``bce_dice_loss`` -- confirm this is
    intended.

    Args:
        input_shape: Shape passed to ``Input`` (batch axis excluded).
        num_classes: Channel count of every prediction output.

    Returns:
        A compiled ``Model`` with seven outputs, ordered from the last
        refinement step down to the initial prediction, weighted
        0.15 (x6) and 0.1 in the total loss.
    """
    inputs = Input(shape=input_shape)

    # ---- Encoder -------------------------------------------------------
    stage_filters = (8, 16, 32, 64, 128, 256, 512)
    pooled = []        # pooled output of every stage, shallowest first
    features = None    # un-pooled output of the most recent stage
    x = inputs
    for stage, filters in enumerate(stage_filters, 1):
        for conv_idx in (1, 2):
            # Yields 'f1_conv_1', 'f1_conv_2', ..., 'f7_conv_2'.
            layer_name = 'f%d_conv_%d' % (stage, conv_idx)
            x = Conv2D(filters, (3, 3), padding='same', name=layer_name)(x)
            x = BatchNormalization()(x)
            x = Activation('selu')(x)
        features = x
        x = MaxPooling2D((2, 2), strides=(2, 2))(x)
        pooled.append(x)

    # ---- Initial prediction from the last stage (before its pooling) ---
    prediction = Conv2D(num_classes, (1, 1), activation='sigmoid',
                        name='ru_0_conv')(features)
    outputs = [prediction]

    # ---- Decoder: (pool6, pool7), (pool5, pool6), ..., (pool1, pool2) --
    for finer, coarser in reversed(list(zip(pooled, pooled[1:]))):
        gate = gate_unit_selu(f_i=finer, f_i1=coarser)
        prediction = gated_refinement_unit_selu(r_f=prediction, m_f=gate,
                                                num_classes=num_classes)
        outputs.append(prediction)

    # Model outputs run from the last refinement step to the initial one.
    outputs.reverse()

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=RMSprop(lr=0.0001),
        loss=bce_dice_loss,
        metrics=[dice_coeff],
        loss_weights=[0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1],
    )
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment