Skip to content

Instantly share code, notes, and snippets.

@soulitzer
Created July 6, 2023 23:03
Show Gist options
  • Save soulitzer/1db4d29f1223c7a24ffd8715be9f230f to your computer and use it in GitHub Desktop.
output_nr issue when requires_grad=False: a mutating custom op returns mismatched per-output gradients
import torch
from torch.library import Library
test_ns = "abc"
lib = Library(test_ns, "FRAGMENT")
lib.define("foo(Tensor(a!) a, Tensor(b!) b) -> (Tensor(a!), Tensor(b!))")
def get_op(name, overload="default", namespace=None):
    """Return an operator overload looked up through ``torch.ops``.

    Args:
        name: Operator name inside the namespace (e.g. ``"foo"``).
        overload: Overload attribute to select. The default, ``"default"``,
            reproduces the original ``.default`` lookup.
        namespace: ``torch.ops`` namespace to search. ``None`` (the default)
            means the module-level ``test_ns``, as before.

    Returns:
        The object ``torch.ops.<namespace>.<name>.<overload>`` evaluates to.
        It is called directly like a function.

    Lookup failures propagate unchanged from ``getattr`` on the
    ``torch.ops`` objects.
    """
    # Resolve test_ns lazily so the default tracks the module-level value.
    ops_namespace = getattr(
        torch.ops, test_ns if namespace is None else namespace
    )
    return getattr(getattr(ops_namespace, name), overload)
# Resolve the "default" overload of the "foo" op defined in test_ns.
op = get_op(name="foo")
def foo_impl(a, b):
    """Python kernel for the custom ``foo`` op: return both inputs unchanged.

    The op's schema gives each output the same alias annotation as the
    matching input (``Tensor(a!)`` / ``Tensor(b!)``). Handing back the input
    objects themselves satisfies that aliasing without doing any mutation.

    Args:
        a: First input, returned as the first output.
        b: Second input, returned as the second output.

    Returns:
        The tuple ``(a, b)``.
    """
    return a, b
# Repro: register the aliasing kernel, call the op on inputs created without
# requires_grad, then differentiate a weighted sum of its two outputs.
# Register foo_impl as the CPU kernel for the "foo" op declared via lib.define.
lib.impl("foo", foo_impl, "CPU")
# Both inputs explicitly do not require grad -- the condition under test.
x = torch.randn(3, requires_grad=False)
y = torch.randn(3, requires_grad=False)
# Per the schema's alias annotations, z corresponds to x and w to y.
z, w = op(x, y)
# NOTE(review): this prints y.grad_fn (an input) rather than w.grad_fn; likely
# meant w, to mirror the output_nr line below -- confirm.
print(z.grad_fn, y.grad_fn)
print(z.output_nr, w.output_nr) # observed: 0, 0
# out = sum(2*z + 3*w), so d(out)/dz should be 2 and d(out)/dw should be 3.
out = (z * 2 + w * 3).sum()
# NOTE(review): x and y do not require grad, so this call only succeeds if z/w
# end up attached to an autograd graph -- confirm on the PyTorch version under test.
dz, dw = torch.autograd.grad(out, (z, w))
# Observed 5 (= 2 + 3) for both, which looks like the two outputs' gradients
# are being combined rather than kept per-output.
print(dz, dw) # observed: tensor([5., 5., 5.]) tensor([5., 5., 5.])
# expect tensor([2., 2., 2.]) tensor([3., 3., 3.])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment