Created
January 12, 2021 12:36
-
-
Save vincentsarago/12555830684ca23c2ac7a083ad35f461 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Cache Plugin.""" | |
import asyncio
import urllib
import urllib.parse
from typing import Optional

import aiocache
from fastapi.dependencies.utils import is_coroutine_callable
from starlette.responses import Response
class cached(aiocache.cached):
    """Custom Cached Decorator.

    Extends ``aiocache.cached`` in two ways:

    * a cached ``Response`` gets an ``X-Cache: HIT`` header when it is served
      from the cache;
    * the decorated callable may be a coroutine function or a plain function.
    """

    async def get_from_cache(self, key):
        """Return the cached value for ``key``, adding 'X-Cache' in headers.

        Args:
            key: cache key to look up.

        Returns:
            Whatever ``self.cache.get(key)`` returns. When that value is a
            ``Response`` its ``X-Cache`` header is set to ``"HIT"`` first.
            ``None`` is returned when the lookup raises; the error is logged
            and swallowed so a failing cache never breaks the caller.
        """
        try:
            value = await self.cache.get(key)
            # If the value is a Response we update the headers.
            if isinstance(value, Response):
                value.headers["X-Cache"] = "HIT"
            return value
        except Exception:
            # Best effort: treat any cache failure as a miss.
            aiocache.logger.exception("Couldn't retrieve %s, unexpected error", key)
            return None

    async def decorator(
        self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs
    ):
        """Run ``f`` through the cache, with compatibility for non-async callables.

        Args:
            f: the decorated callable (coroutine function or plain function).
            *args: positional arguments forwarded to ``f``.
            cache_read: when false, skip the cache lookup and always call ``f``.
            cache_write: when false, do not store the result of ``f``.
            aiocache_wait_for_write: when true, await the cache write before
                returning; otherwise schedule it with ``asyncio.ensure_future``
                and return without waiting for it.
            **kwargs: keyword arguments forwarded to ``f``.

        Returns:
            The cached value when one is found (a cached ``None`` counts as a
            miss), otherwise the result of calling ``f``.
        """
        key = self.get_cache_key(f, args, kwargs)

        if cache_read:
            value = await self.get_from_cache(key)
            if value is not None:
                return value

        # Here we check if the callable can be awaited or not. A plain
        # function is called directly (it is not off-loaded to a thread).
        if is_coroutine_callable(f):
            result = await f(*args, **kwargs)
        else:
            result = f(*args, **kwargs)

        if cache_write:
            if aiocache_wait_for_write:
                await self.set_in_cache(key, result)
            else:
                # Fire-and-forget write; requires ``asyncio`` to be imported
                # at module level.
                asyncio.ensure_future(self.set_in_cache(key, result))

        return result
def setup_cache(
    endpoint: Optional[str],
    ttl: Optional[int],
    serializer: str = "aiocache.serializers.PickleSerializer",
) -> None:
    """Setup aiocache.

    Builds the ``default`` cache configuration and registers it with
    ``aiocache.caches.set_config``. Without an endpoint the configuration
    uses ``aiocache.SimpleMemoryCache``.

    Args:
        endpoint: cache URL (for example ``redis://:secret@host:6379/0``).
            Its query-string parameters, path options, hostname, port and
            password are copied into the configuration. ``None`` or an empty
            string keeps the in-memory cache.
        ttl: value stored under the ``ttl`` configuration key. A falsy value
            (``None`` or ``0``) leaves the key unset.
        serializer: dotted path of the serializer class to use.
    """
    config = {
        "cache": "aiocache.SimpleMemoryCache",
        "serializer": {"class": serializer},
    }
    if ttl:
        config["ttl"] = ttl

    if endpoint:
        url = urllib.parse.urlparse(endpoint)

        # Query-string parameters become configuration entries as-is
        # (keys and values are strings).
        url_config = dict(urllib.parse.parse_qsl(url.query))
        config.update(url_config)

        cache_class = aiocache.Cache.get_scheme_class(url.scheme)
        config.update(cache_class.parse_uri_path(url.path))

        # Only copy the parts that are actually present in the URL, so that a
        # missing host or port does not end up in the configuration as
        # ``None`` / the literal string ``"None"``.
        if url.hostname:
            config["endpoint"] = url.hostname
        if url.port:
            config["port"] = str(url.port)
        if url.password:
            config["password"] = url.password

        # Any other scheme class keeps the SimpleMemoryCache default.
        if cache_class == aiocache.Cache.REDIS:
            config["cache"] = "aiocache.RedisCache"
        elif cache_class == aiocache.Cache.MEMCACHED:
            config["cache"] = "aiocache.MemcachedCache"

    aiocache.caches.set_config({"default": config})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment