

@skaae
Created June 12, 2015 20:55
from lasagne.layers import MergeLayer


class ElemwiseMergeLayer(MergeLayer):
    """
    This layer performs an elementwise merge of its input layers.
    It requires all input layers to have the same output shape.

    Parameters
    ----------
    incomings : a list of :class:`Layer` instances or tuples
        the layers feeding into this layer, or expected input shapes,
        with all incoming shapes being equal
    merge_function : callable
        the merge function to use. Should take two tensors and return
        their elementwise combination; it is applied pairwise, left to
        right, across all inputs. Possible merge functions include T.mul
        and T.add (from theano.tensor).
    coeffs : list or scalar
        A list with one coefficient per incoming layer, or a single
        coefficient applied to all inputs. Each input is multiplied by its
        coefficient before merging. By default, these will not be included
        in the learnable parameters of this layer.

    Notes
    -----
    Depending on your architecture, this can be used to avoid the more
    costly :class:`ConcatLayer`. For example, instead of concatenating layers
    before a :class:`DenseLayer`, insert separate :class:`DenseLayer` instances
    with the same number of output units and add them up afterwards. (This
    avoids the copy operations in concatenation, but splits up the dot
    product.)
    """
    def __init__(self, incomings, merge_function, coeffs=1, **kwargs):
        super(ElemwiseMergeLayer, self).__init__(incomings, **kwargs)
        # A list must have one coefficient per incoming layer;
        # a scalar is repeated for every input.
        if isinstance(coeffs, list):
            if len(coeffs) != len(incomings):
                raise ValueError("Mismatch: got %d coeffs for %d incomings" %
                                 (len(coeffs), len(incomings)))
        else:
            coeffs = [coeffs] * len(incomings)
        self.coeffs = coeffs
        self.merge_function = merge_function

    def get_output_shape_for(self, input_shapes):
        # An elementwise merge needs identical shapes; the output keeps them.
        if any(shape != input_shapes[0] for shape in input_shapes):
            raise ValueError("Mismatch: not all input shapes are the same")
        return input_shapes[0]

    def get_output_for(self, inputs, **kwargs):
        output = None
        for coeff, inp in zip(self.coeffs, inputs):
            # Scale into a new expression instead of updating the input in place.
            if coeff != 1:
                inp = inp * coeff
            # Fold the binary merge function over the inputs, left to right.
            if output is not None:
                output = self.merge_function(output, inp)
            else:
                output = inp
        return output
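The Notes suggest replacing ConcatLayer followed by DenseLayer with one DenseLayer per input whose outputs are summed. This works because the dot product splits over the concatenated feature axis: [x1, x2] W = x1 W1 + x2 W2, where W1 and W2 are the row blocks of W. The NumPy sketch below checks this numerically. The shapes and random data are arbitrary, and it assumes the separate layers are linear with the bias added only once: applying a nonlinearity before the sum would break the equality, so apply it after the merge.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d1, d2, units = 4, 3, 5, 2

x1 = rng.standard_normal((batch, d1))
x2 = rng.standard_normal((batch, d2))
W = rng.standard_normal((d1 + d2, units))  # weights of one dense layer on the concatenation
b = rng.standard_normal(units)

# One linear dense layer on the concatenated input (ConcatLayer + DenseLayer, pre-nonlinearity).
concat_out = np.concatenate([x1, x2], axis=1) @ W + b

# One linear dense layer per input, summed elementwise (an add-style merge).
# W is split row-wise to match each input's features; the bias is added once.
W1, W2 = W[:d1], W[d1:]
sum_out = (x1 @ W1 + b) + (x2 @ W2)

print(np.allclose(concat_out, sum_out))  # True
```

The split version skips the concatenation copy at the cost of two smaller dot products, which is the trade-off the Notes describe.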
@lc82111

lc82111 commented Jun 29, 2015

Does it perform an element-wise multiplication between the input tensors?
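The operation is whatever merge_function is passed in: get_output_for scales each input by its coefficient and then folds that binary function over the inputs, so T.mul gives an elementwise product and T.add an elementwise sum. The minimal sketch below mirrors that logic in NumPy so it runs without Theano; elemwise_merge is a hypothetical helper, and np.multiply / np.add stand in for T.mul / T.add.

```python
from functools import reduce

import numpy as np


def elemwise_merge(inputs, merge_function, coeffs=1):
    """Hypothetical NumPy mirror of ElemwiseMergeLayer.get_output_for."""
    if not isinstance(coeffs, list):
        coeffs = [coeffs] * len(inputs)
    if len(coeffs) != len(inputs):
        raise ValueError("got %d coeffs for %d inputs" % (len(coeffs), len(inputs)))
    if any(x.shape != inputs[0].shape for x in inputs):
        raise ValueError("not all input shapes are the same")
    # Scale each input by its coefficient without modifying the caller's arrays,
    # then fold the binary merge function over the inputs, left to right.
    scaled = [x * c if c != 1 else x for c, x in zip(coeffs, inputs)]
    return reduce(merge_function, scaled)


a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[10.0, 20.0], [30.0, 40.0]])

print(elemwise_merge([a, b], np.multiply))             # product: [[10, 40], [90, 160]]
print(elemwise_merge([a, b], np.add))                  # sum: [[11, 22], [33, 44]]
print(elemwise_merge([a, b], np.add, coeffs=[1, -1]))  # difference a - b
```

With coeffs=[1, -1] and an add merge, the result is a - b, which shows how the coefficients combine with the merge function.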
