Last active
February 2, 2025 04:52
-
-
Save ktnyt/ec145e6354ec428acd56466af8d9e8e1 to your computer and use it in GitHub Desktop.
A simple asyncio-based JSON-RPC implementation over standard I/O (stdio).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import json | |
import logging | |
import os | |
import re | |
from collections.abc import AsyncIterator | |
from contextlib import asynccontextmanager | |
from typing import Any, TypedDict | |
# A JSON-RPC request message: a method call that carries an ``id`` and
# therefore expects a response. Used for outgoing calls and for calls
# received from the peer alike.
# NOTE(review): JSON-RPC 2.0 also permits string ids and omitted params;
# this type is narrower than the spec.
JsonRpcRequest = TypedDict(
    "JsonRpcRequest",
    {
        "jsonrpc": str,
        "method": str,
        "id": int,
        "params": Any,
    },
)
# A JSON-RPC notification: like a request but without an ``id``, so no
# response is sent back for it.
JsonRpcNotification = TypedDict(
    "JsonRpcNotification",
    {
        "jsonrpc": str,
        "method": str,
        "params": Any,
    },
)
# A successful JSON-RPC response; ``id`` echoes the id of the request it answers.
JsonRpcResponse = TypedDict(
    "JsonRpcResponse",
    {
        "jsonrpc": str,
        "id": int,
        "result": Any,
    },
)
# The ``error`` member of a failed JSON-RPC response.
# NOTE(review): the JSON-RPC 2.0 spec makes ``data`` optional; it is
# declared as required here.
JsonRpcError = TypedDict(
    "JsonRpcError",
    {
        "code": int,
        "message": str,
        "data": Any,
    },
)
# A failed JSON-RPC response: carries an ``error`` object instead of ``result``.
JsonRpcErrorResponse = TypedDict(
    "JsonRpcErrorResponse",
    {
        "jsonrpc": str,
        "id": int,
        "error": JsonRpcError,
    },
)
class JsonRpcException(Exception):
    """Exception wrapping a JSON-RPC error object.

    ``str(exc)`` renders as ``[code] message``; the complete error object
    (including any ``data``) stays available as ``exc.error``.
    """

    def __init__(self, error: JsonRpcError):
        code, message = error["code"], error["message"]
        super().__init__(f"[{code}] {message}")
        self.error = error
# Header fields of a framed message (``Content-Length`` / ``Content-Type``).
# NOTE(review): not referenced by any other code visible in this file.
JsonRpcHeader = TypedDict(
    "JsonRpcHeader",
    {
        "content_length": int,
        "content_type": str,
    },
)
class Interrupted(Exception):
    """Set on a subprocess stream to wake a reader blocked on it so the reader can exit."""
class JsonRpcStdio:
    """JSON-RPC 2.0 endpoint that talks to a child process over its stdio pipes.

    Outgoing messages are written to the child's stdin behind a
    ``Content-Length`` header. A background task reads the same framing from
    the child's stdout and routes each message: requests and notifications
    sent by the child are queued for ``get_request`` / ``get_notification``,
    and responses resolve the future of the matching ``request`` call. A
    second background task logs every line the child writes to stderr.
    """

    def __init__(
        self,
        process: asyncio.subprocess.Process,
        loop: asyncio.AbstractEventLoop,
        *,
        logger: logging.Logger | None = None,
    ) -> None:
        """Wrap *process* (created with stdin/stdout/stderr pipes) and start its reader tasks on *loop*."""
        self._process = process
        self._loop = loop
        self.logger = logger or logging.getLogger(__name__).getChild(self.__class__.__name__)
        self._requests: asyncio.Queue[JsonRpcRequest] = asyncio.Queue()
        self._notifications: asyncio.Queue[JsonRpcNotification] = asyncio.Queue()
        # Id for the next outgoing request; every request() call consumes one.
        self.request_id = 0
        # Futures of in-flight outgoing requests, keyed by request id.
        self.response_futures: dict[int, asyncio.Future[JsonRpcResponse]] = {}
        # Id of the request most recently returned by get_request(); respond() answers it by default.
        self._last_request_id: int | None = None
        # Start the readers only after every attribute they touch exists.
        self._stdout_task = self._loop.create_task(self._handle_stdout())
        self._stderr_task = self._loop.create_task(self._handle_stderr())

    async def shutdown(self) -> None:
        """Close stdin, stop both reader tasks, then terminate the child.

        The child gets 5 seconds to exit after ``terminate()``; after that it is
        killed. Either way it is awaited so it is fully reaped.
        """
        logger = self.logger.getChild("shutdown")
        logger.debug("Initiate")
        logger.debug(" Close stdin")
        assert self._process.stdin is not None
        self._process.stdin.close()
        await asyncio.sleep(0)
        logger.debug(" Interrupt stdout")
        assert self._process.stdout is not None
        # Makes a reader blocked on this stream raise Interrupted.
        self._process.stdout.set_exception(Interrupted())
        await asyncio.sleep(0)
        logger.debug(" Interrupt stderr")
        assert self._process.stderr is not None
        self._process.stderr.set_exception(Interrupted())
        await asyncio.sleep(0)
        logger.debug(" Wait for tasks")
        # return_exceptions=True: a crashed reader must not prevent the child from being reaped.
        results = await asyncio.gather(self._stdout_task, self._stderr_task, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                logger.error(f" Reader task failed: {result!r}")
        logger.debug(" Terminate process")
        try:
            self._process.terminate()
        except ProcessLookupError:
            logger.debug(" Process already exited")
        try:
            logger.debug(" Wait for process")
            await asyncio.wait_for(self._process.wait(), timeout=5)
        except asyncio.TimeoutError:
            logger.error(" Process did not terminate: killing")
            try:
                self._process.kill()
            except ProcessLookupError:
                pass
            # Reap the killed child so it does not linger as a zombie.
            await self._process.wait()
        logger.debug("Complete")

    @classmethod
    @asynccontextmanager
    async def run(cls, cmd: str) -> AsyncIterator["JsonRpcStdio"]:
        """Start *cmd* through the shell and yield a client attached to its stdio.

        The client is shut down when the ``async with`` block exits, including
        when it exits because of an exception.
        """
        logger = logging.getLogger(__name__).getChild(cls.__name__).getChild("run")
        logger.debug("Start")
        logger.debug(f" Command: {cmd}")
        # NOTE: cmd is interpreted by the shell; never build it from untrusted input.
        process = await asyncio.subprocess.create_subprocess_shell(
            cmd=cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=os.environ,
            cwd=os.getcwd(),
        )
        loop = asyncio.get_running_loop()
        logger.debug(" Create client")
        client = cls(process, loop, logger=logger)
        logger.debug(" Yield client")
        try:
            yield client
        finally:
            logger.debug(" Shutdown client")
            await client.shutdown()

    async def _read_header(self, stream: asyncio.StreamReader, logger: logging.Logger) -> int | None:
        """Consume one header block up to and including its terminating blank line.

        Returns the ``Content-Length`` value, or ``None`` if the stream ended
        first or the block carried no valid ``Content-Length``. Header names are
        matched case-insensitively and in any order, and lines that are not
        ``Name: value`` pairs are skipped, so stray output cannot wedge parsing.
        """
        content_length: int | None = None
        header_lines = 0
        while True:
            logger.debug(" Read line")
            line = await stream.readline()
            logger.debug(f" Line: {line!r}")
            if not line:
                return None  # EOF
            if not line.strip():
                break  # a blank line ends the header block
            header_lines += 1
            name, sep, value = line.decode("ascii", errors="replace").partition(":")
            if not sep or name.strip().lower() != "content-length":
                continue
            try:
                length = int(value.strip())
            except ValueError:
                length = -1
            if length < 0:
                logger.warning(f" Invalid Content-Length header: {line!r}")
            else:
                content_length = length
        if content_length is None:
            if header_lines:
                logger.warning(" Header block has no Content-Length; skipping")
            return None
        logger.debug(" Header matched")
        logger.debug(f" Content length: {content_length}")
        return content_length

    def _dispatch(self, obj: Any, logger: logging.Logger) -> None:
        """Route one decoded message to the request/notification queues or to a pending future."""
        if isinstance(obj, list):
            # A JSON-RPC batch: dispatch each member on its own.
            for item in obj:
                self._dispatch(item, logger)
            return
        if not isinstance(obj, dict):
            logger.warning(f" Ignoring non-object message: {obj!r}")
            return
        has_method = "method" in obj
        has_id = "id" in obj
        if has_method and has_id:
            request = JsonRpcRequest(
                jsonrpc="2.0",
                method=obj["method"],
                id=obj["id"],
                # params may be omitted in JSON-RPC 2.0.
                params=obj.get("params"),
            )
            logger.debug(f" Request: {request}")
            self._requests.put_nowait(request)
        elif has_method:
            notification = JsonRpcNotification(
                jsonrpc="2.0",
                method=obj["method"],
                params=obj.get("params"),
            )
            logger.debug(f" Notification: {notification}")
            self._notifications.put_nowait(notification)
        elif has_id and "error" in obj:
            error_response = JsonRpcErrorResponse(
                jsonrpc="2.0",
                id=obj["id"],
                error=obj["error"],
            )
            logger.debug(f" Error: {error_response}")
            future = self.response_futures.get(error_response["id"])
            if future is None or future.done():
                logger.error(f" Error response for unknown request: {error_response}")
            else:
                # Surface the error in the awaiting request() call instead of crashing the reader.
                future.set_exception(JsonRpcException(error_response["error"]))
        elif has_id:
            response = JsonRpcResponse(
                jsonrpc="2.0",
                id=obj["id"],
                result=obj.get("result"),
            )
            logger.debug(f" Response: {response}")
            future = self.response_futures.get(response["id"])
            if future is None or future.done():
                # Unknown id, or the caller already gave up (e.g. cancelled by a timeout).
                logger.warning(f" Response for unknown request: {response}")
            else:
                future.set_result(response)
        else:
            logger.warning(f" Ignoring message without method or id: {obj!r}")

    async def _handle_stdout(self) -> None:
        """Reader task: parse framed messages from the child's stdout until EOF or interrupt.

        When the task ends for any reason, every still-pending ``request`` call
        is failed with ``ConnectionResetError`` instead of waiting forever.
        """
        logger = self.logger.getChild("stdout")
        logger.debug("Start")
        stdout = self._process.stdout
        try:
            while stdout is not None and not stdout.at_eof():
                content_length = await self._read_header(stdout, logger)
                if content_length is None:
                    continue
                # readexactly(), not read(): read(n) may return fewer than n bytes.
                body = await stdout.readexactly(content_length)
                try:
                    obj = json.loads(body)
                except ValueError as exc:  # JSONDecodeError or invalid UTF-8
                    logger.error(f" Malformed message body {body!r}: {exc}")
                    continue
                try:
                    self._dispatch(obj, logger)
                except (KeyError, TypeError) as exc:
                    logger.error(f" Malformed message {obj!r}: {exc!r}")
        except (BrokenPipeError, ConnectionResetError, Interrupted):
            logger.debug("Received interrupt")
        except asyncio.IncompleteReadError as exc:
            logger.warning(f" Stream ended in the middle of a message: {exc}")
        finally:
            for future in self.response_futures.values():
                if not future.done():
                    future.set_exception(ConnectionResetError("stdout closed before the response arrived"))

    async def _handle_stderr(self) -> None:
        """Reader task: log each line the child writes to stderr until EOF or interrupt."""
        logger = self.logger.getChild("stderr")
        logger.debug("Start")
        try:
            while self._process.stderr and not self._process.stderr.at_eof():
                line = await self._process.stderr.readline()
                # readline() returns b"" at EOF; do not emit an empty log record for it.
                if line:
                    logger.error(line)
        except (BrokenPipeError, ConnectionResetError, Interrupted):
            logger.debug("Received interrupt")

    async def _send(self, payload: str) -> None:
        """Write *payload* to the child's stdin behind a ``Content-Length`` header, then drain."""
        logger = self.logger.getChild("send")
        logger.debug("Start")
        logger.debug(f" Payload: {payload}")
        assert self._process.stdin is not None
        body = payload.encode("utf-8")
        # Content-Length counts bytes of the UTF-8 body, not characters.
        content_length = len(body)
        logger.debug(f" Content length: {content_length}")
        self._process.stdin.write(f"Content-Length: {content_length}\r\n\r\n".encode("utf-8") + body)
        logger.debug(" Drain")
        await self._process.stdin.drain()

    async def request(self, method: str, params: Any) -> JsonRpcResponse:
        """Send a request and wait for its response.

        Raises:
            JsonRpcException: If the child answers with an error response.
            ConnectionResetError: If stdout closes before the response arrives.
        """
        logger = self.logger.getChild("request")
        logger.debug("Start")
        logger.debug(f" Method: {method}")
        logger.debug(f" Params: {params}")
        # Claim a unique id before any await so concurrent requests never share one.
        request_id = self.request_id
        self.request_id += 1
        request = JsonRpcRequest(jsonrpc="2.0", method=method, params=params, id=request_id)
        request_str = json.dumps(request)
        future: asyncio.Future[JsonRpcResponse] = self._loop.create_future()
        self.response_futures[request_id] = future
        try:
            if self._stdout_task.done():
                raise ConnectionResetError("stdout reader has stopped; no response can arrive")
            logger.debug(" Send request")
            await self._send(request_str)
            logger.debug(" Wait for response")
            response = await future
        finally:
            # Also runs on error or cancellation, so abandoned futures never accumulate.
            self.response_futures.pop(request_id, None)
        logger.debug(f" Return response {response}")
        return response

    async def notify(self, method: str, params: Any) -> None:
        """Send a notification (a message without an id, so no response is expected)."""
        logger = self.logger.getChild("notify")
        logger.debug("Start")
        logger.debug(f" Method: {method}")
        logger.debug(f" Params: {params}")
        notification = JsonRpcNotification(jsonrpc="2.0", method=method, params=params)
        notification_str = json.dumps(notification)
        logger.debug(" Send notification")
        await self._send(notification_str)

    async def respond(self, result: Any, *, request_id: int | None = None) -> None:
        """Send a successful response to a request the child sent to us.

        Args:
            result: Value for the response's ``result`` member.
            request_id: Id of the request being answered. Defaults to the id of
                the request most recently returned by ``get_request``.

        Raises:
            ValueError: If no id is given and no request has been received yet.
        """
        logger = self.logger.getChild("respond")
        logger.debug("Start")
        logger.debug(f" Result: {result}")
        if request_id is None:
            request_id = self._last_request_id
        if request_id is None:
            raise ValueError("respond() needs request_id: no request has been received via get_request()")
        response = JsonRpcResponse(jsonrpc="2.0", id=request_id, result=result)
        response_str = json.dumps(response)
        logger.debug(" Send response")
        await self._send(response_str)

    async def get_request(self, *, timeout: float = 10.0) -> JsonRpcRequest:
        """Return the next request sent by the child, waiting up to *timeout* seconds.

        Raises:
            asyncio.TimeoutError: If no request arrives in time.
        """
        logger = self.logger.getChild("get_request")
        logger.debug("Start")
        logger.debug(f" Timeout: {timeout}")
        try:
            request = await asyncio.wait_for(self._requests.get(), timeout=timeout)
        except asyncio.TimeoutError as e:
            logger.error(f"Timeout: {e}")
            raise
        logger.debug(f" Request: {request}")
        # Remembered so that respond() answers with the matching id.
        self._last_request_id = request["id"]
        return request

    async def get_notification(self, *, timeout: float = 10.0) -> JsonRpcNotification:
        """Return the next notification sent by the child, waiting up to *timeout* seconds.

        Raises:
            asyncio.TimeoutError: If no notification arrives in time.
        """
        logger = self.logger.getChild("get_notification")
        logger.debug("Start")
        try:
            notification = await asyncio.wait_for(self._notifications.get(), timeout=timeout)
        except asyncio.TimeoutError as e:
            logger.error(f"Timeout: {e}")
            raise
        logger.debug(f" Notification: {notification}")
        return notification

    async def flush_notifications(self) -> list[JsonRpcNotification]:
        """Drain and return every notification queued so far, without waiting for more."""
        logger = self.logger.getChild("flush_notifications")
        logger.debug("Start")
        notifications: list[JsonRpcNotification] = []
        while not self._notifications.empty():
            notifications.append(self._notifications.get_nowait())
        logger.debug(f" Notifications: {notifications}")
        return notifications
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment