import math

import numpy as np
import tensorflow as tf
import keras
import keras.backend as K
import keras.layers as KL
import keras.engine as KE
import keras.models as KM
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers, constraints
from keras.utils.generic_utils import get_custom_objects

from mrcnn import utils

# Requires TensorFlow 1.3+ and Keras 2.0.8+.
from distutils.version import LooseVersion
assert LooseVersion(tf.__version__) >= LooseVersion("1.3")
assert LooseVersion(keras.__version__) >= LooseVersion("2.0.8")

class BatchNorm(KL.BatchNormalization):
    """Extends the Keras BatchNormalization class to allow a central place
    to make changes if needed.

    Batch normalization has a negative effect on training if batches are small,
    so this layer is often frozen (via a setting in the Config class) and functions
    as a linear layer.
    """
    def call(self, inputs, training=None):
        """
        Note about training values:
            None: Train BN layers. This is the normal mode
            False: Freeze BN layers. Good when batch size is small
            True: (don't use). Set layer in training mode even when inferencing
        """
        return super(self.__class__, self).call(inputs, training=training)
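
# Illustrative usage sketch (not from the original gist): how the `training`
# argument above is used when wiring BatchNorm into a graph. The layer names
# and the toy input shape are arbitrary examples.
def _example_frozen_batchnorm():
    inputs = KL.Input(shape=(32, 32, 3))
    x = KL.Conv2D(8, (3, 3), padding="same", name="conv_example")(inputs)
    # training=False freezes the layer so it always normalizes with the
    # moving statistics, which is preferable when batches are small.
    x = BatchNorm(name="bn_example")(x, training=False)
    return KM.Model(inputs, x)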

class GroupNorm(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes within each group
    the mean and variance for normalization. GN's computation is independent of batch size,
    and its accuracy is stable over a wide range of batch sizes.

    # Arguments
        groups: Integer, the number of groups for Group Normalization.
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `GroupNorm`.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
    """
    def __init__(self,
                 groups=32,
                 axis=-1,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            raise ValueError('Number of channels (' + str(dim) + ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None

        self.built = True
    def call(self, inputs, **kwargs):
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)

        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(1, self.groups)

        reshape_group_shape = K.shape(inputs)
        group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
        group_axes[self.axis] = input_shape[self.axis] // self.groups
        group_axes.insert(1, self.groups)

        # Reshape inputs to the new group shape: (batch, groups, ..., channels // groups).
        group_shape = [group_axes[0], self.groups] + group_axes[2:]
        group_shape = K.stack(group_shape)
        inputs = K.reshape(inputs, group_shape)

        # Normalize within each group (all axes except batch and group).
        group_reduction_axes = list(range(len(group_axes)))
        group_reduction_axes = group_reduction_axes[2:]
        mean = K.mean(inputs, axis=group_reduction_axes, keepdims=True)
        variance = K.var(inputs, axis=group_reduction_axes, keepdims=True)
        inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

        # prepare broadcast shape
        inputs = K.reshape(inputs, group_shape)
        outputs = inputs

        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            outputs = outputs * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            outputs = outputs + broadcast_beta

        # Reshape back to the original input shape.
        outputs = K.reshape(outputs, tensor_input_shape)
        return outputs
    def get_config(self):
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNorm, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
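
# Illustrative usage sketch (not from the original gist): GroupNorm computes
# statistics per group of channels, so its behaviour does not depend on the
# batch size. The channel count (64 here) must be divisible by `groups`.
def _example_group_norm():
    inputs = KL.Input(shape=(56, 56, 64))
    x = KL.Conv2D(64, (3, 3), padding="same", name="conv_gn_example")(inputs)
    x = GroupNorm(groups=32, axis=-1, name="gn_example")(x)
    return KM.Model(inputs, x)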

class SwitchNorm(Layer):
    """Switchable Normalization layer.

    Switchable Normalization performs Instance Normalization, Layer Normalization and Batch
    Normalization using its parameters, and then weighs them using learned parameters to
    allow different levels of interaction of the 3 normalization schemes for each layer.

    Only supports the moving average variant from the paper, since the `batch average`
    scheme requires dynamic graph execution to compute the mean and variance of several
    batches at runtime.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `SwitchNorm`.
        momentum: Momentum for the moving mean and the moving variance. The original
            implementation suggests a default momentum of `0.997`, however it is highly
            unstable and training can fail after a few epochs. To stabilise training, use
            lower values of momentum such as `0.99` or `0.98`.
        epsilon: Small float added to variance to avoid dividing by zero.
        final_gamma: Bool value to determine if this layer is the final
            normalization layer for the residual block. Overrides the initialization
            of the scaling weights to be `zeros`. Only used for Residual Networks,
            to make the forward/backward signal initially propagated through an
            identity shortcut.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        mean_weights_initializer: Initializer for the mean weights.
        variance_weights_initializer: Initializer for the variance weights.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        mean_weights_regularizer: Optional regularizer for the mean weights.
        variance_weights_regularizer: Optional regularizer for the variance weights.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
        mean_weights_constraints: Optional constraint for the mean weights.
        variance_weights_constraints: Optional constraint for the variance weights.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Differentiable Learning-to-Normalize via Switchable Normalization](https://arxiv.org/abs/1806.10779)
    """
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 final_gamma=False,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 mean_weights_initializer='ones',
                 variance_weights_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 mean_weights_regularizer=None,
                 variance_weights_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 mean_weights_constraints=None,
                 variance_weights_constraints=None,
                 **kwargs):
        super(SwitchNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

        self.beta_initializer = initializers.get(beta_initializer)
        if final_gamma:
            self.gamma_initializer = initializers.get('zeros')
        else:
            self.gamma_initializer = initializers.get(gamma_initializer)
        self.mean_weights_initializer = initializers.get(mean_weights_initializer)
        self.variance_weights_initializer = initializers.get(variance_weights_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.mean_weights_regularizer = regularizers.get(mean_weights_regularizer)
        self.variance_weights_regularizer = regularizers.get(variance_weights_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.mean_weights_constraints = constraints.get(mean_weights_constraints)
        self.variance_weights_constraints = constraints.get(variance_weights_constraints)
    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name='gamma',
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name='beta',
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint)
        else:
            self.beta = None

        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)

        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)

        self.mean_weights = self.add_weight(
            shape=(3,),
            name='mean_weights',
            initializer=self.mean_weights_initializer,
            regularizer=self.mean_weights_regularizer,
            constraint=self.mean_weights_constraints)

        self.variance_weights = self.add_weight(
            shape=(3,),
            name='variance_weights',
            initializer=self.variance_weights_initializer,
            regularizer=self.variance_weights_regularizer,
            constraint=self.variance_weights_constraints)

        self.built = True
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)

        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        if self.axis != 0:
            del reduction_axes[0]

        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Instance and layer statistics are computed from the current batch.
        mean_instance = K.mean(inputs, reduction_axes, keepdims=True)
        variance_instance = K.var(inputs, reduction_axes, keepdims=True)

        mean_layer = K.mean(mean_instance, self.axis, keepdims=True)
        temp = variance_instance + K.square(mean_instance)
        variance_layer = K.mean(temp, self.axis, keepdims=True) - K.square(mean_layer)

        def training_phase():
            mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
            variance_batch = K.mean(temp, axis=0, keepdims=True) - K.square(mean_batch)

            mean_batch_reshaped = K.flatten(mean_batch)
            variance_batch_reshaped = K.flatten(variance_batch)

            if K.backend() != 'cntk':
                sample_size = K.prod([K.shape(inputs)[axis]
                                      for axis in reduction_axes])
                sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

                # sample variance - unbiased estimator of population variance
                variance_batch_reshaped *= sample_size / (sample_size - (1.0 + self.epsilon))

            self.add_update([K.moving_average_update(self.moving_mean,
                                                     mean_batch_reshaped,
                                                     self.momentum),
                             K.moving_average_update(self.moving_variance,
                                                     variance_batch_reshaped,
                                                     self.momentum)],
                            inputs)

            return normalize_func(mean_batch, variance_batch)

        def inference_phase():
            mean_batch = self.moving_mean
            variance_batch = self.moving_variance

            return normalize_func(mean_batch, variance_batch)

        def normalize_func(mean_batch, variance_batch):
            mean_batch = K.reshape(mean_batch, broadcast_shape)
            variance_batch = K.reshape(variance_batch, broadcast_shape)

            # Learned softmax weights mix the three sets of statistics.
            mean_weights = K.softmax(self.mean_weights, axis=0)
            variance_weights = K.softmax(self.variance_weights, axis=0)

            mean = (mean_weights[0] * mean_instance +
                    mean_weights[1] * mean_layer +
                    mean_weights[2] * mean_batch)

            variance = (variance_weights[0] * variance_instance +
                        variance_weights[1] * variance_layer +
                        variance_weights[2] * variance_batch)

            outputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                outputs = outputs * broadcast_gamma

            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                outputs = outputs + broadcast_beta

            return outputs

        if training in {0, False}:
            return inference_phase()

        return K.in_train_phase(training_phase,
                                inference_phase,
                                training=training)
    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'momentum': self.momentum,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'mean_weights_initializer': initializers.serialize(self.mean_weights_initializer),
            'variance_weights_initializer': initializers.serialize(self.variance_weights_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'mean_weights_regularizer': regularizers.serialize(self.mean_weights_regularizer),
            'variance_weights_regularizer': regularizers.serialize(self.variance_weights_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
            'mean_weights_constraints': constraints.serialize(self.mean_weights_constraints),
            'variance_weights_constraints': constraints.serialize(self.variance_weights_constraints),
        }
        base_config = super(SwitchNorm, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

get_custom_objects().update({'GroupNorm': GroupNorm})
get_custom_objects().update({'SwitchNorm': SwitchNorm})
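
# Illustrative usage sketch (not from the original gist): registering the two
# layers in Keras' custom objects above lets a saved model containing them be
# reloaded without passing `custom_objects` explicitly. The file name below is
# a placeholder, and saving to HDF5 requires h5py.
def _example_switch_norm_roundtrip():
    inputs = KL.Input(shape=(28, 28, 16))
    x = SwitchNorm(axis=-1, momentum=0.98, name="sn_example")(inputs)
    model = KM.Model(inputs, x)
    model.save("switchnorm_example.h5")
    return KM.load_model("switchnorm_example.h5")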

def compute_backbone_shapes(config, image_shape):
    """Computes the width and height of each stage of the backbone network.

    Returns:
        [N, (height, width)]. Where N is the number of stages
    """
    if callable(config.BACKBONE):
        return config.COMPUTE_BACKBONE_SHAPE(image_shape)

    # Currently supports ResNet only
    # assert config.BACKBONE in ["resnet50", "resnet101", "densenet121", "densenet169"]
    return np.array(
        [[int(math.ceil(image_shape[0] / stride)),
          int(math.ceil(image_shape[1] / stride))]
         for stride in config.BACKBONE_STRIDES])
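
# Illustrative worked example (not from the original gist): with the Mask R-CNN
# default BACKBONE_STRIDES of [4, 8, 16, 32, 64], a 1024x1024 input yields one
# [height, width] pair per stage:
# [[256, 256], [128, 128], [64, 64], [32, 32], [16, 16]].
# The stand-in config class below exists only for demonstration.
def _example_backbone_shapes():
    class _ToyConfig(object):
        BACKBONE = "resnet101"
        BACKBONE_STRIDES = [4, 8, 16, 32, 64]
    return compute_backbone_shapes(_ToyConfig(), [1024, 1024, 3])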

def normalize_layer(tensor, name, norm_use='sn', train_bn=False):
    """Set up the desired normalization layer.

    # Arguments
        tensor: input tensor.
        name: layer name; norm_use is used as the prefix, e.g. "sn_norm1"
        norm_use: "sn": Switchable Normalization, "bn": Batch Normalization, "gn": Group Normalization
        train_bn: if norm_use="bn" and the batch size is large enough, train_bn can be True
    # Returns
        Output tensor for the block.
    """
    if norm_use == "gn":
        x = GroupNorm(name='gn_' + name)(tensor)
    elif norm_use == "sn":
        x = SwitchNorm(name='sn_' + name)(tensor)
    else:
        x = BatchNorm(name='bn_' + name)(tensor, training=train_bn)
    return x
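
# Illustrative usage sketch (not from the original gist): a conv + norm + relu
# micro-block showing how `norm_use` selects the normalization layer. When
# norm_use="gn", `filters` should be a multiple of 32 (the default GroupNorm
# group count). The layer names are arbitrary examples.
def _example_conv_norm_relu(tensor, filters, norm_use="sn", train_bn=False):
    x = KL.Conv2D(filters, (3, 3), padding="same", name="conv_nl_example")(tensor)
    x = normalize_layer(x, name="nl_example", norm_use=norm_use, train_bn=train_bn)
    return KL.Activation("relu")(x)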

############################################################
#  Resnet Graph
############################################################

# Code adapted from:
# https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py

def identity_block(input_tensor, kernel_size, filters, stage, block,
                   use_bias=True, norm_use="sn", train_bn=True):
    """The identity_block is the block that has no conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the nb_filters of 3 conv layers at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        use_bias: Boolean. To use or not use a bias in conv layers.
        norm_use: which normalization layer to use: "sn", "gn" or "bn"
        train_bn: Boolean. Train or freeze Batch Norm layers
    """
    nb_filter1, nb_filter2, nb_filter3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    norm_name_base = str(stage) + block + "_branch"

    x = KL.Conv2D(nb_filter1, (1, 1), name=conv_name_base + '2a',
                  use_bias=use_bias)(input_tensor)
    x = normalize_layer(x, name=norm_name_base + "2a", norm_use=norm_use, train_bn=train_bn)
    x = KL.Activation('relu')(x)

    x = KL.Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same',
                  name=conv_name_base + '2b', use_bias=use_bias)(x)
    x = normalize_layer(x, name=norm_name_base + "2b", norm_use=norm_use, train_bn=train_bn)
    x = KL.Activation('relu')(x)

    x = KL.Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c',
                  use_bias=use_bias)(x)
    x = normalize_layer(x, name=norm_name_base + "2c", norm_use=norm_use, train_bn=train_bn)

    x = KL.Add()([x, input_tensor])
    x = KL.Activation('relu', name='res' + str(stage) + block + '_out')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block,
               strides=(2, 2), use_bias=True, norm_use='sn', train_bn=True):
    """conv_block is the block that has a conv layer at shortcut
    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of middle conv layer at main path
        filters: list of integers, the nb_filters of 3 conv layers at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        norm_use: default 'sn' (Switchable Normalization, a state-of-the-art normalization by SenseTime);
            selects the normalization layer to use: "sn", "gn" or "bn"
        use_bias: Boolean. To use or not use a bias in conv layers.
        train_bn: Boolean. Train or freeze Batch Norm layers
    Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
    And the shortcut should have subsample=(2,2) as well
    """
    nb_filter1, nb_filter2, nb_filter3 = filters
    conv_name_base = "res" + str(stage) + block + "_branch"
    norm_name_base = str(stage) + block + "_branch"

    x = KL.Conv2D(nb_filter1, (1, 1), strides=strides,
                  name=conv_name_base + "2a", use_bias=use_bias)(input_tensor)
    x = normalize_layer(x, name=norm_name_base + "2a", norm_use=norm_use, train_bn=train_bn)
    x = KL.Activation('relu')(x)

    x = KL.Conv2D(nb_filter2, (kernel_size, kernel_size), padding='same',
                  name=conv_name_base + '2b', use_bias=use_bias)(x)
    x = normalize_layer(x, name=norm_name_base + "2b", norm_use=norm_use, train_bn=train_bn)
    x = KL.Activation('relu')(x)

    x = KL.Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c',
                  use_bias=use_bias)(x)
    x = normalize_layer(x, name=norm_name_base + "2c", norm_use=norm_use, train_bn=train_bn)

    shortcut = KL.Conv2D(nb_filter3, (1, 1), strides=strides,
                         name=conv_name_base + '1', use_bias=use_bias)(input_tensor)
    shortcut = normalize_layer(shortcut, name=norm_name_base + '1', norm_use=norm_use, train_bn=train_bn)

    x = KL.Add()([x, shortcut])
    x = KL.Activation('relu', name='res' + str(stage) + block + '_out')(x)
    return x

def resnet_graph(input_image, architecture, stage5=False, norm_use="sn", train_bn=True):
    """Build a ResNet graph.

    architecture: Can be resnet50 or resnet101
    stage5: Boolean. If False, stage5 of the network is not created
    norm_use: Default: "sn". "sn": Switchable Normalization, "bn": Batch Normalization, "gn": Group Normalization
    train_bn: Boolean. Train or freeze Batch Norm layers
    """
    assert architecture in ["resnet50", "resnet101"]
    # Stage 1
    x = KL.ZeroPadding2D((3, 3))(input_image)
    x = KL.Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=True)(x)
    x = normalize_layer(x, name="conv1", norm_use=norm_use, train_bn=train_bn)
    x = KL.Activation('relu')(x)
    C1 = x = KL.MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)
    # Stage 2
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), norm_use=norm_use, train_bn=train_bn)
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', norm_use=norm_use, train_bn=train_bn)
    C2 = x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', norm_use=norm_use, train_bn=train_bn)
    # Stage 3
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', norm_use=norm_use, train_bn=train_bn)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', norm_use=norm_use, train_bn=train_bn)
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', norm_use=norm_use, train_bn=train_bn)
    C3 = x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', norm_use=norm_use, train_bn=train_bn)
    # Stage 4
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', norm_use=norm_use, train_bn=train_bn)
    block_count = {"resnet50": 5, "resnet101": 22}[architecture]
    for i in range(block_count):
        x = identity_block(x, 3, [256, 256, 1024], stage=4, block=chr(98 + i), norm_use=norm_use, train_bn=train_bn)
    C4 = x
    # Stage 5
    if stage5:
        x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', norm_use=norm_use, train_bn=train_bn)
        x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', norm_use=norm_use, train_bn=train_bn)
        C5 = x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', norm_use=norm_use, train_bn=train_bn)
    else:
        C5 = None
    return [C1, C2, C3, C4, C5]
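
# Illustrative usage sketch (not from the original gist): building the backbone
# and wrapping the stage outputs in a Model to inspect them. With a 256x256
# input, C1..C5 have spatial sizes 64, 64, 32, 16 and 8 respectively.
def _example_resnet_backbone():
    input_image = KL.Input(shape=(256, 256, 3))
    C1, C2, C3, C4, C5 = resnet_graph(input_image, "resnet50", stage5=True,
                                      norm_use="sn", train_bn=False)
    return KM.Model(input_image, [C1, C2, C3, C4, C5])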