@albanD
Created January 24, 2023 19:34
Make a PyTorch custom Function unpack its inputs and outputs using pytree.
import torch
from torch.autograd import Function
import torch.utils._pytree as pytree
# Wraps inputs and outputs: apply() flattens arbitrarily nested (pytree) inputs
# before calling into the real user-defined Function, and unflattens the
# outputs (and gradients) on the way back.
def pytreeify(cls):
    assert issubclass(cls, Function)

    orig_fw = cls.forward
    orig_bw = cls.backward
    orig_apply = cls.apply

    def new_apply(*inp):
        # Flatten the (possibly nested) inputs so that autograd only ever
        # sees a flat list of Tensors, and remember the nesting structure.
        flat_inp, struct = pytree.tree_flatten(inp)
        # The forward fills this holder so apply() can rebuild the nested
        # output structure after the call.
        out_struct_holder = []
        flat_out = orig_apply(struct, out_struct_holder, *flat_inp)
        assert len(out_struct_holder) == 1
        return pytree.tree_unflatten(flat_out, out_struct_holder[0])

    def new_forward(ctx, struct, out_struct_holder, *flat_inp):
        # Rebuild the nested inputs before calling the user-defined forward.
        inp = pytree.tree_unflatten(flat_inp, struct)
        out = orig_fw(ctx, *inp)
        # Flatten the nested outputs so autograd gets plain Tensors back.
        flat_out, out_struct = pytree.tree_flatten(out)
        ctx._inp_struct = struct
        ctx._out_struct = out_struct
        out_struct_holder.append(out_struct)
        return tuple(flat_out)

    def new_backward(ctx, *flat_grad_outputs):
        # Rebuild the nested grad_outputs to match the user forward's outputs.
        grad_outputs = pytree.tree_unflatten(flat_grad_outputs, ctx._out_struct)
        if not isinstance(grad_outputs, tuple):
            grad_outputs = (grad_outputs,)
        grad_inputs = orig_bw(ctx, *grad_outputs)
        flat_grad_inputs, grad_inputs_struct = pytree.tree_flatten(grad_inputs)
        if grad_inputs_struct != ctx._inp_struct:
            raise RuntimeError("The backward generated an arg structure that doesn't "
                               "match the forward's input.")
        # The two extra Nones are the grads for the struct and
        # out_struct_holder arguments of new_forward.
        return (None, None) + tuple(flat_grad_inputs)

    cls.apply = new_apply
    cls.forward = new_forward
    cls.backward = new_backward
    return cls
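# For reference, a minimal sketch of the flatten/unflatten round-trip the
# decorator relies on (the example values below are illustrative only):
_demo = ((torch.ones(2), torch.zeros(2)), {"k": torch.ones(3)})
_leaves, _spec = pytree.tree_flatten(_demo)
# _leaves is the flat list [ones(2), zeros(2), ones(3)] and _spec records the
# tuple/dict nesting, so unflattening rebuilds the original structure.
_rebuilt = pytree.tree_unflatten(_leaves, _spec)
assert len(_leaves) == 3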
# Very basic test where pytreeify is not needed
@pytreeify
class BasicTest(Function):
    @staticmethod
    def forward(ctx, inp1, inp2):
        return inp1 + inp2

    @staticmethod
    def backward(ctx, gO):
        return gO, gO
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)
out = BasicTest.apply(a, b)
out.sum().backward()
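# Sanity check: out = a + b, so the gradient of out.sum() w.r.t. both
# a and b should be a tensor of ones.
assert torch.equal(a.grad, torch.ones(10))
assert torch.equal(b.grad, torch.ones(10))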
# A more advanced test
@pytreeify
class Test(Function):
    @staticmethod
    def forward(ctx, tuple_holder, dict_holder):
        out0 = tuple_holder[0] + tuple_holder[1] + dict_holder["foo"]
        out1 = (dict_holder["bar"] + tuple_holder[0],)
        return (out0, out1)

    @staticmethod
    def backward(ctx, gO0, gO1):
        gtuple0 = gO0 + gO1[0]
        gtuple1 = gO0
        gdictfoo = gO0
        gdictbar = gO1[0]
        return (gtuple0, gtuple1), {"foo": gdictfoo, "bar": gdictbar}
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)
c = torch.rand(10, requires_grad=True)
d = torch.rand(10, requires_grad=True)
inp1 = (a, b)
inp2 = {"foo": c, "bar": d}
out = Test.apply(inp1, inp2)
print(out)
(out[0] + out[1][0]).sum().backward()
print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
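# Sanity check: out[0] = a + b + c and out[1][0] = d + a, so in
# (out[0] + out[1][0]).sum() the input a contributes twice and b, c, d once.
assert torch.equal(a.grad, 2 * torch.ones(10))
assert torch.equal(b.grad, torch.ones(10))
assert torch.equal(c.grad, torch.ones(10))
assert torch.equal(d.grad, torch.ones(10))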