Skip to content

Instantly share code, notes, and snippets.

@wanchaol
Last active April 29, 2019 05:00
Show Gist options
  • Save wanchaol/c50875569d4266aff20133b6d0f387c3 to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
import torch.jit as jit
from torch import Tensor
@jit.script
def test_mm_back(input1, input2, normalized_shape):
    # type: (Tensor, Tensor, List[int]) -> Tensor
    """Layer-normalize the matrix product ``input1 @ input2``.

    Compiled with TorchScript; the rest of this script calls ``backward`` on
    the result, so the scripted mm -> layer_norm graph is differentiated.

    Args:
        input1: left operand of ``torch.mm``.
        input2: right operand of ``torch.mm``.
        normalized_shape: trailing dimensions that ``F.layer_norm``
            normalizes over (no affine weight/bias is passed).
    """
    product = torch.mm(input1, input2)
    normalized = F.layer_norm(product, normalized_shape)
    return normalized
# Leaf inputs for a (3 x 10) @ (10 x 80) product; requires_grad=True so the
# later backward call can reach them.
input1 = torch.randn(3, 10, requires_grad=True)
input2 = torch.randn(10, 80, requires_grad=True)
# Layer-normalize over the product's trailing dimension, whose size equals
# input2's column count (80).
output = test_mm_back(input1, input2, (input2.size(1),))
def simple_backward_setup(output, seed=None):
    """Pair ``output`` with a random upstream gradient of the same shape.

    Args:
        output: tensor whose backward pass will be driven.
        seed: optional RNG seed handed to ``torch.manual_seed`` before the
            gradient is drawn. ``None`` leaves the global RNG state alone;
            any other value (including ``0``) reseeds it.

    Returns:
        ``(output, grad_output)`` where ``grad_output`` is
        ``torch.randn_like(output)``.

    Raises:
        AssertionError: if ``output`` is not a ``torch.Tensor`` (note: the
            check is skipped when Python runs with ``-O``).
    """
    assert isinstance(output, torch.Tensor)
    # Compare against None rather than relying on truthiness: ``0`` is a
    # perfectly valid seed and must not be silently ignored.
    if seed is not None:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output
def simple_backward(output, grad_output):
    """Backpropagate ``grad_output`` through the graph that produced ``output``.

    Gradients accumulate into ``.grad`` of the leaf tensors that require
    grad. Returns ``None``, exactly as ``Tensor.backward`` does.
    """
    output.backward(gradient=grad_output)
    return None
# Draw an (unseeded) random upstream gradient shaped like `output`, then
# backpropagate it through the scripted mm + layer_norm graph.
backward_input = simple_backward_setup(output)
result_tensor, upstream_grad = backward_input
simple_backward(result_tensor, upstream_grad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment