Last active
January 19, 2024 13:17
-
-
Save mogwai/e9437649335ab75338bdb35eef1e9e88 to your computer and use it in GitHub Desktop.
A simple decorator that caches function results in a folder, keyed on the function's arguments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashlib
import os
import pickle
import tempfile
from functools import wraps
from typing import Callable, Union

import numpy as np
from torch import Tensor
def sha256(b: Union[int, float, list, tuple, Tensor, str, bytes, np.ndarray]) -> str:
    """Return the hex SHA-256 digest of ``b`` for use as a cache-key fragment.

    Supported inputs:
      * ``int`` / ``float`` (including ``bool``): hashed via ``str(b)``, so
        ``1`` and ``"1"`` share a digest.
      * ``list`` / ``tuple``: hashed element-wise (recursively) and tagged with
        the container type, so array/tensor elements are hashed in full rather
        than through a possibly truncated ``repr``.
      * ``torch.Tensor``: detached, moved to CPU and hashed as a NumPy array.
      * ``np.ndarray``: dtype and shape are hashed together with the raw bytes.
      * ``str``: encoded with ``str.encode()`` (UTF-8); ``bytes``: hashed as-is.

    Raises:
        TypeError: for any other type. This is a subclass of ``Exception``, so
            callers catching ``Exception`` keep working.
    """
    if isinstance(b, (list, tuple)):
        # str(list) abbreviates large array/tensor reprs with "...", which
        # would let lists with different contents collide; hash each element.
        tag = "list" if isinstance(b, list) else "tuple"
        b = "{}[{}]".format(tag, ",".join(sha256(item) for item in b))
    elif isinstance(b, (int, float)):
        b = str(b)
    if isinstance(b, Tensor):
        # detach() so tensors with requires_grad=True can be converted.
        b = b.detach().cpu().numpy()
    if isinstance(b, np.ndarray):
        # Raw bytes alone ignore layout: zeros((2, 3)) and zeros((3, 2)), or
        # equal-sized buffers of different dtypes, would hash identically.
        # tobytes() replaces tostring(), which is deprecated in NumPy.
        b = "{}{}".format(b.dtype.str, b.shape).encode() + b.tobytes()
    if isinstance(b, str):
        b = b.encode()
    if isinstance(b, bytes):
        return hashlib.sha256(b).hexdigest()
    raise TypeError("Not implemented a method to handle {0}".format(type(b)))
def hash_arguments(args, kwargs) -> str:
    """Build a deterministic string key from a call's arguments.

    Each positional argument, keyword name and keyword value is hashed with
    ``sha256``. Positional and keyword sections are labelled separately, so
    ``f("y", 3)`` and ``f(y=3)`` produce different keys. Keyword arguments are
    sorted by name, so ``f(a=1, b=2)`` and ``f(b=2, a=1)`` share a key.
    Default values are not filled in, so ``f(1)`` and ``f(1, b=0)`` are keyed
    separately.

    Args:
        args: Positional arguments of the call (any iterable).
        kwargs: Keyword arguments of the call (a mapping of name to value).

    Returns:
        A string key (not itself a fixed-length digest).
    """
    positional = "".join(sha256(arg) for arg in args)
    keyword = "".join(
        sha256(name) + sha256(value) for name, value in sorted(kwargs.items())
    )
    # The labels cannot occur inside hex digests, so the two sections stay distinct.
    return "args:" + positional + "|kwargs:" + keyword
def cache(location=".cache") -> Callable:
    """Return a decorator that memoizes a function's results on disk.

    Each call is keyed on the function's module-qualified name plus
    ``hash_arguments(args, kwargs)``. The return value is pickled to
    ``<location>/<sha256 of key>`` and loaded on later calls with the same key.

    Args:
        location: Directory for cache files. It is created as soon as
            ``cache(...)`` is called, i.e. at decoration time.

    Returns:
        A decorator that wraps a function with the disk cache.

    Warning:
        Entries are read with ``pickle.load``, which can execute arbitrary
        code. Only point ``location`` at a directory you trust.
    """
    os.makedirs(location, exist_ok=True)

    def inner_function(f):
        # Qualify by module and qualname (not bare __name__) so that, e.g., two
        # functions named `load` in different modules, or same-named methods of
        # different classes, do not share cache entries.
        func_id = "{}.{}".format(f.__module__, f.__qualname__)

        @wraps(f)
        def wrapper(*args, **kwargs):
            key = func_id + hash_arguments(args, kwargs)
            # Hash the key so the file name is a fixed-length hex string.
            fname = os.path.join(location, sha256(key))

            try:
                with open(fname, "rb") as fl:
                    return pickle.load(fl)
            except FileNotFoundError:
                pass  # cache miss
            except (EOFError, pickle.UnpicklingError):
                pass  # truncated/corrupt entry: recompute and overwrite it

            ret = f(*args, **kwargs)

            # Dump to a temp file in the same directory, then rename it into
            # place. An interrupted or failed dump (e.g. an unpicklable result)
            # then never leaves a truncated file at `fname`. mkstemp creates
            # the file readable/writable by its owner only.
            fd, tmp_name = tempfile.mkstemp(dir=location)
            try:
                with os.fdopen(fd, "wb") as fl:
                    pickle.dump(ret, fl)
                os.replace(tmp_name, fname)
            except BaseException:
                try:
                    os.remove(tmp_name)
                except OSError:
                    pass
                raise
            return ret

        return wrapper

    return inner_function
@cache()
def my_function(arg1: str):
    """Example cached function: append the literal suffix "cached!" to ``arg1``.

    Decorated with ``@cache()``, so results are pickled under the default
    ``.cache`` directory and reused for repeated arguments.
    """
    suffix = "cached!"
    return arg1 + suffix
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment