# graph-based AC
#
# Source: GitHub gist soulitzer/7472bc97b909474e77617f3bd474406c
# (by @soulitzer, created April 9, 2025 22:10).
import torch
import functools
import contextlib
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils.checkpoint import CheckpointPolicy, _policy_from_bool
from collections import namedtuple
import weakref
_NodeOutput = namedtuple("NodeOutput", ("node", "idx"))


class NodeOutput(_NodeOutput):
    """Handle to output number ``idx`` of graph ``node``.

    Subclassed instead of using the namedtuple directly so that pytree treats
    it as a leaf rather than recursing into its fields.
    """
def realize(node_output):
    """Return the concrete value behind ``node_output``, (re)computing it if needed.

    Delegates to ``node_output.node.run(node_output.idx)``.
    """
    owner = node_output.node
    return owner.run(node_output.idx)
def increment_use_counts(node_output):
    """Register one future ``realize(node_output)`` call with the owning node."""
    owner = node_output.node
    owner.increment_use_counts(node_output.idx)
class _List(list):
    """Plain ``list`` subclass that ``Node`` uses to hold precomputed outputs.

    NOTE(review): presumably subclassed, like ``NodeOutput``, so that pytree
    treats it as an opaque leaf. Confirm.
    """
# Note [AC Node use-count tracking to clear cached tensors sooner]
#
# Consider the graph,
#
#   A <- B <- C <- D (saved)
#          \
#           E (saved)
#
# The arrows indicate the direction of ownership. The flow of data is from left
# to right, e.g. A is used to compute B, etc.; E is also computed from B.
#
# When D is unpacked and realized during backward, its dependencies, C, B, A
# are also recursively realized. To avoid recomputing A and B again when
# E is unpacked, some intermediate result, e.g. B's output, should be cached.
# Conversely, A's output is not needed if B is already cached.
#
# Note that when D is unpacked, A's output is not freed because it is still
# kept alive (through B) by E. Instead of waiting for A to be automatically
# cleared when E is eventually unpacked, we track the number of users of A
# (1 in this case) and clear it as soon as B is computed.
#
# We do this by matching every realize() call with an increment_use_counts()
# call and making sure increment_use_counts() recurses the same way as
# realize().
class Node:
    """One recorded op invocation in the checkpoint graph.

    A node either already holds its outputs (``outs`` given: a graph input,
    or an op whose policy chose to save it) or holds ``func``/``args`` so its
    outputs can be recomputed on demand. ``args`` may contain ``NodeOutput``
    handles, which are realized recursively when the node runs.

    See Note [AC Node use-count tracking to clear cached tensors sooner].
    """

    def __init__(self, func=None, args=None, outs: "tuple | None" = None):
        self.func = func
        self.args = args
        self.out = None
        # out_idx -> number of outstanding run(out_idx) calls.
        self.nb_users = {}
        if outs is not None:
            # Store detached tensors so the cache holds no autograd history.
            self.out = _List([x.detach() if isinstance(x, torch.Tensor) else x for x in outs])

    def run(self, idx):
        """Return output ``idx``, computing all outputs on first use.

        Each call consumes one use registered via ``increment_use_counts``.
        Once output ``idx`` has no remaining uses, the node drops its
        reference to it so the tensor can be freed.
        """
        if self.out is None:
            new_args = tree_map_only(NodeOutput, realize, self.args)
            raw_out = self.func(*new_args)
            # Multi-output ops may return a tuple or a list (e.g. ops that
            # return Tensor[]). TrackingMode indexes both kinds element-wise,
            # so unpack both here; otherwise out[idx] addresses the wrong object.
            self.out = list(raw_out) if isinstance(raw_out, (list, tuple)) else [raw_out]
        out = self.out[idx]
        self.nb_users[idx] -= 1
        if self.nb_users[idx] == 0:
            self.out[idx] = None
        return out

    def increment_use_counts(self, idx):
        """Register one future ``run(idx)`` call.

        The first registration on a node that still has to be computed also
        registers one use of each of its ``NodeOutput`` args. This mirrors
        ``run``, which realizes the args exactly once, when it first computes.
        """
        if self.out is None and not self.nb_users:
            tree_map_only(NodeOutput, increment_use_counts, self.args)
        self.nb_users[idx] = self.nb_users.get(idx, 0) + 1
def get_node_output(node_outputs, t):
    """Return the ``NodeOutput`` tracked for tensor ``t``.

    A tensor produced inside the checkpoint region was recorded in
    ``node_outputs`` by the tracking mode. Any tensor missing from the map is
    therefore a region input: wrap it in a leaf ``Node`` that just holds it,
    record that, and return it.
    """
    if t in node_outputs:
        return node_outputs[t]
    entry = NodeOutput(Node(None, None, (t,)), 0)
    node_outputs[t] = entry
    return entry
class Context:
    """Context object handed to policy functions.

    ``nodes`` is whatever mapping the creator passes in; the tracking mode
    passes its tensor -> ``NodeOutput`` map.
    """

    def __init__(self, nodes):
        self.nodes = nodes
class TrackingMode(TorchDispatchMode):
    """Dispatch mode that records every op run inside the checkpoint region.

    Each tensor output is mapped in ``node_outputs`` to a ``NodeOutput``
    pointing at a ``Node``. Depending on the resolved checkpoint policy, the
    node either holds the (detached) outputs or holds enough information to
    recompute them from their tracked inputs.
    """

    def __init__(self, node_outputs):
        super().__init__()
        # Shared with the saved-tensor hooks: tensor -> NodeOutput.
        self.node_outputs = node_outputs

    @staticmethod
    def _call_op(op, op_args, op_kwargs):
        # Recompute entry point: replays an op with BOTH its positional and
        # keyword arguments. Keyword-only arguments (e.g. ``alpha`` for
        # aten.add.Tensor) arrive in ``kwargs`` and must not be dropped.
        return op(*op_args, **op_kwargs)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)
        # TODO: nodes shouldn't be public API right?
        ctx = Context(self.node_outputs)
        should_save = is_save_policy(current_global_policy(ctx, out, func, *args, **kwargs))
        out_tuple = tuple(out) if isinstance(out, (list, tuple)) else (out,)
        if should_save:
            # Outputs are kept, so the inputs need not be remembered.
            node = Node(func, None, out_tuple)
        else:
            # Replace tensor inputs (positional AND keyword) with NodeOutput
            # handles so they are realized lazily on recompute. Non-tensors
            # are kept alive by the node as-is.
            wrapped = tree_map_only(
                torch.Tensor,
                functools.partial(get_node_output, self.node_outputs),
                (args, kwargs),
            )
            # Node.run calls func(*realized) -> _call_op(func, args, kwargs).
            node = Node(functools.partial(self._call_op, func), wrapped, None)
        for idx, t in enumerate(out_tuple):
            if isinstance(t, torch.Tensor):
                self.node_outputs[t] = NodeOutput(node, idx)
        return out
class CheckpointHook(torch.autograd.graph.saved_tensors_hooks):
    """Saved-tensor hooks that store ``NodeOutput`` handles instead of tensors.

    Packing looks up (or registers as an input) the tensor's handle in
    ``node_outputs`` and records one more pending use of it. Unpacking
    realizes the handle, recomputing through the graph if needed.
    """

    def __init__(self, node_outputs):
        def _pack(tensor):
            handle = get_node_output(node_outputs, tensor)
            increment_use_counts(handle)
            return handle

        def _unpack(handle):
            return realize(handle)

        super().__init__(_pack, _unpack)
def recompute_all(ctx, op, *args, **kwargs):
return CheckpointPolicy.PREFER_RECOMPUTE
def must_save_all(ctx, op, *args, **kwargs):
return CheckpointPolicy.MUST_SAVE
def simple_auto_ac(ctx, out, op, *args, **kwargs):
    """Placeholder for an automatic save/recompute heuristic (not implemented).

    Intended heuristic, per the author's notes:

    - periodically save some tensor to cap peak memory;
    - prioritize saving expensive tensors;
    - save more tensors later, when peak memory is higher;
    - keep track of everything saved so far;
    - respect a memory budget.

    Currently always returns ``None``.
    """
    return None
# Active policy functions, outermost first; pushed/popped by push_policy.
policy_stack: list = []
@contextlib.contextmanager
def push_policy(policy_fn):
    """Push ``policy_fn`` onto ``policy_stack`` for the duration of the block.

    Args:
        policy_fn: A policy callable ``fn(ctx, out, op, *args, **kwargs)``, or
            one of the names ``"recompute_all"`` / ``"must_save_all"``.

    Raises:
        ValueError: If ``policy_fn`` is a string naming no known policy.
    """
    if isinstance(policy_fn, str):
        named_policies = {
            "recompute_all": recompute_all,
            "must_save_all": must_save_all,
        }
        try:
            policy_fn = named_policies[policy_fn]
        except KeyError:
            # Fail fast. A bare string on the stack would otherwise only fail,
            # with a confusing TypeError, when an op is dispatched.
            raise ValueError(
                f"Unknown policy name {policy_fn!r}; expected 'recompute_all', "
                "'must_save_all', or a policy callable."
            ) from None
    # Push outside the try so the finally only ever pops our own entry.
    policy_stack.append(policy_fn)
    try:
        yield
    finally:
        policy_stack.pop()
def is_must_policy(policy):
return policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.MUST_RECOMPUTE)
def is_save_policy(policy):
return policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE)
def current_global_policy(ctx, out, op, *args, **kwargs):
    """Resolve the effective checkpoint policy for one op from ``policy_stack``.

    Policy functions are applied from outermost to innermost:

    - The default policy is PREFER_SAVE.
    - A bool result is first converted with ``_policy_from_bool`` and then
      treated exactly like any other policy.
    - Once a MUST_* policy is selected, later PREFER_* policies do not
      override it.
    - Two different MUST_* policies conflict and raise RuntimeError.
    - Otherwise the inner policy replaces the current one.

    Returns:
        The resolved ``CheckpointPolicy``.

    Raises:
        RuntimeError: On conflicting MUST_* policies.
    """
    current_policy = CheckpointPolicy.PREFER_SAVE
    for policy_fn in policy_stack:
        policy = policy_fn(ctx, out, op, *args, **kwargs)
        if isinstance(policy, bool):
            # Normalize first so bool results also go through the
            # MUST-override and conflict checks below.
            policy = _policy_from_bool(policy)
        if is_must_policy(current_policy):
            if is_must_policy(policy) and policy != current_policy:
                raise RuntimeError(
                    "Conflicting policies found in the policy stack. "
                    "Please ensure that the policy stack is consistent."
                )
            # A MUST_* policy already chosen wins over inner PREFER_* ones.
            continue
        current_policy = policy
    return current_policy
# True while an outermost save_recompute_policy region has its saved-tensor
# hooks and tracking mode installed.
is_checkpoint_enabled: bool = False
@contextlib.contextmanager
def save_recompute_policy(policy_fn="recompute_all"):
    """Run the enclosed code under a save/recompute checkpoint policy.

    The outermost region installs the saved-tensor hooks and the tracking
    dispatch mode, unless that region's policy is ``"must_save_all"``. Nested
    regions, and an outermost ``"must_save_all"`` region, only push their
    policy onto the policy stack.

    Args:
        policy_fn: ``"recompute_all"``, ``"must_save_all"``, or any policy
            accepted by ``push_policy``.
    """
    global is_checkpoint_enabled
    if is_checkpoint_enabled or policy_fn == "must_save_all":
        # An enclosing region already tracks ops (only the extra policy is
        # needed), or nothing will be recomputed, so the hooks are skipped
        # and autograd saves tensors as usual.
        with push_policy(policy_fn):
            yield
        return

    # Bind before the try so the cleanup below can never hit an unbound name.
    node_outputs = WeakTensorKeyDictionary()
    is_checkpoint_enabled = True
    try:
        with push_policy(policy_fn), CheckpointHook(node_outputs), TrackingMode(node_outputs):
            yield
    finally:
        is_checkpoint_enabled = False
        node_outputs.clear()
if __name__ == "__main__":
    # Demo: checkpoint the region, but force every op in the inner region to
    # be saved (MUST_SAVE overrides the outer PREFER_RECOMPUTE), so nothing is
    # recomputed during backward.
    a = torch.tensor(1.0, requires_grad=True)
    with save_recompute_policy(policy_fn="recompute_all"):
        with save_recompute_policy(policy_fn="must_save_all"):
            out = (a.sin().cos() * 10 * 10 * 10 * 10).sin().cos()
    out.backward()
# End of gist.