Created August 5, 2021 22:12
-
-
Save adriangb/8b55ca2d5f29f31acedce9a97f8be0fb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextvars import Context, ContextVar, copy_context | |
from typing import Any | |
def _set_cvar(cvar: ContextVar, val: Any): | |
cvar.set(val) | |
class CaptureContext:
    """Capture changes made to the current Context inside a ``with`` block.

    On entry a snapshot of the current Context is taken.  On exit (and on
    every explicit :meth:`sync` call) each ContextVar that is new, or now
    holds a different object than in that snapshot, is copied into
    :attr:`context` -- a separate Context that can later be replayed
    elsewhere.

    Call .sync() to capture changes before exiting the block.
    """

    def __init__(self) -> None:
        # Receives only the variables changed inside the block; starts empty.
        self.context = Context()

    def __enter__(self) -> "CaptureContext":
        # Snapshot of the Context as it was when the block was entered.
        self._outer = copy_context()
        return self

    def sync(self):
        """Copy variables changed since ``__enter__`` into ``self.context``.

        A variable is copied when it was not set at entry, when it now refers
        to a different object than it did at entry, or when an earlier
        ``sync()`` already captured it.  The last condition ensures that
        reverting a variable to its entry value is reflected too.

        Must be called after ``__enter__``; otherwise ``self._outer`` does not
        exist and AttributeError is raised.
        """
        missing = object()
        for cvar, value in copy_context().items():
            # Compare by identity rather than ``!=``: ContextVar.set stores the
            # object itself, and ``!=`` could invoke a user ``__eq__`` that
            # raises or returns a non-bool (e.g. array-like values).
            changed = self._outer.get(cvar, missing) is not value
            if changed or cvar in self.context:
                # Run the assignment inside the captured Context, not ours.
                self.context.run(cvar.set, value)

    def __exit__(self, *args: Any):
        self.sync()
def restore_context(context: Context) -> None:
    """Apply every variable set in *context* to the current Context.

    Each ContextVar that has a value in *context* is ``set`` to that value in
    the Context active when this function is called; variables absent from
    *context* are left untouched.  The Tokens returned by ``ContextVar.set``
    are discarded, so these assignments cannot be individually reset.
    """
    # Iterating items() yields only variables that have a value, so no
    # LookupError is possible here.
    for cvar, value in context.items():
        cvar.set(value)
# Usage
import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifespan(cap):
    """Set ``ctxvar`` inside *cap* so the change is recorded in ``cap.context``."""
    with cap:
        ctxvar.set("spam")


async def endpoint():
    """Check that the value set in ``lifespan`` is visible in this task."""
    assert ctxvar.get() == "spam"


async def main():
    cap = CaptureContext()
    # create_task runs lifespan in a copy of this context, so its
    # ctxvar.set() would be invisible here unless it is captured.
    await asyncio.create_task(lifespan(cap))
    # Replay the captured changes into main's own context...
    restore_context(cap.context)
    # ...so this task, which copies main's context, sees lifespan's value.
    await asyncio.create_task(endpoint())


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.