Function Disk Cache
""" | |
This module provides utilities for caching the results of time-consuming functions | |
to disk for faster retrieval in future runs. The main classes and functions are: | |
- hash_det: A function that generates a deterministic hash value for an object | |
(immune to changes in its id). | |
- function_takes_no_arguments: A helper function that checks if a given function | |
takes no arguments. | |
- CachedResultCallable: A class that implements a disk-cached version of a callable. | |
The result is cached in a specified location and retrieved from the disk cache in | |
future runs. This class can also handle dependencies to other caches, updating | |
the cache if the dependency cache files have been updated more recently. | |
- cache_result: A decorator that exposes the CachedResultCallable functionality, | |
allowing for a more readable syntax when using disk caching for functions. | |
Usage examples can be found in the documentation of the CachedResultCallable and | |
cache_result classes. | |
Note: CachedResultCallable can only be initialized with a callable that takes no | |
arguments. This is because its purpose is to retrieve the results of long | |
computations from the disk, and we don't expect multiple function calls with | |
different parameters. | |
""" | |
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2023, maharjun
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################
from __future__ import annotations

import os
from os.path import join as opj

import dill

from typing import Any, Callable, Iterable
import inspect
import logging

logger = logging.getLogger('utils.generic.cacheutils')

cached_callable_name_set = set()
def hash_det(object_to_hash, n_hex_digits=10):
    """
    Gets a deterministic hash value for an object (immune to changes in its id).

    This serializes the object via dill and returns an md5 hash computed on the
    dill dump.
    """
    from utils.generic.dillshim import dill
    import hashlib
    m = hashlib.md5()
    m.update(dill.dumps(object_to_hash))
    return m.hexdigest()[:n_hex_digits]
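# Illustrative note (not from the original gist): because hash_det hashes the
# dill serialization of the object rather than its id, equal values produce
# equal digests, which is what makes the result usable as a cache file suffix.
# For example, hash_det((1, 2, 3)) == hash_det((1, 2, 3)) holds even though the
# two tuples are distinct objects, and the digest is n_hex_digits characters long.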
def function_takes_no_arguments(func: Callable):
    """Returns True if `func` can be called with no arguments (bound methods may take only `self`)."""
    # inspect.getargspec was removed in Python 3.11; getfullargspec is the
    # drop-in replacement (the 'keywords' field is named 'varkw' there).
    args_tuple = inspect.getfullargspec(func)
    if args_tuple.varkw is not None:
        return False
    if args_tuple.varargs is not None:
        return False
    if args_tuple.defaults is not None:
        return False
    if len(args_tuple.args) > 1:
        return False
    elif not hasattr(func, '__self__') and len(args_tuple.args) > 0:
        # plain functions must take no positional arguments; bound methods may
        # take only `self` (getfullargspec still reports it in args)
        return False
    return True
class CachedResultCallable:
    """
    This class implements a disk-cached version of a callable where the result is
    cached in the specified location and retrieved from the disk cache in the
    future. Additionally, one can specify dependencies on other caches, in which
    case the cache is invalidated and recomputed if the dependency cache files
    have been updated more recently.

    Example
    -------
    The following is a basic example::

        import numpy as np

        A = np.random.rand(100, 100)
        B = np.random.rand(100, 100)

        def long_computation():
            return np.sum(A @ B, axis=0)

        cached_long_computation = CachedResultCallable(long_computation,
                                                       cache_dir='.',
                                                       cache_name='rand_multiply_sum',
                                                       key_to_hash=(A, B))

        D = cached_long_computation()  # takes a while the first time. If this code
                                       # is rerun, the result is simply read from
                                       # the disk cache

    NOTE: CachedResultCallable can only be initialized with a callable that takes
    no arguments. This is because its purpose is to retrieve the results of long
    computations from the disk; we don't expect multiple function calls with
    different parameters. It is most applicable to caching closures that
    encapsulate the logic of these involved computations.

    A more readable way of doing the same thing is to use the decorator
    cache_result::

        import numpy as np

        A = np.random.rand(100, 100)
        B = np.random.rand(100, 100)

        @cache_result(cache_dir='.', cache_name='rand_multiply_sum', key_to_hash=(A, B))
        def long_computation():
            return np.sum(A @ B, axis=0)

        D = long_computation()  # takes a while the first time. If this code is
                                # rerun, the result is simply read from the disk cache

    The various options, and the specification of dependencies, can be seen in the
    documentation of __init__().
    """
    def __init__(self, function: Callable,
                 cache_dir: str, cache_name: str, key_to_hash: Any = None,
                 perform_cache: bool = True,
                 dependency_callables: Iterable[CachedResultCallable] = []):
        """
        Parameters
        ----------
        function: Callable[[], Any]
            This must be a callable that takes no arguments. The return value of
            this callable is cached.
        cache_dir: str
            The directory where the cache file will be created.
        cache_name: str
            The name of the cache file that will be created is <cache_name>.p
            (or <cache_name>_<hash>.p if key_to_hash is specified).
        key_to_hash: Any
            Can be any serializable object. This object is serialized, an md5
            hash is computed, and the hash is appended to the filename. If
            unspecified, the file name is entirely specified by cache_name.
        perform_cache: bool (default: True)
            A boolean flag to indicate whether to use the disk cache at all. If
            False, the result is always recomputed. Note that even if this flag
            is False, the recomputed result is still written to the cache file.
        dependency_callables: Iterable[CachedResultCallable] (default: [])
            If this cache depends on any other cached callables, they can be
            specified here. In case any of those caches are updated more recently
            than the current cache, the cache is invalidated and the next function
            evaluation will recalculate the cache value.
        """
        assert function_takes_no_arguments(function), \
            "CachedResultCallable only accepts functions that take no arguments"
        self._function = function

        if not isinstance(cache_dir, str):
            raise TypeError("The cache_dir must be a string")
        if not os.path.isdir(cache_dir):
            raise ValueError("The cache_dir must point to a directory that already exists")
        self._cache_dir = cache_dir

        if not isinstance(cache_name, str):
            raise TypeError("The cache_name must be a string")
        self._cache_name = cache_name

        if key_to_hash is not None:
            try:
                self._hash = hash_det(key_to_hash)
            except dill.PicklingError:
                raise TypeError("The 'key_to_hash' must be serializable using dill")
            self._key_to_hash = key_to_hash
        else:
            self._key_to_hash = None

        self._dependency_callables = list(dependency_callables)
        self._perform_cache = perform_cache

        # initializes self._is_consistent_at_init (and caches self._time_of_update)
        self._is_consistent_at_init = self.cache_exists_and_is_consistent

        # Verify unique path
        if self.full_path in cached_callable_name_set:
            raise ValueError(f"It appears that the cache path {self.full_path} is already in use in another CachedResultCallable")
        cached_callable_name_set.add(self.full_path)
    @property
    def is_recomputed(self):
        return not self._is_consistent_at_init

    @property
    def time_of_update(self):
        if not hasattr(self, '_time_of_update'):
            if self._perform_cache:
                if os.path.isfile(self.full_path):
                    self._time_of_update = os.path.getmtime(self.full_path)
                else:
                    self._time_of_update = None
            else:
                self._time_of_update = None
        return self._time_of_update

    @property
    def full_name(self):
        if self._key_to_hash is None:
            return f"{self._cache_name}.p"
        else:
            return f"{self._cache_name}_{self._hash}.p"

    @property
    def full_path(self):
        return opj(self._cache_dir, self.full_name)

    @property
    def cache_dir(self):
        return self._cache_dir

    @property
    def cached_value(self):
        if not hasattr(self, '_cached_value'):
            raise AttributeError("The cached value hasn't been assigned / doesn't exist")
        return self._cached_value

    @cached_value.setter
    def cached_value(self, cvalue):
        if self.cache_exists_and_is_consistent:
            raise AttributeError("Cannot reassign cached_value for cache that already contains a cached value")
        self._cached_value = cvalue

    @property
    def cache_exists_and_is_consistent(self):
        if self.time_of_update is None:
            return False
        if any(x.time_of_update is None or x.time_of_update > self.time_of_update
               for x in self._dependency_callables):
            return False
        return True
    def __enter__(self):
        if self.cache_exists_and_is_consistent:
            with open(self.full_path, 'rb') as fin:
                self._cached_value = dill.load(fin)
        return self

    def __exit__(self, *args):
        import warnings
        if not self.cache_exists_and_is_consistent:
            if hasattr(self, '_cached_value'):
                with open(self.full_path, 'wb') as fout:
                    dill.dump(self._cached_value, fout, protocol=-1)
                self._time_of_update = os.path.getmtime(self.full_path)  # overwrite None value
                del self._cached_value
            else:
                warnings.warn("No cache was created by the CachedResultCallable since the cached value wasn't assigned")

    def __call__(self):
        """
        Calls the callable and caches the result
        """
        with self as C:
            if not C.cache_exists_and_is_consistent:
                C.cached_value = self._function()
            return C.cached_value
class cache_result:
    """
    Exposes CachedResultCallable as a decorator; the function to be cached is
    passed in through the __call__ method.
    """

    def __init__(self, cache_dir: str, cache_name: str, key_to_hash: Any = None,
                 dependency_callables: Iterable[CachedResultCallable] = [],
                 perform_cache: bool = True):
        self.cache_dir = cache_dir
        self.cache_name = cache_name
        self.key_to_hash = key_to_hash
        self.dependency_callables = dependency_callables
        self.perform_cache = perform_cache

    def __call__(self, func: Callable) -> CachedResultCallable:
        return CachedResultCallable(func,
                                    cache_dir=self.cache_dir,
                                    cache_name=self.cache_name,
                                    key_to_hash=self.key_to_hash,
                                    dependency_callables=self.dependency_callables,
                                    perform_cache=self.perform_cache)
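
# Minimal usage sketch, not part of the original gist. It assumes numpy is
# installed and that the current working directory is writable; the cache names
# 'base_matrix' and 'row_sums' and the seed value are made up for this example.
# The second cache declares the first as a dependency, so whenever the
# 'base_matrix' cache file is newer than the 'row_sums' one, the latter is
# invalidated and recomputed.
if __name__ == "__main__":
    import numpy as np

    seed = 42
    rng = np.random.default_rng(seed)
    A = rng.random((100, 100))

    # A zero-argument closure wrapped in CachedResultCallable via cache_result
    base_cache = cache_result(cache_dir='.', cache_name='base_matrix',
                              key_to_hash=seed)(lambda: A @ A.T)

    @cache_result(cache_dir='.', cache_name='row_sums', key_to_hash=seed,
                  dependency_callables=[base_cache])
    def row_sums():
        return np.sum(base_cache(), axis=1)

    # Computed on the first run; read back from the .p files on subsequent runs
    print(row_sums()[:5])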