Created April 26, 2025 04:00
Save Graeme22/a898a7afc88f3e26d8cf6ef4ad567a6b to your computer and use it in GitHub Desktop.
Fix redis-py's terrible types and add serialization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle
from typing import Any, Set, Type, TypeVar

from redis.asyncio import Redis
from redis.typing import AbsExpiryT, ExpiryT, KeyT, ResponseT
# Generic placeholder for the deserialized value type. The read methods of
# TypedRedis take a ``return_type: Type[T]`` argument purely so a static type
# checker can bind T and infer the type of the value they return.
T = TypeVar("T")
class TypedRedis(Redis):
    """Async Redis client that pickles values on write and unpickles on read.

    Write methods (``set``, ``hset``, ``sadd``, ``rpush``) serialize every
    value with ``pickle.dumps``; the matching read methods (``get``, ``hget``,
    ``hgetall``, ``hmget``, ``smembers``, ``lindex``, ``lrange``) deserialize
    with ``pickle.loads``. The ``return_type`` argument on the read methods is
    never used at runtime -- it only lets a type checker infer ``T``; nothing
    is validated against it.

    NOTE(review): the read methods assume replies arrive as raw ``bytes``
    (``hgetall`` calls ``.decode()`` on keys and ``pickle.loads`` needs bytes),
    so the client presumably must NOT be created with ``decode_responses=True``.

    SECURITY: ``pickle.loads`` can execute arbitrary code. Only use this class
    against a Redis instance whose contents are fully trusted.
    """

    async def set(
        self,
        name: KeyT,
        value: Any,
        ex: ExpiryT | None = None,
        px: ExpiryT | None = None,
        nx: bool = False,
        xx: bool = False,
        keepttl: bool = False,
        get: bool = False,
        exat: AbsExpiryT | None = None,
        pxat: AbsExpiryT | None = None,
    ) -> ResponseT:
        """Pickle ``value`` and store it at ``name``.

        Every option is forwarded unchanged to ``Redis.set``. With
        ``get=True`` Redis replies with the key's previous value; because that
        value was stored pickled, it is unpickled here so the result matches
        what :meth:`get` would have returned.
        """
        res = await super().set(
            name=name,
            value=pickle.dumps(value),
            ex=ex,
            px=px,
            nx=nx,
            xx=xx,
            keepttl=keepttl,
            get=get,
            exat=exat,
            pxat=pxat,
        )
        if get and isinstance(res, bytes):
            # The old value was written pickled; without this the caller would
            # receive opaque pickle bytes instead of the stored object.
            return pickle.loads(res)
        return res

    async def get(self, name: KeyT, return_type: Type[T]) -> T | None:  # type: ignore[override]
        """Return the unpickled value at ``name``, or ``None`` if it is missing."""
        res = await super().get(name)
        if not res:
            # Missing key (None); an empty reply can never be a pickle payload.
            return None
        return pickle.loads(res)

    async def hset(self, name: str, key: str, value: Any) -> int:  # type: ignore[override]
        """Pickle ``value`` and store it in field ``key`` of hash ``name``."""
        return await super().hset(name=name, key=key, value=pickle.dumps(value))  # type: ignore

    async def hget(self, name: str, key: str, return_type: Type[T]) -> T | None:  # type: ignore[override]
        """Return the unpickled value of field ``key`` in hash ``name``, or ``None``."""
        res: bytes = await super().hget(name=name, key=key)  # type: ignore
        if not res:
            return None
        return pickle.loads(res)

    async def hgetall(self, name: str, return_type: Type[T]) -> dict[str, T]:  # type: ignore[override]
        """Return every field of hash ``name`` as ``{decoded_key: unpickled_value}``.

        Field names are decoded with ``bytes.decode()`` (UTF-8); an empty or
        missing hash yields ``{}``.
        """
        res = await super().hgetall(name=name)  # type: ignore
        if not res:
            return {}
        return {k.decode(): pickle.loads(v) for k, v in res.items()}

    async def hmget(self, name: str, keys: list[str], return_type: Type[T]) -> list[T | None]:  # type: ignore[override]
        """Return the unpickled values of ``keys`` in hash ``name``, in order.

        Fields that do not exist come back from Redis as nil; they are
        reported as ``None`` (as :meth:`hget` does) instead of being fed to
        ``pickle.loads``, which would raise ``TypeError``.
        """
        res: list[bytes | None] = await super().hmget(name=name, keys=keys)  # type: ignore
        if not res:
            return []
        return [pickle.loads(v) if v else None for v in res]

    async def hincrby(self, name: str, key: str, amount: int = 1) -> int:
        """Increment integer field ``key`` of hash ``name`` by ``amount``.

        NOTE(review): this passes straight through without pickling. Fields
        written by :meth:`hset` hold pickle bytes, which Redis will not treat
        as integers, and fields created here hold plain integer text, which
        :meth:`hget`/:meth:`hgetall` will still try to ``pickle.loads``. Do
        not mix this method with the pickling accessors on the same field.
        """
        return await super().hincrby(name, key, amount)  # type: ignore

    async def sadd(self, name: str, *values: Any) -> int:
        """Pickle each of ``values`` and add them to set ``name``.

        Redis deduplicates by the pickled bytes, not by Python equality, so
        equal objects that pickle differently may both be stored.
        """
        serialized = [pickle.dumps(v) for v in values]
        return await super().sadd(name, *serialized)  # type: ignore

    async def smembers(self, name: str, return_type: Type[T]) -> Set[T]:  # type: ignore[override]
        """Return all members of set ``name``, unpickled.

        The unpickled members must be hashable to be collected into a Python set.
        """
        res = await super().smembers(name)  # type: ignore
        if not res:
            return set()
        return {pickle.loads(v) for v in res}

    async def rpush(self, name: str, *values: Any) -> int:
        """Pickle each of ``values`` and append them to list ``name``."""
        serialized = [pickle.dumps(v) for v in values]
        return await super().rpush(name, *serialized)  # type: ignore

    async def lindex(self, name: str, index: int, return_type: Type[T]) -> T | None:  # type: ignore[override]
        """Return the unpickled element at ``index`` of list ``name``, or ``None``."""
        res: bytes | None = await super().lindex(name, index)  # type: ignore
        if not res:
            return None
        return pickle.loads(res)

    async def lrange(  # type: ignore[override]
        self, name: str, start: int, end: int, return_type: Type[T]
    ) -> list[T]:
        """Return elements ``start``..``end`` (inclusive) of list ``name``, unpickled."""
        res = await super().lrange(name, start, end)  # type: ignore
        if not res:
            return []
        return [pickle.loads(v) for v in res]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.