Created September 16, 2011 17:09
-
-
Save dpo/1222577 to your computer and use it in GitHub Desktop.
Memoization of functions and methods with sha1 digest of numpy arrays.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import functools | |
import hashlib | |
class Memoized(object):
    """
    Decorator class used to cache the most recent value of a function or method
    based on the signature of its arguments. If any single argument changes,
    is added or is removed, the function or method is evaluated afresh.

    Only the single most recent result is kept. numpy arrays are identified by
    a SHA-1 digest of their contents together with their dtype and shape;
    every other argument is compared by value with ``==``. Non-array arguments
    are stored by reference, so mutating one in place between calls is not
    detected.

    When decorating an instance method, the cached value is tied to the
    instance it was computed for: calling the method on a different instance
    forces a re-evaluation. The most recently used instance stays referenced
    by the decorator.
    """

    def __init__(self, callable):
        self._callable = callable
        # Set to True by __get__ when the callable is accessed as a method.
        self._callable_is_method = False
        self.value = None  # Cached value or derivative.
        # Argument signatures of the last *successful* evaluation.
        self._args_signatures = {}
        # Instance the cached value belongs to (None for plain functions).
        self._instance = None
        # False until the callable has been evaluated successfully once, so
        # that a call with no arguments at all is still evaluated.
        self._has_value = False

    def _get_signature(self, x):
        """
        Return the signature of a single argument.

        The signature is the argument itself, or a SHA-1 hex digest when the
        argument is a numpy array. Subclass and override this method to
        implement other digests.
        """
        if isinstance(x, np.ndarray):
            # A C-ordered 1-d view (copied only if needed) can always be
            # reinterpreted as raw bytes, including non-contiguous and 0-d
            # arrays.
            data = np.ascontiguousarray(x).reshape(-1)
            digest = hashlib.sha1(data.view(np.uint8))
            # Arrays sharing the same bytes but differing in dtype or shape
            # must not share a signature.
            digest.update(repr((str(x.dtype), x.shape)).encode('utf-8'))
            return digest.hexdigest()
        return x

    # Backward-compatible alias for the original (name-mangled) method name.
    __get_signature = _get_signature

    def __call__(self, *args, **kwargs):
        """
        Return the cached value if the instance (for methods) and the
        signatures of all arguments match the last successful call;
        otherwise call the wrapped callable and cache its result.
        """
        # For a method the first argument is the instance. It is tracked by
        # identity rather than by signature.
        firstarg = 1 if self._callable_is_method else 0
        instance = args[0] if (firstarg and args) else None

        code = self._callable.__code__
        nargs = code.co_argcount  # Named positional parameters.
        argnames = code.co_varnames[firstarg:nargs]
        argvals = args[firstarg:]

        signatures = {}
        for argname, argval in zip(argnames, argvals):
            signatures[argname] = self._get_signature(argval)
        # Surplus positional values go to *args; key them by position.
        for position, argval in enumerate(argvals[len(argnames):]):
            signatures[('*args', position)] = self._get_signature(argval)
        for argname, argval in kwargs.items():
            signatures[argname] = self._get_signature(argval)

        # Comparing the whole dict also catches arguments that were dropped
        # since the previous call.
        unchanged = (self._has_value
                     and instance is self._instance
                     and signatures == self._args_signatures)
        if not unchanged:
            value = self._callable(*args, **kwargs)
            # Commit only after a successful call, so an exception cannot
            # leave new signatures paired with a stale value.
            self.value = value
            self._args_signatures = signatures
            self._instance = instance
            self._has_value = True
        return self.value

    def __get__(self, obj, objtype=None):
        "Support instance methods."
        self._callable_is_method = True
        if obj is None:
            # Accessed through the class: the caller passes the instance
            # explicitly as the first argument.
            return self
        return functools.partial(self.__call__, obj)

    def __repr__(self):
        "Return the wrapped function or method's docstring."
        doc = getattr(self._callable, '__doc__', None)
        if doc is None:
            # __repr__ must return a string; fall back to the callable's repr.
            return repr(self._callable)
        return doc
class Something(object):
    """Small demonstration class whose method ``f`` is memoized."""

    def __init__(self):
        self._private_val = 2
        self.ncalls = 0  # Incremented each time f is actually evaluated.

    @Memoized
    def f(self, x, k, *args, **kwargs):
        "This is a documented method."
        # A single-argument print() parses under both Python 2 and 3 and
        # reproduces the Python 2 statement's output (including the soft
        # space inserted between the two printed items).
        print('private val =  %s' % (self._private_val,))  # Check access to attributes.
        self.ncalls += 1
        return np.random.random()
# Demonstration: the second call reuses the cached value; the third call
# passes a different array and triggers a fresh evaluation.
# Single-argument print() calls parse under both Python 2 and 3 and reproduce
# the original Python 2 print-statement output, including the soft space.
something = Something()
e = np.ones(5)
print('First call:')
val = something.f(e, 1, thingy='some stuff', z=3, y=-1, extra_arg='not important')
print('return value =  %s' % (val,))
print('Second call (all same args):')
val = something.f(e, 1, thingy='some stuff', z=3, y=-1, extra_arg='not important')
print('return value =  %s' % (val,))
print('Third call (first arg is different):')
val = something.f(2*e, 1, thingy='some stuff', z=3, y=-1, extra_arg='not important')
print('return value =  %s' % (val,))
print('Number of evaluations:  %s' % (something.ncalls,))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.