Last active
December 10, 2022 05:25
-
-
Save soulitzer/91fd975b60b90209aab9c7be634312dc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from typing import Callable, Any | |
import contextlib | |
from torch.utils._python_dispatch import TorchDispatchMode | |
from typing import Dict, Tuple, Optional, Set | |
import weakref | |
# Module-level registries shared by the saved-tensor hooks, the dispatch mode
# and the `allow_mutation_on_saved_tensors` context manager below.
#
# A "tid" is the 3-tuple returned by `_get_tid`: (id, data_ptr, version).
# A "sid" is the 2-tuple returned by `_get_sid`: (data_ptr, version).

# handle -> clone taken of the saved tensor right before an in-place op ran.
_cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
# handle -> the tensor exactly as it was handed to the pack hook.
_original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
# tid -> weak reference to the handle that was created for that saved tensor.
_tid_to_weakhandle: Dict[Tuple[int, int, int], weakref.ReferenceType] = {}
# sid -> every tid that has been saved with that data pointer and version.
_sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = {}
# Incremented on entry to and exit from the context manager; the pack hook
# stores it next to the handle and the unpack hook compares against it.
_ctx_id = 0
# True while an `allow_mutation_on_saved_tensors` context is active.
_inside_ctx = False
def _get_tid(t) -> Tuple[int, int, int]:
    """Return the key identifying one tensor object at one version.

    The key is ``(id(t), t.data_ptr(), t._version)``.
    """
    return (id(t), t.data_ptr(), t._version)
def _get_sid(t) -> Tuple[int, int]:
    """Return the ``(data_ptr, version)`` pair used to group saved tensors.

    Unlike the key built by `_get_tid`, this one does not include ``id(t)``,
    so distinct objects reporting the same data pointer and version share it.
    """
    data_ptr = t.data_ptr()
    version = t._version
    return data_ptr, version
class _Handle:
    """Opaque token that the pack hook returns in place of a saved tensor.

    Instances carry no data of their own; they serve as keys of the
    ``_original`` and ``_cloned`` weak-key dictionaries, so they must stay
    hashable and weak-referenceable.
    """
class _swap_with_cloned(torch.autograd.graph.saved_tensors_hooks):
    """Saved-tensor hooks that store an opaque handle instead of the tensor.

    The pack hook records each tensor in the module-level registries and
    returns ``(_ctx_id, handle)``. The unpack hook resolves the handle to the
    clone in ``_cloned`` when one exists, and otherwise to the tensor kept in
    ``_original``.
    """

    def __init__(self):
        def pack_hook(t):
            """Register ``t`` and return the ``(_ctx_id, handle)`` pair to save."""
            tid = _get_tid(t)
            sid = _get_sid(t)

            # Save aliasing information: every tid seen under this
            # (data_ptr, version) pair.
            _sid_to_tid.setdefault(sid, set()).add(tid)

            # Tensors saved for backward have an entry in _tid_to_weakhandle.
            # NB: The same tensor (of the same version) can be saved multiple
            # times; while its handle is still alive, hand out that same
            # handle again instead of creating a new one.
            weak_handle = _tid_to_weakhandle.get(tid)
            handle: Optional[_Handle] = (
                weak_handle() if weak_handle is not None else None
            )
            if handle is None:
                # Never saved before, or the previous handle has been collected.
                handle = _Handle()
                _tid_to_weakhandle[tid] = weakref.ref(handle)
                _original[handle] = t
            return _ctx_id, handle

        def unpack_hook(tup):
            """Return the clone for the saved handle if any, else the original."""
            ctx_id, handle = tup
            # The context manager bumps _ctx_id on entry and on exit, so a
            # mismatch means the graph was recorded under a different context.
            assert ctx_id == _ctx_id, (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded.")
            if handle in _cloned:
                return _cloned[handle]
            return _original[handle]

        super().__init__(pack_hook, unpack_hook)
class _CloneArgBeforeMutateMode(TorchDispatchMode):
    """Dispatch mode that clones saved tensors before an in-place op runs.

    An op is treated as in-place when the part of ``func.__name__`` before
    the first ``.`` ends with an underscore. For such ops, every saved tensor
    registered under the first argument's ``(data_ptr, version)`` pair is
    cloned into ``_cloned`` and removed from ``_original`` before the op is
    executed.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # (only for in-place ops now, we may want to handle out= later)
        if func.__name__.split('.')[0].endswith("_"):
            # The first argument is assumed to be modified in-place
            sid = _get_sid(args[0])
            if sid in _sid_to_tid:
                for tid in _sid_to_tid[sid]:
                    weak_handle = _tid_to_weakhandle.get(tid)
                    if weak_handle is None:
                        # It's never been saved
                        continue
                    handle = weak_handle()
                    if handle is None or handle in _cloned:
                        # It's been saved, but backward was run OR
                        # the exact same tensor has been cloned already
                        continue
                    _cloned[handle] = _original[handle].clone()
                    del _original[handle]
            else:
                # Nothing was saved under this (data_ptr, version) pair.
                # this can happen with math views, I'm not sure why yet
                assert not args[0]._is_view()

        return func(*args, **kwargs)
@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    Raises:
        RuntimeError: if entered while another such context is already active.
    """
    global _inside_ctx, _ctx_id

    # Reject nesting before touching any shared state, so that a failed nested
    # entry cannot clear the outer context's clones, change _ctx_id, or reset
    # _inside_ctx on its way out.
    if _inside_ctx:
        raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested")

    # Compare (major, minor) numerically; comparing the version strings
    # directly is lexicographic and would order e.g. "1.9" after "1.13".
    release_parts = str(torch.__version__).split("+")[0].split(".")
    try:
        major_minor = (int(release_parts[0]), int(release_parts[1]))
    except (IndexError, ValueError):
        # Unparseable version string: fall through to the current API.
        major_minor = None

    if major_minor is not None and major_minor < (1, 13):
        from torch.utils._python_dispatch import push_torch_dispatch_mode
        ctx = push_torch_dispatch_mode(_CloneArgBeforeMutateMode)
    else:
        ctx = _CloneArgBeforeMutateMode()

    with _swap_with_cloned(), ctx:
        _ctx_id += 1
        _inside_ctx = True
        try:
            yield
        finally:
            # Drop all per-context bookkeeping. Entries keyed by
            # (id, data_ptr, version) must not outlive the context: a later
            # tensor reusing the same values would otherwise be matched to a
            # stale handle.
            _cloned.clear()
            _original.clear()
            _tid_to_weakhandle.clear()
            _sid_to_tid.clear()
            # Invalidate graphs recorded under this context.
            _ctx_id += 1
            _inside_ctx = False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment