Last active
March 10, 2021 09:47
-
-
Save braingineer/b64ca35223c7782667984d34ddb7a7fa to your computer and use it in GitHub Desktop.
Keras Lambda layer supporting masking
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Lambda(Layer):
    """Evaluate an arbitrary Theano / TensorFlow expression on the output of
    the previous layer, optionally computing a new mask for the output.

    # Examples

    ```python
        # add a x -> x^2 layer
        model.add(Lambda(lambda x: x ** 2))
    ```
    ```python
        # add a layer that returns the concatenation
        # of the positive part of the input and
        # the opposite of the negative part

        def antirectifier(x):
            x -= K.mean(x, axis=1, keepdims=True)
            x = K.l2_normalize(x, axis=1)
            pos = K.relu(x)
            neg = K.relu(-x)
            return K.concatenate([pos, neg], axis=1)

        def antirectifier_output_shape(input_shape):
            shape = list(input_shape)
            assert len(shape) == 2  # only valid for 2D tensors
            shape[-1] *= 2
            return tuple(shape)

        model.add(Lambda(antirectifier, output_shape=antirectifier_output_shape))
    ```

    # Arguments
        function: The function to be evaluated. Takes one positional
            argument: the output of the previous layer. If it declares a
            parameter named `mask`, the incoming mask is passed to it as a
            keyword argument.
        output_shape: Expected output shape from function. Either a
            tuple/list (without the samples axis, which is prepended) or a
            function mapping the input shape to the full output shape.
        mask_function: Optional function `f(x, mask)` returning the new mask.
            When omitted, the input mask is passed through unchanged.
        arguments: Optional dictionary of keyword arguments to be passed
            to `function`. The layer never modifies it.

    # Input shape
        Arbitrary. Use the keyword argument input_shape
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Specified by the `output_shape` argument. When it is omitted, the
        shape is inferred by tracing `function` on the TensorFlow backend and
        is assumed equal to the input shape on other backends.
    """

    def __init__(self, function, output_shape=None, mask_function=None,
                 arguments=None, **kwargs):
        self.function = function
        # A `{}` default would be one dict shared by every Lambda instance.
        self.arguments = {} if arguments is None else arguments

        if output_shape is None:
            self._output_shape = None
        elif isinstance(output_shape, (tuple, list)):
            self._output_shape = tuple(output_shape)
        elif callable(output_shape):
            self._output_shape = output_shape
        else:
            raise Exception('In Lambda, `output_shape` '
                            'must be a list, a tuple, or a function.')

        if mask_function is None:
            self._mask_function = None
            # compute_mask still passes an incoming mask through unchanged.
            self.supports_masking = False
        elif callable(mask_function):
            self._mask_function = mask_function
            self.supports_masking = True
        else:
            raise Exception("In Lambda, `mask_function` "
                            "must be a function that computes the new mask")

        super(Lambda, self).__init__(**kwargs)

    def get_output_shape_for(self, input_shape):
        """Return the output shape (samples axis included) for `input_shape`."""
        if self._output_shape is None:
            # With TensorFlow the shape can be inferred by running the
            # function on placeholders of the input shape(s).
            if K._BACKEND == 'tensorflow':
                if isinstance(input_shape, list):
                    x = self.call([K.placeholder(shape=shape)
                                   for shape in input_shape])
                else:
                    x = self.call(K.placeholder(shape=input_shape))
                if isinstance(x, list):
                    return [K.int_shape(x_elem) for x_elem in x]
                return K.int_shape(x)
            # Otherwise, default to the input shape.
            return input_shape

        if isinstance(self._output_shape, (tuple, list)):
            nb_samples = input_shape[0] if input_shape else None
            return (nb_samples,) + tuple(self._output_shape)

        shape = self._output_shape(input_shape)
        if not isinstance(shape, (list, tuple)):
            raise Exception('output_shape function must return a tuple')
        return tuple(shape)

    def _function_accepts_mask(self):
        """Return True if `self.function` declares a parameter named `mask`."""
        # inspect.getargspec was removed in Python 3.11; fall back to it only
        # on interpreters that lack getfullargspec (Python 2).
        get_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        spec = get_spec(self.function)
        return 'mask' in spec.args or 'mask' in (getattr(spec, 'kwonlyargs', None) or [])

    def call(self, x, mask=None):
        """Apply `function` to `x`, forwarding `mask` if the function takes it."""
        # Work on a copy: the per-call mask must never be written into
        # self.arguments, which may be shared and is serialized by get_config.
        arguments = dict(self.arguments)
        if self._function_accepts_mask():
            arguments['mask'] = mask
        return self.function(x, **arguments)

    def compute_mask(self, x, mask=None):
        """Return the output mask.

        Delegates to `mask_function(x, mask)` when one was given; otherwise the
        input mask is returned unchanged.
        """
        if self._mask_function is None:
            return mask
        return self._mask_function(x, mask)

    @staticmethod
    def _code_to_str(func):
        """Marshal `func`'s code object into text so the config stays JSON-safe."""
        # __code__ exists on Python 2.6+ and 3, so no version switch is needed.
        return marshal.dumps(func.__code__).decode('raw_unicode_escape')

    @staticmethod
    def _str_to_func(data):
        """Rebuild a function (bound to this module's globals) from `_code_to_str` output.

        Raw `bytes` are also accepted, because earlier versions of this layer
        stored a lambda `output_shape` without decoding it.
        """
        if not isinstance(data, bytes):
            data = data.encode('raw_unicode_escape')
        return python_types.FunctionType(marshal.loads(data), globals())

    def get_config(self):
        """Return a serializable config from which `from_config` rebuilds the layer."""
        if isinstance(self.function, python_types.LambdaType):
            function = self._code_to_str(self.function)
            function_type = 'lambda'
        else:
            function = self.function.__name__
            function_type = 'function'

        if isinstance(self._mask_function, python_types.LambdaType):
            mask_func = self._code_to_str(self._mask_function)
            mask_func_type = 'lambda'
        elif callable(self._mask_function):
            mask_func = self._mask_function.__name__
            mask_func_type = 'function'
        else:
            # No mask function; from_config maps the 'unknown' type back to None.
            mask_func = None
            mask_func_type = 'unknown'

        if isinstance(self._output_shape, python_types.LambdaType):
            output_shape = self._code_to_str(self._output_shape)
            output_shape_type = 'lambda'
        elif callable(self._output_shape):
            output_shape = self._output_shape.__name__
            output_shape_type = 'function'
        else:
            output_shape = self._output_shape
            output_shape_type = 'raw'

        config = {'function': function,
                  'function_type': function_type,
                  'mask_function': mask_func,
                  'mask_function_type': mask_func_type,
                  'output_shape': output_shape,
                  'output_shape_type': output_shape_type,
                  'arguments': self.arguments}
        base_config = super(Lambda, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Rebuild a Lambda layer from the output of `get_config`.

        Named functions are looked up in this module's globals; lambdas are
        unmarshaled. NOTE(review): unmarshaling executes arbitrary code, so only
        load configs from trusted sources.
        """
        # Copy so the caller's config dict is not mutated by the pops below.
        config = dict(config)

        function_type = config.pop('function_type')
        if function_type == 'function':
            function = globals()[config['function']]
        elif function_type == 'lambda':
            function = cls._str_to_func(config['function'])
        else:
            raise Exception('Unknown function type: ' + str(function_type))

        # 'unknown' is written when no mask function was set. A missing key
        # (a config from a Lambda without mask support) means the same thing.
        mask_function_type = config.pop('mask_function_type', 'unknown')
        if mask_function_type == 'function':
            mask_function = globals()[config['mask_function']]
        elif mask_function_type == 'lambda':
            mask_function = cls._str_to_func(config['mask_function'])
        elif mask_function_type == 'unknown':
            mask_function = None
        else:
            raise Exception('Unknown function type: ' + str(mask_function_type))

        output_shape_type = config.pop('output_shape_type')
        if output_shape_type == 'function':
            output_shape = globals()[config['output_shape']]
        elif output_shape_type == 'lambda':
            output_shape = cls._str_to_func(config['output_shape'])
        else:
            output_shape = config['output_shape']

        config['function'] = function
        config['output_shape'] = output_shape
        config['mask_function'] = mask_function
        return cls(**config)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment