Last active
June 22, 2018 20:11
-
-
Save kkew3/e63aad07ba4139bc66a9f49075848a76 to your computer and use it in GitHub Desktop.
Cache the return value of a function call in a pickle file, keyed on the function's module and name and its positional and keyword arguments. Each lookup unpickles the existing cache files, so the overhead can be large; only use this decorator on functions that take at least ten minutes to return.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import tempfile | |
| import pickle | |
| import os | |
| from functools import wraps | |
def _encode_fcall(f, *args, **kwargs):
    """
    Encode a function invocation to a dictionary (function call descriptor).

    Structure of the function call descriptor:

        fd
        |- m => __module__
        |- n => __name__
        |- a => pickle.dumps(args)
        |- k => pickle.dumps(kwargs items sorted by keyword)
        `- _hash
           |- m => sha256 hex digest of __module__
           |- n => sha256 hex digest of __name__
           |- a => sha256 hex digest of the pickled args
           `- k => sha256 hex digest of the pickled kwargs

    The digests are SHA-256 hex strings rather than builtin ``hash()`` values.
    ``hash()`` of str/bytes is salted per interpreter process (PYTHONHASHSEED),
    so a descriptor stored in a cache file by one run would never equal the
    descriptor computed for the same call by a later run.

    Arguments are compared through their pickled bytes: equal arguments that
    pickle differently (e.g. dicts with different insertion order) are treated
    as different calls.

    :param f: the function being invoked (only __module__/__name__ are read)
    :param args: positional arguments of the invocation; must be picklable
    :param kwargs: keyword arguments of the invocation; must be picklable
    :return: the function call descriptor
    """
    import hashlib

    fd = dict()
    fd['m'] = f.__module__
    fd['n'] = f.__name__
    fd['a'] = pickle.dumps(args)
    # Sort by keyword so the encoding does not depend on argument order.
    fd['k'] = pickle.dumps(sorted(kwargs.items(), key=lambda x: x[0]))
    eval_keys = 'mnak'
    # Digests allow a quick comparison between descriptors.
    fd['_hash'] = dict()
    for k in eval_keys:
        value = fd[k]
        if not isinstance(value, bytes):
            # __module__ / __name__ are text; hashlib needs bytes.
            value = str(value).encode('utf-8')
        fd['_hash'][k] = hashlib.sha256(value).hexdigest()
    return fd
class CacheNotFoundError(Exception):
    """Raised when a file does not hold a cached result for the queried call.

    Derives from Exception (not BaseException) so that ordinary
    ``except Exception`` handlers can catch it, as for any user-defined error.
    """
def _query_fcall(fd, filename):
    """
    Query the cached return value for a function call descriptor.

    Cache file structure (a pickled dict):

        cache
        |- h => fd (descriptor of the call that produced the value)
        `- v => return value

    A file matches when its stored descriptor has the same keys as `fd` and
    equal module name, function name, pickled args and pickled kwargs. The
    `_hash` entries are deliberately not compared: descriptors built with
    builtin ``hash()`` values differ between interpreter processes for the
    same call, so comparing them would make every cross-run lookup miss.

    Matching is checked with explicit conditions rather than ``assert``
    statements, because asserts are stripped under ``python -O``. Stripping
    them would return the value of any readable cache file.

    :param fd: the function call descriptor (as returned by `_encode_fcall`)
    :param filename: path of a potential cache file
    :return: the cached return value
    :raise CacheNotFoundError: if the file cannot be opened or unpickled, is
        not a cache file, or holds the result of a different call.
        KeyboardInterrupt / SystemExit are not intercepted.
    """
    eval_keys = 'mnak'
    try:
        # NOTE(review): this unpickles the file, which can execute arbitrary
        # code; only use cache directories whose contents you trust.
        with open(filename, 'rb') as infile:
            cache = pickle.load(infile)
        other_fd = cache['h']
        matched = (
            sorted(fd.keys()) == sorted(other_fd.keys())
            and all(fd[k] == other_fd[k] for k in eval_keys))
        if matched:
            return cache['v']
    except Exception:
        # Unreadable, truncated, foreign or malformed file: treat as a miss.
        pass
    raise CacheNotFoundError()
def pickle_cached(todir='.', protocol=2):
    """
    Decorator factory that caches a function's return value in pickle files.

    Each call builds a descriptor from the function's module, name and
    pickled arguments. It then scans the ``*.cache.pkl`` files in `todir`,
    unpickling each until one holds the same call. On a miss the function is
    called and a new ``*.cache.pkl`` file with the descriptor and return value
    is written to `todir`. Lookup cost grows with the number and size of cache
    files, so use this only for very slow functions. If the decorated function
    raises, nothing is cached.

    Usage:
        @pickle_cached(todir='.cache')
        def very_time_consuming_function(x, y):
            pass

    :param todir: where to put the cache pickle files; the directory must exist
    :param protocol: the protocol of the pickle file; for python2 compatible
           concern it's recommended to use protocol 2
    :return: the decorator
    """
    def _pickle_cached(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            fd = _encode_fcall(f, *args, **kwargs)
            for basename in os.listdir(todir):
                # Only files written by this decorator can be hits; skipping
                # the rest avoids unpickling unrelated files in `todir`.
                if not basename.endswith('.cache.pkl'):
                    continue
                # os.listdir yields bare names: join with `todir`, otherwise
                # the file would be looked up relative to the current directory.
                path = os.path.join(todir, basename)
                try:
                    return _query_fcall(fd, path)
                except CacheNotFoundError:
                    pass
            # No cache file matched: compute the value and persist it.
            val = f(*args, **kwargs)
            cachef = tempfile.NamedTemporaryFile(
                mode='wb', suffix='.cache.pkl', dir=todir, delete=False)
            try:
                with cachef:
                    pickle.dump(dict(h=fd, v=val), cachef, protocol=protocol)
            except BaseException:
                # Don't leave a truncated cache file behind (e.g. when `val`
                # is not picklable); the error itself still propagates.
                os.remove(cachef.name)
                raise
            return val
        return wrapper
    return _pickle_cached
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment