Created March 20, 2016 16:51
Save gist dlebech/c16a34f735c0c4e9b604 to your computer and use it in GitHub Desktop.
Python LRU cache that works with coroutines (asyncio)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Global LRU caching utility. For that little bit of extra speed. | |
The caching utility provides a single wrapper function that can be used to | |
provide a bit of extra speed for some often used function. The cache is an LRU | |
cache including a key timeout. | |
Usage:: | |
import cache | |
@cache.memoize | |
def myfun(x, y): | |
return x + y | |
Also support asyncio coroutines:: | |
@cache.memoize | |
async def myfun(x, y): | |
x_result = await fetch_x(x) | |
return x_result + y | |
The cache can be manually cleared with `myfun.cache.clear()` | |
""" | |
import asyncio | |
from functools import wraps | |
from lru import LRUCacheDict | |
__all__ = ['memoize'] | |
def _wrap_coroutine_storage(cache_dict, key, future): | |
async def wrapper(): | |
val = await future | |
cache_dict[key] = val | |
return val | |
return wrapper() | |
def _wrap_value_in_coroutine(val): | |
async def wrapper(): | |
return val | |
return wrapper() | |
def memoize(f=None, *, max_size=256, expiration=60):
    """An in-memory LRU cache wrapper that can be used on any function,
    including coroutines.

    Usable bare (``@memoize``) or with options
    (``@memoize(max_size=..., expiration=...)``). Each decorated function gets
    its own ``LRUCacheDict`` built with *max_size* and *expiration*, both
    passed straight through (expiration is presumably in seconds -- confirm
    against the ``lru`` package).

    For coroutine functions every call returns an awaitable. On a miss the
    result is stored once the coroutine finishes successfully. On a hit the
    cached value is wrapped in a fresh coroutine.

    The cache is exposed as the wrapper's ``cache`` attribute, so it can be
    cleared with ``myfun.cache.clear()``.

    :param f: the function to memoize; omit it to obtain a configured
        decorator.
    :param max_size: maximum number of cached entries (default 256).
    :param expiration: per-entry timeout given to LRUCacheDict (default 60).
    :returns: the memoizing wrapper, or a decorator when *f* is None.
    """
    if f is None:
        # Called as ``@memoize(...)``: return a decorator bound to the options.
        def decorator(func):
            return memoize(func, max_size=max_size, expiration=expiration)
        return decorator

    # Local import keeps this block self-contained. asyncio.iscoroutinefunction
    # is deprecated since Python 3.14 in favour of the inspect version.
    import inspect

    cache_dict = LRUCacheDict(max_size=max_size, expiration=expiration)
    # Hoisted: whether f is a coroutine function cannot change between calls.
    # NOTE(review): a plain function that *returns* a coroutine is awaitable on
    # a miss but gets the raw cached value on a hit -- same as before.
    is_coroutine_fn = inspect.iscoroutinefunction(f)
    # Every wrapper owns its cache, so the prefix only makes keys readable.
    # getattr() tolerates a None __module__ or callables without __name__.
    key_prefix = '{}#{}#'.format(getattr(f, '__module__', None),
                                 getattr(f, '__name__', None))
    missing = object()

    @wraps(f)
    def wrapper(*args, **kwargs):
        # Simple key generation. Notice that there are no guarantees that the
        # key will be the same when using dict arguments. Arguments whose repr
        # embeds an id() (e.g. ``self`` with the default repr) may also collide
        # with a later object once the first is garbage collected.
        key = key_prefix + repr((args, kwargs))
        try:
            val = cache_dict[key]
        except KeyError:
            val = missing

        if val is not missing:
            if is_coroutine_fn:
                return _wrap_value_in_coroutine(val)
            return val

        # Call f outside the except clause, so its own exceptions are not
        # chained onto the cache-miss KeyError in tracebacks.
        val = f(*args, **kwargs)
        if asyncio.iscoroutine(val):
            # If the value returned by the function is a coroutine, wrap it in
            # a new coroutine that stores the actual result in the cache.
            return _wrap_coroutine_storage(cache_dict, key, val)
        # Otherwise just store and return the value directly.
        cache_dict[key] = val
        return val

    wrapper.cache = cache_dict
    return wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Tests the caching module.""" | |
import asyncio | |
import unittest | |
import cache | |
# How many times the body of ``wrapped`` has actually executed.
called = 0


@cache.memoize
def wrapped():
    """Memoized zero-argument function; bumps ``called`` on each real call."""
    global called
    called = called + 1
    return 10
class MemoizeClass:
    """Test fixture whose methods are wrapped with ``cache.memoize``.

    Every method increments a counter before returning a constant. A test can
    therefore tell a real invocation (the counter moves) from a cache hit (the
    counter stays put).
    """

    # Real invocations of my_class_fun and my_async_classmethod respectively.
    cls_called = 0
    cls_async_called = 0

    def __init__(self):
        # Shared by my_fun and my_async_fun: real invocations on this instance.
        self.called = 0

    # classmethod is applied last (outermost) so that memoize wraps the plain
    # function. ``cls`` then arrives as the first positional argument and
    # becomes part of the cache key.
    @classmethod
    @cache.memoize
    def my_class_fun(cls):
        cls.cls_called = cls.cls_called + 1
        return 20

    @classmethod
    @cache.memoize
    async def my_async_classmethod(cls):
        cls.cls_async_called = cls.cls_async_called + 1
        return 40

    @cache.memoize
    def my_fun(self):
        self.called = self.called + 1
        return 30

    @cache.memoize
    async def my_async_fun(self):
        self.called = self.called + 1
        return 50
class TestMemoize(unittest.TestCase):
    """Tests for ``cache.memoize`` on plain functions, methods and coroutines.

    The memoized fixtures keep their caches across tests. Each test therefore
    exercises a function that no other test calls, which keeps the call-counter
    assertions independent of test order.
    """

    def setUp(self):
        # Fresh private loop per test. Close it afterwards so no loop is leaked
        # (an unclosed event loop emits a ResourceWarning when collected).
        self.loop = asyncio.new_event_loop()
        self.addCleanup(self.loop.close)

    def test_memoize_fun(self):
        """It should work for a module level method"""
        self.assertEqual(called, 0)
        val = wrapped()
        self.assertEqual(val, 10)
        self.assertEqual(called, 1)
        val = wrapped()
        self.assertEqual(val, 10)
        self.assertEqual(called, 1)

    def test_memoize_class_method(self):
        """It should work for a classmethod"""
        self.assertEqual(MemoizeClass.cls_called, 0)
        val = MemoizeClass.my_class_fun()
        self.assertEqual(val, 20)
        self.assertEqual(MemoizeClass.cls_called, 1)
        val = MemoizeClass.my_class_fun()
        self.assertEqual(val, 20)
        self.assertEqual(MemoizeClass.cls_called, 1)

    def test_memoize_instance_method(self):
        """It should work for an instance method"""
        mc = MemoizeClass()
        self.assertEqual(mc.called, 0)
        val = mc.my_fun()
        self.assertEqual(val, 30)
        self.assertEqual(mc.called, 1)
        val = mc.my_fun()
        self.assertEqual(val, 30)
        self.assertEqual(mc.called, 1)

    def test_memoize_async_classmethod(self):
        """It should work with an async coroutine as classmethod."""
        self.assertEqual(MemoizeClass.cls_async_called, 0)

        async def go():
            # Awaiting yields plain values; the second call must be a cache hit.
            first = await MemoizeClass.my_async_classmethod()
            second = await MemoizeClass.my_async_classmethod()
            self.assertEqual(first, 40)
            self.assertEqual(second, 40)

        self.loop.run_until_complete(go())
        self.assertEqual(MemoizeClass.cls_async_called, 1)

    def test_memoize_async(self):
        """It should work with an async coroutine instance method."""
        mc = MemoizeClass()
        self.assertEqual(mc.called, 0)

        async def go():
            first = await mc.my_async_fun()
            second = await mc.my_async_fun()
            self.assertEqual(first, 50)
            self.assertEqual(second, 50)

        self.loop.run_until_complete(go())
        self.assertEqual(mc.called, 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment