Last active
June 27, 2025 17:23
-
-
Save clemlesne/463a88bb87fe658c6df5e147a5e272dd to your computer and use it in GitHub Desktop.
Stream a blob with the native BytesIO and AsyncIterable Python interfaces, from async Redis.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import io
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Buffer
from contextlib import asynccontextmanager
from logging import getLogger
from os import environ as env
from threading import Thread
from typing import TypeVar, cast

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import (
    BusyLoadingError,
    RedisError,
)
from redis.exceptions import (
    ConnectionError as RedisConnectionError,
)
T = TypeVar("T") | |
logger = getLogger(__name__) | |
@asynccontextmanager | |
async def redis_stream(key: str) -> AsyncGenerator[io.BytesIO]: | |
""" | |
Get a stream from the cache. | |
""" | |
try: | |
yield RedisBytesIO(key.encode("utf-8")) | |
except RedisError as e: | |
logger.error(f"Error getting Redis stream: {e}") | |
raise e | |
class RedisBytesIO(io.BytesIO, AsyncIterable[bytes]): | |
_closed: bool = False | |
_length: int | None = None | |
_loop: AbstractEventLoop | |
_redis_cache: Redis | None = None | |
_thread: Thread | |
chunk_size: int | |
key: bytes | |
pos: int = 0 | |
def __init__( | |
self, | |
key: bytes, | |
chunk_size: int = 131072, # 128 KiB | |
) -> None: | |
self._loop = asyncio.new_event_loop() | |
self.chunk_size = chunk_size | |
self.key = key | |
# Create a dedicated event loop in a separate thread | |
self._thread = Thread( | |
target=self._run_loop, | |
daemon=True, | |
) | |
self._thread.start() | |
def _run_loop(self) -> None: | |
asyncio.set_event_loop(self._loop) | |
self._loop.run_forever() | |
def _run_sync(self, coro: Awaitable[T]) -> T: | |
future = asyncio.run_coroutine_threadsafe(coro, self._loop) | |
return future.result() | |
async def _redis(self) -> Redis: | |
if self._redis_cache is None: | |
self._redis_cache = await Redis( | |
connection_pool=await _redis_connection_pool(), | |
).__aenter__() | |
return self._redis_cache | |
# BytesIO | |
async def length(self) -> int: | |
if self._length is None: | |
self._length = cast(int, await (await self._redis()).strlen(self.key)) | |
return self._length | |
# BytesIO | |
def readable(self) -> bool: | |
return True | |
# BytesIO | |
def seekable(self) -> bool: | |
return True | |
# BytesIO | |
def writable(self) -> bool: | |
return False | |
# BytesIO | |
def read(self, n: int | None = -1) -> bytes: | |
if not n: | |
return b"" | |
return self._run_sync(self._aread(n)) | |
# AsyncIterable | |
def __aiter__(self): | |
return self | |
# AsyncIterable | |
async def __anext__(self) -> bytes: | |
chunk = await self._aread(self.chunk_size) | |
if not chunk: | |
raise StopAsyncIteration | |
return chunk | |
async def _aread(self, n: int = -1) -> bytes: | |
length = await self.length() | |
if self.pos >= length: | |
return b"" | |
# Read until the end | |
if n < 0: | |
n = length - self.pos | |
# Calculate the end index for GETRANGE (inclusive) | |
end: int = self.pos + n - 1 | |
chunk: bytes = await (await self._redis()).getrange(self.key, self.pos, end) | |
self.pos += len(chunk) | |
return chunk | |
# BytesIO | |
def seek( | |
self, | |
offset: int, | |
whence: int = io.SEEK_SET, | |
) -> int: | |
return self._run_sync(self._aseek(offset, whence)) | |
async def _aseek( | |
self, | |
offset: int, | |
whence: int, | |
) -> int: | |
length = await self.length() | |
if whence == io.SEEK_SET: | |
new_pos: int = offset | |
elif whence == io.SEEK_CUR: | |
new_pos = self.pos + offset | |
elif whence == io.SEEK_END: | |
new_pos = length + offset | |
else: | |
raise ValueError("Invalid whence value") | |
if new_pos < 0: | |
raise ValueError("New position cannot be negative") | |
self.pos = new_pos | |
return self.pos | |
# BytesIO | |
def tell(self) -> int: | |
return self.pos | |
# BytesIO | |
def write(self, _: Buffer) -> int: | |
raise io.UnsupportedOperation("write") | |
# BytesIO | |
@property | |
def closed(self) -> bool: | |
return self._closed | |
# BytesIO | |
def close(self) -> None: | |
if self.closed: | |
return | |
# Close the Redis connection | |
self._run_sync(self._aclose()) | |
# Stop the background event loop | |
self._loop.call_soon_threadsafe(self._loop.stop) | |
self._thread.join() | |
# Mark as closed | |
self._closed = True | |
async def _aclose(self) -> None: | |
if self._redis_cache is not None: | |
await self._redis_cache.__aexit__(None, None, None) | |
self._redis_cache = None | |
async def _redis_connection_pool() -> ConnectionPool: | |
""" | |
Get the Redis connection pool. | |
""" | |
return ConnectionPool.from_url( | |
# Reliability | |
health_check_interval=10, # Check the health of the connection every 10 secs | |
retry_on_error=[BusyLoadingError, RedisConnectionError], | |
retry_on_timeout=True, | |
retry=Retry(backoff=ExponentialBackoff(), retries=3), | |
socket_connect_timeout=5, # Give the system sufficient time to connect even under higher CPU conditions | |
socket_timeout=1, # Respond quickly or abort, this is a cache | |
# Deployment | |
url=env["REDIS_URL"], | |
) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment