Skip to content

Instantly share code, notes, and snippets.

@KubaO
Last active October 7, 2024 02:35
Show Gist options
  • Save KubaO/7aa1f571c60e4bccdf683e339dbff9f4 to your computer and use it in GitHub Desktop.
Save KubaO/7aa1f571c60e4bccdf683e339dbff9f4 to your computer and use it in GitHub Desktop.
Python 3.12 asyncio ProactorEventLoop integration with Windows message pump / event loop and tkinter event loop
"""
An event loop policy that integrates Windows message loop with IOCP-based async processing.
Usage:
asyncio.set_event_loop_policy(winmsgasyncio.MsgProactorEventLoopPolicy())
To integrate with tkinter, use tk_sync_mainloop(root) or await tk_async_mainloop(root).
"""
# SPDX-License-Identifier: MIT
# Names exported by ``from <this module> import *``.
__all__ = ('MsgProactorEventLoopPolicy', 'MsgProactorEventLoop',
           'MsgIocpProactor', 'tk_sync_mainloop', 'tk_async_mainloop')
import asyncio
import asyncio.events
import ctypes
import logging
import math
# import sys
import traceback
from ctypes.wintypes import HANDLE, MSG, BOOL, LPMSG, HWND, UINT, LPARAM, DWORD, LPHANDLE
from typing import Callable
from _overlapped import INVALID_HANDLE_VALUE, GetQueuedCompletionStatus
from _winapi import INFINITE, CloseHandle
# Win32 constants used below, written in hex as they appear in the SDK headers.
# MsgWaitForMultipleObjectsEx() return values:
WAIT_OBJECT_0 = 0x00000000       # first handle signaled (+nCount: input queued)
WAIT_IO_COMPLETION = 0x000000C0  # alertable wait ended by a queued APC
WAIT_TIMEOUT = 0x00000102
WAIT_FAILED = 0xFFFFFFFF
# MsgWaitForMultipleObjectsEx() dwFlags:
MWMO_ALERTABLE = 0x0002
MWMO_INPUTAVAILABLE = 0x0004
# PeekMessageW() wRemoveMsg flag:
PM_REMOVE = 0x0001
# Queue-status mask covering every kind of input:
QS_ALLINPUT = 0x04FF
WS_OVERLAPPED = 0x00000000
WM_QUIT = 0x0012
# LRESULT and LPARAM are both pointer-sized signed integers (LONG_PTR).
LRESULT = LPARAM
def _winfn(lib, name, *args):
    """Bind the stdcall export *name* of *lib* as a typed module global.

    lib: a loaded ``ctypes.WinDLL``.
    name: the exported function name; the bound function is stored in this
        module's globals under the same name.
    args: the prototype, i.e. the return type followed by the argument types,
        exactly as accepted by ``ctypes.WINFUNCTYPE``.

    Returns the bound foreign function (existing callers ignore it).
    """
    prototype = ctypes.WINFUNCTYPE(*args)
    # Instantiate from a (name, library) pair so the prototype binds directly
    # to the DLL export. Passing the export object itself, as a callable, would
    # make ctypes build a *callback* thunk around it, so every call would
    # bounce from C back into the interpreter and out again.
    fun = prototype((name, lib))
    globals()[name] = fun
    return fun
_user32 = ctypes.WinDLL("USER32")
# (name, restype, *argtypes) for each USER32 export we call; _winfn() publishes
# each one as a module global of the same name.
for _spec in (
    ('PeekMessageW', BOOL, LPMSG, HWND, UINT, UINT, UINT),
    ('TranslateMessage', BOOL, LPMSG),
    ('DispatchMessageW', LRESULT, LPMSG),
    ('MsgWaitForMultipleObjectsEx', DWORD,
     DWORD, LPHANDLE, DWORD, DWORD, DWORD),
):
    _winfn(_user32, *_spec)
del _spec

logger = logging.getLogger(__name__)
def debug_stack(limit=None):
    """Log the current call stack at DEBUG level, one record per frame.

    limit: forwarded to ``traceback.extract_stack``; at most this many of the
        innermost frames are logged (None logs all of them). The innermost
        frame logged is this function itself.

    Returns immediately when DEBUG is disabled for this module's logger, so the
    stack is neither walked nor formatted for nothing.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    # extract_stack() already returns a StackSummary, so format it directly.
    for frame_text in traceback.extract_stack(limit=limit).format():
        logger.debug(frame_text)
class MsgIocpProactor(asyncio.IocpProactor):
    """IOCP proactor that also pumps the calling thread's Windows message queue.

    ``_poll`` blocks in ``MsgWaitForMultipleObjectsEx`` on both the I/O
    completion port and the thread's message queue, so window messages are
    translated and dispatched while the event loop waits for I/O.

    Attributes:
        pre_message_hook: optional ``hook(msg) -> bool`` called before each
            message is translated/dispatched; a truthy result skips
            ``TranslateMessage``/``DispatchMessageW`` for that message.
        post_message_hook: optional ``hook(msg)`` called after every message,
            whether or not it was dispatched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reused buffer filled by PeekMessageW.
        self._msg = MSG()
        # One-element handle array for MsgWaitForMultipleObjectsEx: the
        # completion port created by the base class.
        self._handles = (HANDLE * 1)()
        self._handles[0] = self._iocp
        # Log the caller's stack once, on the first _poll (debugging aid).
        self._dumpstack: bool = True
        # Number of waits performed so far; only used for debug logging.
        self._counter = 0
        self.pre_message_hook: Callable[[MSG], bool] | None = None
        self.post_message_hook: Callable[[MSG], None] | None = None

    def _pump_messages(self):
        """Remove and dispatch every message currently queued for this thread.

        Retrieving WM_QUIT stops the event loop this proactor belongs to.
        """
        msg = self._msg
        while PeekMessageW(msg, 0, 0, 0, PM_REMOVE):
            if not self.pre_message_hook or not self.pre_message_hook(msg):
                logger.debug("MSG=0x%x", msg.message)
                TranslateMessage(msg)
                DispatchMessageW(msg)
            if self.post_message_hook:
                self.post_message_hook(msg)
            if msg.message == WM_QUIT:
                logger.debug("WM_QUIT!")
                # Stop the loop this proactor is attached to, not whatever
                # asyncio.get_event_loop() happens to return from here.
                self._loop.stop()

    def _process_completion(self, status):
        """Resolve the future registered for one dequeued completion packet.

        status: the ``(err, transferred, key, address)`` tuple returned by
            GetQueuedCompletionStatus. The handling is the same as in
            ``asyncio.IocpProactor._poll``.
        """
        err, transferred, key, address = status
        try:
            f, ov, obj, callback = self._cache.pop(address)
        except KeyError:
            if self._loop.get_debug():
                self._loop.call_exception_handler({
                    'message': ('GetQueuedCompletionStatus() returned an '
                                'unexpected event'),
                    'status': ('err=%s transferred=%s key=%#x address=%#x'
                               % (err, transferred, key, address)),
                })
            # key is either zero, or it is used to return a pipe
            # handle which should be closed to avoid a leak.
            if key not in (0, INVALID_HANDLE_VALUE):
                CloseHandle(key)
            return

        if obj in self._stopped_serving:
            f.cancel()
        # Don't call the callback if _register() already read the result or
        # if the overlapped has been cancelled
        elif not f.done():
            try:
                value = callback(transferred, key, ov)
            except OSError as e:
                f.set_exception(e)
                self._results.append(f)
            else:
                f.set_result(value)
                self._results.append(f)
            finally:
                # Break the frame -> future -> exception -> traceback cycle.
                f = None

    def _poll(self, timeout=None):
        """Wait up to *timeout* seconds for I/O completions or window messages.

        Works like ``asyncio.IocpProactor._poll``, except that the wait also
        wakes for queued window messages, which are then dispatched. Only the
        first wait may block. Once a completion has been handled, further
        completions and messages are drained without blocking. After a batch of
        messages the poll returns, so the event loop can run whatever callbacks
        the message handlers scheduled.

        timeout: seconds to wait, or None to wait indefinitely.

        Raises:
            ValueError: *timeout* is negative or too large.
            OSError: MsgWaitForMultipleObjectsEx failed.
        """
        logger.debug("_poll enter")
        if self._dumpstack:
            debug_stack(limit=6)
            self._dumpstack = False
        if timeout is None:
            ms = INFINITE
        elif timeout < 0:
            raise ValueError("negative timeout")
        else:
            # MsgWaitForMultipleObjectsEx() has a resolution of 1 millisecond,
            # round away from zero to wait *at least* timeout seconds.
            ms = math.ceil(timeout * 1e3)
            if ms >= INFINITE:
                raise ValueError("timeout too big")

        while True:
            rv = MsgWaitForMultipleObjectsEx(
                1, self._handles, ms, QS_ALLINPUT,
                MWMO_ALERTABLE | MWMO_INPUTAVAILABLE)
            if rv == WAIT_FAILED:
                # Raise immediately, before other calls can overwrite the
                # thread's last-error value. Returning would only make the
                # caller retry the same failing wait.
                raise ctypes.WinError()
            logger.debug("* %d", self._counter)
            self._counter += 1
            status = None
            if rv == WAIT_OBJECT_0 or rv == WAIT_IO_COMPLETION:
                logger.debug("MsgWaitForMultipleObjectsEx IOC %d", rv)
                # Dequeue without blocking. After WAIT_IO_COMPLETION (an APC
                # ran) the port may be empty, and a blocking dequeue would
                # stop the message pump for up to the whole timeout.
                status = GetQueuedCompletionStatus(self._iocp, 0)
            elif rv == WAIT_OBJECT_0 + 1:
                self._pump_messages()
            elif rv == WAIT_TIMEOUT:
                logger.debug("MsgWaitForMultipleObjectsEx WAIT_TIMEOUT")
            else:
                logger.debug("MsgWaitForMultipleObjectsEx returned %d", rv)
            if status is None:
                break
            # Only the first wait may block. From now on, just drain what is
            # already queued (asyncio.IocpProactor._poll does the same).
            ms = 0
            self._process_completion(status)

        # Remove unregistered futures
        for ov in self._unregistered:
            self._cache.pop(ov.address, None)
        self._unregistered.clear()
        logger.debug("_poll exit")
class MsgProactorEventLoop(asyncio.ProactorEventLoop):
    """Proactor event loop that pumps Windows messages while it waits.

    Uses a MsgIocpProactor unless another proactor is given, and adds helpers
    that drain a tkinter root's Tcl/Tk event queue from the message pump.
    """

    # Tk DLL that provides Tk_GetNumMainWindows(); override for another Tk build.
    tk_dll_name = "tk86t.dll"

    def __init__(self, proactor=None):
        if proactor is None:
            proactor = MsgIocpProactor()
        super().__init__(proactor)
        # Bound lazily by _tk_init(); reports how many Tk main windows exist.
        self.Tk_GetNumMainWindows = None
        self._tk_initialized = False

    def _tk_init(self):
        """Load the Tk DLL and bind Tk_GetNumMainWindows (first call only)."""
        if self._tk_initialized:
            return
        tkdll = ctypes.WinDLL(self.tk_dll_name)
        get_num_main_windows = tkdll.Tk_GetNumMainWindows
        # int Tk_GetNumMainWindows(void)
        get_num_main_windows.restype = ctypes.c_int
        get_num_main_windows.argtypes = []
        self.Tk_GetNumMainWindows = get_num_main_windows
        self._tk_initialized = True

    def tk_sync_mainloop(self, root):
        """Run this loop until Tk reports that no main windows remain.

        After each Windows message, *root*'s pending Tcl/Tk events are drained.
        Once Tk_GetNumMainWindows() returns 0, the loop is stopped. The
        previous post_message_hook is restored on return, including when
        run_forever() raises.
        """
        # Imported here so the method works when this file is used as a
        # module, not only when it runs as a script.
        import _tkinter

        logger.info("tk_sync_mainloop")
        self._tk_init()
        dont_wait = _tkinter.DONT_WAIT

        def message_hook(msg) -> bool:
            while root.tk.dooneevent(dont_wait):
                pass
            if self.Tk_GetNumMainWindows() == 0:
                self.stop()
            return False

        post_hook = self._proactor.post_message_hook
        self._proactor.post_message_hook = message_hook
        try:
            self.run_forever()
        finally:
            self._proactor.post_message_hook = post_hook

    async def tk_async_mainloop(self, root):
        """Wait until Tk reports that no main windows remain; return True.

        Installs a post_message_hook that drains *root*'s pending Tcl/Tk events
        after each Windows message. The hook resolves an internal future once
        Tk_GetNumMainWindows() returns 0. The previous hook is restored when
        that future completes or is cancelled.
        """
        # Imported here so the method works when this file is used as a
        # module, not only when it runs as a script.
        import _tkinter

        logger.info("tk_async_mainloop")
        self._tk_init()
        dont_wait = _tkinter.DONT_WAIT
        future = self.create_future()

        def post_message_hook(msg) -> bool:
            while root.tk.dooneevent(dont_wait):
                pass
            # The hook stays installed until release_hook runs, and that
            # done-callback is only scheduled, not run inline. More messages
            # can be pumped before then. Resolving the future a second time
            # (or after cancellation) would raise InvalidStateError out of
            # the proactor's poll.
            if not future.done() and self.Tk_GetNumMainWindows() == 0:
                future.set_result(True)
            return True

        post_hook = self._proactor.post_message_hook

        def release_hook(_future):
            self._proactor.post_message_hook = post_hook

        self._proactor.post_message_hook = post_message_hook
        future.add_done_callback(release_hook)
        return await future
class MsgProactorEventLoopPolicy(asyncio.events.BaseDefaultEventLoopPolicy):
    """Event loop policy whose new loops are MsgProactorEventLoop instances.

    Install it with ``asyncio.set_event_loop_policy(MsgProactorEventLoopPolicy())``.
    """

    _loop_factory = MsgProactorEventLoop
def tk_sync_mainloop(root, close=True):
    """Run a fresh event loop until Tk reports that no main windows remain.

    The loop comes from ``asyncio.new_event_loop()``, so the installed policy
    must create loops that provide ``tk_sync_mainloop()``. For example, call
    ``asyncio.set_event_loop_policy(MsgProactorEventLoopPolicy())`` first.

    close: close the loop afterwards. This also happens when the mainloop
        raises.

    Returns the loop (already closed when *close* is true), so a caller passing
    close=False can still close it.
    """
    loop = asyncio.new_event_loop()
    try:
        loop.tk_sync_mainloop(root)
    finally:
        if close:
            loop.close()
    return loop
async def tk_async_mainloop(root):
    """Await the running loop's ``tk_async_mainloop(root)``; returns None.

    The running loop must provide ``tk_async_mainloop()``, e.g. a
    MsgProactorEventLoop.
    """
    running = asyncio.get_running_loop()
    await running.tk_async_mainloop(root)
if __name__ == '__main__':
    # Demo: a small Tk window driven by the message-pumping event loop.
    # Pass "sync" or "async" on the command line to choose the mainloop
    # flavor; otherwise one is picked at random.
    import random
    import sys
    # _tkinter is bound as a module global here, since the loop's Tk hooks use
    # _tkinter.DONT_WAIT.
    from tkinter import Tk, Label, Button, _tkinter

    logging.basicConfig(level=logging.INFO)
    asyncio.set_event_loop_policy(MsgProactorEventLoopPolicy())

    root = Tk()
    text = "This is Tcl/Tk %s" % root.globalgetvar('tk_patchLevel')
    text += "\nThis should be a cedilla: \xe7"
    label = Label(root, text=text)
    label.pack()
    test = Button(root, text="Click me!",
                  command=lambda root=root: root.test.configure(
                      text="[%s]" % root.test['text']))
    test.pack()
    root.test = test
    # Named quit_button so the builtin quit() is not shadowed.
    quit_button = Button(root, text="QUIT", command=root.destroy)
    quit_button.pack()
    root.iconify()
    root.update()
    root.deiconify()

    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode not in ('sync', 'async'):
        mode = random.choice(('sync', 'async'))
    if mode == 'sync':
        tk_sync_mainloop(root)
    else:
        asyncio.run(tk_async_mainloop(root))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment