Skip to content

Instantly share code, notes, and snippets.

@soulitzer
Last active October 14, 2024 17:59
Show Gist options
  • Save soulitzer/8041a239b42514a4514950496e5e31a3 to your computer and use it in GitHub Desktop.
Save soulitzer/8041a239b42514a4514950496e5e31a3 to your computer and use it in GitHub Desktop.
Priority Cache
import dataclasses
import sys
import weakref
from dataclasses import dataclass
from typing import *

import torch
from torch.utils.weak import WeakTensorKeyDictionary
_Slot = Optional[Union[torch.Tensor, weakref.ReferenceType]]


@dataclasses.dataclass
class CacheEntry:
    """Storage for one group of related tensors.

    Each slot is empty (``None``), a strong reference to a tensor, or a
    ``weakref.ref`` to the group's canonical tensor. Declaration order is
    significant: ``PriorityCache`` ranks fields by it.
    """

    one: _Slot = dataclasses.field(default=None)
    two: _Slot = dataclasses.field(default=None)
    three: _Slot = dataclasses.field(default=None)
def _has_external_references(obj) -> bool:
    """Return True if ``obj``'s refcount exceeds what the cache code itself accounts for.

    The check compares ``sys.getrefcount(obj)`` against a fixed baseline.
    The callers in this file load a tensor from a ``CacheEntry`` attribute
    into a plain local variable and pass that local straight in.

    NOTE(review): the result is coupled to the caller's exact shape. An extra
    (or missing) reference held by the caller at call time shifts the count
    and flips the answer. The per-call accounting may also differ between
    CPython versions, so re-check BASELINE_REFCOUNT when upgrading.
    """
    refcount = sys.getrefcount(obj)
    # Expect refcount to be 5 when there are no external references.
    # - temporary in "getrefcount"
    # - "obj" in this scope
    # - cache_entry's __dict__ (the attribute holding the tensor)
    # - caller's argument (the caller's local variable, e.g. "value"/"val")
    # - (where is the last one coming from? unidentified -- TODO confirm)
    BASELINE_REFCOUNT = 5
    return refcount > BASELINE_REFCOUNT
def _finalizer(wkd, old_key, cache_entry):
    """Handle the death of the canonical tensor stored under ``old_key``.

    Registered through ``weakref.finalize``. Scans the fields of
    ``cache_entry`` in declaration order and picks the first one that holds a
    strong tensor which still has external references. That tensor is
    promoted to canonical:
    - it is re-stored as a weakref;
    - it becomes the key under which ``wkd`` holds ``cache_entry``;
    - it gets a finalizer of its own;
    - the dead canonical's field is cleared.

    If no field qualifies, nothing is changed. The dead weakref stays in
    ``old_key``.

    Args:
        wkd: weak-key dict mapping a canonical tensor to its entry.
        old_key: field name of the canonical tensor that just died.
        cache_entry: the entry that tensor belonged to.
    """
    new_key, new_canonical = None, None
    for field in dataclasses.fields(cache_entry):
        value = getattr(cache_entry, field.name)
        # Keep ``value`` as a plain local passed directly: the baseline in
        # _has_external_references counts exactly this one caller reference.
        if (
            isinstance(value, torch.Tensor) and
            _has_external_references(value)
        ):
            new_key, new_canonical = field.name, value
            break
    # Compare with None explicitly. Truth-testing a tensor evaluates its
    # value: a zero-valued tensor would be skipped, and a multi-element
    # tensor raises "Boolean value of Tensor ... is ambiguous".
    if new_canonical is not None:
        setattr(cache_entry, new_key, weakref.ref(new_canonical))
        wkd[new_canonical] = cache_entry
        weakref.finalize(new_canonical, _finalizer, wkd, new_key, cache_entry)
        setattr(cache_entry, old_key, None)
def _get_current_canonical(cache_entry) -> str:
for field in dataclasses.fields(cache_entry):
value = getattr(cache_entry, field.name)
if isinstance(value, weakref.ReferenceType):
return field.name, value
assert False, "Expected there to be a canonical!"
def viz_ownership(cache_entry):
    """Describe which field owns ``cache_entry`` and what it keeps alive.

    Returns ``"<canonical> -> <field>[<has_external>], ..."``:
    - the left side is the field currently held by weakref;
    - the right side lists every field holding a strong tensor, tagged with
      whether that tensor has references outside the cache.

    Returns ``"<canonical> -> ()"`` when no tensor is held strongly.
    """
    canonical_name, _unused_ref = _get_current_canonical(cache_entry)
    parts = []
    for fld in dataclasses.fields(cache_entry):
        # Exactly one local reference to the tensor while probing it, as
        # _has_external_references' baseline expects.
        member = getattr(cache_entry, fld.name)
        if not isinstance(member, torch.Tensor):
            continue
        parts.append(f"{fld.name}[{_has_external_references(member)}]")
    owned = ", ".join(parts) if parts else "()"
    return f"{canonical_name} -> {owned}"
class PriorityCache:
    """Cache that groups related tensors into one entry kept alive by a
    single "canonical" tensor.

    Fields of ``cache_entry_cls`` are ranked by declaration order. An earlier
    field has a smaller number in ``cache_key_priority`` and is preferred as
    the canonical.

    Within an entry:
    - The canonical field stores a ``weakref.ref`` to its tensor. That tensor
      is the key under which ``self.wkd`` holds the entry (weak key, strong
      value).
    - Every other populated field stores its tensor strongly.

    So the entry, and the non-canonical tensors it holds, stay reachable from
    ``self.wkd`` while the canonical tensor is alive. When the canonical dies,
    ``_finalizer`` tries to promote another field.
    """

    def __init__(self, cache_entry_cls):
        """
        Args:
            cache_entry_cls: dataclass type used for entries. It must be
                constructible with no arguments: it is instantiated here to
                read its fields and once per ``create``. Presumably all fields
                default to None.
        """
        self.cache_entry_cls = cache_entry_cls
        # Field name -> rank in declaration order; a smaller rank is preferred
        # as the canonical.
        self.cache_key_priority = {
            field.name: i
            for i, field in enumerate(dataclasses.fields(cache_entry_cls()))
        }
        # Canonical tensor (weak key) -> entry (strong value). This is the
        # reference that keeps each entry alive.
        self.wkd = WeakTensorKeyDictionary()
        # Every tensor registered in an entry (weak key) -> weakref to that
        # entry, so any member tensor can be used to find its entry.
        self.reverse = WeakTensorKeyDictionary()

    def _lookup(self, a):
        """Return the live entry that tensor ``a`` was registered in.

        Raises:
            AssertionError: if ``a`` is unknown or its entry has been freed.
        """
        cache_entry_ref = self.reverse.get(a)
        assert cache_entry_ref is not None
        cache_entry = cache_entry_ref()
        assert cache_entry is not None
        return cache_entry

    def _maybe_swap_with_current(self, cache_entry, new_key, new_val):
        """Make ``new_val`` (stored under ``new_key``) the canonical if
        ``new_key`` ranks ahead of the entry's current canonical field.

        On a swap:
        - ``new_val`` is stored as a weakref, becomes the entry's key in
          ``self.wkd`` and gets a finalizer;
        - the previous canonical is demoted to a strong reference and its
          ``self.wkd`` key is removed.
        """
        curr_key, curr_val_ref = _get_current_canonical(cache_entry)
        if self.cache_key_priority[new_key] >= self.cache_key_priority[curr_key]:
            return
        curr_val = curr_val_ref()
        setattr(cache_entry, new_key, weakref.ref(new_val))
        self.wkd[new_val] = cache_entry
        weakref.finalize(new_val, _finalizer, self.wkd, new_key, cache_entry)
        # Demote the old canonical: the entry now owns it strongly. If it has
        # already died (the entry was kept alive by something other than
        # self.wkd, e.g. a caller), this just clears the field. There is then
        # no self.wkd key to remove, and ``del self.wkd[None]`` would raise.
        setattr(cache_entry, curr_key, curr_val)
        if curr_val is not None:
            del self.wkd[curr_val]
        # NOTE(review): the finalizer registered for the demoted tensor is not
        # detached. It holds ``cache_entry`` strongly while the entry holds the
        # tensor strongly. The weakref.finalize docs warn that this prevents
        # collection if the tensor never becomes canonical again -- confirm.

    def add(self, a, key, b):
        """Store tensor ``b`` under ``key`` in the entry that ``a`` belongs to.

        ``b`` becomes the canonical if ``key`` ranks ahead of the current
        canonical field. Otherwise the entry holds it strongly.

        Returns:
            The entry ``b`` was added to.

        Raises:
            AssertionError: in any of these cases:
                - ``key`` is not a field;
                - ``a`` has no live entry;
                - ``key`` is the entry's current canonical field.
        """
        assert key in self.cache_key_priority
        cache_entry = self._lookup(a)
        # Overwriting the canonical's weakref with a strong reference would
        # leave the entry with no canonical. _get_current_canonical would then
        # fail after the state was already corrupted, so reject it up front.
        assert not isinstance(getattr(cache_entry, key), weakref.ReferenceType), (
            f"cannot overwrite canonical field {key!r}"
        )
        setattr(cache_entry, key, b)
        self.reverse[b] = weakref.ref(cache_entry)
        self._maybe_swap_with_current(cache_entry, key, b)
        return cache_entry

    def create(self, a, key):
        """Start a new entry with ``a`` as its canonical, stored under ``key``.

        Returns:
            The new entry, a fresh ``cache_entry_cls()`` instance.
        """
        assert key in self.cache_key_priority
        # Use the entry class's own defaults instead of hard-coding
        # one/two/three, so entry dataclasses with other fields work too.
        cache_entry = self.cache_entry_cls()
        self.wkd[a] = cache_entry
        setattr(cache_entry, key, weakref.ref(a))
        self.reverse[a] = weakref.ref(cache_entry)
        weakref.finalize(a, _finalizer, self.wkd, key, cache_entry)
        return cache_entry

    def get(self, a, key):
        """Return the tensor stored under ``key`` in ``a``'s entry, or None.

        Returning a tensor gives the caller an external reference to it. If
        ``key`` ranks ahead of the current canonical, it is therefore
        promoted to canonical.
        """
        assert key in self.cache_key_priority
        cache_entry = self._lookup(a)
        ret = getattr(cache_entry, key)
        if isinstance(ret, weakref.ReferenceType):
            # ``key`` is already the canonical. Return the tensor itself, not
            # the internal weakref (None if that tensor has died).
            return ret()
        if ret is not None:
            self._maybe_swap_with_current(cache_entry, key, ret)
        return ret
def test_no_leaks():
    """Walk one entry through a sequence of ownership changes.

    Afterwards, check that every tensor is freed by plain refcounting once
    the scope exits.
    """
    def scope():
        a = torch.tensor(1.)
        b = torch.tensor(1.)
        c = torch.tensor(1.)
        d = torch.tensor(1.)
        a_ref = weakref.ref(a)
        b_ref = weakref.ref(b)
        c_ref = weakref.ref(c)
        d_ref = weakref.ref(d)
        pc = PriorityCache(CacheEntry)
        x = pc.create(a, "one")
        assert viz_ownership(x) == "one -> ()"
        pc.add(a, "two", b)
        # This string indicates the ownership direction, e.g., "a -> b, c"
        # indicates a keeps b and c alive. The True/False inside "[]" indicates
        # whether that field's tensor has any external references.
        assert viz_ownership(x) == "one -> two[True]"
        del a
        # Deleting a (the canonical, under "one") runs its finalizer. It
        # promotes "two", the first remaining field whose tensor has an
        # external reference, to keep the entry alive.
        assert viz_ownership(x) == "two -> ()"
        pc.add(b, "one", c)
        # "one" ranks ahead of "two", so the newly added tensor takes over as
        # canonical and b is demoted to a strong reference.
        assert viz_ownership(x) == "one -> two[True]"
        pc.add(b, "three", d)
        assert viz_ownership(x) == "one -> two[True], three[True]"
        del b
        assert viz_ownership(x) == "one -> two[False], three[True]"
        del c
        # After deleting "one", "three" is chosen over "two" even though it
        # ranks later, because "two" has no external references.
        assert viz_ownership(x) == "three -> two[False]"
        # After querying for "two", "two" is assumed to gain an external
        # reference and replaces "three".
        _ = pc.get(d, "two")
        assert viz_ownership(x) == "two -> three[True]"
        del _
        # As canonical, "two" is held only weakly by the entry, so "_" was its
        # last strong reference and deleting it frees b. b's finalizer
        # promotes "three" (d is still referenced here) and clears "two".
        assert viz_ownership(x) == "three -> ()"
        return a_ref, b_ref, c_ref, d_ref

    import gc
    # Turn off the cyclic collector so the check only passes if everything
    # is freed by refcounting alone (no reference cycles). Restore the
    # previous state afterwards, and let assertion failures propagate.
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        refs = scope()
        assert all(ref() is None for ref in refs)
    finally:
        if was_enabled:
            gc.enable()
test_no_leaks()
def test_swap1():
    """Promotion and demotion when fields are added from lowest to highest rank.

    Starts with "three" (the last-ranked field) as canonical. Each
    higher-ranked field that is added takes over as canonical. Deleting the
    canonical tensors then hands ownership back down the ranking.
    """
    a = torch.tensor(1.)
    b = torch.tensor(1.)
    c = torch.tensor(1.)
    pc = PriorityCache(CacheEntry)
    x = pc.create(a, "three")
    assert viz_ownership(x) == "three -> ()"
    # "two" ranks ahead of "three": b becomes canonical, a is held strongly.
    pc.add(a, "two", b)
    assert viz_ownership(x) == "two -> three[True]"
    # "one" ranks ahead of "two": c becomes canonical, b is held strongly.
    pc.add(a, "one", c)
    assert viz_ownership(x) == "one -> two[True], three[True]"
    # c (canonical) dies; its finalizer promotes the first field whose tensor
    # still has an external reference: b under "two".
    del c
    assert viz_ownership(x) == "two -> three[True]"
    # b dies; "three" (a, still referenced here) is promoted.
    del b
    assert viz_ownership(x) == "three -> ()"
test_swap1()
def test_swap2():
    """Promotion that skips a field without external references.

    Also checks the state left behind when the canonical dies with no
    promotable field.
    """
    a = torch.tensor(1.)
    b = torch.tensor(1.)
    c = torch.tensor(1.)
    pc = PriorityCache(CacheEntry)
    x = pc.create(a, "three")
    assert viz_ownership(x) == "three -> ()"
    pc.add(a, "two", b)
    assert viz_ownership(x) == "two -> three[True]"
    pc.add(a, "one", c)
    assert viz_ownership(x) == "one -> two[True], three[True]"
    # b is now kept alive only by the entry.
    del b
    assert viz_ownership(x) == "one -> two[False], three[True]"
    # c (canonical) dies. "two" is skipped because b has no external
    # reference, so "three" (a) is promoted instead.
    del c
    assert viz_ownership(x) == "three -> two[False]"
    # get() hands b back to the caller, so "two" takes over as canonical and
    # a is demoted to a strong reference.
    new_b = pc.get(a, "two")
    assert viz_ownership(x) == "two -> three[True]"
    del a
    assert viz_ownership(x) == "two -> three[False]"
    # b dies and no remaining tensor has an external reference, so nothing is
    # promoted. "two" keeps its dead weakref, "three" stays strong, and the
    # wkd entry keyed by b is gone.
    # NOTE(review): finalizers registered for a while it was canonical were
    # never detached and hold x strongly, while x holds a strongly. Per the
    # weakref.finalize docs, neither is ever collected. Confirm whether this
    # end state is expected to leak.
    del new_b
    assert isinstance(x.two, weakref.ReferenceType)
    assert x.two() is None
    assert isinstance(x.three, torch.Tensor)
    assert len(pc.wkd) == 0
test_swap2()
def test_basic():
    """The first-ranked field ("one") is canonical from the start.

    Later adds therefore never swap. When the canonical dies with no
    promotable field, the entry keeps the dead weakref and its strongly-held
    tensors.
    """
    a = torch.tensor(1.)
    b = torch.tensor(1.)
    c = torch.tensor(1.)
    pc = PriorityCache(CacheEntry)
    x = pc.create(a, "one")
    assert viz_ownership(x) == "one -> ()"
    pc.add(a, "two", b)
    assert viz_ownership(x) == "one -> two[True]"
    pc.add(a, "three", c)
    assert viz_ownership(x) == "one -> two[True], three[True]"
    # Deleting non-canonical tensors only drops their external reference;
    # the entry still holds them.
    del c
    assert viz_ownership(x) == "one -> two[True], three[False]"
    del b
    assert viz_ownership(x) == "one -> two[False], three[False]"
    # a (canonical) dies. Neither b nor c has an external reference, so
    # nothing is promoted and the wkd entry keyed by a disappears.
    del a
    assert isinstance(x.one, weakref.ReferenceType)
    assert x.one() is None
    assert isinstance(x.two, torch.Tensor)
    assert isinstance(x.three, torch.Tensor)
    assert len(pc.wkd) == 0
test_basic()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment