Created
February 16, 2023 21:51
-
-
Save outofmbufs/c0d7ca230d7db90afbaa985c714619d4 to your computer and use it in GitHub Desktop.
enhance python lru_cache with key capability so it will only use a subset of the arguments for cache lookup
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MIT License | |
# | |
# Copyright (c) 2023 Neil Webber | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in | |
# all copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
import functools | |
import collections.abc | |
# keyed_lru_cache extends functools.lru_cache, allowing caching/memoizing | |
# to depend only on a *subset* of the function's arguments. It does this by | |
# using a 'keyfunc' that will be called like this: | |
# | |
# cachekey = keyfunc(args, kwargs) | |
# | |
# Every time the wrapped/memoized function is called, the wrapper function | |
# first calls the keyfunc with all the arguments supplied (in packed form, | |
# i.e., args/kwargs as shown above). The returned cachekey is then the (only) | |
# value that will be seen by lru_cache for cache lookup prior to invoking the | |
# memoized function (if the cache misses). Thus the keyfunc has complete | |
# control over how arguments affect cache hits/misses. | |
# | |
# Here's an example: | |
# | |
# def bozo(ignored_arg, /, *, important_arg): | |
# print(ignored_arg) | |
# return important_arg | |
# | |
# If memoized with functools.lru_cache, both arguments would be part of the | |
# cache index and changing either argument would cause a cache miss. | |
# | |
# However if memoized like this: | |
# | |
# @keyed_lru_cache(keyfunc=lambda args, kwargs: kwargs['important_arg']) | |
# def bozo(ignored_arg, /, *, important_arg): | |
# print(ignored_arg) | |
# return important_arg | |
# | |
# Only the 'important_arg' will be used in cache lookup. Therefore: | |
# | |
# bozo("HELLO WORLD", important_arg=2) | |
# | |
# would print "HELLO WORLD", but a second call: | |
# | |
# bozo("WOAH", important_arg=2) | |
# | |
# would not print anything as the value would be found in the cache | |
# (based on the important_arg only). | |
# | |
# *************************************************************** | |
# CAVEAT CODER: This could certainly be fertile ground for truly
# bizarre bugs if not used with care.
# *************************************************************** | |
# | |
# See kwargkeys() for a helper function for the common case of creating | |
# a keyfunc based on specific named arguments. Instead of this lambda: | |
# | |
# keyfunc=lambda args, kwargs: kwargs['important_arg']
# | |
# kwargkeys can be used like this: | |
# | |
# keyfunc=kwargkeys('important_arg') | |
# | |
# to build a (better) keyfunc. See kwargkeys() for details. | |
# ---------- | |
# A __SmuggleArgs is the (devious) way arguments are passed but ignored
#
# Every __SmuggleArgs compares equal to every other __SmuggleArgs,
# regardless of the values of args/kwargs (!!)
# | |
# This is an internal class not intended for use elsewhere. | |
# | |
class __SmuggleArgs:
    """Carrier for call arguments that lru_cache must not discriminate on.

    Every instance hashes the same and compares equal to every other
    instance of this class, whatever its args/kwargs hold. Putting one
    in an lru_cache key therefore never causes a cache miss by itself.
    Comparisons against objects of any other type are not claimed as
    equal; they return NotImplemented so Python's normal fallback applies.

    This is an internal class not intended for use elsewhere.
    """
    __slots__ = ['args', 'kwargs']

    def __init__(self, args, kwargs):
        self.args = args        # packed positional arguments
        self.kwargs = kwargs    # packed keyword arguments

    # all __SmuggleArgs are equal and hash the same
    def __hash__(self):
        return 42

    def __eq__(self, other):
        # Refer to the class through type(self): spelling its double-
        # underscore name inside this class body would be name-mangled
        # (to _SmuggleArgs__SmuggleArgs) and fail with NameError.
        if isinstance(other, type(self)):
            return True
        return NotImplemented

    # Didn't *have to* provide this, but may as well (kept consistent
    # with __eq__ for non-__SmuggleArgs operands)
    def __ne__(self, other):
        if isinstance(other, type(self)):
            return False
        return NotImplemented
def keyed_lru_cache(*, keyfunc, **lrucache_kwargs):
    """Decorator like lru_cache, but caching on *some* arguments, not all.

    keyfunc: a function that will be called every time the decorated
             function is called; keyfunc is called like this:
                  key = keyfunc(args, kwargs)
             where args/kwargs are the arguments (packed) given in
             the call of the cached/memoized function. The returned
             key is what will be used by lru_cache for cache lookups,
             so it must be hashable.

    lrucache_kwargs: additional arguments that will be given to lru_cache
                     (e.g. maxsize, typed).

    Returns a decorator. The decorated function also gets the public
    (non-underscore) attributes lru_cache attaches, such as cache_info
    and cache_clear.

    Raises TypeError immediately if keyfunc is not callable.
    """
    if not callable(keyfunc):
        raise TypeError(
            f"keyfunc must be callable, not {type(keyfunc).__name__}")

    def _deco(f):
        # As far as lru_cache is concerned, the function only has two
        # arguments: the key (obtained from keyfunc()) and the smuggled
        # arguments, which ALWAYS compare equal (so they never cause a
        # cache miss; only the key can do that).
        #
        # The two functions _f() and _smuggler() are the bookends around
        # this whole subterfuge. _smuggler() calls keyfunc() to create the
        # key, then passes that and the __SmuggleArgs-wrapped args to _f(),
        # which will in fact only be invoked if lru_cache misses. _f(), in
        # turn, ignores the key argument and unpacks the smuggled
        # arguments to pass to the real (original) function.
        #
        # NOTE: Code relies on the local closure (e.g., to get f into _f,
        #       though conceptually that could have been part of the smuggle)

        @functools.lru_cache(**lrucache_kwargs)
        def _f(key, smarg):
            # 'key' is only here so lru_cache hashes/compares it.
            args, kwargs = smarg.args, smarg.kwargs
            # On a miss, lru_cache stores (key, smarg) as the entry's key
            # for as long as the entry lives. Empty the smuggled payload so
            # the cache does not keep the (ignored) arguments of this call
            # alive. An emptied smarg still compares equal to every other
            # __SmuggleArgs, so later lookups are unaffected.
            smarg.args = smarg.kwargs = None
            return f(*args, **kwargs)

        @functools.wraps(f)
        def _smuggler(*args, **kwargs):
            return _f(keyfunc(args, kwargs), __SmuggleArgs(args, kwargs))

        # propagate any "public" (no '_' at start) attrs lru_cache adds
        # to _f; the intent is to pick up cache_clear, cache_info, etc.
        for attr in dir(_f):
            if not attr.startswith('_') and not hasattr(_smuggler, attr):
                setattr(_smuggler, attr, getattr(_f, attr))
        return _smuggler

    return _deco
def kwargkeys(kwargname0, *kwaN, missingkeys={}):
    """For keyed_lru_cache(): build a keyfunc for one or more given keys.

    ARGUMENTS:
        one or more positional arguments: keyword argument names
        missingkeys: see below

    The returned keyfunc(args, kwargs) ignores args. It returns a tuple
    with one entry per given name, in the order given: kwargs[name] if
    that name is present in kwargs, otherwise a stand-in value (below).
    Only kwargs is examined, so an argument supplied positionally is
    treated as missing.

    Any keyword arguments not present in a memoized function call cause a
    miss by default (a unique value is used in the key each time for each
    missing key). To change that, provide missingkeys=None (or ANY value).
    That (fixed) value will then be used in every key for every missing arg.

    To specify a per-key missing value (perhaps to match the defaults of
    the memoized function), provide missingkeys=dict(blah=foo, ...) where
    each entry in the missingkeys dictionary is a keyword and the value is
    the value to use when that keyword is missing. The special key '*'
    may be used to specify a default value. If after all this no match
    is found, a unique value will be used for each not-found missing key.

    A missingkeys mapping is copied when the keyfunc is built. Changing
    the caller's mapping afterwards does not change the keys produced.
    (The {} default is never mutated.)

    NOTE (and hopefully UNDERSTAND) that this has nothing to do with the
    value of those arguments seen by the memoized/wrapped function; this is
    solely about how to treat defaulted key arguments when computing the key.
    """
    kn = (kwargname0, *kwaN)

    # Normalize to a private dict snapshot, so later mutation of the
    # caller's mapping cannot silently change the cache keys.
    if isinstance(missingkeys, collections.abc.Mapping):
        missingkeys = dict(missingkeys)
    else:
        missingkeys = {'*': missingkeys}

    # Marker meaning "no fallback configured: use a fresh object() per
    # call". It is a local object, so no user-supplied value can be it.
    unique = object()

    # Resolve each name's fallback once, in lookup order:
    #   - missingkeys[k] if present
    #   - else missingkeys['*'] if present
    #   - else a unique object() (chosen per call, in keyf)
    plan = tuple(
        (k, missingkeys[k] if k in missingkeys
            else missingkeys.get('*', unique))
        for k in kn)

    def keyf(args, kwargs):
        # A fresh object() compares equal only to itself, so a key that
        # contains one can never produce a cache hit.
        return tuple(
            kwargs[k] if k in kwargs
            else (object() if fallback is unique else fallback)
            for k, fallback in plan)
    return keyf
if __name__ == "__main__":
    # Self-test: when this file is run as a script, build a few memoized
    # test functions and run the unittest suite below. None of this runs
    # when the module is imported.
    import unittest

    # maxsize for the fooby/fooby2 caches. The *vecs() builders produce
    # twice this many vectors: half to fill the cache, half known misses.
    fooby_cachesize = 5

    def fooxcachefill(fx, fxv):
        # Put the cache of memoized function fx into a known state.
        # fxv() returns (testvectors, notcached); each is a list of
        # (arg0, kwargs, expected_result) tuples. The notcached vectors
        # are called first and the testvectors last. When there are
        # fooby_cachesize testvectors with distinct keys, the LRU cache
        # therefore ends up holding exactly the testvectors, with the
        # first one least recently used.
        # Returns (testvectors, notcached).
        testvargs, notcached = fxv()

        # Reset the cache to a "known" state, or at a minimum, ensure
        # that none of the testvectors values are in the cache...
        # COULD use cache_clear but prefer to do it "manually" this way.
        for tv in notcached:
            fx(tv[0], **tv[1])

        # now fill cache with testvectors
        for tv in testvargs:
            fx(tv[0], **tv[1])
        return testvargs, notcached

    # Only 'kwarg' is part of the cache key; arg0 is ignored for lookups.
    @keyed_lru_cache(keyfunc=kwargkeys('kwarg'), maxsize=fooby_cachesize)
    def fooby(arg0, *, kwarg):
        return arg0 + kwarg

    def foobyvecs():
        # construct test vectors for the fooby variations.
        # XXX This "knows" (as do many other parts of the tests) that there
        #     is only one positional argument in all the test functions.

        # toomany is constructed to be twice the cache size.
        # Each vector is (arg0, kwargs, expected result), where the
        # expected result is (3i+1) + (3i+2) == 6i+3.
        toomany = [((i*3) + 1, {'kwarg': (i*3) + 2}, (i*6) + 3)
                   for i in range(fooby_cachesize*2)]

        # and therefore the first half is the testvectors and the
        # second half are values known to not match the testvectors
        return toomany[:fooby_cachesize], toomany[fooby_cachesize:]

    fooby.filler = lambda: fooxcachefill(fooby, foobyvecs)

    # same but: multiple keyword arguments that are cache keys and not
    # (kw0 and kw1 are in the key; kwnot is not)
    @keyed_lru_cache(keyfunc=kwargkeys('kw0', 'kw1'), maxsize=fooby_cachesize)
    def fooby2(arg0, *, kw0, kw1, kwnot):
        return arg0 + kw0 + kw1 + kwnot

    def fooby2vecs():
        # construct test vectors for the fooby variations.
        # XXX This "knows" (as do many other parts of the tests) that there
        #     is only one positional argument in all the test functions.

        # toomany is constructed to be twice the cache size.
        # Expected result: (3i+1) + (3i+2) + i + 0 == 7i+3.
        toomany = [((i*3) + 1,
                    {'kw0': (i*3) + 2, 'kw1': i, 'kwnot': 0}, (i*7) + 3)
                   for i in range(fooby_cachesize*2)]

        # and therefore the first half is the testvectors and the
        # second half are values known to not match the testvectors
        return toomany[:fooby_cachesize], toomany[fooby_cachesize:]

    fooby2.filler = lambda: fooxcachefill(fooby2, fooby2vecs)

    # the memoized functions exercised by the generic tests below
    testfuncs = [fooby, fooby2]

    class TestMethods(unittest.TestCase):
        def test_commentexample(self):
            # taken (adjusted slightly) from the module comments example
            @keyed_lru_cache(
                keyfunc=lambda args, kwargs: kwargs['important_arg'])
            def bozo(ignored_arg, /, *, important_arg):
                # the example uses 'print(ignored_arg)' -- adjusted to:
                # (this assertion fails if the "WOAH" call below is not
                # served from the cache)
                self.assertEqual(ignored_arg, "HELLO WORLD")
                return important_arg

            self.assertEqual(2, bozo("HELLO WORLD", important_arg=2))
            self.assertEqual(2, bozo("WOAH", important_arg=2))
            self.assertEqual('xx', bozo("HELLO WORLD", important_arg='xx'))

        def testtrivial(self):
            # this basically would work even if the cache didn't exist,
            # so it tests that the cache hasn't bolluxed anything up
            for f in testfuncs:
                with self.subTest(f=f):
                    testvectors, nc = f.filler()
                    for tv in testvectors:
                        self.assertEqual(f(tv[0], **tv[1]), tv[2])

        def test_ignored_positionals(self):
            # none of the fooby functions use the positional arg
            # as part of the cache key; this tests that the positional
            # argument is properly ignored
            for f in testfuncs:
                with self.subTest(f=f):
                    testvectors, nc = f.filler()

                    # different arg0 values should not affect the return value
                    # because all the "correct" results are cached and arg0
                    # is not part of the key
                    for i in range(20):
                        for tv in testvectors:
                            self.assertEqual(f(i, **tv[1]), tv[2])

        def test_cacheflow(self):
            # this is really just a test of lru_cache - fill the cache,
            # add 1 extra thing, then all the previously cached entries
            # should age out in order and new entries should be made.
            # This tests that.
            for f in testfuncs:
                with self.subTest(f=f):
                    testvectors, nc = f.filler()

                    # poison the cache with something that isn't in it
                    # (this evicts the least recently used testvector)
                    nc0 = nc[0]
                    self.assertEqual(f(nc0[0], **nc0[1]), nc0[2])

                    # now the new calls should first cache the i=0 case.
                    # adjtv holds the expected results for arg0 == 0.
                    # At i == 0 each call misses, because its entry was
                    # evicted by the previous insertion, so it is
                    # recomputed with arg0 == 0. For every later i, all
                    # calls hit and return those arg0 == 0 results.
                    adjtv = [(0, tv[1], tv[2] - tv[0]) for tv in testvectors]
                    for i in range(20):
                        for atv in adjtv:
                            self.assertEqual(f(i, **atv[1]), atv[2])

        def test_cacheclear(self):
            for f in testfuncs:
                with self.subTest(f=f):
                    testvectors, nc = f.filler()
                    tv = testvectors[0]
                    self.assertEqual(f(tv[0], **tv[1]), tv[2])

                    # but since it is cached, this should also be the case:
                    # (xtv is an arg0 that would give a different result
                    # if the function were actually called)
                    xtv = tv[2] + 1000000000000
                    self.assertEqual(f(xtv, **tv[1]), tv[2])

                    # now clear the cache and try again; should NOT be equal
                    f.cache_clear()
                    self.assertNotEqual(f(xtv, **tv[1]), tv[2])

        def test_kkmiss(self):
            # test the simple missingkeys functionality of kwargkeys
            missval = 17

            # with missingkeys=missval, omitting kwarg keys the same as
            # passing kwarg=missval, so the second call is a cache hit
            @keyed_lru_cache(keyfunc=kwargkeys('kwarg', missingkeys=missval))
            def m17(arg0, *, kwarg=missval):
                return arg0 + kwarg

            x = m17(0)
            self.assertEqual(m17(8, kwarg=missval), x)

            # without missingkeys, the omitted kwarg gets a unique key
            # value, so the explicit kwarg=missval call misses
            @keyed_lru_cache(keyfunc=kwargkeys('kwarg'))
            def m17miss(arg0, *, kwarg=missval):
                return arg0 + kwarg

            x = m17miss(0)
            self.assertNotEqual(m17miss(8, kwarg=missval), x)

        def test_kkmissmulti(self):
            # test the more elaborate missingkeys functionality of kwargkeys
            dz = 333
            dstar = 100
            defk = {'kw1': 1, 'kw2': 2, '*': dstar}

            @keyed_lru_cache(
                keyfunc=kwargkeys('kw1', 'kw2', 'z', missingkeys=defk))
            def foo(arg0, *, kw1=1, kw2=2, z=dz, zz=0):
                return arg0 + kw1 + kw2 + z + zz

            # (args, kwargs, result, (hits, misses, currsize))
            # NOTE WELL: These tests demonstrate some of the chaos that
            #            can ensue from ill-advised use of caching/defaults
            testvecs = (((0,), dict(kw1=0, kw2=0), dz, (0, 1, 1)),

                        # because dz != dstar, this should be a miss
                        ((0,), dict(kw1=0, kw2=0, z=dz), dz, (0, 2, 2)),

                        # this should be a hit because (again) dz != dstar
                        # but note that since it is a hit, the return is dz!
                        ((0,), dict(kw1=0, kw2=0, z=dstar), dz, (1, 2, 2)),

                        # also a hit because zz is not a key
                        ((0,), dict(kw1=0, kw2=0, zz=dstar), dz, (2, 2, 2)),

                        # also a hit bcs zz not a key
                        ((0,), dict(kw1=0, kw2=0, zz=1234), dz, (3, 2, 2)),
                        )

            # run the vectors in order; cache statistics are cumulative
            for tv in testvecs:
                with self.subTest(tv=tv):
                    args, kwargs, result, cs = tv
                    x = foo(*args, **kwargs)
                    self.assertEqual(x, result)
                    ci = foo.cache_info()
                    self.assertEqual((ci.hits, ci.misses, ci.currsize), cs)

        def test_kkunique(self):
            # test that something unique is used for each missing key
            dz = 1

            # 'kw0' is never passed (foo doesn't even accept it), so every
            # call gets a unique key value and must miss
            @keyed_lru_cache(keyfunc=kwargkeys('kw0'))
            def foo(arg0, *, z=dz):
                return arg0 + z

            # (args, kwargs, result, (hits, misses, currsize))
            testvecs = (((0,), dict(), dz, (0, 1, 1)),

                        # again should miss again
                        ((0,), dict(), dz, (0, 2, 2)),
                        )

            for tv in testvecs:
                with self.subTest(tv=tv):
                    args, kwargs, result, cs = tv
                    x = foo(*args, **kwargs)
                    self.assertEqual(x, result)
                    ci = foo.cache_info()
                    self.assertEqual((ci.hits, ci.misses, ci.currsize), cs)

        def test_multikey(self):
            # verify that in a multi-arg key, each component matters
            @keyed_lru_cache(keyfunc=kwargkeys('kw0', 'kw1'))
            def foo(arg0, *, kw0, kw1):
                return arg0 + kw0 + kw1

            # (args, kwargs, result, (hits, misses, currsize))
            testvecs = (((0,), dict(kw0=0, kw1=0), 0, (0, 1, 1)),

                        # same thing again should bump hits
                        ((0,), dict(kw0=0, kw1=0), 0, (1, 1, 1)),

                        # varying kw0 should be a miss and a new answer
                        ((0,), dict(kw0=1, kw1=0), 1, (1, 2, 2)),

                        # varying kw1 should be a miss and a new answer
                        ((0,), dict(kw0=1, kw1=2), 3, (1, 3, 3)))

            for tv in testvecs:
                with self.subTest(tv=tv):
                    args, kwargs, result, cs = tv
                    x = foo(*args, **kwargs)
                    self.assertEqual(x, result)
                    ci = foo.cache_info()
                    self.assertEqual((ci.hits, ci.misses, ci.currsize), cs)

    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment