Skip to content

Instantly share code, notes, and snippets.

@jacksmith15
Last active January 10, 2023 00:47
Show Gist options
  • Save jacksmith15/f9596e8b7c237e8c10e040d5da17a0ef to your computer and use it in GitHub Desktop.
Save jacksmith15/f9596e8b7c237e8c10e040d5da17a0ef to your computer and use it in GitHub Desktop.
Redis wrapper that leverages pydantic models for a clean (and very strict) interface. Requires the data structure to be consistent within each database (one model type per database).
from typing import AsyncIterator
from pydantic import BaseModel
from radish import RedisInterface, Database
class User(BaseModel):
    """Example pydantic model for a user record."""

    id: int  # used as the redis key via ``key="id"`` in the example interface
    name: str
class Tweet(BaseModel):
    """Example pydantic model for a tweet record."""

    id: int
    text: str
    # The full author record is nested here, so each stored tweet carries its
    # own copy of the author's fields (and filtering compares whole ``User`` values).
    author: User
class Radish(RedisInterface):
    """Example interface: users in redis database 1, tweets in database 2, both keyed by ``id``."""

    users = Database(database_id=1, model=User, key="id")
    tweets = Database(database_id=2, model=Tweet, key="id")
async def get_user_tweets(user_id: int) -> AsyncIterator[Tweet]:
    """Yield every tweet whose author is the user stored under ``user_id``.

    This is contrived - in reality you would just filter tweets on their
    author id. ``redis.users.get`` raises if no such user is stored.
    """
    settings = {"address": "redis://localhost:6379", "password": "XXX"}
    async with Radish(**settings) as redis:
        author: User = await redis.users.get(user_id)
        matching = redis.tweets.filter(author=author)
        async for tweet in matching:
            yield tweet
from abc import abstractmethod
import asyncio
from functools import partial
from operator import attrgetter
from typing import Any, Callable, cast, Dict, Generic, Tuple, Type, TypeVar, Union
from aioredis import create_redis_pool, Redis
from pydantic import BaseModel
from typing_extensions import Protocol
class SupportsStr(Protocol):
    """Structural type: anything that declares its own ``__str__`` rendering."""

    @abstractmethod
    def __str__(self) -> str:
        ...
# Type variable for the pydantic model type stored in a given database.
Model = TypeVar("Model", bound=BaseModel)
class RadishError(Exception):
    """Base exception for errors raised by the radish interface itself."""
class RadishKeyError(RadishError):
    """Raised when an operation targets a redis key that is not present."""
class FilterFactory(Generic[Model]):
    """Builds predicates matching ``Model`` instances by exact field values."""

    def __init__(self, target_model: Type[Model]):
        self.target_model = target_model

    def __call__(self, **filter_kwargs) -> Callable[[Model], bool]:
        """Return a predicate that is true when every given field equals its value.

        Raises:
            RadishError: if any keyword is not a field of ``target_model``.
        """
        unknown = set(filter_kwargs) - set(self.target_model.__fields__)
        if unknown:
            raise RadishError(
                f"Invalid filter fields for {self.target_model}: {unknown}."
            )

        def matches(instance: Model):
            # ``all`` short-circuits on the first mismatching field, in kwarg order.
            return all(
                getattr(instance, field) == expected
                for field, expected in filter_kwargs.items()
            )

        return matches
class _DatabaseDescriptor(Generic[Model]):
    """Static description of one redis database: id, stored model and key scheme.

    Calling the descriptor produces a connectable ``_Database`` bound to it.
    """

    def __init__(
        self,
        database_id: int,
        model: Type[Model],
        key: Union[str, Callable[[Model], SupportsStr]],
    ):
        self.database_id = database_id
        self.model = model
        # A string key names a model attribute; anything else is already a callable.
        if isinstance(key, str):
            key_func: Callable[[Model], SupportsStr] = attrgetter(key)
        else:
            key_func = key
        self._key_func = key_func
        self.filter_factory = FilterFactory(self.model)

    def __call__(
        self, connection_factory: Callable = create_redis_pool, **connection_kwargs: Any
    ) -> "_Database[Model]":
        """Create a ``_Database`` for this descriptor with the given connection settings."""
        database: "_Database[Model]" = _Database(
            self, connection_factory, **connection_kwargs
        )
        return database

    def get_key(self, instance: Union[Model, SupportsStr]) -> str:
        """Resolve the redis key for a model instance, raw ``bytes`` key, or plain key."""
        if isinstance(instance, self.model):
            raw_key = self._key_func(instance)
            return str(raw_key)
        if isinstance(instance, bytes):
            # Keys returned by redis scans arrive as bytes.
            return instance.decode("utf-8")
        return str(instance)

    def deserialize(self, data: bytes) -> Model:
        """Parse a stored JSON payload back into a ``Model`` instance."""
        return self.model.parse_raw(data)

    def serialize(self, instance: Model) -> bytes:
        """Encode ``instance`` as UTF-8 JSON bytes for storage."""
        payload = instance.json()
        return payload.encode("utf-8")
# Sentinel default for the ``default`` argument of ``_Database.get``: an identity
# check against it tells "no default given" apart from any value a caller passes.
NOT_PASSED = object()
class _Database(Generic[Model]):
    """Connection-backed access to one redis database of ``Model`` records.

    Created by calling a ``_DatabaseDescriptor``. Use as an async context
    manager: entering opens a connection through ``connection_factory``
    (always with ``db`` set to the descriptor's database id), exiting closes it.
    """

    def __init__(
        self,
        descriptor: _DatabaseDescriptor[Model],
        connection_factory: Callable = create_redis_pool,
        **connection_kwargs: Any,
    ):
        self.descriptor = descriptor
        # ``db`` is bound here so every connection selects this database.
        self._connection_factory = partial(
            connection_factory, **connection_kwargs, db=self.descriptor.database_id
        )
        self._connection = None

    async def __aenter__(self):
        """Open the connection.

        Raises:
            RadishError: if a connection is already open.
        """
        if self._connection is not None:
            raise RadishError("Already connected to redis!")
        self._connection = await self._connection_factory()
        return self

    async def __aexit__(self, _exception_type, _exception_value, _traceback):
        """Close the connection; a no-op if it was never opened."""
        await self.close()

    @property
    def connection(self) -> Redis:
        """The open connection.

        Raises:
            RadishError: if no connection is currently open.
        """
        if self._connection is None:
            raise RadishError("Connection to redis has not been initialised.")
        return self._connection

    async def close(self) -> None:
        """Close the connection, if open, and mark this database as disconnected.

        ``_connection`` is cleared before waiting so the object can be re-entered
        afterwards, even if waiting for the close fails.
        """
        connection, self._connection = self._connection, None
        if connection is not None:
            connection.close()
            await connection.wait_closed()

    async def save(
        self,
        instance: Model,
        allow_update: bool = True,
        expire: Union[int, None] = None,
    ):
        """Store ``instance`` under its key.

        Args:
            instance: The record to write.
            allow_update: If False, refuse to overwrite an existing record.
            expire: Optional time-to-live in seconds; ``None`` sets no expiry.

        Returns:
            The result of the underlying ``set`` call.

        Raises:
            RadishError: if ``allow_update`` is False and the key already exists.
        """
        key = self.descriptor.get_key(instance)
        payload = self.descriptor.serialize(instance)
        if allow_update:
            return await self.connection.set(key, payload, expire=expire)
        # SET ... NX checks existence and writes in one atomic command, closing the
        # race where another client creates the key between an EXISTS and a SET.
        created = await self.connection.set(
            key, payload, expire=expire, exist=self.connection.SET_IF_NOT_EXIST
        )
        if not created:
            raise RadishError(f"Record for {repr(instance)} already exists")
        return created

    async def get(self, instance: Union[Model, SupportsStr], default=NOT_PASSED) -> Model:
        """Fetch and deserialize the record for ``instance`` (a model or a key).

        Args:
            instance: A model instance, or a raw key (``str``/``bytes``/other).
            default: Returned as-is when the key is missing; if omitted, a
                missing key raises instead.

        Raises:
            RadishKeyError: if the key is missing and no ``default`` was given.
        """
        key: str = self.descriptor.get_key(instance)
        value = await self.connection.get(key)
        if value is None:
            if default is NOT_PASSED:
                raise RadishKeyError(f"Key {repr(key)} does not exist.")
            # The default is a ready-made value, not a serialized payload.
            return default
        return self.descriptor.deserialize(value)

    async def delete(self, instance: Union[Model, SupportsStr]) -> None:
        """Delete the record for ``instance``.

        Raises:
            RadishKeyError: if no such key existed.
        """
        key: str = self.descriptor.get_key(instance)
        existed = bool(await self.connection.delete(key))
        if not existed:
            raise RadishKeyError(f"Key {repr(key)} does not exist.")

    async def expire(
        self, instance: Union[Model, SupportsStr], timeout: int = 0
    ) -> None:
        """Set a time-to-live of ``timeout`` seconds on the record for ``instance``.

        With the default ``timeout`` of 0, redis removes the key immediately.

        Raises:
            RadishKeyError: if no such key exists.
        """
        key: str = self.descriptor.get_key(instance)
        existed = bool(await self.connection.expire(key, timeout))
        if not existed:
            raise RadishKeyError(f"Key {repr(key)} does not exist.")

    async def __aiter__(self):
        """Yield every record in this database (one SCAN pass, one GET per key)."""
        async for key in self.connection.iscan():
            # A key can disappear between SCAN and GET (concurrent delete or
            # expiry); skip it instead of aborting the whole iteration.
            instance = await self.get(key, default=None)
            if instance is not None:
                yield instance

    async def filter(self, **filter_kwargs: Any):
        """Yield records whose fields equal every given keyword value.

        Unknown field names raise ``RadishError`` on the first iteration step.
        """
        filter_func = self.descriptor.filter_factory(**filter_kwargs)
        async for instance in self:
            if filter_func(instance):
                yield instance
class RadishMeta(type):
    """Metaclass that collects ``Database(...)`` declarations into ``cls._meta``.

    Database descriptors are removed from the class namespace and stored in
    ``cls._meta["databases"]``, where ``RedisInterface.__init__`` picks them up.
    Databases declared on ``RadishMeta`` base classes are inherited. Earlier
    bases take precedence over later ones, and the class's own declarations
    override inherited ones.
    """

    _meta: Dict[str, Any]

    def __new__(
        mcs, name: str, bases: Tuple[Type, ...], classdict: Dict[str, Any],
    ):
        if "_meta" in classdict:
            raise RadishError(
                "'_meta' is a reserved class property for `RadishInterface` classes"
            )
        own_databases = {
            attr: value
            for attr, value in classdict.items()
            if isinstance(value, _DatabaseDescriptor)
        }
        for attr in own_databases:
            del classdict[attr]

        # Descriptors are stripped from base namespaces too, so inheritance must
        # be done explicitly through each base's ``_meta``.
        databases: Dict[str, Any] = {}
        for base in reversed(bases):
            if isinstance(base, RadishMeta):
                databases.update(base._meta["databases"])
        databases.update(own_databases)

        cls: RadishMeta = cast(
            RadishMeta, super().__new__(mcs, name, bases, classdict)
        )
        cls._meta = {"databases": databases}
        return cls
class RedisInterface(metaclass=RadishMeta):
    """Base class for declaring a set of redis databases as class attributes.

    Each ``Database(...)`` declaration becomes a per-instance ``_Database``,
    built from the same ``connection_factory`` and ``redis_settings``. Used as
    an async context manager, the instance connects every database concurrently
    and closes them all on exit.
    """

    def __init__(self, connection_factory: Callable = create_redis_pool, **redis_settings):
        for attr, database_meta in type(self)._meta["databases"].items():
            setattr(self, attr, database_meta(connection_factory, **redis_settings))

    async def __aenter__(self: "RedisInterfaceT") -> "RedisInterfaceT":
        """Connect all databases.

        If any connection fails, the ones that succeeded are closed again before
        the first failure is re-raised, so a partial failure does not leak
        connections.
        """
        databases = [getattr(self, attr) for attr in type(self)._meta["databases"]]
        results = await asyncio.gather(
            *[database.__aenter__() for database in databases],
            return_exceptions=True,
        )
        errors = [result for result in results if isinstance(result, BaseException)]
        if errors:
            opened = [
                database
                for database, result in zip(databases, results)
                if not isinstance(result, BaseException)
            ]
            await asyncio.gather(
                *[database.__aexit__(None, None, None) for database in opened],
                return_exceptions=True,
            )
            raise errors[0]
        return self

    async def __aexit__(self, _exception_type, _exception_value, _traceback):
        """Close all databases, then re-raise the first close failure (if any).

        Every close is awaited to completion before any error propagates.
        """
        results = await asyncio.gather(
            *[
                getattr(self, attr).__aexit__(_exception_type, _exception_value, _traceback)
                for attr in type(self)._meta["databases"]
            ],
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, BaseException):
                raise result
# Type variable for ``self`` in ``RedisInterface.__aenter__``, so that
# ``async with SomeInterface(...) as redis`` is typed as the concrete subclass.
RedisInterfaceT = TypeVar("RedisInterfaceT", bound=RedisInterface)
def Database(
    database_id: int, model: Type[Model], key: Union[str, Callable[[Model], SupportsStr]]
) -> _Database[Model]:
    """Declare a redis database on a ``RedisInterface`` subclass.

    At runtime this returns a ``_DatabaseDescriptor``. It is annotated as
    ``_Database`` because ``RedisInterface.__init__`` replaces each declaration
    with a ``_Database`` instance attribute, so type checkers then see the
    correct attribute type.
    """
    descriptor = _DatabaseDescriptor(database_id, model, key)
    return cast(_Database, descriptor)
if __name__ == "__main__":
    # Manual smoke tests; they need a reachable redis server at the address below.

    class User(BaseModel):
        id: int
        name: str

    class Radish(RedisInterface):
        users = Database(15, User, key="id")

    redis_interface = Radish(
        address="redis://localhost:6379",
        password="XXX",
    )

    async def cleanup() -> None:
        """Delete every stored user so each test starts from an empty database."""
        async with redis_interface as redis:
            async for stale in redis.users:
                await redis.users.delete(stale)

    async def test_simple_redis() -> None:
        """Round-trip one user through save/get, then check iteration yields Users."""
        async with redis_interface as redis:
            bob: User = User(id=1, name="bob")
            await redis.users.save(bob)
            fetched = await redis.users.get(bob)
            assert fetched == bob
            async for stored in redis.users:
                assert isinstance(stored, User)

    async def test_filterable_redis() -> None:
        """Filtering by name returns only the matching user."""
        async with redis_interface as redis:
            people = [User(id=1, name="bob"), User(id=2, name="fred")]
            for person in people:
                await redis.users.save(person)
            matches = [match async for match in redis.users.filter(name="bob")]
            assert len(matches) == 1
            assert matches[0] == people[0]

    async def run_all():
        """Run each test between cleanups: cleanup, test, cleanup, test, cleanup."""
        await cleanup()
        for test in (test_simple_redis, test_filterable_redis):
            await test()
            await cleanup()

    asyncio.run(run_all())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment