Last active
June 25, 2018 07:16
-
-
Save agronholm/6bc337870fdc676cbac62befa9eb5988 to your computer and use it in GitHub Desktop.
Trio version of "async with threadpool():"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import gc | |
| import inspect | |
| import threading | |
| import trio | |
| from trio import run_sync_in_worker_thread | |
| from trio.hazmat import ( | |
| spawn_system_task, wait_task_rescheduled, reschedule, current_task, Error, Value, Abort) | |
def get_current_coroutine():
    """Return the coroutine object that owns the frame two levels up the stack.

    The frame inspected is the caller's caller (``f_back.f_back`` of this
    function's own frame).  Every object the garbage collector reports as
    referring to that frame's code object is examined, and the first
    coroutine whose ``cr_frame`` is that exact frame is returned.

    Raises:
        RuntimeError: if no such coroutine is found among the referrers.
    """
    target_frame = inspect.currentframe().f_back.f_back

    def _owns_target_frame(candidate):
        # Identity check: several coroutines may share one code object,
        # but only one of them can be running this particular frame.
        return inspect.iscoroutine(candidate) and candidate.cr_frame is target_frame

    referrers = gc.get_referrers(target_frame.f_code)
    try:
        return next(filter(_owns_target_frame, referrers))
    except StopIteration:
        raise RuntimeError('Cannot find the current coroutine object') from None
class _ThreadSwitcher:
    """Async context manager that drives the calling coroutine from a worker thread.

    On ``__aenter__`` the coroutine that entered the ``async with`` block is
    captured and a system task is spawned; that task calls
    ``self.coro.send(Value(None))`` from a worker thread while the original
    Trio task waits in ``wait_task_rescheduled``.

    NOTE(review): the intent appears to be that the body of the ``async with``
    block executes in the worker thread and control returns to the Trio task
    on exit -- confirm against Trio's scheduler behaviour before relying on it.
    """

    def __init__(self) -> None:
        # Task that constructed this object; later handed to reschedule().
        self.task = current_task()
        # Set by abort_func(); makes start_thread() skip the reschedule.
        self.cancelled = False
        # Assigned here but not read anywhere else in this class.
        self.exited = False

    def work_thread(self):
        """Send ``Value(None)`` into the captured coroutine; log and re-raise errors.

        Called through ``run_sync_in_worker_thread`` by ``start_thread``.
        """
        print('sending to coroutine in thread', threading.get_ident())
        try:
            self.coro.send(Value(None))
        except BaseException as e:
            print('error in thread:', e)
            raise

        print('thread done')

    async def start_thread(self):
        """Run ``work_thread`` in a worker thread, then reschedule ``self.task``.

        Any exception from the worker call is wrapped in ``Error``; success
        becomes ``Value(None)``.  The reschedule is skipped when
        ``abort_func`` has already flagged the wait as cancelled.
        """
        print('starting thread')
        try:
            await run_sync_in_worker_thread(self.work_thread)
            # Earlier variant, kept for reference (bypasses work_thread's logging):
            # await run_sync_in_worker_thread(self.coro.send, Value(None))
        except BaseException as e:
            print('error: ', e)
            outcome = Error(e)
        else:
            print('success')
            outcome = Value(None)

        if not self.cancelled:
            print('rescheduling with outcome:', outcome)
            reschedule(self.task, outcome)

    def abort_func(self, raise_cancel):
        """Abort callback given to ``wait_task_rescheduled``.

        Marks this switcher as cancelled and returns ``Abort.SUCCEEDED``.
        ``raise_cancel`` is accepted but not used.
        """
        print('aborted')
        self.cancelled = True
        return Abort.SUCCEEDED

    async def __aenter__(self):
        """Capture the entering coroutine, spawn the worker task, and wait."""
        print('__aenter__, thread=', threading.get_ident())
        # get_current_coroutine() looks two frames up from itself, i.e. at the
        # frame that is awaiting this __aenter__ call.
        self.coro = get_current_coroutine()
        spawn_system_task(self.start_thread)
        await wait_task_rescheduled(self.abort_func)

    def __aexit__(self, exc_type, exc_val, exc_tb):
        """Log the exit and return ``self`` as the awaitable for ``async with``.

        This is a plain method, not a coroutine function: the object that
        ``async with`` awaits on exit is ``self``, via ``__await__`` below.
        """
        print('__aexit__, exc_type=', exc_type, 'thread=', threading.get_ident())
        return self

    def __await__(self):
        """Generator-based awaitable that yields ``None`` once and returns ``None``."""
        print('__await__, thread=', threading.get_ident())
        # NOTE(review): presumably this bare yield hands control back to
        # whoever is currently driving the coroutine (the send() call in
        # work_thread) -- confirm.
        yield None
def switch_to_worker_thread():
    """Create a fresh ``_ThreadSwitcher`` for use with ``async with``."""
    switcher = _ThreadSwitcher()
    return switcher
async def foo():
    """Demo coroutine: print the thread id before, inside and after the block."""
    ident = threading.get_ident()
    print('begin, running in thread', ident)

    async with switch_to_worker_thread():
        ident = threading.get_ident()
        print('running in worker thread', ident)

    ident = threading.get_ident()
    print('end, running in thread', ident)
# Script entry point: execute the foo() demo coroutine via trio.run.
trio.run(foo)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment