Tied-weights version of the tf.python.keras Dense layer.
import tensorflow as tf

from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import standard_ops
class DenseTied(Layer):
  """Densely-connected NN layer whose kernel is tied to another layer.

  `DenseTied` implements the operation:
  `output = activation(dot(input, transpose(tied_kernel)) + bias)`
  where `activation` is the element-wise activation function
  passed as the `activation` argument, `tied_kernel` is the kernel of
  the layer given by `tied_to` (this layer does not create or train a
  kernel of its own), and `bias` is a bias vector created by the layer
  (only applicable if `use_bias` is `True`).
  Note: if the input to the layer has a rank greater than 2, the dot
  product is taken between the kernel and the last axis of the inputs.
  Example:
  ```python
  # tie a decoder layer to an encoder layer:
  encoder = Dense(32, input_shape=(16,))
  decoder = DenseTied(16, tied_to=encoder)
  model = Sequential()
  model.add(encoder)
  # the encoder maps arrays of shape (*, 16) to (*, 32); the decoder
  # maps them back to (*, 16) using the transposed encoder kernel:
  model.add(decoder)
  ```
  Arguments:
      units: Positive integer, dimensionality of the output space.
          Must equal the input dimension of the `tied_to` layer, since
          the kernel is used transposed.
      activation: Activation function to use.
          If you don't specify anything, no activation is applied
          (ie. "linear" activation: `a(x) = x`).
      use_bias: Boolean, whether the layer uses a bias vector.
      bias_initializer: Initializer for the bias vector.
      bias_regularizer: Regularizer function applied to the bias vector.
      activity_regularizer: Regularizer function applied to
          the output of the layer (its "activation").
      bias_constraint: Constraint function applied to the bias vector.
      tied_to: The layer to tie to: either a layer instance, or the name
          of a layer whose kernel variable can be found among the
          graph's global variables.
  Input shape:
      nD tensor with shape: `(batch_size, ..., input_dim)`.
      The most common situation would be
      a 2D input with shape `(batch_size, input_dim)`.
  Output shape:
      nD tensor with shape: `(batch_size, ..., units)`.
      For instance, for a 2D input with shape `(batch_size, input_dim)`,
      the output would have shape `(batch_size, units)`.
  """
  def __init__(self,
               units,
               activation=None,
               use_bias=True,
               bias_initializer='zeros',
               bias_regularizer=None,
               activity_regularizer=None,
               bias_constraint=None,
               tied_to=None,
               **kwargs):
    if 'input_shape' not in kwargs and 'input_dim' in kwargs:
      kwargs['input_shape'] = (kwargs.pop('input_dim'),)
    super(DenseTied, self).__init__(
        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
    self.tied_to = tied_to
    self.units = int(units)
    self.activation = activations.get(activation)
    # The transposed kernel is borrowed from the tied layer, so it takes
    # no initializer, regularizer or constraint of its own (the kernel_*
    # arguments of Dense are intentionally absent).
    # The bias is still created, initialized and regularized here.
    self.use_bias = use_bias
    self.bias_initializer = initializers.get(bias_initializer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.bias_constraint = constraints.get(bias_constraint)
    self.supports_masking = True
    self.input_spec = InputSpec(min_ndim=2)
  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if input_shape[-1].value is None:
      raise ValueError('The last dimension of the inputs to `DenseTied` '
                       'should be defined. Found `None`.')
    self.input_spec = InputSpec(min_ndim=2,
                                axes={-1: input_shape[-1].value})
    # Get and transpose the tied kernel.
    # Caution: `weights` returns a list of [kernel, bias] variables;
    # only the kernel (index 0) is used here.
    if isinstance(self.tied_to, str):
      # `tied_to` is a layer name: look the kernel up by name among the
      # graph's global variables.
      weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.tied_to)[0]
      self.transposed_weights = tf.transpose(
          weights, name='{}_kernel_transpose'.format(self.tied_to))
    else:
      # `tied_to` is a layer instance: it must already be built, so that
      # its kernel variable exists.
      weights = self.tied_to.weights[0]
      self.transposed_weights = tf.transpose(
          weights, name='{}_kernel_transpose'.format(self.tied_to.name))
    if self.use_bias:
      self.bias = self.add_weight(
          'bias',
          shape=[self.units],
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    self.built = True
  def call(self, inputs):
    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
    rank = common_shapes.rank(inputs)
    if rank > 2:
      # Broadcasting is required for the inputs.
      outputs = standard_ops.tensordot(inputs, self.transposed_weights,
                                       [[rank - 1], [0]])
      # Reshape the output back to the original ndim of the input.
      if not context.executing_eagerly():
        shape = inputs.get_shape().as_list()
        output_shape = shape[:-1] + [self.units]
        outputs.set_shape(output_shape)
    else:
      outputs = gen_math_ops.mat_mul(inputs, self.transposed_weights)
    if self.use_bias:
      outputs = nn.bias_add(outputs, self.bias)
    if self.activation is not None:
      return self.activation(outputs)  # pylint: disable=not-callable
    return outputs
  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    input_shape = input_shape.with_rank_at_least(2)
    if input_shape[-1].value is None:
      raise ValueError(
          'The innermost dimension of input_shape must be defined, but saw: %s'
          % input_shape)
    return input_shape[:-1].concatenate(self.units)
  def get_config(self):
    # Note: `tied_to` is not serialized, so a layer restored via
    # `from_config` will not have its kernel tied; the kernel_* entries
    # of Dense are likewise absent because no kernel is created here.
    config = {
        'units': self.units,
        'activation': activations.serialize(self.activation),
        'use_bias': self.use_bias,
        'bias_initializer': initializers.serialize(self.bias_initializer),
        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'bias_constraint': constraints.serialize(self.bias_constraint)
    }
    base_config = super(DenseTied, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
Tying weights is commonly used in autoencoders.
Example: a symmetrical autoencoder with the weights of the encoder and the decoder tied, as sketched below.
The only trainable parameters of the decoder part (layers dense3 and dense4) are the biases.
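A minimal sketch of such a model, using the `DenseTied` layer above. The 784/128/32 dimensions, the activations, and the optimizer/loss are illustrative assumptions; only the layer names dense1-dense4 and the tying pattern come from the description above.

```python
from tensorflow.python.keras.layers import Dense, Input
from tensorflow.python.keras.models import Model

inputs = Input(shape=(784,))  # e.g. flattened MNIST images (assumed size)

# Encoder: regular Dense layers that own (and train) their kernels.
dense1 = Dense(128, activation='relu', name='dense1')
dense2 = Dense(32, activation='relu', name='dense2')

# Decoder: DenseTied layers reuse the encoder kernels, transposed.
# dense3 mirrors dense2 (32 -> 128); dense4 mirrors dense1 (128 -> 784).
dense3 = DenseTied(128, activation='relu', tied_to=dense2, name='dense3')
dense4 = DenseTied(784, activation='sigmoid', tied_to=dense1, name='dense4')

# Calling the encoder layers first ensures they are built before the
# tied decoder layers look up their kernels.
x = dense1(inputs)
encoded = dense2(x)
x = dense3(encoded)
decoded = dense4(x)

autoencoder = Model(inputs, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# The parameter counts confirm the tying: dense3 and dense4 contribute
# only their bias vectors (128 and 784 trainable parameters).
autoencoder.summary()
```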