Last active
October 14, 2024 17:59
-
-
Save soulitzer/8041a239b42514a4514950496e5e31a3 to your computer and use it in GitHub Desktop.
Priority Cache
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses
import sys
import weakref
from dataclasses import dataclass
from typing import *

import torch
from torch.utils.weak import WeakTensorKeyDictionary
# A slot is empty (None), holds a tensor strongly, or holds a weakref to the
# tensor that currently owns (keeps alive) the entry.
_Slot = Optional[Union[torch.Tensor, weakref.ReferenceType]]


@dataclass
class CacheEntry:
    """Group of related tensors managed together by a ``PriorityCache``.

    ``PriorityCache`` enumerates these fields in declaration order to assign
    priorities, so ``one`` ranks highest and ``three`` lowest.
    """

    one: _Slot = None
    two: _Slot = None
    three: _Slot = None
def _has_external_references(obj: object) -> bool:
    """Heuristically report whether ``obj`` is referenced from outside the cache.

    Compares ``sys.getrefcount(obj)`` against a fixed baseline. The baseline
    counts the references expected when the only owner is a cache entry slot
    and the caller holds ``obj`` in a local variable, as ``_finalizer`` and
    ``viz_ownership`` do. Returns True when the count exceeds that baseline.

    NOTE(review): the baseline depends on how CPython counts references during
    this call. It is therefore sensitive both to the caller's exact code shape
    and to the interpreter version. Verify BASELINE_REFCOUNT on the target
    Python before relying on it.
    """
    refcount = sys.getrefcount(obj)
    # Expect refcount to be 5 when there are no external references.
    # - temporary in "getrefcount"
    # - "obj" in this scope
    # - cache_entry's __dict__
    # - caller's argument
    # - (where is the last one coming from?)
    #   NOTE(review): unexplained. It is possibly the caller's own local
    #   variable or its evaluation-stack slot for the argument; unverified.
    BASELINE_REFCOUNT = 5
    return refcount > BASELINE_REFCOUNT
def _finalizer(wkd, old_key, cache_entry):
    """Run when the canonical tensor stored under ``old_key`` has died.

    Scans the entry's fields in declaration order. The first field that holds
    a tensor strongly and still has external references is promoted to be the
    new canonical:

    - it is re-stored as a weakref;
    - it is registered in ``wkd``;
    - it gets its own finalizer;
    - the ``old_key`` slot is cleared.

    If no field qualifies, the entry is left unchanged, so ``old_key`` keeps
    its (now dead) weakref.

    Args:
        wkd: the cache's weak-keyed mapping from canonical tensor to entry.
        old_key: field name of the canonical that just died.
        cache_entry: the dataclass instance holding the slots.
    """
    new_key, new_canonical = None, None
    for field in dataclasses.fields(cache_entry):
        # Keep the tensor in a plain local while probing. The refcount baseline
        # in _has_external_references depends on the references held here.
        value = getattr(cache_entry, field.name)
        if (
            isinstance(value, torch.Tensor) and
            _has_external_references(value)
        ):
            new_key, new_canonical = field.name, value
            break
    # Compare against None explicitly. ``if tensor:`` would call
    # Tensor.__bool__, which is False for a zero-valued scalar and raises for
    # tensors with more than one element.
    if new_canonical is not None:
        setattr(cache_entry, new_key, weakref.ref(new_canonical))
        wkd[new_canonical] = cache_entry
        weakref.finalize(new_canonical, _finalizer, wkd, new_key, cache_entry)
        setattr(cache_entry, old_key, None)
def _get_current_canonical(cache_entry) -> str: | |
for field in dataclasses.fields(cache_entry): | |
value = getattr(cache_entry, field.name) | |
if isinstance(value, weakref.ReferenceType): | |
return field.name, value | |
assert False, "Expected there to be a canonical!" | |
def viz_ownership(cache_entry):
    """Describe which slot is canonical and which tensors it keeps alive.

    The result has the form ``"<canonical> -> <slot>[<has_external>], ..."``.
    It is ``"<canonical> -> ()"`` when no slot holds a tensor strongly.
    ``has_external`` is the result of ``_has_external_references`` for that
    slot's tensor.
    """
    canonical_name, _ = _get_current_canonical(cache_entry)
    owned = []
    for fld in dataclasses.fields(cache_entry):
        # Probe through a plain local, exactly as before. The refcount baseline
        # in _has_external_references presumably counts this reference.
        candidate = getattr(cache_entry, fld.name)
        if not isinstance(candidate, torch.Tensor):
            continue
        externally_held = _has_external_references(candidate)
        owned.append(f"{fld.name}[{externally_held}]")
    listing = ", ".join(owned) if owned else "()"
    return f"{canonical_name} -> {listing}"
class PriorityCache():
    """Cache of related tensors in which one "canonical" tensor owns the rest.

    The PriorityCache prioritizes keeping certain entries alive over others.
    Each cache entry is an instance of a dataclass (``cache_entry_cls``) whose
    fields name the slots. Field declaration order defines priority: a lower
    index ranks higher.

    One slot holds a ``weakref.ref`` to the canonical tensor. Other populated
    slots hold their tensors strongly, so they live as long as the entry does.
    ``self.wkd`` maps the canonical tensor (weak key) to its entry (strong
    value), which keeps the entry alive for as long as the canonical lives.
    """

    def __init__(self, cache_entry_cls):
        self.cache_entry_cls = cache_entry_cls
        # Field name -> priority (lower number = preferred canonical). Read the
        # fields from the class itself; no throwaway instance is needed.
        self.cache_key_priority = {
            field.name: i for i, field in enumerate(dataclasses.fields(cache_entry_cls))
        }
        # Manages lifetime: canonical tensor (weak key) -> cache entry.
        self.wkd = WeakTensorKeyDictionary()
        # Maps all offsets/lengths/cpu/cuda to weak CacheEntry, i.e. every tensor
        # stored in an entry (weak key) -> weakref to that entry.
        self.reverse = WeakTensorKeyDictionary()

    def _maybe_swap_with_current(self, cache_entry, new_key, new_val):
        """Make ``new_key`` canonical if it outranks the current canonical.

        On a swap:

        - ``new_val`` is stored weakly, registered in ``self.wkd`` and given a
          finalizer;
        - the previous canonical tensor is stored strongly in its slot (None if
          it has already died) and dropped from ``self.wkd``.

        NOTE(review): the demoted tensor's own ``weakref.finalize`` is not
        detached. ``weakref.finalize`` keeps its args (including
        ``cache_entry``) alive until the referent dies, so the global finalizer
        registry keeps the entry alive, and the entry keeps the demoted tensor
        alive, for as long as that tensor stays in a strong slot. Confirm this
        retention is intended.
        """
        curr_key, curr_val_ref = _get_current_canonical(cache_entry)
        curr_val = curr_val_ref()
        if self.cache_key_priority[new_key] < self.cache_key_priority[curr_key]:
            setattr(cache_entry, new_key, weakref.ref(new_val))
            self.wkd[new_val] = cache_entry
            weakref.finalize(new_val, _finalizer, self.wkd, new_key, cache_entry)
            setattr(cache_entry, curr_key, curr_val)
            # A dead canonical (its finalizer found nothing to promote) has
            # already left the weak-keyed ``self.wkd``, and None is not a valid
            # key.
            if curr_val is not None:
                del self.wkd[curr_val]

    def add(self, a, key, b):
        """Store ``b`` under ``key`` in the entry that tensor ``a`` belongs to.

        ``b`` is stored strongly and registered in ``self.reverse``. It becomes
        canonical if ``key`` outranks the current canonical.

        Returns:
            The cache entry.
        """
        assert key in self.cache_key_priority, f"Unknown cache key: {key!r}"
        cache_entry_ref = self.reverse.get(a)
        assert cache_entry_ref is not None, "Tensor is not registered in this cache"
        cache_entry = cache_entry_ref()
        assert cache_entry is not None, "Cache entry has already been freed"
        setattr(cache_entry, key, b)
        self.reverse[b] = weakref.ref(cache_entry)
        self._maybe_swap_with_current(cache_entry, key, b)
        return cache_entry

    def create(self, a, key):
        """Create a new cache entry whose canonical is tensor ``a`` under ``key``.

        Returns:
            The new cache entry.
        """
        assert key in self.cache_key_priority, f"Unknown cache key: {key!r}"
        # Start with every slot empty. This works for any entry dataclass whose
        # fields accept None, not just one/two/three.
        cache_entry = self.cache_entry_cls(**dict.fromkeys(self.cache_key_priority))
        self.wkd[a] = cache_entry
        setattr(cache_entry, key, weakref.ref(a))
        self.reverse[a] = weakref.ref(cache_entry)
        weakref.finalize(a, _finalizer, self.wkd, key, cache_entry)
        return cache_entry

    def get(self, a, key):
        """Return the tensor stored under ``key`` in ``a``'s entry, or None.

        Returning a tensor gives the caller an external reference. A strongly
        held result is therefore promoted to canonical if it outranks the
        current canonical. If ``key`` already is the canonical, the referenced
        tensor is returned (None if it has died), not the weakref itself.
        """
        assert key in self.cache_key_priority, f"Unknown cache key: {key!r}"
        cache_entry_ref = self.reverse.get(a)
        assert cache_entry_ref is not None, "Tensor is not registered in this cache"
        cache_entry = cache_entry_ref()
        assert cache_entry is not None, "Cache entry has already been freed"
        ret = getattr(cache_entry, key)
        if isinstance(ret, weakref.ReferenceType):
            # ``key`` is the current canonical and is stored weakly.
            return ret()
        if ret is not None:
            self._maybe_swap_with_current(cache_entry, key, ret)
        return ret
def test_no_leaks():
    """Check that an entry juggled through several canonical changes frees
    every tensor once nothing outside the cache refers to it.

    The garbage collector is disabled while ``scope`` runs. The tensors must
    therefore be released by reference counting alone, without relying on
    cycle collection.
    """
    def scope():
        a = torch.tensor(1.)
        b = torch.tensor(1.)
        c = torch.tensor(1.)
        d = torch.tensor(1.)
        a_ref = weakref.ref(a)
        b_ref = weakref.ref(b)
        c_ref = weakref.ref(c)
        d_ref = weakref.ref(d)
        pc = PriorityCache(CacheEntry)
        x = pc.create(a, "one")
        assert viz_ownership(x) == "one -> ()"
        pc.add(a, "two", b)
        # This string indicates the ownership direction, e.g., "a -> b, c"
        # indicates a keeps b and c alive. The True/False inside "[]" indicates
        # whether that slot's tensor has any external references.
        assert viz_ownership(x) == "one -> two[True]"
        del a
        # After "a" (key "one") is deleted, its finalizer promotes "two" to be
        # responsible for keeping the cache alive. "two" is the first slot whose
        # tensor is still referenced from outside the cache.
        assert viz_ownership(x) == "two -> ()"
        pc.add(b, "one", c)
        # "one" outranks "two" (lower index = higher priority), so the newly
        # added "one" takes over keeping the cache alive.
        assert viz_ownership(x) == "one -> two[True]"
        pc.add(b, "three", d)
        assert viz_ownership(x) == "one -> two[True], three[True]"
        del b
        assert viz_ownership(x) == "one -> two[False], three[True]"
        del c
        # After deleting "one", "three" is chosen over "two" even though "two"
        # ranks higher, because "two" has no external references.
        assert viz_ownership(x) == "three -> two[False]"
        # After querying for "two", "two" is assumed to gain an external
        # reference and replaces "three".
        _ = pc.get(d, "two")
        assert viz_ownership(x) == "two -> three[True]"
        del _
        # "two" is now the weakly held canonical, so dropping its last external
        # reference frees it. Its finalizer then promotes "three" (still
        # referenced via d) and clears the "two" slot. Compare test_swap2, where
        # deleting new_b likewise frees the canonical.
        assert viz_ownership(x) == "three -> ()"
        return a_ref, b_ref, c_ref, d_ref

    import gc
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        # No reference cycles: everything must be freed by refcounting alone.
        refs = scope()
        assert all(ref() is None for ref in refs)
    finally:
        # Restore the collector on success as well as on failure, and let any
        # assertion error propagate instead of swallowing it.
        if gc_was_enabled:
            gc.enable()


test_no_leaks()
def test_swap1():
    """Higher-priority additions become canonical. Deleting the canonical hands
    ownership to the first remaining slot that is still externally referenced."""
    a = torch.tensor(1.)
    b = torch.tensor(1.)
    c = torch.tensor(1.)
    pc = PriorityCache(CacheEntry)
    x = pc.create(a, "three")
    assert viz_ownership(x) == "three -> ()"
    # "two" outranks "three": b becomes canonical and a is now held strongly.
    pc.add(a, "two", b)
    assert viz_ownership(x) == "two -> three[True]"
    # "one" outranks "two": c becomes canonical.
    pc.add(a, "one", c)
    assert viz_ownership(x) == "one -> two[True], three[True]"
    # Deleting the canonical c runs its finalizer. It promotes the first slot
    # still referenced from outside the cache ("two", via b).
    del c
    assert viz_ownership(x) == "two -> three[True]"
    # Same again: "three" (via a) is promoted.
    del b
    assert viz_ownership(x) == "three -> ()"


test_swap1()
def test_swap2():
    """Exercise promotion, demotion, and a ``get`` that re-promotes a slot,
    ending with a canonical that dies with no candidate to replace it."""
    a = torch.tensor(1.)
    b = torch.tensor(1.)
    c = torch.tensor(1.)
    pc = PriorityCache(CacheEntry)
    x = pc.create(a, "three")
    assert viz_ownership(x) == "three -> ()"
    pc.add(a, "two", b)
    assert viz_ownership(x) == "two -> three[True]"
    pc.add(a, "one", c)
    assert viz_ownership(x) == "one -> two[True], three[True]"
    # b is now owned only by the entry's strong "two" slot.
    del b
    assert viz_ownership(x) == "one -> two[False], three[True]"
    # The canonical c dies. "two" has no external references, so the finalizer
    # promotes "three" (still referenced via a) instead.
    del c
    assert viz_ownership(x) == "three -> two[False]"
    # get() hands b back to the caller. Since "two" outranks "three", b becomes
    # canonical again.
    new_b = pc.get(a, "two")
    assert viz_ownership(x) == "two -> three[True]"
    del a
    assert viz_ownership(x) == "two -> three[False]"
    # The canonical b dies and no slot has external references, so nothing is
    # promoted. "two" keeps its dead weakref and "three" keeps its tensor.
    del new_b
    assert isinstance(x.two, weakref.ReferenceType)
    assert x.two() is None
    assert isinstance(x.three, torch.Tensor)
    # b's weak key vanished when b died. a was removed from wkd when it was
    # demoted.
    assert len(pc.wkd) == 0


test_swap2()
def test_basic():
    """With "one" canonical from the start, later additions never swap. When
    the canonical dies and no other tensor is externally referenced, the
    remaining tensors stay in their slots."""
    a = torch.tensor(1.)
    b = torch.tensor(1.)
    c = torch.tensor(1.)
    pc = PriorityCache(CacheEntry)
    x = pc.create(a, "one")
    assert viz_ownership(x) == "one -> ()"
    # "two" and "three" rank below "one", so they are just stored strongly.
    pc.add(a, "two", b)
    assert viz_ownership(x) == "one -> two[True]"
    pc.add(a, "three", c)
    assert viz_ownership(x) == "one -> two[True], three[True]"
    del c
    assert viz_ownership(x) == "one -> two[True], three[False]"
    del b
    assert viz_ownership(x) == "one -> two[False], three[False]"
    # The canonical a dies. No slot is externally referenced, so nothing is
    # promoted and "one" keeps its dead weakref.
    del a
    assert isinstance(x.one, weakref.ReferenceType)
    assert x.one() is None
    assert isinstance(x.two, torch.Tensor)
    assert isinstance(x.three, torch.Tensor)
    assert len(pc.wkd) == 0


test_basic()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment