graph-based AC
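A sketch of graph-based activation checkpointing (AC): a TorchDispatchMode records the op graph during forward, saved-tensor hooks stash pointers into that graph instead of real tensors, and backward recomputes (or fetches cached) values on unpack, with use-count tracking to free cached intermediates as early as possible.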
import torch
import functools
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils.checkpoint import CheckpointPolicy, _policy_from_bool
from collections import namedtuple
import weakref
_NodeOutput = namedtuple('NodeOutput', ['node', 'idx'])

# Subclass instead of using namedtuple directly so that it's a pytree leaf
class NodeOutput(_NodeOutput):
    pass

def realize(node_output):
    return node_output.node.run(node_output.idx)

def increment_use_counts(node_output):
    node_output.node.increment_use_counts(node_output.idx)

class _List(list):
    pass
# Note [AC Node use-count tracking to clear cached tensors sooner]
#
# Consider the graph,
#
# A <- B <- C <- D (saved)
#       \
#        E (saved)
#
# The arrows indicate the direction of ownership. The flow of data is from left
# to right, e.g. A is used to compute B, etc.
#
# When D is unpacked and realized during backward, its dependencies, C, B, A,
# are also recursively realized. To avoid recomputing A and B again when
# E is unpacked, some intermediate result, e.g. B's output, should be cached.
# Conversely, A's output is not needed once B's output is cached.
#
# Note that when D is unpacked, A is not freed, because A is still kept alive
# by E. Instead of waiting for A to be automatically cleared when E is
# eventually unpacked, we track the number of users of A (1 in this case) and
# clear it as soon as B is computed.
#
# We do this by matching every realize call with an increment_use_counts call
# and making sure increment_use_counts() recurses the same way as realize().
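#
# Worked trace of the note above: packing D increments use counts along its
# whole chain (A: 1, B: 1, C: 1, D: 1); packing E recurses only as far as B,
# because B's counts are already initialized (B: 2, E: 1). Realizing D
# consumes A's single use, so A's cached slot is cleared as soon as B is
# computed, while B's output stays cached (count 2 -> 1) for E's later unpack.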
class Node:
    def __init__(self, func=None, args=None, outs: tuple = None):
        self.func = func
        self.args = args
        self.out = None
        self.nb_users = dict()  # out_idx -> nb_users
        if outs is not None:
            # Saved node: hold detached outputs directly
            self.out = _List([x.detach() if isinstance(x, torch.Tensor) else x for x in outs])

    def run(self, idx):
        if self.out is None:
            # Recursively realize the arguments, then re-run the op
            new_args = tree_map_only(NodeOutput, realize, self.args)
            raw_out = self.func(*new_args)
            self.out = list(raw_out) if isinstance(raw_out, tuple) else [raw_out]
        out = self.out[idx]
        self.nb_users[idx] -= 1
        if self.nb_users[idx] == 0:
            # Last user has consumed this output; free the cached tensor eagerly
            self.out[idx] = None
        return out

    def increment_use_counts(self, idx):
        # Recurse only on the first visit; afterwards the subtree below has
        # already been counted, so only this node's own count needs bumping
        if self.out is None and len(self.nb_users) == 0:
            tree_map_only(NodeOutput, increment_use_counts, self.args)
        self.nb_users[idx] = self.nb_users.get(idx, 0) + 1
def get_node_output(node_outputs, t):
    if t not in node_outputs:
        # If the tensor was created in the checkpoint region, then it would've
        # been saved to node_outputs. If it's not there, then it's an input.
        node_outputs[t] = NodeOutput(Node(None, None, (t,)), 0)
    return node_outputs[t]
class Context:
    def __init__(self, nodes):
        self.nodes = nodes
class TrackingMode(TorchDispatchMode):
    def __init__(self, node_outputs):
        super().__init__()
        self.node_outputs = node_outputs

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        # Non-tensors are always kept alive by the node
        wrapped_args = tree_map_only(
            torch.Tensor,
            functools.partial(get_node_output, self.node_outputs),
            args
        )
        # TODO: nodes shouldn't be public API right?
        ctx = Context(self.node_outputs)
        should_save = is_save_policy(current_global_policy(ctx, out, func, *args, **kwargs))
        out_tuple = tuple(out) if isinstance(out, (list, tuple)) else (out,)
        # Saved nodes hold their outputs directly; recomputed nodes instead
        # hold pointers to their arguments' nodes so they can be re-run later
        node = (
            Node(func, None, out_tuple) if should_save else
            Node(func, wrapped_args, None)
        )
        for idx, t in enumerate(out_tuple):
            if isinstance(t, torch.Tensor):
                self.node_outputs[t] = NodeOutput(node, idx)
        return out
class CheckpointHook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, node_outputs):
        def pack_hook(raw_tensor):
            # Save a pointer into the graph instead of the tensor itself
            node_output = get_node_output(node_outputs, raw_tensor)
            increment_use_counts(node_output)
            return node_output

        def unpack_hook(node_output):
            # Recompute (or fetch the cached value) on demand during backward
            return realize(node_output)

        super().__init__(pack_hook, unpack_hook)
def recompute_all(ctx, out, op, *args, **kwargs):
    return CheckpointPolicy.PREFER_RECOMPUTE

def must_save_all(ctx, out, op, *args, **kwargs):
    return CheckpointPolicy.MUST_SAVE
def simple_auto_ac(ctx, out, op, *args, **kwargs):
    # Heuristic (not yet implemented):
    # - Periodically save some tensors to cap peak memory
    # - Prioritize saving tensors that are expensive to recompute
    # - Save more tensors later on, when the peak memory is higher
    # - Keep track of everything saved so far
    # - Provide a budget for how much memory can be used
    pass
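
# A minimal sketch of one possible budget-style policy, assuming we only count
# tensor storage bytes. The name budget_auto_ac and the closure structure are
# illustrative, not part of the heuristic described above:
def budget_auto_ac(budget_bytes):
    saved_bytes = [0]  # mutable cell so the inner closure can update it

    def policy_fn(ctx, out, op, *args, **kwargs):
        outs = out if isinstance(out, (list, tuple)) else (out,)
        size = sum(
            t.untyped_storage().nbytes()
            for t in outs
            if isinstance(t, torch.Tensor)
        )
        # Save while under budget; once exhausted, prefer recomputation
        if saved_bytes[0] + size <= budget_bytes:
            saved_bytes[0] += size
            return CheckpointPolicy.PREFER_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy_fn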
policy_stack = []

@contextlib.contextmanager
def push_policy(policy_fn):
    if policy_fn == "recompute_all":
        policy_fn = recompute_all
    elif policy_fn == "must_save_all":
        policy_fn = must_save_all
    try:
        policy_stack.append(policy_fn)
        yield
    finally:
        policy_stack.pop()
def is_must_policy(policy):
    return policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.MUST_RECOMPUTE)

def is_save_policy(policy):
    return policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE)
def current_global_policy(ctx, out, op, *args, **kwargs):
    # Applies policy functions in policy_stack from outermost to innermost.
    # - Default policy is PREFER_SAVE.
    # - A MUST_* policy overrides any PREFER_* policies seen so far.
    # - If a policy function returns a bool, convert it to a policy via _policy_from_bool.
    # - If two MUST_* policies conflict (MUST_SAVE vs MUST_RECOMPUTE), raise an error.
    # - Return the final resolved policy.
    current_policy = CheckpointPolicy.PREFER_SAVE
    for policy_fn in policy_stack:
        policy = policy_fn(ctx, out, op, *args, **kwargs)
        if isinstance(policy, bool):
            policy = _policy_from_bool(policy)
        # Plain `if` (not `elif`) so that policies returned as bools still
        # participate in the conflict resolution below
        if current_policy != policy:
            if is_must_policy(policy) and is_must_policy(current_policy):
                raise RuntimeError(
                    "Conflicting policies found in the policy stack. "
                    "Please ensure that the policy stack is consistent."
                )
            if is_must_policy(current_policy):
                continue
            current_policy = policy
    return current_policy
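
# Worked example: with policy_stack = [recompute_all, must_save_all] (the
# nesting used at the bottom of this file), resolution starts at PREFER_SAVE,
# recompute_all lowers it to PREFER_RECOMPUTE, and must_save_all overrides it
# to MUST_SAVE; any later PREFER_* policy is then ignored, since a MUST_*
# policy is sticky.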
is_checkpoint_enabled = False

@contextlib.contextmanager
def save_recompute_policy(policy_fn="recompute_all"):
    global is_checkpoint_enabled
    if is_checkpoint_enabled or policy_fn == "must_save_all":
        # Already inside a checkpoint region (or only refining the policy):
        # just push the policy without installing new hooks/modes
        with push_policy(policy_fn):
            yield
    else:
        try:
            is_checkpoint_enabled = True
            node_outputs = WeakTensorKeyDictionary()
            with push_policy(policy_fn), CheckpointHook(node_outputs), TrackingMode(node_outputs):
                yield
        finally:
            is_checkpoint_enabled = False
            node_outputs.clear()
a = torch.randn((), requires_grad=True)  # scalar leaf, so out.backward() below needs no grad argument

# Recompute everything in the outer region, but force the inner region's
# results to be saved
with save_recompute_policy(policy_fn="recompute_all"):
    with save_recompute_policy(policy_fn="must_save_all"):
        out = (a.sin().cos() * 10 * 10 * 10 * 10).sin().cos()
out.backward()
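
# A hypothetical follow-up using the budget_auto_ac sketch above: save at most
# 1 MiB of activations, and recompute the rest on unpack during backward
b = torch.randn(256, 256, requires_grad=True)
with save_recompute_policy(policy_fn=budget_auto_ac(2**20)):
    loss = b.sin().cos().sin().sum()
loss.backward()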