A metaprogramming approach with minimal overhead to perform Jacobian computation for arbitrary networks
""" | |
Registers jacobian for basic PyTorch layers | |
""" | |
import torch | |
from torch import nn | |
from .metric import register_jacobian | |
@register_jacobian(nn.Linear) | |
def _linear_jacobian(module): | |
# Jacobian is simply the weight matrix transposed | |
w = module.weight | |
def jacobian(_x, dx): | |
eq = "bi, ij -> bij" if dx.dim() == w.dim() else "bij, jk -> bik" | |
dx = torch.einsum(eq, dx, w) | |
return dx | |
return jacobian | |
@register_jacobian(nn.Tanh) | |
def _tanh_jacobian(_module): | |
# dtanh(x)/dx | |
dtanh = lambda x: (1 - torch.tanh(x)**2) | |
def jacobian(x, dx): | |
# Send x back through derivative of activation | |
dy = dtanh(x) | |
dy = dy.unsqueeze(1) if dx.dim() != dy.dim() else dy | |
# Accumulate gradient | |
return dx * dy | |
return jacobian |
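The same pattern extends to other activations. As a sketch that is not part of the original gist, a ReLU rule could be registered like this, reusing the broadcasting trick from the Tanh Jacobian (the name _relu_jacobian is mine):

@register_jacobian(nn.ReLU)
def _relu_jacobian(_module):
    # dReLU(x)/dx is 1 where x > 0, and 0 elsewhere
    def jacobian(x, dx):
        dy = (x > 0).to(dx.dtype)
        # Broadcast over the accumulated Jacobian's output dimension
        dy = dy.unsqueeze(1) if dx.dim() != dy.dim() else dy
        return dx * dy

    return jacobian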
import torch
from torch import nn

# When imported, the Jacobians are automatically registered
from .metric import model_jacobian

model = nn.Sequential(
    nn.Linear(3, 32),
    nn.Tanh(),
    nn.Linear(32, 32),
    nn.Tanh(),
    nn.Linear(32, 2)
)

# Compute the full Jacobian in a single call;
# jac has shape (64, 2, 3): batch x output dim x input dim
x = torch.randn(64, 3)
jac = model_jacobian(model, x)

# Define another model
model2 = nn.Sequential(
    nn.Linear(3, 32),
    nn.Linear(32, 2)
)

# No need to change anything or write model-specific code
jac = model_jacobian(model2, x)
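As a sanity check (my addition, not part of the gist), the result for model2 can be compared against rows computed by plain autograd for a single sample:

x0 = x[:1].clone().requires_grad_(True)
y0 = model2(x0)
rows = []
for i in range(y0.shape[1]):
    g, = torch.autograd.grad(y0[0, i], x0, retain_graph=True)
    rows.append(g[0])
# jac[0] should match the stacked autograd rows up to float error
assert torch.allclose(jac[0], torch.stack(rows), atol=1e-6)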
"""
metric: registry-based Jacobian computation for PyTorch modules
"""
import torch
from torch import nn

_JACOBIAN_REGISTRY = {}
_JACOBIAN_MEMOIZE = {}


def register_jacobian(module_type):
    """
    Adds a Jacobian method for the specific module type
    to the lookup.

    :param module_type: type of module (must inherit nn.Module)
    :return: a function that decorates the Jacobian
    """
    if not (isinstance(module_type, type) and issubclass(module_type, nn.Module)):
        raise TypeError('Expected module_type to be a Module subclass but got {}'.format(module_type))

    def decorator(fun):
        _JACOBIAN_REGISTRY[module_type] = fun
        _JACOBIAN_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator


def jacobian(module, input, grad=None):
    """
    Computes the Jacobian of a module (layer) given
    an input and optionally the accumulated Jacobian
    of the downstream layers.
    The Jacobian of the module must be preregistered through
    `register_jacobian`.

    :param module: layer to compute for
    :param input: input to compute the Jacobian wrt.
    :param grad: accumulated Jacobian from downstream layers
    :return: J_input
    """
    module_type = type(module)

    # Try the cache first
    try:
        fun = _JACOBIAN_MEMOIZE[module_type]
    except KeyError:
        # Look the module type up in the registry and cache the result
        if module_type not in _JACOBIAN_REGISTRY:
            return NotImplemented
        fun = _JACOBIAN_REGISTRY[module_type]
        _JACOBIAN_MEMOIZE[module_type] = fun

    if fun is NotImplemented:
        raise NotImplementedError

    return fun(module)(input, grad)


def model_jacobian(model, x):
    """
    Computes the Jacobian of a PyTorch module whose
    forward computation is straightforward, e.g. nn.Sequential.
    All layers must have registered Jacobians.

    :param model: nn.Module
    :param x: input
    :return: J_x
    """
    # Compute the forward pass and save intermediate states
    forwards = []
    for module in model:
        forwards.append(x)
        x = module(x)

    # Go backwards, accumulating the Jacobian by the chain rule;
    # the recursion is seeded with ones of the output shape
    dx = torch.ones_like(x)
    for x, module in zip(reversed(forwards), reversed(model)):
        dx = jacobian(module, x, dx)

    return dx
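The lookup above matches module types exactly, so a subclass of nn.Linear would miss the registered rule. A sketch of subclass-aware matching, modelled on the registry pattern in torch.distributions.kl rather than taken from the gist:

def _find_jacobian_fn(module_type):
    # Match the module type against registered base classes as well
    matches = [key for key in _JACOBIAN_REGISTRY if issubclass(module_type, key)]
    if not matches:
        return NotImplemented
    # Prefer the most derived match: the key that is a subclass
    # of every other matching key
    best = matches[0]
    for key in matches[1:]:
        if issubclass(key, best):
            best = key
    return _JACOBIAN_REGISTRY[best]

jacobian could call _find_jacobian_fn(module_type) in place of the exact dictionary lookup; the memoization cache would work unchanged.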