Skip to content

Instantly share code, notes, and snippets.

@dirkgr
Last active September 12, 2018 20:46
Show Gist options
  • Save dirkgr/3b82d0d223510f5888268c0f5a9ecc65 to your computer and use it in GitHub Desktop.
Save dirkgr/3b82d0d223510f5888268c0f5a9ecc65 to your computer and use it in GitHub Desktop.
A Python function decorator that memoizes a function persistently, to disk
import dill
import lmdb
import functools
from typing import *
import io
import mmh3
# Process-wide registry of opened LMDB environments, keyed by directory path.
# py-lmdb documents opening the same environment more than once in a single
# process as a serious error, so every decorated function that targets the
# same directory must share one handle instead of calling lmdb.open() again.
_MEMOIZE_ENVS = {}


def _get_memoize_env(path):
    """Return the shared LMDB environment for *path*, opening it on first use."""
    env = _MEMOIZE_ENVS.get(path)
    if env is None:
        env = lmdb.open(
            path,
            map_size=1024 * 1024 * 1024 * 1024,  # 1024**4 bytes
            metasync=False,
            meminit=False,
            max_dbs=0)
        _MEMOIZE_ENVS[path] = env
    return env


def memoize(exceptions: Optional[List] = None, version=None, cache_dir: str = "/tmp/memoize"):
    """Build a decorator that memoizes a function persistently in an LMDB store.

    The decorated function is pickled with dill (followed by ``version`` when
    one is given) and hashed into a per-function fingerprint. On each call the
    ``(args, kwargs)`` pair is pickled and hashed with that fingerprint as the
    seed; the digest is the database key under which the pickled return value
    is stored and looked up.

    Args:
        exceptions: Objects that must not be serialized. Wherever one of them
            (matched by identity) is met while pickling, only its position in
            this list is written instead of the object itself.
        version: Optional serializable value mixed into the function
            fingerprint. Changing it makes earlier cache entries unreachable.
        cache_dir: Directory of the LMDB environment that holds the cache.
            All functions decorated with the same ``cache_dir`` share a single
            environment handle.

    Returns:
        A decorator that wraps a callable with the persistent cache.

    NOTE(review): cached results are unpickled from ``cache_dir``. Unpickling
    executes whatever the stored data dictates, and ``/tmp`` is typically
    writable by other users -- confirm the directory is trusted before relying
    on the default.
    """
    # Snapshot the list so later mutation by the caller cannot desynchronise
    # the identity index below from the objects handed back on load.
    exceptions = [] if exceptions is None else list(exceptions)

    # id -> first position in `exceptions`; setdefault keeps the first index
    # for duplicates, exactly what list.index() would have returned.
    exception_index = {}
    for position, obj in enumerate(exceptions):
        exception_index.setdefault(id(obj), position)

    class MemoizePickler(dill.Pickler):
        def persistent_id(self, obj):
            # None tells the pickler to serialize the object normally.
            return exception_index.get(id(obj))

    class MemoizeUnpickler(dill.Unpickler):
        def persistent_load(self, pid):
            return exceptions[pid]

    def memoize_decorator(fn: Callable):
        # Fingerprint the function (and version); used as the hash seed for
        # every call key of this function.
        with io.BytesIO() as buffer:
            pickler = MemoizePickler(buffer)
            pickler.dump(fn)
            if version is not None:
                pickler.dump(version)
            fn_hash = mmh3.hash(buffer.getvalue(), signed=False)

        env = _get_memoize_env(cache_dir)

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with io.BytesIO() as buffer:
                MemoizePickler(buffer).dump((args, kwargs))
                key = mmh3.hash_bytes(buffer.getvalue(), seed=fn_hash)

            # Read from the db. With buffers=True the returned buffer is only
            # usable while the transaction is open, so unpickle inside it.
            with env.begin(buffers=True) as read_txn:
                cached = read_txn.get(key, default=None)
                if cached is not None:
                    return MemoizeUnpickler(io.BytesIO(cached)).load()

            # Nothing cached: run the function and write the result to the db.
            result = fn(*args, **kwargs)
            with io.BytesIO() as buffer:
                MemoizePickler(buffer).dump(result)
                with env.begin(write=True) as write_txn:
                    write_txn.put(key, buffer.getbuffer(), overwrite=True)
            return result

        return inner

    return memoize_decorator
@dirkgr
Copy link
Author

dirkgr commented Sep 12, 2018

Requirements:

dill
mmh3
lmdb

Example:

from typing import Iterable, List

from nltk.stem.porter import PorterStemmer

# Get a stemmer from nltk.
stemmer = PorterStemmer()

# Read the word list inside a `with` block so the file is closed promptly, and
# strip each line: iterating over a file keeps the trailing newline, which
# would otherwise be handed to the stemmer as part of the word.
with open('/usr/share/dict/words') as word_file:
    words = [line.strip() for line in word_file]


@memoize()
def stem_all(words: Iterable[str]) -> List[str]:
    """Stem every word; stands in for some operation that takes a long time."""
    return [stemmer.stem(word) for word in words]


# The second time you run this, even across process restarts, it will use the
# cached version.
stemmed_words = stem_all(words)

Example with exception:

from typing import Callable, Iterable, List

import Stemmer

# Stemmer from the PyStemmer package.
# The stemmer is written in C, so it cannot be serialized; a plain @memoize()
# would fail on it.
stem: Callable[[str], str] = Stemmer.Stemmer('english').stemWord

# Read the word list inside a `with` block so the file is closed promptly, and
# strip each line: iterating over a file keeps the trailing newline, which
# would otherwise be handed to the stemmer as part of the word.
with open('/usr/share/dict/words') as word_file:
    words = [line.strip() for line in word_file]


@memoize([stem])    # list `stem` as an exception so it is not serialized
def stem_all(words: Iterable[str]) -> List[str]:
    """Stem every word using the excepted C stemmer."""
    return [stem(word) for word in words]


stemmed_words = stem_all(words)

You can use exceptions when a function isn't serializable, or when an object is too big and you don't want it serialized. Note, though, that if an excepted object changes, the cache will not be updated! If you know you changed an excepted object and you want to force a cache refresh, you can specify a version: @memoize([exception1, exception2], version=1). The default version is None. You can use any serializable object as a version; for example, you could use some signature of the thing you are excepting.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment