Last active
May 30, 2025 07:28
-
-
Save frostming/1585f1e3f40bd7627df3d9399bca778c to your computer and use it in GitHub Desktop.
Go context in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import time | |
from typing import Any, Awaitable, Callable, TypeVar | |
Callback = TypeVar("Callback", bound=Callable[[], Any]) | |
class Context: | |
def _cancel(self, msg: str | None = None) -> None: | |
"""Cancel the current task.""" | |
raise NotImplementedError("This method should be implemented by subclasses.") | |
def done(self) -> asyncio.Future[None]: | |
"""Return a Future that resolves when the current task is done.""" | |
raise NotImplementedError("This method should be implemented by subclasses.") | |
def value(self, key: Any) -> Any: | |
"""Get a value from the context.""" | |
raise NotImplementedError("This method should be implemented by subclasses.") | |
class _ContextWithParent(Context):
    """Pass-through context that answers every call by asking ``parent``.

    It holds no state beyond the parent reference: ``value``, ``done`` and
    ``_cancel`` are all forwarded unchanged.

    NOTE(review): since ``_cancel`` is forwarded, cancelling a derived
    context cancels the parent as well, and therefore every other context
    that delegates to that parent. Go's ``context`` only propagates
    cancellation from parent to child. Confirm this upward direction is
    intended.
    """

    def __init__(self, parent: Context) -> None:
        self._parent = parent

    def value(self, key: Any) -> Any:
        """Resolve ``key`` through the parent context."""
        return self._parent.value(key)

    def done(self) -> asyncio.Future[None]:
        """Return the parent's done-future; this context has none of its own."""
        return self._parent.done()

    def _cancel(self, msg: str | None = None) -> None:
        """Forward cancellation, together with ``msg``, to the parent."""
        self._parent._cancel(msg)
class Background(Context):
    """Root context that owns a fresh future created on the event loop.

    Nothing in this class sets a result on the future. It finishes when
    ``_cancel`` cancels it. ``value`` always answers ``None``.

    NOTE(review): ``asyncio.get_event_loop()`` is deprecated when no loop is
    running. Constructing a Background from inside a coroutine avoids that
    path.
    """

    def __init__(self) -> None:
        self._future = asyncio.get_event_loop().create_future()

    def value(self, key: Any) -> Any:
        """Return ``None``: a background context carries no values."""
        return None

    def done(self) -> asyncio.Future[None]:
        """Return the future owned by this context."""
        return self._future

    def _cancel(self, msg: str | None = None) -> None:
        """Cancel the owned future, passing ``msg`` through as the reason."""
        self._future.cancel(msg)
class _ValueContext(_ContextWithParent):
    """Derived context that layers a single key/value pair over ``parent``.

    Cancellation and ``done()`` are inherited from the pass-through base.
    Only ``value`` is specialised.
    """

    def __init__(self, parent: Context, key: Any, value: Any) -> None:
        super().__init__(parent)
        # The one binding this layer contributes; compared with ``==``.
        self._key = key
        self._value = value

    def value(self, key: Any) -> Any:
        """Answer ``key`` from this layer if it matches, else from the parent chain."""
        return self._value if key == self._key else super().value(key)
def with_cancel(parent: Context) -> tuple[Context, Callable[[], None]]:
    """Wrap ``parent`` and return ``(ctx, cancel)``.

    ``cancel()`` invokes ``ctx._cancel()``. Because ``ctx`` is a
    ``_ContextWithParent``, that forwards to ``parent`` and cancels the
    parent context itself.
    """
    child = _ContextWithParent(parent)

    def cancel() -> None:
        child._cancel()

    return child, cancel
def with_value(parent: Context, key: Any, value: Any) -> Context:
    """Return a child of ``parent`` in which ``key`` maps to ``value``.

    Lookups for any other key, as well as ``done()`` and cancellation, are
    delegated to ``parent``.
    """
    return _ValueContext(parent=parent, key=key, value=value)
def with_timeout(parent: Context, timeout: float) -> tuple[Context, Callable[[], None]]:
    """Return ``(ctx, cancel)`` where ``parent`` is cancelled after ``timeout`` seconds.

    The timer is scheduled with ``call_later`` on the event loop that owns
    ``parent.done()``. When it fires it calls
    ``parent._cancel("Timeout reached")``. The returned ``ctx`` forwards to
    ``parent``, so the timeout, like ``cancel()``, cancels the parent
    context itself, not only ``ctx``.

    Calling ``cancel()`` cancels ``parent`` immediately and also disarms
    the pending timer. Otherwise an explicitly cancelled context would keep
    a timer, and a reference to ``parent``, alive on the loop until it
    expired.

    Args:
        parent: Context to cancel when the timeout elapses.
        timeout: Delay in seconds, passed straight to ``loop.call_later``.

    Returns:
        The derived context and a function that cancels it early.
    """
    # Use the loop that owns the parent's future rather than whatever
    # asyncio.get_event_loop() resolves to, so the timer and the future it
    # cancels always live on the same loop.
    loop = parent.done().get_loop()
    timer = loop.call_later(timeout, parent._cancel, "Timeout reached")

    def cancel() -> None:
        # Disarm first so the timer cannot fire after an explicit cancel.
        timer.cancel()
        parent._cancel()

    return _ContextWithParent(parent), cancel
class Selector:
    """A port of the Go ``select`` statement for asyncio futures.

    Register callbacks with :meth:`on`, then await :meth:`select` exactly
    once. ``select`` returns as soon as at least one registered future has
    finished. It then invokes the callback of every future that is finished
    at that moment, in registration order.

    Futures that are still pending are left untouched; the selector never
    cancels them. That includes Tasks it created from coroutines passed to
    :meth:`on`.

    Example usage:
    ```python
    selector = Selector()
    ctx, cancel = context.with_cancel(context.Background())

    @selector.on(ctx.done())
    def on_done():
        print("Task is done")

    await selector.select()
    ```
    """

    def __init__(self) -> None:
        # Set by the first select(); a Selector is single-use.
        self._waited = False
        # Registered future -> callback. Dicts keep insertion order, which
        # select() relies on to run ready callbacks in registration order.
        self._callbacks: dict[asyncio.Future[Any], Callable[[], Any]] = {}

    def on(self, future: Awaitable[Any]) -> Callable[[Callback], Callback]:
        """Return a decorator that registers its function as ``future``'s callback.

        When the decorator is applied, ``future`` is passed through
        ``asyncio.ensure_future``, so a coroutine is wrapped in a Task on
        the event loop. The decorated function is returned unchanged.
        Registering the same Future object twice replaces the earlier
        callback.
        """

        def decorator(callback: Callback) -> Callback:
            self._callbacks[asyncio.ensure_future(future)] = callback
            return callback

        return decorator

    async def select(self) -> None:
        """Wait for any registered future to finish, then run the ready callbacks.

        Exceptions raised by a callback propagate out of ``select``, and
        the remaining ready callbacks are skipped.

        Raises:
            RuntimeError: If ``select`` has already been called on this selector.
        """
        if self._waited:
            raise RuntimeError("Selector has already been waited on.")
        self._waited = True
        if not self._callbacks:
            # Nothing registered: asyncio.wait() would reject an empty set.
            return
        done, _ = await asyncio.wait(
            self._callbacks.keys(), return_when=asyncio.FIRST_COMPLETED
        )
        # ``done`` is an unordered set. Walk the registrations instead so
        # callbacks of futures that finished together run in registration
        # order rather than hash order. Snapshot first so callbacks may
        # safely touch the selector.
        ready = [future for future in self._callbacks if future in done]
        for future in ready:
            callback = self._callbacks.pop(future, None)
            if callback is not None:
                callback()
async def test():
    """Demonstrate Selector with a 2 s timeout context racing a 3 s sleep.

    The timeout fires first, so only ``on_done`` is expected to print.
    ``cancel`` runs in a ``finally`` block, the Python analogue of Go's
    ``defer cancel()``, so the context and its timer are released even if
    ``select()`` raises.
    """
    selector = Selector()
    ctx, cancel = with_timeout(Background(), 2)
    try:
        start_time = time.perf_counter()

        @selector.on(ctx.done())
        def on_done():
            print("Task is done", time.perf_counter() - start_time)

        @selector.on(asyncio.sleep(3))
        def on_sleep():
            print("Sleep completed", time.perf_counter() - start_time)

        await selector.select()
    finally:
        cancel()


if __name__ == "__main__":
    asyncio.run(test())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment