@zbessinger
Last active September 1, 2016 21:19
import tensorflow as tf

from tensorflow.contrib import slim
from tensorflow.contrib.slim.nets import resnet_v2


def _res_net_tfslim(self, x, n_outputs, reg=None, activation_fn=tf.nn.relu, reuse=False, is_training=True):
    with slim.arg_scope(self.resnet_arg_scope(is_training)):
        net, end_points = resnet_v2.resnet_v2_50(x, num_classes=n_outputs, reuse=reuse)
    net = tf.squeeze(net)  # Comes out as 4D by default

    # Final linear layer
    net = fc_layer('fully_connected', net, n_outputs, activation_fn=None, reg=reg, reuse=reuse,
                   is_training=is_training)
    return net
def fc_layer(name, net, num_outputs, reg=None, reuse=None, is_training=True, activation_fn=tf.nn.relu):
    return tf.contrib.layers.fully_connected(net,
                                             num_outputs=num_outputs,
                                             normalizer_fn=tf.contrib.layers.batch_norm,
                                             normalizer_params={'is_training': is_training,
                                                                'updates_collections': None},
                                             weights_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                             weights_regularizer=reg,
                                             activation_fn=activation_fn,
                                             scope=name,
                                             reuse=reuse)
def resnet_arg_scope(self,
                     is_training=True,
                     weight_decay=0.0001,
                     batch_norm_decay=0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     updates_collections=None):
    """Defines the default ResNet arg scope.

    TODO(gpapan): The batch-normalization related default values above are
    appropriate for use in conjunction with the reference ResNet models
    released at https://github.com/KaimingHe/deep-residual-networks. When
    training ResNets from scratch, they might need to be tuned.

    Args:
      is_training: Whether or not we are training the parameters in the batch
        normalization layers of the model.
      weight_decay: The weight decay to use for regularizing the model.
      batch_norm_decay: The moving average decay when estimating layer activation
        statistics in batch normalization.
      batch_norm_epsilon: Small constant to prevent division by zero when
        normalizing activations by their variance in batch normalization.
      batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
        activations in the batch normalization layer.

    Returns:
      An `arg_scope` to use for the resnet models.
    """
    batch_norm_params = {
        'is_training': is_training,
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
        'scale': batch_norm_scale,
        'updates_collections': updates_collections,
    }

    with slim.arg_scope(
            [slim.conv2d],
            weights_regularizer=slim.l2_regularizer(weight_decay),
            weights_initializer=slim.variance_scaling_initializer(),
            activation_fn=tf.nn.relu,
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm], **batch_norm_params):
            # The following implies padding='SAME' for pool1, which makes feature
            # alignment easier for dense prediction tasks. This is also used in
            # https://github.com/facebook/fb.resnet.torch. However the accompanying
            # code of 'Deep Residual Learning for Image Recognition' uses
            # padding='VALID' for pool1. You can switch to that choice by setting
            # slim.arg_scope([slim.max_pool2d], padding='VALID').
            with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
                return arg_sc
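
For reference, a minimal wiring sketch of how these pieces might be used outside the class (dropping `self`); the placeholder shape and the 10 output classes are illustrative assumptions, not part of the gist:

images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name='images')
with slim.arg_scope(resnet_arg_scope(is_training=True)):
    # With global_pool=True (the default), resnet_v2_50 returns a
    # [batch, 1, 1, num_classes] tensor, so it still needs a squeeze/flatten.
    logits, end_points = resnet_v2.resnet_v2_50(images, num_classes=10)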

jacobsn commented Sep 1, 2016

I think there is a potential bug in your use of tf.squeeze. If only a single example passes through the network, it will fail. You should either use one of the flatten layers (my preferred solution) or explicitly specify which dimensions to squeeze.
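
A sketch of both options, assuming `net` is the [batch, 1, 1, n_outputs] tensor returned by resnet_v2_50 (illustrative only):

# Option 1: squeeze only the two spatial dimensions, leaving the batch dimension intact
net = tf.squeeze(net, [1, 2])
# Option 2: flatten everything after the batch dimension with slim's flatten layer
net = slim.flatten(net)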
