@albanD
albanD / pytreeify.py
Created January 24, 2023 19:34
Make PyTorch custom Function unpack input and output using pytree.
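For readers unfamiliar with pytrees, here is a minimal pure-Python sketch of the flatten/unflatten idea that the description refers to. The helpers `flatten` and `unflatten` are hypothetical names written only for illustration. They are not part of this gist or of PyTorch, whose `torch.utils._pytree` module provides `tree_flatten` and `tree_unflatten` with broader container support.

```python
# Illustrative sketch only: hypothetical helpers, not torch.utils._pytree itself.

def flatten(tree):
    """Return (leaves, spec): a flat list of leaves plus a nesting description."""
    if type(tree) in (list, tuple):
        leaves, specs = [], []
        for child in tree:
            child_leaves, child_spec = flatten(child)
            leaves.extend(child_leaves)
            specs.append(child_spec)
        return leaves, (type(tree), specs)
    if type(tree) is dict:
        leaves, specs = [], []
        for key, child in tree.items():
            child_leaves, child_spec = flatten(child)
            leaves.extend(child_leaves)
            specs.append((key, child_spec))
        return leaves, (dict, specs)
    # Anything else is treated as a leaf; None in the spec marks a leaf slot.
    return [tree], None


def unflatten(leaves, spec):
    """Rebuild the nested structure described by spec from a flat list of leaves."""
    it = iter(leaves)

    def build(node):
        if node is None:
            return next(it)
        kind, children = node
        if kind is dict:
            return {key: build(child) for key, child in children}
        return kind(build(child) for child in children)

    return build(spec)


nested = {"a": [1, 2], "b": (3, {"c": 4})}
leaves, spec = flatten(nested)
restored = unflatten(leaves, spec)
scaled = unflatten([x * 10 for x in leaves], spec)
```

Keeping the structure description separate from the leaves is what lets a wrapper pass a flat argument list to a function and rebuild the nested form afterwards.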
import torch
from torch.autograd import Function
import torch.utils._pytree as pytree
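# Wraps the user-defined forward and backward so that inputs and outputs are
# packed and unpacked with pytree around the call to the real function.
def pytreeify(cls):
    assert issubclass(cls, Function)

    # Keep references to the original methods before they are wrapped.
    orig_fw = cls.forward
    orig_bw = cls.backward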