Last active
December 17, 2021 11:15
-
-
Save Tishka17/5ad10fa4641814798b9b58e47ef048b8 to your computer and use it in GitHub Desktop.
Cache protocol
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| from logging import getLogger | |
| from typing import ClassVar, TypeVar, Optional, Protocol, Dict, Type | |
| logger = getLogger(__name__) | |
class Cachable(Protocol):
    """Interface of objects that can be stored in the cache.

    Implementations provide a class-wide model name and a per-object id;
    together they are used to build the key under which the object is kept.
    """

    @classmethod
    def get_model_name(cls) -> str:
        """Return the model name shared by every instance of the class."""
        raise NotImplementedError

    def get_id(self) -> str:
        """Return the identifier of this particular object."""
        raise NotImplementedError

    def json(self) -> Dict:
        """Return the instance attributes as a new dictionary.

        A shallow copy is returned instead of the live ``__dict__`` so that
        storing or editing the result never changes the object itself (and
        later changes to the object do not leak into the result).
        """
        return dict(vars(self))
# Type variable standing for a concrete Cachable class, so that a loader given
# ``Type[CacheType]`` can be annotated as returning that same type.
CacheType = TypeVar("CacheType", covariant=True, bound=Cachable)
class DatasetBase(Cachable):
    """Common ``Cachable`` implementation backed by two attributes.

    Subclasses supply the ``id`` instance attribute and the ``model_name``
    class attribute; the accessors below simply expose them.
    """

    model_name: ClassVar[str]
    id: str

    @classmethod
    def get_model_name(cls) -> str:
        """Return the class-level ``model_name``."""
        return cls.model_name

    def get_id(self) -> str:
        """Return this object's ``id`` attribute."""
        return self.id
@dataclass
class DatasetHeader(DatasetBase):
    """Header ("card") record of a dataset as kept in Redis."""

    # Cache key prefix; declared as ClassVar, so it is not a dataclass field.
    model_name: ClassVar[str] = "dataset_header"

    id: str
    name: str
@dataclass
class Dataset(DatasetBase):
    """Dataset record holding an ``id`` and its ``data`` string."""

    # Cache key prefix; declared as ClassVar, so it is not a dataclass field.
    model_name: ClassVar[str] = "dataset"

    id: str
    data: str
class RedisDBManager:
    """Dictionary-backed cache manager for ``Cachable`` objects.

    Items are kept in ``self.data`` under keys of the form
    ``"<model name>:<id>"``; the stored value is a copy of the mapping
    returned by the item's ``json()`` method.
    """

    def __init__(self):
        # Maps cache key -> attribute mapping of the saved item.
        self.data = {}

    def _create_cache_id(
            self, cache_type: "Type[Cachable]", object_id: str,
    ) -> str:
        """Build the cache key for ``object_id`` of the given model class."""
        return f"{cache_type.get_model_name()}:{object_id}"

    def _get_cache_id(
            self, item: "Cachable", cache_id: Optional[str] = None,
    ) -> str:
        """Build the cache key for ``item``.

        A falsy ``cache_id`` (``None`` or ``""``) falls back to
        ``item.get_id()``.
        """
        if not cache_id:
            cache_id = item.get_id()
        return f"{item.get_model_name()}:{cache_id}"

    def save(
            self, item: "Cachable", cache_id: Optional[str] = None,
    ) -> str:
        """Store ``item`` and return the key it was stored under.

        :param item: object to cache.
        :param cache_id: optional id to use instead of ``item.get_id()``.
        """
        cache_id = self._get_cache_id(item, cache_id)
        # Keep a snapshot: json() may hand back the item's live attribute
        # dict, and later changes to the item must not rewrite the cache.
        self.data[cache_id] = dict(item.json())
        logger.info("Cache in Redis: %s", cache_id)
        return cache_id

    def _load(
            self, cache_id: str, cache_type: "Type[CacheType]",
    ) -> "CacheType":
        """Rebuild the object stored under ``cache_id``.

        :raises ValueError: if nothing is stored under ``cache_id``.
        """
        try:
            stored = self.data[cache_id]
        except KeyError:
            logger.info("Cache miss in Redis: %s", cache_id)
            raise ValueError(f"No cached item for {cache_id!r}") from None
        # An empty mapping is a valid entry (an item without attributes),
        # so only a missing key counts as a miss.
        return cache_type(**stored)

    def load(
            self, object_id: str, cache_type: "Type[CacheType]",
    ) -> "CacheType":
        """Load the ``cache_type`` object that was saved with ``object_id``.

        :raises ValueError: if no such object is cached.
        """
        cache_id = self._create_cache_id(cache_type, object_id)
        return self._load(cache_id, cache_type)
def test_1():
    """A saved DatasetHeader is loaded back equal to the original."""
    manager = RedisDBManager()
    header = DatasetHeader("1", "n1")
    key = manager.save(header)
    assert key
    assert header == manager.load("1", DatasetHeader)
def test_2():
    """A saved Dataset is loaded back equal to the original."""
    manager = RedisDBManager()
    dataset = Dataset("2", "d2")
    key = manager.save(dataset)
    assert key
    assert dataset == manager.load("2", Dataset)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment