Last active
January 7, 2022 09:48
-
-
Save ItsDrike/4507190723a8887d2b0668c06a79cc94 to your computer and use it in GitHub Desktop.
Automatic caching for selected methods of a Python class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from typing import cast, Callable | |
from functools import wraps | |
_MISSING = object()  # Sentinel value to avoid conflicts with vars set to None


class AutoCacheMeta(type):
    """Metaclass that memoizes the methods/descriptors named in ``_cached``.

    A class using this metaclass lists attribute names in a ``_cached`` tuple.
    At class-creation time every listed plain method is wrapped with
    :meth:`cache`, and every listed read-only descriptor (a ``property`` or a
    custom non-data descriptor) gets its getter routed through :meth:`cache`.

    Results are stored per instance, keyed by call arguments, and the whole
    per-instance store is discarded whenever ``hash(self)`` changes.
    Pass ``allow_missing_cache=True`` as a class keyword to allow a class
    without ``_cached`` (it is then treated as an empty tuple).
    """

    def __new__(
        cls: "type[AutoCacheMeta]",
        name: str,
        bases: tuple[type, ...],
        clsdict: dict[str, object],
        **kwargs: object,
    ) -> "AutoCacheMeta":
        # Consume our own class keyword so it isn't forwarded to __init_subclass__.
        allow_missing_cache = kwargs.pop("allow_missing_cache", False)
        clsobj = super().__new__(cls, name, bases, clsdict, **kwargs)
        cache = cls.get_cache(clsobj, allow_missing_cache)
        cls.cache_methods(clsobj, cache)
        return clsobj

    @staticmethod
    def get_cache(clsobj: type, allow_missing: bool) -> tuple[str, ...]:
        """Return the validated ``_cached`` tuple of ``clsobj``.

        ``_cached`` is read with normal attribute access, so a value inherited
        from a base class is accepted too.

        :param clsobj: The newly created class.
        :param allow_missing: Treat a missing ``_cached`` as ``()`` instead of failing.
        :raises ValueError: If ``_cached`` is missing and that isn't allowed.
        :raises TypeError: If ``_cached`` is not a tuple of strings.
        """
        cache = getattr(clsobj, "_cached", _MISSING)
        if cache is _MISSING:
            if not allow_missing:
                raise ValueError("AutoCacheMeta requires _cached class variable.")
            cache = ()
        if not isinstance(cache, tuple):
            raise TypeError("_cached class variable must be a tuple of string method names.")
        if not all(isinstance(el, str) for el in cache):
            raise TypeError("_cached class variable can only contain strings (method names)")
        return cast("tuple[str, ...]", cache)

    @staticmethod
    def is_descriptor(obj: object) -> bool:
        """Return True if ``obj`` implements any descriptor-protocol method.

        The protocol methods are ``__get__``, ``__set__`` and ``__delete__``.
        ``__del__`` is the object finalizer and is unrelated to descriptors.
        """
        return any(hasattr(obj, attr) for attr in ("__get__", "__set__", "__delete__"))

    @staticmethod
    def _lookup_static(clsobj: type, name: str) -> object:
        """Return the raw ``name`` entry found along ``clsobj``'s MRO, or ``_MISSING``.

        Unlike ``getattr`` this does not trigger the descriptor protocol, so
        ``staticmethod``/``classmethod`` objects and custom descriptors are
        returned as stored in the class namespace.
        """
        for klass in clsobj.__mro__:
            namespace = vars(klass)
            if name in namespace:
                return namespace[name]
        return _MISSING

    @classmethod
    def _cache_descriptor(cls: "type[AutoCacheMeta]", clsobj: type, name: str) -> None:
        """Replace descriptor ``name`` on ``clsobj`` with one whose getter is cached.

        Only read-only descriptors are supported for now.

        :raises ValueError: If the descriptor has no getter.
        :raises NotImplementedError: If the descriptor has a setter or a deleter.
        """
        descriptor = cls._lookup_static(clsobj, name)
        if isinstance(descriptor, property):
            getter_f = descriptor.fget
            setter_f = descriptor.fset
            deleter_f = descriptor.fdel
        else:
            getter_f = getattr(descriptor, "__get__", None)
            setter_f = getattr(descriptor, "__set__", None)
            deleter_f = getattr(descriptor, "__delete__", None)

        if getter_f is None:
            raise ValueError(f"Tried to cache getter function of '{name}' descriptor, but it wasn't defined.")
        if setter_f is not None:
            # TODO: We need to make a function that overrides setter which
            # when called, resets the cache for the descriptor's getter
            raise NotImplementedError("Setter functionality isn't yet supported with caching")
        if deleter_f is not None:
            # TODO: We need to make a function that overrides deleter which
            # when called, removes the cache for the descriptor's getter
            raise NotImplementedError("Deleter functionality isn't yet supported with caching")

        if isinstance(descriptor, property):
            # A property's fget/fset/fdel are read-only, so install a copy
            # whose getter is the cached wrapper.
            setattr(clsobj, name, descriptor.getter(cls.cache(getter_f)))
            return

        # General non-data descriptor. Assigning ``__get__`` on the descriptor
        # *instance* would be ignored, because special methods are looked up
        # on the type. Install a wrapper descriptor that routes instance
        # lookups through the cache instead.
        cached_get = cls.cache(getter_f)

        class _CachedGetDescriptor:
            """Non-data descriptor forwarding to the original one, caching instance access."""

            def __get__(self, instance, owner=None):
                if instance is None:
                    # Class-level access has no instance to cache on; defer to the original.
                    return getter_f(instance, owner)
                return cached_get(instance, owner)

        setattr(clsobj, name, _CachedGetDescriptor())

    @classmethod
    def cache_methods(cls: "type[AutoCacheMeta]", clsobj: type, cache: tuple[str, ...]) -> None:
        """Decorate every attribute named in ``cache`` with the memoization decorator.

        :raises AttributeError: If a named attribute doesn't exist on ``clsobj``.
        :raises NotImplementedError: For static/class methods, which can't reach
            the instance-bound cache.
        :raises TypeError: For attributes that are neither callable nor descriptors.
        """
        for name in cache:
            # Use the raw namespace entry: plain getattr would unwrap
            # staticmethod/classmethod objects and make them undetectable.
            attribute = cls._lookup_static(clsobj, name)
            if attribute is _MISSING:
                raise AttributeError(f"Tried to cache non-existent attribute: '{name}'.")
            # Caching methods without self isn't possible without class-bound cache,
            # we're only using instance-bound cache here though.
            if isinstance(attribute, (staticmethod, classmethod)):
                raise NotImplementedError("Can't cache static/class methods, they can't access the instance-bound cache.")
            if callable(attribute):
                setattr(clsobj, name, cls.cache(attribute))
            elif cls.is_descriptor(attribute):
                cls._cache_descriptor(clsobj, name)
            else:
                raise TypeError(f"Tried to cache non-callable attribute (can only cache methods/descriptors): '{name}'.")

    @staticmethod
    def cache(func: Callable) -> Callable:
        """Memoize ``func`` per instance, keyed on its call arguments.

        ``func`` must take the instance as its first argument. Results live on
        the instance in ``_AutoCacheMeta__cache`` (a dict of ``func`` -> {key: result}).
        The whole store is dropped when ``hash(self)`` differs from the hash
        recorded in ``_AutoCacheMeta__hash``. All arguments must be hashable.
        """
        kwd_mark = object()  # Sentinel separating positional args from kwargs in the key

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            current_hash = hash(self)
            # ``self.__hash``/``self.__cache`` are name-mangled to
            # ``_AutoCacheMeta__hash``/``_AutoCacheMeta__cache``.
            if getattr(self, "_AutoCacheMeta__hash", _MISSING) != current_hash:
                self.__hash = current_hash
                self.__cache = {}
            # Keep a local reference so storing the result can't hit a KeyError
            # if a nested cached call reset the store while ``func`` was running.
            func_cache = self.__cache.setdefault(func, {})
            key = args + (kwd_mark,) + tuple(sorted(kwargs.items()))
            result = func_cache.get(key, _MISSING)
            if result is _MISSING:
                result = func(self, *args, **kwargs)
                func_cache[key] = result
            return result

        return wrapper
class AutoCacheMixin(metaclass=AutoCacheMeta, allow_missing_cache=True):
    """Convenience base class that applies ``AutoCacheMeta`` to subclasses.

    The mixin itself is created with ``allow_missing_cache=True``, so it needs
    no ``_cached`` value. It only annotates ``_cached`` without assigning it.
    Subclasses must therefore assign ``_cached`` themselves (or inherit it from
    an intermediate base that assigns it), or pass ``allow_missing_cache=True``.
    """

    # Names of the methods/read-only descriptors to memoize. The tuple holds
    # any number of names, hence ``tuple[str, ...]`` rather than ``tuple[str]``.
    _cached: tuple[str, ...]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Note: This code snippet is licensed under the MIT license. (i.e. you can use this basically anywhere, you can sublicense it, etc. so long as you mention the original source). https://spdx.org/licenses/MIT.html