import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.slim.nets import resnet_v2


# Note: _res_net_tfslim and resnet_arg_scope are methods of a model class
# (hence the `self` argument); fc_layer is a module-level helper.
def _res_net_tfslim(self, x, n_outputs, reg=None, activation_fn=tf.nn.relu, reuse=False, is_training=True):
    with slim.arg_scope(self.resnet_arg_scope(is_training)):
        net, end_points = resnet_v2.resnet_v2_50(x, num_classes=n_outputs, reuse=reuse)
        net = tf.squeeze(net)  # Comes out as 4D by default
        # Final linear layer
        net = fc_layer('fully_connected', net, n_outputs, activation_fn=None, reg=reg, reuse=reuse,
                       is_training=is_training)
        return net


def fc_layer(name, net, num_outputs, reg=None, reuse=None, is_training=True, activation_fn=tf.nn.relu):
    return tf.contrib.layers.fully_connected(net,
                                             num_outputs=num_outputs,
                                             normalizer_fn=tf.contrib.layers.batch_norm,
                                             normalizer_params={'is_training': is_training,
                                                                'updates_collections': None},
                                             weights_initializer=tf.contrib.layers.variance_scaling_initializer(),
                                             weights_regularizer=reg,
                                             activation_fn=activation_fn,
                                             scope=name,
                                             reuse=reuse)


def resnet_arg_scope(self,
                     is_training=True,
                     weight_decay=0.0001,
                     batch_norm_decay=0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     updates_collections=None):
    """Defines the default ResNet arg scope.

    TODO(gpapan): The batch-normalization related default values above are
    appropriate for use in conjunction with the reference ResNet models
    released at https://github.com/KaimingHe/deep-residual-networks. When
    training ResNets from scratch, they might need to be tuned.

    Args:
      is_training: Whether or not we are training the parameters in the batch
        normalization layers of the model.
      weight_decay: The weight decay to use for regularizing the model.
      batch_norm_decay: The moving average decay when estimating layer activation
        statistics in batch normalization.
      batch_norm_epsilon: Small constant to prevent division by zero when
        normalizing activations by their variance in batch normalization.
      batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
        activations in the batch normalization layer.

    Returns:
      An `arg_scope` to use for the resnet models.
    """
    batch_norm_params = {
        'is_training': is_training,
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
        'scale': batch_norm_scale,
        'updates_collections': updates_collections,
    }

    with slim.arg_scope(
            [slim.conv2d],
            weights_regularizer=slim.l2_regularizer(weight_decay),
            weights_initializer=slim.variance_scaling_initializer(),
            activation_fn=tf.nn.relu,
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_params):
        with slim.arg_scope([slim.batch_norm], **batch_norm_params):
            # The following implies padding='SAME' for pool1, which makes feature
            # alignment easier for dense prediction tasks. This is also used in
            # https://github.com/facebook/fb.resnet.torch. However the accompanying
            # code of 'Deep Residual Learning for Image Recognition' uses
            # padding='VALID' for pool1. You can switch to that choice by setting
            # slim.arg_scope([slim.max_pool2d], padding='VALID').
            with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
                return arg_sc
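For context, a minimal sketch of how this might be invoked, assuming the two methods above are attached to a hypothetical model class `ResNetModel` and that 224x224 RGB inputs and 10 output classes are used (none of which is specified in the gist):

import tensorflow as tf
from tensorflow.contrib import slim

model = ResNetModel()  # hypothetical class exposing _res_net_tfslim and resnet_arg_scope
images = tf.placeholder(tf.float32, [None, 224, 224, 3])  # assumed input shape
logits = model._res_net_tfslim(images, n_outputs=10,
                               reg=slim.l2_regularizer(1e-4),
                               is_training=True)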
I think there is a potential bug in your use of tf.squeeze. Without an explicit axis argument it removes every size-1 dimension, so if only a single example passes through the network (batch size 1) the batch dimension is squeezed away as well and the following layer will fail. You should either use one of the flatten layers (my preferred solution) or explicitly specify which dimensions to squeeze.
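For example, either of the following would avoid squeezing the batch dimension; this is just a sketch against the snippet above, not code from the gist:

# Option 1: squeeze only the two spatial (size-1) dimensions, keeping the batch dimension
net = tf.squeeze(net, [1, 2])

# Option 2: flatten everything except the batch dimension
net = slim.flatten(net)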