Skip to content

Instantly share code, notes, and snippets.

@kkew3
Last active June 22, 2018 20:11
Show Gist options
  • Select an option

  • Save kkew3/e63aad07ba4139bc66a9f49075848a76 to your computer and use it in GitHub Desktop.

Select an option

Save kkew3/e63aad07ba4139bc66a9f49075848a76 to your computer and use it in GitHub Desktop.
Cache function calls (keyed by function name, positional arguments, and keyword arguments) on disk using pickle. The overhead can be large, so don't use this decorator unless the decorated function call takes at least ten minutes to return.
import hashlib
import os
import pickle
import tempfile
from functools import wraps
def _encode_fcall(f, *args, **kwargs):
    """
    Encode a function invocation as a dictionary (function call descriptor).

    Structure of the function call descriptor::

        fd
        |- m => __module__
        |- n => __name__
        |- a => pickle.dumps(args)
        |- k => pickle.dumps(kwargs items sorted by key)
        `- _hash
           |- m => digest of fd['m']
           |- n => digest of fd['n']
           |- a => digest of fd['a']
           `- k => digest of fd['k']

    The digests are SHA-256 hex strings rather than values of the built-in
    ``hash``: on Python 3 ``hash`` of str/bytes is salted per interpreter
    process, so built-in hashes stored in a cache file would never compare
    equal in a later run.

    :param f: the function being invoked; only its ``__module__`` and
           ``__name__`` are read, it is not called
    :param args: positional arguments of the invocation (must be picklable)
    :param kwargs: keyword arguments of the invocation (must be picklable)
    :return: the function call descriptor
    """
    eval_keys = 'mnak'
    fd = {
        'm': f.__module__,
        'n': f.__name__,
        'a': pickle.dumps(args),
        # Sort by keyword name so the call-site keyword order is irrelevant.
        'k': pickle.dumps(sorted(kwargs.items(), key=lambda item: item[0])),
    }
    # Digests kept alongside the payload for fast comparison.
    digests = {}
    for key in eval_keys:
        payload = fd[key]
        if not isinstance(payload, bytes):
            # Module and function names are text; hashlib needs bytes.
            payload = payload.encode('utf-8')
        digests[key] = hashlib.sha256(payload).hexdigest()
    fd['_hash'] = digests
    return fd
class CacheNotFoundError(Exception):
    """
    Raised when a cache file does not hold the result of the queried call.

    Derives from ``Exception`` (not ``BaseException``) so that a routine
    cache miss is handled like any other ordinary error and is never mistaken
    for an interpreter-exit signal such as ``KeyboardInterrupt``.
    """
def _query_fcall(fd, filename):
    """
    Query the return value cached in one file, given a function call
    descriptor (as returned by `_encode_fcall`).

    Cache file structure::

        cache
        |- h => fd
        `- v => return_value

    The file is a hit only if its stored descriptor has the same keys, the
    same ``_hash`` entry and the same ``m``/``n``/``a``/``k`` entries as
    ``fd``.  An unreadable file, a file that is not a valid cache pickle, or
    a descriptor mismatch all count as a miss.

    :param fd: the function call descriptor
    :param filename: path of the potential cache file
    :return: the cached return value
    :raise: CacheNotFoundError, if the file does not hold a matching entry
    """
    eval_keys = 'mnak'
    value = None
    try:
        with open(filename, 'rb') as infile:
            # NOTE(security): unpickling runs arbitrary code from the file;
            # only point this at cache files you trust.
            cache = pickle.load(infile)
        other_fd = cache['h']
        # Explicit comparisons instead of ``assert``: assertions are removed
        # under ``python -O``, which would turn every readable pickle into a
        # hit.
        matched = (
            sorted(fd.keys()) == sorted(other_fd.keys())
            and fd['_hash'] == other_fd['_hash']
            and all(fd[key] == other_fd[key] for key in eval_keys)
        )
        if matched:
            value = cache['v']
    except Exception:
        # Unpickling a foreign or corrupt file may raise almost any ordinary
        # exception, so all of them mean "miss".  KeyboardInterrupt and
        # SystemExit are deliberately NOT caught and propagate to the caller.
        matched = False
    if not matched:
        raise CacheNotFoundError()
    return value
def pickle_cached(todir='.', protocol=2):
    """
    Cache the return value of the decorated function in a pickle file.

    On each call the ``*.cache.pkl`` files in `todir` are searched for an
    entry recorded for the same function (module and name) and the same
    arguments; on a hit the stored value is returned without calling the
    function.  On a miss the function is called and its result is written to
    a new, uniquely named ``*.cache.pkl`` file in `todir`.

    Usage::

        @pickle_cached(todir='.cache')
        def very_time_consuming_function(x, y):
            pass

    Arguments and return value must be picklable.  Cache files are read with
    ``pickle``, which can execute arbitrary code, so `todir` must only
    contain files you trust.

    :param todir: where to put the cache pickle file; the directory must exist
    :param protocol: the protocol of the pickle file; for python2 compatible
           concern it's recommended to use protocol 2
    :return: a decorator adding the on-disk cache to a function
    """
    cache_suffix = '.cache.pkl'

    def _pickle_cached(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            fd = _encode_fcall(f, *args, **kwargs)
            for basename in os.listdir(todir):
                # Only inspect files this decorator wrote; unrelated files in
                # `todir` are never unpickled.
                if not basename.endswith(cache_suffix):
                    continue
                # os.listdir yields bare names, so resolve them against
                # `todir` rather than the current working directory.
                candidate = os.path.join(todir, basename)
                try:
                    return _query_fcall(fd, candidate)
                except CacheNotFoundError:
                    continue
            # if reaching here, all queries must have failed
            val = f(*args, **kwargs)
            cachef = tempfile.NamedTemporaryFile(
                mode='wb', suffix=cache_suffix, dir=todir, delete=False)
            try:
                with cachef:
                    pickle.dump(dict(h=fd, v=val), cachef, protocol=protocol)
            except BaseException:
                # Do not leave a truncated cache file behind when the value
                # cannot be pickled or the write is interrupted.
                os.remove(cachef.name)
                raise
            return val
        return wrapper
    return _pickle_cached
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment