Created January 24, 2023, 19:34
Save albanD/804d5909295a1e71b5d726597dfbd605 to your computer and use it in GitHub Desktop.
Make PyTorch custom Function unpack input and output using pytree.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.autograd import Function | |
import torch.utils._pytree as pytree | |
def pytreeify(cls):
    """Class decorator for ``torch.autograd.Function`` subclasses.

    Lets the user-written ``forward``/``backward`` accept and return
    arbitrary pytrees (nested tuples/lists/dicts of tensors) instead of
    flat positional tensors. Autograd only tracks tensors passed
    positionally to ``apply``, so the wrappers flatten every input
    pytree before the real ``apply`` and rebuild the structures on the
    other side.
    """
    assert issubclass(cls, Function)
    user_forward = cls.forward
    user_backward = cls.backward
    raw_apply = cls.apply

    def wrapped_apply(*args):
        # Flatten the caller's (pytree) arguments into leaves + a spec.
        leaves, in_spec = pytree.tree_flatten(args)
        # forward() smuggles the output spec back through this mutable list,
        # since apply() itself only returns the flat tensor outputs.
        spec_box = []
        flat_result = raw_apply(in_spec, spec_box, *leaves)
        assert len(spec_box) == 1
        return pytree.tree_unflatten(flat_result, spec_box[0])

    def wrapped_forward(ctx, in_spec, spec_box, *leaves):
        # Rebuild the user's original argument structure from the leaves.
        args = pytree.tree_unflatten(leaves, in_spec)
        result = user_forward(ctx, *args)
        flat_result, out_spec = pytree.tree_flatten(result)
        # Stash both specs on ctx for backward(); hand out_spec to apply().
        ctx._inp_struct = in_spec
        ctx._out_struct = out_spec
        spec_box.append(out_spec)
        return tuple(flat_result)

    def wrapped_backward(ctx, *flat_grads):
        # Rebuild the output structure so the user's backward sees
        # gradients shaped like its forward's return value.
        grads = pytree.tree_unflatten(flat_grads, ctx._out_struct)
        grads = grads if isinstance(grads, tuple) else (grads,)
        grad_args = user_backward(ctx, *grads)
        flat_grad_args, grad_spec = pytree.tree_flatten(grad_args)
        if grad_spec != ctx._inp_struct:
            raise RuntimeError("The backward generated an arg structure that doesn't "
                               "match the forward's input.")
        # The two leading Nones are the grads for the in_spec / spec_box
        # arguments that wrapped_forward takes before the tensor leaves.
        return (None, None) + tuple(flat_grad_args)

    cls.apply = wrapped_apply
    cls.forward = wrapped_forward
    cls.backward = wrapped_backward
    return cls
# Sanity case: flat tensor args, so pytreeify adds nothing beyond plumbing.
@pytreeify
class BasicTest(Function):
    """Elementwise sum of two tensors; gradients pass straight through."""

    @staticmethod
    def forward(ctx, inp1, inp2):
        # Nothing to save: d(inp1 + inp2)/d(inp*) is the identity.
        return inp1 + inp2

    @staticmethod
    def backward(ctx, grad_out):
        # The same upstream gradient flows to both addends.
        return grad_out, grad_out
# Run the simple case end to end (forward + backward).
x = torch.rand(10, requires_grad=True)
y = torch.rand(10, requires_grad=True)
out = BasicTest.apply(x, y)
out.sum().backward()
# The interesting case: inputs and outputs are nested pytrees.
@pytreeify
class Test(Function):
    """Function taking a tuple and a dict, returning a nested tuple."""

    @staticmethod
    def forward(ctx, tuple_holder, dict_holder):
        first = tuple_holder[0] + tuple_holder[1] + dict_holder["foo"]
        second = (dict_holder["bar"] + tuple_holder[0],)
        return (first, second)

    @staticmethod
    def backward(ctx, g_out0, g_out1):
        # g_out1 arrives as a 1-tuple, mirroring the forward's output shape.
        grad_tuple = (g_out0 + g_out1[0], g_out0)
        grad_dict = {"foo": g_out0, "bar": g_out1[0]}
        # Grad structure mirrors the forward's inputs: (tuple, dict).
        return grad_tuple, grad_dict
# Exercise the pytree path: tuple + dict inputs, nested tuple output.
a = torch.rand(10, requires_grad=True)
b = torch.rand(10, requires_grad=True)
c = torch.rand(10, requires_grad=True)
d = torch.rand(10, requires_grad=True)

out = Test.apply((a, b), {"foo": c, "bar": d})
print(out)
(out[0] + out[1][0]).sum().backward()
for leaf in (a, b, c, d):
    print(leaf.grad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.