Skip to content

Instantly share code, notes, and snippets.

@clemlesne
Last active June 27, 2025 17:23
Show Gist options
  • Save clemlesne/463a88bb87fe658c6df5e147a5e272dd to your computer and use it in GitHub Desktop.
Stream a blob with the native BytesIO and AsyncIterable Python interfaces, from async Redis.
import asyncio
import io
from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Buffer
from contextlib import asynccontextmanager
from logging import getLogger
from os import environ as env
from threading import Thread
from typing import TypeVar, cast

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import BusyLoadingError, RedisError
from redis.exceptions import ConnectionError as RedisConnectionError
T = TypeVar("T")
logger = getLogger(__name__)
@asynccontextmanager
async def redis_stream(key: str) -> AsyncGenerator[io.BytesIO]:
"""
Get a stream from the cache.
"""
try:
yield RedisBytesIO(key.encode("utf-8"))
except RedisError as e:
logger.error(f"Error getting Redis stream: {e}")
raise e
class RedisBytesIO(io.BytesIO, AsyncIterable[bytes]):
_closed: bool = False
_length: int | None = None
_loop: AbstractEventLoop
_redis_cache: Redis | None = None
_thread: Thread
chunk_size: int
key: bytes
pos: int = 0
def __init__(
self,
key: bytes,
chunk_size: int = 131072, # 128 KiB
) -> None:
self._loop = asyncio.new_event_loop()
self.chunk_size = chunk_size
self.key = key
# Create a dedicated event loop in a separate thread
self._thread = Thread(
target=self._run_loop,
daemon=True,
)
self._thread.start()
def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()
def _run_sync(self, coro: Awaitable[T]) -> T:
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result()
async def _redis(self) -> Redis:
if self._redis_cache is None:
self._redis_cache = await Redis(
connection_pool=await _redis_connection_pool(),
).__aenter__()
return self._redis_cache
# BytesIO
async def length(self) -> int:
if self._length is None:
self._length = cast(int, await (await self._redis()).strlen(self.key))
return self._length
# BytesIO
def readable(self) -> bool:
return True
# BytesIO
def seekable(self) -> bool:
return True
# BytesIO
def writable(self) -> bool:
return False
# BytesIO
def read(self, n: int | None = -1) -> bytes:
if not n:
return b""
return self._run_sync(self._aread(n))
# AsyncIterable
def __aiter__(self):
return self
# AsyncIterable
async def __anext__(self) -> bytes:
chunk = await self._aread(self.chunk_size)
if not chunk:
raise StopAsyncIteration
return chunk
async def _aread(self, n: int = -1) -> bytes:
length = await self.length()
if self.pos >= length:
return b""
# Read until the end
if n < 0:
n = length - self.pos
# Calculate the end index for GETRANGE (inclusive)
end: int = self.pos + n - 1
chunk: bytes = await (await self._redis()).getrange(self.key, self.pos, end)
self.pos += len(chunk)
return chunk
# BytesIO
def seek(
self,
offset: int,
whence: int = io.SEEK_SET,
) -> int:
return self._run_sync(self._aseek(offset, whence))
async def _aseek(
self,
offset: int,
whence: int,
) -> int:
length = await self.length()
if whence == io.SEEK_SET:
new_pos: int = offset
elif whence == io.SEEK_CUR:
new_pos = self.pos + offset
elif whence == io.SEEK_END:
new_pos = length + offset
else:
raise ValueError("Invalid whence value")
if new_pos < 0:
raise ValueError("New position cannot be negative")
self.pos = new_pos
return self.pos
# BytesIO
def tell(self) -> int:
return self.pos
# BytesIO
def write(self, _: Buffer) -> int:
raise io.UnsupportedOperation("write")
# BytesIO
@property
def closed(self) -> bool:
return self._closed
# BytesIO
def close(self) -> None:
if self.closed:
return
# Close the Redis connection
self._run_sync(self._aclose())
# Stop the background event loop
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join()
# Mark as closed
self._closed = True
async def _aclose(self) -> None:
if self._redis_cache is not None:
await self._redis_cache.__aexit__(None, None, None)
self._redis_cache = None
async def _redis_connection_pool() -> ConnectionPool:
"""
Get the Redis connection pool.
"""
return ConnectionPool.from_url(
# Reliability
health_check_interval=10, # Check the health of the connection every 10 secs
retry_on_error=[BusyLoadingError, RedisConnectionError],
retry_on_timeout=True,
retry=Retry(backoff=ExponentialBackoff(), retries=3),
socket_connect_timeout=5, # Give the system sufficient time to connect even under higher CPU conditions
socket_timeout=1, # Respond quickly or abort, this is a cache
# Deployment
url=env["REDIS_URL"],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment