Enhance Python's functools.lru_cache with a keyfunc capability so that cache lookups use only a subset of the function's arguments.
# MIT License
#
# Copyright (c) 2023 Neil Webber
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import functools
import collections.abc
# keyed_lru_cache extends functools.lru_cache, allowing caching/memoizing
# to depend only on a *subset* of the function's arguments. It does this by
# using a 'keyfunc' that will be called like this:
#
# cachekey = keyfunc(args, kwargs)
#
# Every time the wrapped/memoized function is called, the wrapper function
# first calls the keyfunc with all the arguments supplied (in packed form,
# i.e., args/kwargs as shown above). The returned cachekey is then the (only)
# value that will be seen by lru_cache for cache lookup prior to invoking the
# memoized function (if the cache misses). Thus the keyfunc has complete
# control over how arguments affect cache hits/misses.
#
# Here's an example:
#
# def bozo(ignored_arg, /, *, important_arg):
#     print(ignored_arg)
#     return important_arg
#
# If memoized with functools.lru_cache, both arguments would be part of the
# cache index and changing either argument would cause a cache miss.
#
# However if memoized like this:
#
# @keyed_lru_cache(keyfunc=lambda args, kwargs: kwargs['important_arg'])
# def bozo(ignored_arg, /, *, important_arg):
#     print(ignored_arg)
#     return important_arg
#
# Only the 'important_arg' will be used in cache lookup. Therefore:
#
# bozo("HELLO WORLD", important_arg=2)
#
# would print "HELLO WORLD", but a second call:
#
# bozo("WOAH", important_arg=2)
#
# would not print anything as the value would be found in the cache
# (based on the important_arg only).
#
# ***************************************************************
# CAVEAT CODER: This could certainly be fertile grounds for truly
# bizarre bugs if not used with care.
# ***************************************************************
#
# See kwargkeys() for a helper function for the common case of creating
# a keyfunc based on specific named arguments. Instead of this lambda:
#
# keyfunc=lambda args, kwargs: kwargs['important_arg']
#
# kwargkeys can be used like this:
#
# keyfunc=kwargkeys('important_arg')
#
# to build a (better) keyfunc. See kwargkeys() for details.
# ----------
# A __SmuggleArgs is the (devious) way arguments are passed along but
# ignored for cache lookup.
#
# Every __SmuggleArgs compares equal to every other __SmuggleArgs
# regardless of the values of args/kwargs (!!)
#
# This is an internal class not intended for use elsewhere.
#
class __SmuggleArgs:
__slots__ = ['args', 'kwargs']
def __init__(self, args, kwargs):
self.args = args
self.kwargs = kwargs
# all __SmuggleArgs are equal and hash the same
def __hash__(self):
return 42
def __eq__(self, other):
return True
# Didn't *have to* provide this, but may as well
def __ne__(self, other):
return False
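# Illustrative sketch (added; _smuggle_demo is a made-up name, not part of
# the original gist): two carriers with completely different payloads still
# compare equal and hash identically, so the smuggled arguments can never
# cause a cache miss on their own -- only the key computed by keyfunc can.
def _smuggle_demo():
    a = __SmuggleArgs((1, 2), {'x': 3})
    b = __SmuggleArgs(('totally', 'different'), {})
    return a == b and hash(a) == hash(b)    # always True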
def keyed_lru_cache(*, keyfunc, **lrucache_kwargs):
"""Decorator like lru_cache, but caching on *some* arguments, not all.
keyfunc: a function that will be called every time the decorated
function is called; keyfunc is called like this:
key = keyfunc(args, kwargs)
where args/kwargs are the arguments (packed) given in
the call of the cached/memoized function. The returned
key is what will be used by lru_cache for cache lookups.
lrucache_kwargs: additional arguments that will be given to lru_cache
"""
def _deco(f):
        # as far as lru_cache is concerned, the function only has
        # two arguments: the key (which will be obtained by keyfunc())
        # and the smuggled arguments, which ALWAYS COMPARE EQUAL
        # (so they never cause a cache miss; only the key can do that).
#
# The two functions _f() and _smuggler() are essentially the bookends
# around this whole subterfuge. _smuggler() calls the keyfunc()
        # to create the key, then passes that and the __SmuggleArgs'd args
# to _f(), which will in fact only be invoked if lru_cache misses.
# _f(), in turn, ignores the key argument and unpacks the smuggled
# arguments to pass to the real (original) function.
#
# NOTE: Code relies on the local closure (e.g., to get f into _f,
# though conceptually that could have been part of the smuggle)
@functools.lru_cache(**lrucache_kwargs)
def _f(key, smarg):
return f(*smarg.args, **smarg.kwargs)
@functools.wraps(f)
def _smuggler(*args, **kwargs):
return _f(keyfunc(args, kwargs), __SmuggleArgs(args, kwargs))
# propagate any "public" (no '_' at start) attrs lru_cache adds to _f
# the intent is to pick up cache_clear, cache_info, etc
for attr in dir(_f):
if attr[0] != '_' and not hasattr(_smuggler, attr):
setattr(_smuggler, attr, getattr(_f, attr))
return _smuggler
return _deco
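# Minimal usage sketch (added for illustration; _keyed_cache_demo and the
# inner 'scaled' function are made-up names, not part of the original gist).
# Extra keyword arguments such as maxsize are forwarded to functools.lru_cache,
# and the lru_cache attributes (cache_info, cache_clear, ...) are available
# on the returned wrapper.
def _keyed_cache_demo():
    @keyed_lru_cache(keyfunc=lambda args, kwargs: args[0], maxsize=16)
    def scaled(x, factor):
        return x * factor

    scaled(3, 10)               # miss: computes 30, caches it under key 3
    result = scaled(3, 999)     # hit on key 3: returns the cached 30 (!)
    # cache_info() here reports hits=1, misses=1, maxsize=16, currsize=1
    return result, scaled.cache_info()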
def kwargkeys(kwargname0, *kwaN, missingkeys={}):
"""For keyed_lru_cache(): build a keyfunc for one or more given keys.
ARGUMENTS:
one or more positional arguments: keyword argument names
missingkeys: see below
Any keyword arguments not present in a memoized function call cause a
miss by default (a unique value is used in the key each time for each
missing key). To change that, provide missingkeys=None (or ANY value).
That (fixed) value will then be used in every key for every missing arg.
To specify a per-key missing value (perhaps to match the defaults of
the memoized function), provide missingkeys=dict(blah=foo, ...) where
each entry in the missingkeys dictionary is a keyword and the value is
the value to use when that keyword is missing. The special key '*'
may be used to specify a default value. If after all this no match
is found, a unique value will be used for each not-found missing key.
NOTE (and hopefully UNDERSTAND) that this has nothing to do with the
value of those arguments seen by the memoized/wrapped function; this is
solely about how to treat defaulted key arguments when computing the key.
"""
kn = (kwargname0, *kwaN)
if not isinstance(missingkeys, collections.abc.Mapping):
missingkeys = {'*': missingkeys}
def keyf(args, kwargs):
# the key is a tuple of the values of the specific keyword args
# - if kwarg[k] is not present then use missingkeys[k]
# - if missingkeys[k] is not present then use missingkeys['*']
# - if missingkeys['*'] is not present use a unique object()
return tuple(
kwargs[k] if k in kwargs
else (missingkeys[k] if k in missingkeys
else missingkeys.get('*', object()))
for k in kn)
return keyf
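# Minimal sketch of kwargkeys with missingkeys (added for illustration;
# _kwargkeys_demo and the inner 'combine' function are made-up names, not
# part of the original gist). The key is built only from 'mode' and 'scale';
# when 'scale' is omitted, the per-key missing value 1 is used, so omitting
# it and passing scale=1 explicitly land on the same cache entry.
def _kwargkeys_demo():
    @keyed_lru_cache(
        keyfunc=kwargkeys('mode', 'scale', missingkeys={'scale': 1}))
    def combine(data, *, mode, scale=1):
        return [x * scale for x in data]

    combine([1, 2], mode='sum')              # miss: key is ('sum', 1)
    combine([9, 9], mode='sum', scale=1)     # hit:  same key ('sum', 1)
    combine([1, 2], mode='sum', scale=2)     # miss: key is ('sum', 2)
    return combine.cache_info()              # hits=1, misses=2, currsize=2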
if __name__ == "__main__":
import unittest
fooby_cachesize = 5
def fooxcachefill(fx, fxv):
testvargs, notcached = fxv()
# Reset the cache to a "known" state, or at a minimum, ensure
# that none of the testvectors values are in the cache...
# COULD use cache_clear but prefer to do it "manually" this way.
for tv in notcached:
fx(tv[0], **tv[1])
# now fill cache with testvectors
for tv in testvargs:
fx(tv[0], **tv[1])
return testvargs, notcached
@keyed_lru_cache(keyfunc=kwargkeys('kwarg'), maxsize=fooby_cachesize)
def fooby(arg0, *, kwarg):
return arg0 + kwarg
def foobyvecs():
# construct test vectors for the fooby variations.
# XXX This "knows" (as do many other parts of the tests) that there
# is only one positional argument in all the test functions.
# toomany is constructed to be twice the cache size
toomany = [((i*3) + 1, {'kwarg': (i*3) + 2}, (i*6) + 3)
for i in range(fooby_cachesize*2)]
# and therefore the first half is the testvectors and the
# second half are values known to not match the testvectors
return toomany[:fooby_cachesize], toomany[fooby_cachesize:]
fooby.filler = lambda: fooxcachefill(fooby, foobyvecs)
# same but: multiple keyword arguments that are cache keys and not
@keyed_lru_cache(keyfunc=kwargkeys('kw0', 'kw1'), maxsize=fooby_cachesize)
def fooby2(arg0, *, kw0, kw1, kwnot):
return arg0 + kw0 + kw1 + kwnot
def fooby2vecs():
# construct test vectors for the fooby variations.
# XXX This "knows" (as do many other parts of the tests) that there
# is only one positional argument in all the test functions.
# toomany is constructed to be twice the cache size
toomany = [((i*3) + 1,
{'kw0': (i*3) + 2, 'kw1': i, 'kwnot': 0}, (i*7) + 3)
for i in range(fooby_cachesize*2)]
# and therefore the first half is the testvectors and the
# second half are values known to not match the testvectors
return toomany[:fooby_cachesize], toomany[fooby_cachesize:]
fooby2.filler = lambda: fooxcachefill(fooby2, fooby2vecs)
testfuncs = [fooby, fooby2]
class TestMethods(unittest.TestCase):
def test_commentexample(self):
# taken (adjusted slightly) from the module comments example
@keyed_lru_cache(
keyfunc=lambda args, kwargs: kwargs['important_arg'])
def bozo(ignored_arg, /, *, important_arg):
# the example uses 'print(ignored_arg)' -- adjusted to:
self.assertEqual(ignored_arg, "HELLO WORLD")
return important_arg
self.assertEqual(2, bozo("HELLO WORLD", important_arg=2))
self.assertEqual(2, bozo("WOAH", important_arg=2))
self.assertEqual('xx', bozo("HELLO WORLD", important_arg='xx'))
def testtrivial(self):
# this basically would work even if the cache didn't exist,
# so it tests that the cache hasn't bolluxed anything up
for f in testfuncs:
with self.subTest(f=f):
testvectors, nc = f.filler()
for tv in testvectors:
self.assertEqual(f(tv[0], **tv[1]), tv[2])
def test_ignored_positionals(self):
# none of the fooby functions use the positional arg
# as part of the cache key; this tests that the positional
# argument is properly ignored
for f in testfuncs:
with self.subTest(f=f):
testvectors, nc = f.filler()
# different arg0 values should not affect the return value
# because all the "correct" results are cached and arg0
# is not part of the key
for i in range(20):
for tv in testvectors:
self.assertEqual(f(i, **tv[1]), tv[2])
def test_cacheflow(self):
# this is really just a test of lru_cache - fill the cache,
# add 1 extra thing, then all the previously cached entries
# should age out in order and new entries should be made.
# This tests that.
for f in testfuncs:
with self.subTest(f=f):
testvectors, nc = f.filler()
# poison the cache with something that isn't in it
nc0 = nc[0]
self.assertEqual(f(nc0[0], **nc0[1]), nc0[2])
# now the new calls should first cache the i=0 case
adjtv = [(0, tv[1], tv[2] - tv[0]) for tv in testvectors]
for i in range(20):
for atv in adjtv:
self.assertEqual(f(i, **atv[1]), atv[2])
def test_cacheclear(self):
for f in testfuncs:
with self.subTest(f=f):
testvectors, nc = f.filler()
tv = testvectors[0]
self.assertEqual(f(tv[0], **tv[1]), tv[2])
# but since it is cached, this should also be the case:
xtv = tv[2] + 1000000000000
self.assertEqual(f(xtv, **tv[1]), tv[2])
# now clear the cache and try again; should NOT be equal
f.cache_clear()
self.assertNotEqual(f(xtv, **tv[1]), tv[2])
def test_kkmiss(self):
# test the simple missingkeys functionality of kwargkeys
missval = 17
@keyed_lru_cache(keyfunc=kwargkeys('kwarg', missingkeys=missval))
def m17(arg0, *, kwarg=missval):
return arg0 + kwarg
x = m17(0)
self.assertEqual(m17(8, kwarg=missval), x)
@keyed_lru_cache(keyfunc=kwargkeys('kwarg'))
def m17miss(arg0, *, kwarg=missval):
return arg0 + kwarg
x = m17miss(0)
self.assertNotEqual(m17miss(8, kwarg=missval), x)
def test_kkmissmulti(self):
# test the more elaborate missingkeys functionality of kwargkeys
dz = 333
dstar = 100
defk = {'kw1': 1, 'kw2': 2, '*': dstar}
@keyed_lru_cache(
keyfunc=kwargkeys('kw1', 'kw2', 'z', missingkeys=defk))
def foo(arg0, *, kw1=1, kw2=2, z=dz, zz=0):
return arg0 + kw1 + kw2 + z + zz
# (args, kwargs, result, (hits, misses, currsize))
# NOTE WELL: These tests demonstrate some of the chaos that
# can ensue from ill-advised use of caching/defaults
testvecs = (((0,), dict(kw1=0, kw2=0), dz, (0, 1, 1)),
# because dz != dstar, this should be a miss
((0,), dict(kw1=0, kw2=0, z=dz), dz, (0, 2, 2)),
# this should be a hit because (again) dz != dstar
# but note that since it is a hit, the return is dz!
((0,), dict(kw1=0, kw2=0, z=dstar), dz, (1, 2, 2)),
# also a hit because zz is not a key
((0,), dict(kw1=0, kw2=0, zz=dstar), dz, (2, 2, 2)),
# also a hit bcs zz not a key
((0,), dict(kw1=0, kw2=0, zz=1234), dz, (3, 2, 2)),
)
for tv in testvecs:
with self.subTest(tv=tv):
args, kwargs, result, cs = tv
x = foo(*args, **kwargs)
self.assertEqual(x, result)
ci = foo.cache_info()
self.assertEqual((ci.hits, ci.misses, ci.currsize), cs)
def test_kkunique(self):
# test that something unique is used for each missing key
dz = 1
@keyed_lru_cache(keyfunc=kwargkeys('kw0'))
def foo(arg0, *, z=dz):
return arg0 + z
# (args, kwargs, result, (hits, misses, currsize))
testvecs = (((0,), dict(), dz, (0, 1, 1)),
                    # same call again, but each missing key gets a unique value: miss again
((0,), dict(), dz, (0, 2, 2)),
)
for tv in testvecs:
with self.subTest(tv=tv):
args, kwargs, result, cs = tv
x = foo(*args, **kwargs)
self.assertEqual(x, result)
ci = foo.cache_info()
self.assertEqual((ci.hits, ci.misses, ci.currsize), cs)
def test_multikey(self):
# verify that in a multi-arg key, each component matters
@keyed_lru_cache(keyfunc=kwargkeys('kw0', 'kw1'))
def foo(arg0, *, kw0, kw1):
return arg0 + kw0 + kw1
# (args, kwargs, result, (hits, misses, currsize))
testvecs = (((0,), dict(kw0=0, kw1=0), 0, (0, 1, 1)),
# same thing again should bump hits
((0,), dict(kw0=0, kw1=0), 0, (1, 1, 1)),
# varying kw0 should be a miss and a new answer
((0,), dict(kw0=1, kw1=0), 1, (1, 2, 2)),
# varying kw1 should be a miss and a new answer
((0,), dict(kw0=1, kw1=2), 3, (1, 3, 3)))
for tv in testvecs:
with self.subTest(tv=tv):
args, kwargs, result, cs = tv
x = foo(*args, **kwargs)
self.assertEqual(x, result)
ci = foo.cache_info()
self.assertEqual((ci.hits, ci.misses, ci.currsize), cs)
unittest.main()