from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Activation, UpSampling2D, BatchNormalization, multiply
from keras.optimizers import RMSprop
from model.losses import bce_dice_loss, dice_loss, weighted_bce_dice_loss, weighted_dice_loss, dice_coeff
import params

# Carvana image size; the width is padded by 2 px below (1918 -> 1920) so that
# both dimensions are divisible by 2**7 and survive all seven pooling stages.
orig_width = 1918
orig_height = 1280
def gate_unit(f_i, f_i1):
    """Gate unit: fuses an encoder feature map f_i with the coarser map f_i1
    from the next stage (upsampled 2x) via element-wise multiplication."""
    f_i = Conv2D(8, (3, 3), padding='same')(f_i)
    f_i = BatchNormalization()(f_i)
    f_i = Activation('relu')(f_i)
    f_i1 = Conv2D(8, (3, 3), padding='same')(f_i1)
    f_i1 = BatchNormalization()(f_i1)
    f_i1 = Activation('relu')(f_i1)
    f_i1 = UpSampling2D((2, 2))(f_i1)
    output = multiply([f_i, f_i1])
    return output
def gated_refinement_unit(r_f, m_f, num_classes):
    """Refinement unit: concatenates the previous refinement map r_f with the
    gated feature map m_f, predicts the next map, and upsamples it 2x."""
    m_f = Conv2D(num_classes, (3, 3), padding='same')(m_f)
    m_f = BatchNormalization()(m_f)
    m_f = Activation('relu')(m_f)
    output = concatenate([r_f, m_f], axis=3)
    output = Conv2D(num_classes, (3, 3), padding='same')(output)
    output = UpSampling2D((2, 2))(output)
    return output
def g_frnet(input_shape=(orig_height, orig_width + 2, 3),
            num_classes=1):
    inputs = Input(shape=input_shape)
    # 1024
    f1 = Conv2D(8, (3, 3), padding='same')(inputs)
    f1 = BatchNormalization()(f1)
    f1 = Activation('relu')(f1)
    f1 = Conv2D(8, (3, 3), padding='same')(f1)
    f1 = BatchNormalization()(f1)
    f1 = Activation('relu')(f1)
    f1_pool = MaxPooling2D((2, 2), strides=(2, 2))(f1)
    # 512
    f2 = Conv2D(16, (3, 3), padding='same')(f1_pool)
    f2 = BatchNormalization()(f2)
    f2 = Activation('relu')(f2)
    f2 = Conv2D(16, (3, 3), padding='same')(f2)
    f2 = BatchNormalization()(f2)
    f2 = Activation('relu')(f2)
    f2_pool = MaxPooling2D((2, 2), strides=(2, 2))(f2)
    # 256
    f3 = Conv2D(32, (3, 3), padding='same')(f2_pool)
    f3 = BatchNormalization()(f3)
    f3 = Activation('relu')(f3)
    f3 = Conv2D(32, (3, 3), padding='same')(f3)
    f3 = BatchNormalization()(f3)
    f3 = Activation('relu')(f3)
    f3_pool = MaxPooling2D((2, 2), strides=(2, 2))(f3)
    # 128
    f4 = Conv2D(64, (3, 3), padding='same')(f3_pool)
    f4 = BatchNormalization()(f4)
    f4 = Activation('relu')(f4)
    f4 = Conv2D(64, (3, 3), padding='same')(f4)
    f4 = BatchNormalization()(f4)
    f4 = Activation('relu')(f4)
    f4_pool = MaxPooling2D((2, 2), strides=(2, 2))(f4)
    # 64
    f5 = Conv2D(128, (3, 3), padding='same')(f4_pool)
    f5 = BatchNormalization()(f5)
    f5 = Activation('relu')(f5)
    f5 = Conv2D(128, (3, 3), padding='same')(f5)
    f5 = BatchNormalization()(f5)
    f5 = Activation('relu')(f5)
    f5_pool = MaxPooling2D((2, 2), strides=(2, 2))(f5)
    # 32
    f6 = Conv2D(256, (3, 3), padding='same')(f5_pool)
    f6 = BatchNormalization()(f6)
    f6 = Activation('relu')(f6)
    f6 = Conv2D(256, (3, 3), padding='same')(f6)
    f6 = BatchNormalization()(f6)
    f6 = Activation('relu')(f6)
    f6_pool = MaxPooling2D((2, 2), strides=(2, 2))(f6)
    # 16
    f7 = Conv2D(512, (3, 3), padding='same')(f6_pool)
    f7 = BatchNormalization()(f7)
    f7 = Activation('relu')(f7)
    f7 = Conv2D(512, (3, 3), padding='same')(f7)
    f7 = BatchNormalization()(f7)
    f7 = Activation('relu')(f7)
    f7_pool = MaxPooling2D((2, 2), strides=(2, 2))(f7)
    # coarsest prediction, refined stage by stage below
    ru0 = Conv2D(num_classes, (1, 1), activation='sigmoid')(f7)
    # 8
    g1 = gate_unit(f_i=f6_pool, f_i1=f7_pool)
    ru1 = gated_refinement_unit(r_f=ru0, m_f=g1, num_classes=num_classes)
    g2 = gate_unit(f_i=f5_pool, f_i1=f6_pool)
    ru2 = gated_refinement_unit(r_f=ru1, m_f=g2, num_classes=num_classes)
    g3 = gate_unit(f_i=f4_pool, f_i1=f5_pool)
    ru3 = gated_refinement_unit(r_f=ru2, m_f=g3, num_classes=num_classes)
    g4 = gate_unit(f_i=f3_pool, f_i1=f4_pool)
    ru4 = gated_refinement_unit(r_f=ru3, m_f=g4, num_classes=num_classes)
    g5 = gate_unit(f_i=f2_pool, f_i1=f3_pool)
    ru5 = gated_refinement_unit(r_f=ru4, m_f=g5, num_classes=num_classes)
    g6 = gate_unit(f_i=f1_pool, f_i1=f2_pool)
    ru6 = gated_refinement_unit(r_f=ru5, m_f=g6, num_classes=num_classes)
    model = Model(inputs=inputs, outputs=[ru6, ru5, ru4, ru3, ru2, ru1, ru0])
    model.compile(optimizer=RMSprop(lr=0.0001), loss=bce_dice_loss,
                  metrics=[dice_coeff],
                  loss_weights=[0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1])
    return model
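
# A minimal helper sketch (not part of the original model code): the network
# above emits seven maps at 1/1, 1/2, ..., 1/64 of the input resolution
# (ru6 .. ru0), so a ground-truth mask needs a matching pyramid of targets.
# Nearest-neighbour subsampling via strided slicing is one simple, assumed way
# to build that pyramid for a (H, W, 1) float numpy mask whose H and W are
# divisible by 2**7, as the padded Carvana size 1280 x 1920 is.
def multiscale_targets(mask):
    """Return per-output targets ordered like [ru6, ru5, ..., ru0]."""
    return [mask[::f, ::f, :] for f in (1, 2, 4, 8, 16, 32, 64)]
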
def gate_unit_selu(f_i, f_i1):
    """SELU variant of the gate unit with wider (64-filter) double convs."""
    # f_i = Conv2D(8, (3, 3), padding='same')(f_i)
    # f_i = BatchNormalization()(f_i)
    # f_i = Activation('selu')(f_i)
    f_i = Conv2D(64, (3, 3), padding='same')(f_i)
    f_i = BatchNormalization()(f_i)
    f_i = Activation('selu')(f_i)
    f_i = Conv2D(64, (3, 3), padding='same')(f_i)
    f_i = BatchNormalization()(f_i)
    f_i = Activation('selu')(f_i)
    # f_i1 = Conv2D(8, (3, 3), padding='same')(f_i1)
    # f_i1 = BatchNormalization()(f_i1)
    # f_i1 = Activation('selu')(f_i1)
    # f_i1 = UpSampling2D((2, 2))(f_i1)
    f_i1 = Conv2D(64, (3, 3), padding='same')(f_i1)
    f_i1 = BatchNormalization()(f_i1)
    f_i1 = Activation('selu')(f_i1)
    f_i1 = Conv2D(64, (3, 3), padding='same')(f_i1)
    f_i1 = BatchNormalization()(f_i1)
    f_i1 = Activation('selu')(f_i1)
    f_i1 = UpSampling2D((2, 2))(f_i1)
    output = multiply([f_i, f_i1])
    return output
def gated_refinement_unit_selu(r_f, m_f, num_classes):
    """SELU variant of the refinement unit."""
    m_f = Conv2D(num_classes, (3, 3), padding='same')(m_f)
    m_f = BatchNormalization()(m_f)
    m_f = Activation('selu')(m_f)
    output = concatenate([r_f, m_f], axis=3)
    output = Conv2D(num_classes, (3, 3), padding='same')(output)
    output = UpSampling2D((2, 2))(output)
    return output
def g_frnet_selu(input_shape=(orig_height, orig_width + 2, 3),
                 num_classes=1):
    inputs = Input(shape=input_shape)
    # 1024
    f1 = Conv2D(8, (3, 3), padding='same', name='f1_conv_1')(inputs)
    f1 = BatchNormalization()(f1)
    f1 = Activation('selu')(f1)
    f1 = Conv2D(8, (3, 3), padding='same', name='f1_conv_2')(f1)
    f1 = BatchNormalization()(f1)
    f1 = Activation('selu')(f1)
    f1_pool = MaxPooling2D((2, 2), strides=(2, 2))(f1)
    # 512
    f2 = Conv2D(16, (3, 3), padding='same', name='f2_conv_1')(f1_pool)
    f2 = BatchNormalization()(f2)
    f2 = Activation('selu')(f2)
    f2 = Conv2D(16, (3, 3), padding='same', name='f2_conv_2')(f2)
    f2 = BatchNormalization()(f2)
    f2 = Activation('selu')(f2)
    f2_pool = MaxPooling2D((2, 2), strides=(2, 2))(f2)
    # 256
    f3 = Conv2D(32, (3, 3), padding='same', name='f3_conv_1')(f2_pool)
    f3 = BatchNormalization()(f3)
    f3 = Activation('selu')(f3)
    f3 = Conv2D(32, (3, 3), padding='same', name='f3_conv_2')(f3)
    f3 = BatchNormalization()(f3)
    f3 = Activation('selu')(f3)
    f3_pool = MaxPooling2D((2, 2), strides=(2, 2))(f3)
    # 128
    f4 = Conv2D(64, (3, 3), padding='same', name='f4_conv_1')(f3_pool)
    f4 = BatchNormalization()(f4)
    f4 = Activation('selu')(f4)
    f4 = Conv2D(64, (3, 3), padding='same', name='f4_conv_2')(f4)
    f4 = BatchNormalization()(f4)
    f4 = Activation('selu')(f4)
    f4_pool = MaxPooling2D((2, 2), strides=(2, 2))(f4)
    # 64
    f5 = Conv2D(128, (3, 3), padding='same', name='f5_conv_1')(f4_pool)
    f5 = BatchNormalization()(f5)
    f5 = Activation('selu')(f5)
    f5 = Conv2D(128, (3, 3), padding='same', name='f5_conv_2')(f5)
    f5 = BatchNormalization()(f5)
    f5 = Activation('selu')(f5)
    f5_pool = MaxPooling2D((2, 2), strides=(2, 2))(f5)
    # 32
    f6 = Conv2D(256, (3, 3), padding='same', name='f6_conv_1')(f5_pool)
    f6 = BatchNormalization()(f6)
    f6 = Activation('selu')(f6)
    f6 = Conv2D(256, (3, 3), padding='same', name='f6_conv_2')(f6)
    f6 = BatchNormalization()(f6)
    f6 = Activation('selu')(f6)
    f6_pool = MaxPooling2D((2, 2), strides=(2, 2))(f6)
    # 16
    f7 = Conv2D(512, (3, 3), padding='same', name='f7_conv_1')(f6_pool)
    f7 = BatchNormalization()(f7)
    f7 = Activation('selu')(f7)
    f7 = Conv2D(512, (3, 3), padding='same', name='f7_conv_2')(f7)
    f7 = BatchNormalization()(f7)
    f7 = Activation('selu')(f7)
    f7_pool = MaxPooling2D((2, 2), strides=(2, 2))(f7)
    ru0 = Conv2D(num_classes, (1, 1), activation='sigmoid', name='ru_0_conv')(f7)
    # 8
    g1 = gate_unit_selu(f_i=f6_pool, f_i1=f7_pool)
    ru1 = gated_refinement_unit_selu(r_f=ru0, m_f=g1, num_classes=num_classes)
    g2 = gate_unit_selu(f_i=f5_pool, f_i1=f6_pool)
    ru2 = gated_refinement_unit_selu(r_f=ru1, m_f=g2, num_classes=num_classes)
    g3 = gate_unit_selu(f_i=f4_pool, f_i1=f5_pool)
    ru3 = gated_refinement_unit_selu(r_f=ru2, m_f=g3, num_classes=num_classes)
    g4 = gate_unit_selu(f_i=f3_pool, f_i1=f4_pool)
    ru4 = gated_refinement_unit_selu(r_f=ru3, m_f=g4, num_classes=num_classes)
    g5 = gate_unit_selu(f_i=f2_pool, f_i1=f3_pool)
    ru5 = gated_refinement_unit_selu(r_f=ru4, m_f=g5, num_classes=num_classes)
    g6 = gate_unit_selu(f_i=f1_pool, f_i1=f2_pool)
    ru6 = gated_refinement_unit_selu(r_f=ru5, m_f=g6, num_classes=num_classes)
    model = Model(inputs=inputs, outputs=[ru6, ru5, ru4, ru3, ru2, ru1, ru0])
    model.compile(optimizer=RMSprop(lr=0.0001), loss=bce_dice_loss,
                  metrics=[dice_coeff],
                  loss_weights=[0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.1])
    return model
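
# A hedged usage sketch: build the SELU variant on a small input divisible by
# 2**7 and check the pyramid of output shapes. Assumes the repo's model.losses
# module is importable, since compile() inside the constructor references it.
if __name__ == '__main__':
    import numpy as np
    model = g_frnet_selu(input_shape=(256, 256, 3))
    preds = model.predict(np.zeros((1, 256, 256, 3), dtype=np.float32))
    for p in preds:
        print(p.shape)  # (1, 256, 256, 1), (1, 128, 128, 1), ..., (1, 4, 4, 1)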