Skip to content

Instantly share code, notes, and snippets.

@adriangb
Created August 5, 2021 22:12
Show Gist options
  • Save adriangb/8b55ca2d5f29f31acedce9a97f8be0fb to your computer and use it in GitHub Desktop.
Save adriangb/8b55ca2d5f29f31acedce9a97f8be0fb to your computer and use it in GitHub Desktop.
from contextvars import Context, ContextVar, copy_context
from typing import Any
def _set_cvar(cvar: ContextVar, val: Any):
    """Bind *val* to *cvar* in whichever Context is active when this runs.

    Meant to be passed to ``Context.run`` so the assignment lands in that
    Context instead of the caller's. The Token produced by ``ContextVar.set``
    is dropped, so nothing is returned.
    """
    setter = cvar.set
    setter(val)
class CaptureContext:
    """Capture changes made to the current Context within a ``with`` block.

    ``__enter__`` snapshots the current Context. ``sync()`` compares the
    Context that is current when it is called against that snapshot. It copies
    every ContextVar that was newly set, or re-bound to a different object,
    into ``self.context``. ``__exit__`` calls ``sync()`` automatically. Call
    ``.sync()`` yourself to capture changes before exiting the block.

    Limitation: only variables present when ``sync()`` runs are inspected. A
    variable that was unset inside the block (e.g. via ``ContextVar.reset``)
    is not reflected in ``self.context``.
    """

    # Sentinel distinguishing "not set in the snapshot" from any real value.
    _MISSING = object()

    def __init__(self) -> None:
        # Accumulates the captured variable bindings across sync() calls.
        self.context = Context()

    def __enter__(self) -> "CaptureContext":
        self._outer = copy_context()
        return self

    def sync(self) -> None:
        """Copy variables set or re-bound since ``__enter__`` into ``self.context``.

        Raises:
            RuntimeError: if called before the block was entered.
        """
        outer = getattr(self, "_outer", None)
        if outer is None:
            raise RuntimeError("CaptureContext.sync() called before __enter__")
        for cvar, value in copy_context().items():
            # Compare by identity, not equality. Re-binding a variable to an
            # equal-but-distinct object is still a change, and `!=` may raise
            # or return a non-bool for array-like values.
            if outer.get(cvar, self._MISSING) is not value:
                self.context.run(cvar.set, value)

    def __exit__(self, *args: Any) -> None:
        self.sync()
def restore_context(context: Context) -> None:
    """Apply every variable binding in *context* to the current Context.

    Each ContextVar set in *context* is set to the same value in the caller's
    current Context. Variables not present in *context* are left untouched.
    The Tokens returned by ``ContextVar.set`` are discarded, so these
    assignments cannot be undone with ``ContextVar.reset``.
    """
    # Iterating items() yields only variables that are set, so no lookup can
    # fail here. The previous try/except LookupError was dead code.
    for cvar, value in context.items():
        cvar.set(value)
# Usage
import asyncio
import contextvars

ctxvar = contextvars.ContextVar("ctx")


async def lifespan(cap):
    """Set ``ctxvar`` inside a capture block so the change is recorded in *cap*."""
    with cap:
        ctxvar.set("spam")


async def endpoint():
    """Check that the value set during ``lifespan`` is visible here."""
    assert ctxvar.get() == "spam"


async def main():
    cap = CaptureContext()
    # The task runs in a copy of main's Context, so its ctxvar.set() does not
    # leak back here. `cap` records the change instead.
    await asyncio.create_task(lifespan(cap))
    # Replay the captured bindings into main's Context...
    restore_context(cap.context)
    # ...so this task, which copies main's Context, sees lifespan's value.
    await asyncio.create_task(endpoint())


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment