Gist skaae/e34d07a56920c522f6c2, created June 12, 2015 20:55
from lasagne.layers import MergeLayer


class ElemwiseMergeLayer(MergeLayer):
    """
    This layer performs an elementwise merge of its input layers.
    It requires all input layers to have the same output shape.

    Parameters
    ----------
    incomings : a list of :class:`Layer` instances or tuples
        The layers feeding into this layer, or expected input shapes,
        with all incoming shapes being equal.
    merge_function : callable
        The merge function to use. It should take two arguments and return
        the merged value. Possible merge functions include
        ``theano.tensor.mul`` and ``theano.tensor.add``.
    coeffs : list or scalar
        A list with one coefficient per incoming layer, or a single
        coefficient that is applied to all inputs. Each input is multiplied
        by its coefficient before merging. The coefficients are fixed
        constants and are not included in the learnable parameters of
        this layer.

    Notes
    -----
    Depending on your architecture, this can be used to avoid the more
    costly :class:`ConcatLayer`. For example, instead of concatenating layers
    before a :class:`DenseLayer`, insert separate :class:`DenseLayer` instances
    with the same number of output units and a linear nonlinearity, add them
    up afterwards, and apply the nonlinearity to the sum. (This avoids the
    copy operations in concatenation, but splits up the dot product.)
    """
    def __init__(self, incomings, merge_function, coeffs=1, **kwargs):
        super(ElemwiseMergeLayer, self).__init__(incomings, **kwargs)
        if isinstance(coeffs, list):
            if len(coeffs) != len(incomings):
                raise ValueError("Mismatch: got %d coeffs for %d incomings" %
                                 (len(coeffs), len(incomings)))
        else:
            # Broadcast a single scalar coefficient to every input.
            coeffs = [coeffs] * len(incomings)
        self.coeffs = coeffs
        self.merge_function = merge_function

    def get_output_shape_for(self, input_shapes):
        # All inputs must share one shape; the output has that shape too.
        if any(shape != input_shapes[0] for shape in input_shapes):
            raise ValueError("Mismatch: not all input shapes are the same")
        return input_shapes[0]

    def get_output_for(self, inputs, **kwargs):
        # Fold the scaled inputs left to right:
        # merge(merge(c0 * x0, c1 * x1), c2 * x2), ...
        output = None
        for coeff, input in zip(self.coeffs, inputs):
            if coeff != 1:
                # Rebind instead of scaling in place, so the caller's
                # input is never modified.
                input = input * coeff
            if output is not None:
                output = self.merge_function(output, input)
            else:
                output = input
        return output
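The Notes section of the docstring relies on a linear-algebra identity: a dense (affine) transform of a concatenation equals the sum of dense transforms of the parts, once the weight matrix is split into the matching row blocks. The following is a small stdlib-only check of that identity, with made-up numbers and a hypothetical helper `matvec` (not part of Lasagne). It only holds before any nonlinearity is applied, which is why the separate dense layers must be linear and the nonlinearity applied after the sum.

```python
def matvec(x, W):
    # Row vector x (len(W) entries) times matrix W; returns len(W[0]) entries.
    return [sum(x[i] * W[i][j] for i in range(len(x)))
            for j in range(len(W[0]))]


# Two hypothetical inputs: x1 has 2 features, x2 has 1 feature.
x1 = [1.0, 2.0]
x2 = [3.0]

# One weight matrix for the concatenated input: 3 input rows, 2 output units.
W = [[1.0, 0.5],
     [2.0, -1.0],
     [0.0, 4.0]]

# Dense transform of the concatenation [x1; x2].
concat_out = matvec(x1 + x2, W)

# Split W into the row blocks that see x1 and x2, transform separately, sum.
W1, W2 = W[:len(x1)], W[len(x1):]
summed_out = [a + b for a, b in zip(matvec(x1, W1), matvec(x2, W2))]

print(concat_out)  # [5.0, 10.5]
print(summed_out)  # [5.0, 10.5]
```

Biases behave the same way: the two separate biases simply add up to one effective bias.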
Does it perform an element-wise multiplication between the input tensors?
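Whether the layer multiplies or adds depends entirely on `merge_function`. The sketch below is a hypothetical stdlib-only mirror of the reduction in `get_output_for`. It assumes plain equal-length Python lists in place of Theano tensors, and a binary function on scalars (`operator.mul` / `operator.add` standing in for `theano.tensor.mul` / `theano.tensor.add`) applied element by element.

```python
import operator
from functools import reduce


def elemwise_merge(inputs, merge_function, coeffs=1):
    """Hypothetical list-based mirror of ElemwiseMergeLayer.get_output_for."""
    if isinstance(coeffs, list):
        if len(coeffs) != len(inputs):
            raise ValueError("Mismatch: got %d coeffs for %d incomings"
                             % (len(coeffs), len(inputs)))
    else:
        # A single scalar coefficient is applied to every input.
        coeffs = [coeffs] * len(inputs)
    if any(len(x) != len(inputs[0]) for x in inputs):
        raise ValueError("Mismatch: not all input shapes are the same")
    # Scale each input by its coefficient (new lists, no in-place mutation).
    scaled = [[c * v for v in x] for c, x in zip(coeffs, inputs)]
    # Left fold: merge(merge(x0, x1), x2), ..., applied elementwise.
    return reduce(lambda acc, x: [merge_function(a, b)
                                  for a, b in zip(acc, x)], scaled)


a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]
print(elemwise_merge([a, b], operator.mul))                   # [4.0, 10.0, 18.0]
print(elemwise_merge([a, b], operator.add, coeffs=[1, -1]))   # [-3.0, -3.0, -3.0]
```

So with a multiplicative merge function the layer computes an elementwise product; with an additive one (and coefficients) it computes a weighted sum.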