Skip to content

Instantly share code, notes, and snippets.

@frostming
Last active May 30, 2025 07:28
Show Gist options
  • Save frostming/1585f1e3f40bd7627df3d9399bca778c to your computer and use it in GitHub Desktop.
Save frostming/1585f1e3f40bd7627df3d9399bca778c to your computer and use it in GitHub Desktop.
Go context in Python
import asyncio
import time
from typing import Any, Awaitable, Callable, TypeVar
# Type variable for zero-argument callbacks. Selector.on is annotated with it
# so the decorator hands back the decorated function with its own type intact.
Callback = TypeVar("Callback", bound=Callable[[], Any])
class Context:
def _cancel(self, msg: str | None = None) -> None:
"""Cancel the current task."""
raise NotImplementedError("This method should be implemented by subclasses.")
def done(self) -> asyncio.Future[None]:
"""Return a Future that resolves when the current task is done."""
raise NotImplementedError("This method should be implemented by subclasses.")
def value(self, key: Any) -> Any:
"""Get a value from the context."""
raise NotImplementedError("This method should be implemented by subclasses.")
class _ContextWithParent(Context):
    """Context wrapper that forwards every operation to the context it wraps.

    Subclasses override only the behaviour they add; cancellation, the
    done-future and value lookup are otherwise answered by ``parent``.
    """

    def __init__(self, parent: Context) -> None:
        self._parent = parent

    def _cancel(self, msg: str | None = None) -> None:
        """Forward the cancellation request, with its message, to the parent."""
        target = self._parent
        target._cancel(msg)

    def done(self) -> asyncio.Future[None]:
        """Hand out the parent's done-future unchanged."""
        target = self._parent
        return target.done()

    def value(self, key: Any) -> Any:
        """Resolve *key* by asking the parent."""
        target = self._parent
        return target.value(key)
class Background(Context):
    """Root of a context tree, owning a single done-future.

    Nothing in this class resolves the future except ``_cancel``, which
    cancels it. The future is created on the loop returned by
    ``asyncio.get_event_loop()`` at construction time. Construct this inside
    the coroutine that will use it (for example under ``asyncio.run``) so the
    future lives on the running loop.
    """

    def __init__(self) -> None:
        self._future = asyncio.get_event_loop().create_future()

    def _cancel(self, msg: str | None = None) -> None:
        """Cancel the root future, passing *msg* through as the reason."""
        fut = self._future
        fut.cancel(msg)

    def done(self) -> asyncio.Future[None]:
        """Give back the root future."""
        fut = self._future
        return fut

    def value(self, key: Any) -> Any:
        """The root carries no values, so every lookup yields None."""
        return None
class _ValueContext(_ContextWithParent):
    """Context that carries exactly one key/value pair on top of its parent.

    Cancellation and the done-future are inherited unchanged from the parent.
    """

    def __init__(self, parent: Context, key: Any, value: Any) -> None:
        super().__init__(parent)
        self._key, self._value = key, value

    def value(self, key: Any) -> Any:
        """Return the stored value when *key* matches, else defer to the parent."""
        return self._value if key == self._key else super().value(key)
class _CancelContext(_ContextWithParent):
    """Child context with its own done-future, modelled on Go's WithCancel.

    Cancelling this context never affects the parent. Cancellation of the
    parent propagates down to this context. Value lookups still go through
    the parent.
    """

    def __init__(self, parent: Context) -> None:
        super().__init__(parent)
        self._parent_done = parent.done()
        # Create the child's future on the same loop as the parent's future
        # so both can be awaited together.
        self._future: asyncio.Future[None] = (
            self._parent_done.get_loop().create_future()
        )
        if self._parent_done.done():
            # Deriving from an already-finished parent yields a finished child.
            self._cancel(self._cancel_message(self._parent_done))
        else:
            self._parent_done.add_done_callback(self._on_parent_done)

    @staticmethod
    def _cancel_message(fut: asyncio.Future[Any]) -> str | None:
        """Return the message *fut* was cancelled with, or None."""
        if not fut.cancelled():
            return None
        try:
            fut.result()
        except asyncio.CancelledError as exc:
            return exc.args[0] if exc.args else None
        return None

    def _on_parent_done(self, parent_done: asyncio.Future[Any]) -> None:
        """Propagate the parent's cancellation, and its message, downwards."""
        self._cancel(self._cancel_message(parent_done))

    def _cancel(self, msg: str | None = None) -> None:
        """Cancel only this context; repeated calls are no-ops."""
        if self._future.done():
            return
        self._future.cancel(msg)
        # Detach from the parent so a long-lived parent does not keep
        # references to children that have already finished.
        self._parent_done.remove_done_callback(self._on_parent_done)

    def done(self) -> asyncio.Future[None]:
        """Return this context's own done-future."""
        return self._future


def with_cancel(parent: Context) -> tuple[Context, Callable[[], None]]:
    """Derive a child context that can be cancelled independently.

    Returns ``(ctx, cancel)``. Calling ``cancel()`` (any number of times)
    cancels ``ctx`` but never ``parent``. ``ctx`` is also cancelled
    automatically when ``parent`` is cancelled.
    """
    ctx = _CancelContext(parent)
    return ctx, lambda: ctx._cancel()


def with_value(parent: Context, key: Any, value: Any) -> Context:
    """Derive a context that maps *key* to *value*.

    Other keys, cancellation and the done-future are delegated to *parent*.
    """
    return _ValueContext(parent, key, value)


def with_timeout(parent: Context, timeout: float) -> tuple[Context, Callable[[], None]]:
    """Derive a child context that is cancelled after *timeout* seconds.

    When the timer fires, the child's done-future is cancelled with the
    message ``"Timeout reached"``. Calling the returned ``cancel()`` earlier
    cancels the child instead. The parent is never cancelled by the child;
    cancellation of the parent still propagates down.
    """
    ctx = _CancelContext(parent)
    done = ctx.done()
    handle = done.get_loop().call_later(timeout, ctx._cancel, "Timeout reached")
    # Stop the timer once the context finishes for any reason (manual cancel,
    # parent cancellation or the timeout itself) so it does not linger.
    done.add_done_callback(lambda _: handle.cancel())
    return ctx, lambda: ctx._cancel()
class Selector:
    """A port of the Go ``select`` statement for asyncio futures.

    Register callbacks with :meth:`on`, then ``await`` :meth:`select` exactly
    once. ``select`` waits until at least one registered awaitable is done.
    It then invokes the callback of every awaitable that is done at that
    moment.

    Coroutines passed to :meth:`on` are wrapped in tasks owned by the
    selector. Any such task still pending when ``select`` finishes is
    cancelled so it does not leak. Futures passed in directly (e.g.
    ``ctx.done()``) belong to the caller and are never cancelled here.

    Example usage:

    ```python
    selector = Selector()
    ctx, cancel = context.with_cancel(context.Background())

    @selector.on(ctx.done())
    def on_done():
        print("Task is done")

    await selector.select()
    ```
    """

    def __init__(self) -> None:
        self._waited = False
        # Each awaited future mapped to the callback registered for it.
        self._callbacks: dict[asyncio.Future[Any], Callable[[], Any]] = {}
        # Tasks this selector created from coroutines; only these are
        # cancelled after selection.
        self._owned: set[asyncio.Future[Any]] = set()

    def on(self, future: Awaitable[Any]) -> "Callable[[Callback], Callback]":
        """Return a decorator that registers a callback for *future*.

        *future* is passed through ``asyncio.ensure_future`` when the
        decorator is applied, so a coroutine is scheduled as a task at
        registration time. The decorated callback is returned unchanged.
        """

        def decorator(callback: "Callback") -> "Callback":
            fut = asyncio.ensure_future(future)
            if fut is not future:
                # ensure_future wrapped a coroutine: the task is ours.
                self._owned.add(fut)
            self._callbacks[fut] = callback
            return callback

        return decorator

    async def select(self) -> None:
        """Wait for any registered future to finish, then run callbacks.

        Every registered future that is done when the wait returns has its
        callback invoked. Returns immediately if nothing was registered.

        Raises:
            RuntimeError: if ``select`` has already been called on this
                selector.
        """
        if self._waited:
            raise RuntimeError("Selector has already been waited on.")
        self._waited = True
        if not self._callbacks:
            return
        try:
            done, _ = await asyncio.wait(
                self._callbacks.keys(), return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            # Whether we selected normally or were cancelled ourselves, stop
            # the tasks we created that can no longer be selected.
            for task in self._owned:
                if not task.done():
                    task.cancel()
        for fut in done:
            callback = self._callbacks.pop(fut, None)
            if callback is not None:
                callback()
async def test():
    """Race a 2-second context timeout against a 3-second sleep.

    Each callback prints how long after start it ran. After selection, the
    cancel function returned by ``with_timeout`` is called.
    """
    selector = Selector()
    ctx, cancel = with_timeout(Background(), 2)
    started = time.perf_counter()

    def on_done():
        print("Task is done", time.perf_counter() - started)

    def on_sleep():
        print("Sleep completed", time.perf_counter() - started)

    # Same registration order as the decorator form: context first, then sleep.
    selector.on(ctx.done())(on_done)
    selector.on(asyncio.sleep(3))(on_sleep)

    await selector.select()
    cancel()
if __name__ == "__main__":
    # Run the demo coroutine on a fresh event loop created by asyncio.run.
    asyncio.run(test())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment