Skip to content

Instantly share code, notes, and snippets.

@lanfon72
Created March 6, 2021 14:15
Show Gist options
  • Save lanfon72/03be838ccdfcdef8f80434ac3a68e92b to your computer and use it in GitHub Desktop.
Save lanfon72/03be838ccdfcdef8f80434ac3a68e92b to your computer and use it in GitHub Desktop.
Fix contextvar behavior inside the asyncio REPL
import ast
import asyncio
import code
import concurrent.futures
import inspect
import sys
import threading
import types
import warnings
import contextvars
from asyncio import futures, iscoroutinefunction
class AsyncIOInteractiveConsole(code.InteractiveConsole):
    """Interactive console whose input is executed on an asyncio event loop.

    Input is compiled with ``ast.PyCF_ALLOW_TOP_LEVEL_AWAIT`` so ``await``
    can be used directly at the prompt. Every compiled chunk is handed to
    ``self.loop`` and the console thread blocks until it has finished.

    Context variables: chunks run inside the module-level ``repl_context``
    (a ``contextvars.Context``), so ``var.set(...)`` on one line is visible
    on the next one.
    """

    def __init__(self, locals, loop):
        """Create the console.

        :param locals: namespace that REPL input reads from and writes to.
        :param loop: event loop that executes the compiled input. It must be
            running in a different thread, because ``runcode`` blocks until
            the loop has processed the chunk.
        """
        super().__init__(locals)
        self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT
        self.loop = loop

    def runcode(self, code):
        """Run one compiled chunk on ``self.loop`` and wait for its result.

        Plain statements run synchronously inside ``repl_context``. A chunk
        containing a top-level ``await`` compiles to a coroutine, which is
        wrapped in a task. On Python 3.11+ that task also runs inside
        ``repl_context``. Older versions can only give a task a copy of the
        current context, so context changes made by awaited chunks are
        discarded there.

        Errors are reported via ``showtraceback`` (or a bare
        ``KeyboardInterrupt`` message after Ctrl-C). ``SystemExit`` is
        re-raised instead of being reported.

        Reads the module global ``repl_context``. Writes the module globals
        ``repl_future`` and ``repl_future_interrupted``, so another thread can
        cancel the running task and flag the interruption.
        """
        future = concurrent.futures.Future()

        def callback():
            global repl_future
            global repl_future_interrupted

            repl_future = None
            repl_future_interrupted = False

            func = types.FunctionType(code, self.locals)
            try:
                # Enter the shared REPL context explicitly, so that plain
                # statements such as ``var.set(1)`` persist. For a chunk with a
                # top-level ``await`` this call only creates the coroutine.
                # Calling it directly avoids building a wrapper namespace that
                # would clash with user-defined names.
                coro = repl_context.run(func)
            except SystemExit:
                raise
            except KeyboardInterrupt as ex:
                repl_future_interrupted = True
                future.set_exception(ex)
                return
            except BaseException as ex:
                future.set_exception(ex)
                return

            if not inspect.iscoroutine(coro):
                future.set_result(coro)
                return

            try:
                if sys.version_info >= (3, 11):
                    # create_task(context=...) exists since 3.11: let awaited
                    # chunks update the shared REPL context as well.
                    repl_future = self.loop.create_task(
                        coro, context=repl_context)
                else:
                    repl_future = self.loop.create_task(coro)
                futures._chain_future(repl_future, future)
            except BaseException as exc:
                future.set_exception(exc)

        # The callback itself runs in a throwaway copy: it enters repl_context
        # on its own (above), and a Context cannot be entered twice at once.
        self.loop.call_soon_threadsafe(callback, context=repl_context.copy())
        try:
            return future.result()
        except SystemExit:
            raise
        except BaseException:
            if repl_future_interrupted:
                self.write("\nKeyboardInterrupt\n")
            else:
                self.showtraceback()
class REPLThread(threading.Thread):
    """Thread that runs the blocking interactive console.

    The event loop is expected to be driven by another thread. Whenever the
    console session ends, normally or through an exception, the loop is
    asked to stop.
    """

    def run(self):
        try:
            prompt = getattr(sys, "ps1", ">>> ")
            banner = "".join([
                f'asyncio REPL {sys.version} on {sys.platform}\n',
                'Use "await" directly instead of "asyncio.run()".\n',
                'Type "help", "copyright", "credits" or "license" ',
                'for more information.\n',
                f'{prompt}import asyncio',
            ])
            console.interact(banner=banner, exitmsg='exiting asyncio REPL...')
        finally:
            # From here on, silence "coroutine ... was never awaited"
            # RuntimeWarnings (e.g. for input abandoned at shutdown).
            warnings.filterwarnings(
                'ignore',
                message=r'^coroutine .* was never awaited$',
                category=RuntimeWarning,
            )
            loop.call_soon_threadsafe(loop.stop)
if __name__ == '__main__':
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # REPL namespace: asyncio pre-imported (as the banner shows) plus this
    # module's dunder names. At module level, globals() is the same mapping
    # that locals() would return.
    repl_locals = {'asyncio': asyncio}
    repl_locals.update(
        (name, globals()[name])
        for name in ('__name__', '__package__', '__loader__',
                     '__spec__', '__builtins__', '__file__')
    )

    console = AsyncIOInteractiveConsole(repl_locals, loop)

    # Module-level state shared with AsyncIOInteractiveConsole.runcode.
    repl_future = None
    repl_future_interrupted = False
    repl_context = contextvars.Context()

    try:
        import readline  # NoQA: imported only to enable line editing for input()
    except ImportError:
        pass

    repl_thread = REPLThread(daemon=True)
    repl_thread.start()

    # The loop runs in this (main) thread until the REPL thread stops it.
    # Ctrl-C arrives here as KeyboardInterrupt: cancel the running chunk,
    # if any, and keep the loop going.
    while True:
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            if repl_future and not repl_future.done():
                repl_future.cancel()
                repl_future_interrupted = True
        else:
            break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment