
@soulitzer · Last active December 10, 2022 05:25
import torch
import contextlib
import weakref
from typing import Dict, Optional, Set, Tuple

from torch.utils._python_dispatch import TorchDispatchMode

# handle -> cloned tensor (populated lazily, only when a saved tensor is mutated)
_cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
# handle -> original saved tensor (removed once it has been cloned)
_original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
# tid (id, data_ptr, version) -> weak reference to the handle returned by pack_hook
_tid_to_weakhandle: Dict[Tuple[int, int, int], weakref.ReferenceType] = dict()
# sid (data_ptr, version) -> set of tids that alias that storage
_sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = dict()
_ctx_id = 0
_inside_ctx = False

def _get_tid(t) -> Tuple[int, int, int]:
    # Identifies a particular tensor object (and its version)
    return (id(t), t.data_ptr(), t._version)

def _get_sid(t) -> Tuple[int, int]:
    # Identifies the underlying storage (and its version), shared across aliases
    return (t.data_ptr(), t._version)

class _Handle():
    pass

class _swap_with_cloned(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            if sid not in _sid_to_tid:
                _sid_to_tid[sid] = set()
            _sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in _tid_to_weakhandle or _tid_to_weakhandle[tid]() is None:
                handle = _Handle()
                _tid_to_weakhandle[tid] = weakref.ref(handle)
                _original[handle] = t
            else:
                # Reuse the existing handle; returning it stores an additional
                # strong reference that keeps it alive
                handle = _tid_to_weakhandle[tid]()
            return _ctx_id, handle

        def unpack_hook(tup):
            ctx_id, handle = tup
            assert ctx_id == _ctx_id, (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded.")
            if handle in _cloned:
                res = _cloned[handle]
            else:
                res = _original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)

class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # (only handles in-place ops for now; we may want to handle out= later)
        if func.__name__.split('.')[0][-1] == "_":
            # The first argument is assumed to be modified in-place
            tid = _get_tid(args[0])
            sid = _get_sid(args[0])
            if sid in _sid_to_tid:
                for tid in _sid_to_tid[sid]:
                    if tid not in _tid_to_weakhandle:
                        # It has never been saved
                        continue
                    handle = _tid_to_weakhandle[tid]()
                    if handle is None or handle in _cloned:
                        # It was saved, but backward has already run, OR
                        # the exact same tensor has already been cloned
                        continue
                    _cloned[handle] = _original[handle].clone()
                    del _original[handle]
            else:
                # This can happen with math views; I'm not sure why yet
                assert not args[0]._is_view()

        rs = func(*args, **kwargs)
        return rs

@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating
    a tensor saved for backward results in an error raised when it is used during
    backward.
    """
    global _inside_ctx, _ctx_id
    if str(torch.__version__) < '1.13':
        # Pre-1.13 releases use the push-based dispatch mode API
        from torch.utils._python_dispatch import push_torch_dispatch_mode
        ctx = push_torch_dispatch_mode(_CloneArgBeforeMutateMode)
    else:
        ctx = _CloneArgBeforeMutateMode()

    with _swap_with_cloned(), ctx:
        try:
            _ctx_id += 1
            if _inside_ctx:
                raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested")
            _inside_ctx = True
            yield
        finally:
            _cloned.clear()
            _ctx_id += 1
            _inside_ctx = False
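
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): under the context manager, a
# tensor saved for backward can be mutated in place; the dispatch mode clones
# the saved value first, so backward still sees the pre-mutation value. The
# tensor shapes and ops below are illustrative only.
if __name__ == "__main__":
    a = torch.randn(2, 3, requires_grad=True)

    with allow_mutation_on_saved_tensors():
        b = a.sin()              # forward
        out = (b ** 2).sum()     # pow saves `b` for its backward
        b.mul_(2)                # in-place mutation: the saved `b` is cloned first
        out.backward()           # backward uses the clone, i.e. the pre-mutation `b`

    print(a.grad)                # gradient w.r.t. `a`, computed from original values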