Skip to content

Instantly share code, notes, and snippets.

@ktnyt
Last active February 2, 2025 04:52
Show Gist options
  • Save ktnyt/ec145e6354ec428acd56466af8d9e8e1 to your computer and use it in GitHub Desktop.
Save ktnyt/ec145e6354ec428acd56466af8d9e8e1 to your computer and use it in GitHub Desktop.
A simple asyncio-based JSON-RPC implementation that communicates over standard I/O (stdio).
import asyncio
import json
import logging
import os
import re
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, TypedDict
class JsonRpcRequest(TypedDict):
    """Dictionary shape of a JSON-RPC request (a call that carries an ``id``)."""

    # Protocol version marker, e.g. "2.0".
    jsonrpc: str
    # Name of the method being invoked.
    method: str
    # Identifier used to pair the request with its response.
    id: int
    # Arguments for the method; any JSON-serialisable value.
    params: Any
class JsonRpcNotification(TypedDict):
    """Dictionary shape of a JSON-RPC notification (a call without an ``id``)."""

    # Protocol version marker, e.g. "2.0".
    jsonrpc: str
    # Name of the method being signalled.
    method: str
    # Arguments for the method; any JSON-serialisable value.
    params: Any
class JsonRpcResponse(TypedDict):
    """Dictionary shape of a successful JSON-RPC response."""

    # Protocol version marker, e.g. "2.0".
    jsonrpc: str
    # Identifier of the request this response answers.
    id: int
    # Value produced by the invoked method.
    result: Any
class JsonRpcError(TypedDict):
    """Dictionary shape of the ``error`` member of a failed JSON-RPC response."""

    # Numeric error code.
    code: int
    # Short human-readable description of the error.
    message: str
    # Additional error information; any JSON value.
    data: Any
class JsonRpcErrorResponse(TypedDict):
    """Dictionary shape of a failed JSON-RPC response."""

    # Protocol version marker, e.g. "2.0".
    jsonrpc: str
    # Identifier of the request this response answers.
    id: int
    # Details of the failure.
    error: JsonRpcError
class JsonRpcException(Exception):
    """Exception wrapping a JSON-RPC error object.

    The exception message is formatted as ``[<code>] <message>`` and the
    original error object is kept on the ``error`` attribute.
    """

    def __init__(self, error: JsonRpcError):
        code, message = error["code"], error["message"]
        super().__init__(f"[{code}] {message}")
        # Keep the full error object so callers can inspect ``data`` as well.
        self.error = error
class JsonRpcHeader(TypedDict):
    """Dictionary shape of the header fields that precede a message body."""

    # Size of the message body (value of the ``Content-Length`` header).
    content_length: int
    # Value of the ``Content-Type`` header.
    content_type: str
class Interrupted(Exception):
    """Marker exception used to signal that a blocking operation should stop."""
class JsonRpcStdio:
    """JSON-RPC 2.0 peer that talks to a child process over its stdio pipes.

    Outgoing messages are written to the child's stdin and incoming messages
    are read from its stdout; each message is framed by a ``Content-Length``
    header followed by a blank line and the JSON body. Lines the child writes
    to stderr are forwarded to the logger at error level.

    Two background tasks are created on construction: one parses stdout into
    requests, notifications and responses, the other drains stderr. Call
    :meth:`shutdown` (or use :meth:`run` as an async context manager) to stop
    them and terminate the child.
    """

    def __init__(
        self,
        process: asyncio.subprocess.Process,
        loop: asyncio.AbstractEventLoop,
        *,
        logger: logging.Logger | None = None,
    ) -> None:
        """Attach to ``process`` and start the reader tasks on ``loop``.

        Args:
            process: Child process whose stdin/stdout/stderr are pipes.
            loop: Event loop on which the reader tasks are scheduled.
            logger: Parent logger; defaults to a child of this module's logger
                named after the class.
        """
        self._process = process
        self._loop = loop
        self.logger = logger or logging.getLogger(__name__).getChild(self.__class__.__name__)
        self._requests: asyncio.Queue[JsonRpcRequest] = asyncio.Queue()
        self._notifications: asyncio.Queue[JsonRpcNotification] = asyncio.Queue()
        # Next id to use for an outgoing request.
        self.request_id = 0
        # Futures of in-flight outgoing requests, keyed by request id.
        self.response_futures: dict[int, asyncio.Future[JsonRpcResponse]] = {}
        # Id of the incoming request most recently returned by get_request().
        self._last_request_id: int | None = None
        # Create the tasks last so every attribute they use already exists.
        self._stdout_task = self._loop.create_task(self._handle_stdout())
        self._stderr_task = self._loop.create_task(self._handle_stderr())

    async def shutdown(self) -> None:
        """Stop the reader tasks and terminate the child process.

        Closes stdin, interrupts both readers, waits for them, then sends
        terminate and waits up to five seconds before killing the process.
        Reader failures are logged rather than raised so that the process is
        always cleaned up.
        """
        logger = self.logger.getChild("shutdown")
        logger.debug("Initiate")

        logger.debug(" Close stdin")
        if self._process.stdin is not None:
            self._process.stdin.close()
        await asyncio.sleep(0)

        streams = (("stdout", self._process.stdout), ("stderr", self._process.stderr))
        for name, stream in streams:
            logger.debug(" Interrupt %s", name)
            if stream is not None:
                # Wakes a reader blocked on this stream with Interrupted.
                stream.set_exception(Interrupted())
            await asyncio.sleep(0)

        logger.debug(" Wait for tasks")
        outcomes = await asyncio.gather(
            self._stdout_task,
            self._stderr_task,
            return_exceptions=True,
        )
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                logger.error(" Reader task failed: %r", outcome)

        if self._process.returncode is None:
            logger.debug(" Terminate process")
            try:
                self._process.terminate()
            except ProcessLookupError:
                # The child exited on its own between the check and the call.
                logger.debug(" Process already exited")
        try:
            logger.debug(" Wait for process")
            await asyncio.wait_for(self._process.wait(), timeout=5)
        except asyncio.TimeoutError:
            logger.error(" Process did not terminate: killing")
            try:
                self._process.kill()
            except ProcessLookupError:
                logger.debug(" Process already exited")
            # Reap the child so it does not linger as a zombie.
            await self._process.wait()
        logger.debug("Complete")

    @classmethod
    @asynccontextmanager
    async def run(cls, cmd: str) -> AsyncIterator["JsonRpcStdio"]:
        """Spawn ``cmd`` and yield a client connected to it.

        The client is shut down when the ``async with`` block exits, including
        when the block raises.

        Args:
            cmd: Command line executed through the shell. It is passed to the
                shell verbatim, so it must not contain untrusted input.

        Yields:
            The connected client.
        """
        logger = logging.getLogger(__name__).getChild(cls.__name__).getChild("run")
        logger.debug("Start")
        logger.debug(" Command: %s", cmd)
        process = await asyncio.subprocess.create_subprocess_shell(
            cmd=cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=os.environ,
            cwd=os.getcwd(),
        )
        logger.debug(" Create client")
        client = cls(process, asyncio.get_running_loop(), logger=logger)
        try:
            logger.debug(" Yield client")
            yield client
        finally:
            logger.debug(" Shutdown client")
            await client.shutdown()

    async def _read_header(
        self,
        stdout: asyncio.StreamReader,
        logger: logging.Logger,
    ) -> int | None:
        """Read header lines up to the blank separator line.

        Header names are matched case-insensitively and in any order; lines
        that are not a ``Content-Length`` header are ignored so that stray
        output cannot wedge the parser.

        Returns:
            The announced body length, or ``None`` when the stream ended.
        """
        content_length: int | None = None
        while True:
            logger.debug(" Read line")
            line = await stdout.readline()
            logger.debug(" Line: %r", line)
            if not line:
                return None
            stripped = line.strip()
            if not stripped:
                if content_length is not None:
                    return content_length
                # Blank line with no preceding Content-Length: keep scanning.
                continue
            name, separator, value = stripped.partition(b":")
            if separator and name.strip().lower() == b"content-length":
                digits = value.strip()
                if digits.isdigit():
                    content_length = int(digits)
                else:
                    logger.warning(" Ignoring invalid Content-Length: %r", line)

    def _dispatch(self, obj: dict[str, Any], logger: logging.Logger) -> None:
        """Route one decoded message to the matching queue or pending future."""
        has_method = "method" in obj
        has_id = "id" in obj
        if has_method and has_id:
            request = JsonRpcRequest(
                jsonrpc="2.0",
                method=obj["method"],
                id=obj["id"],
                params=obj.get("params"),  # "params" may be omitted
            )
            logger.debug(" Request: %s", request)
            self._requests.put_nowait(request)
            return
        if has_method:
            notification = JsonRpcNotification(
                jsonrpc="2.0",
                method=obj["method"],
                params=obj.get("params"),
            )
            logger.debug(" Notification: %s", notification)
            self._notifications.put_nowait(notification)
            return
        if not has_id:
            logger.warning(" Ignoring message with neither method nor id: %s", obj)
            return

        message_id = obj["id"]
        future = None
        if isinstance(message_id, (int, str)):
            future = self.response_futures.get(message_id)
        if future is None or future.done():
            logger.warning(" Ignoring response for unknown request id: %s", obj)
            return
        if "error" in obj:
            error_response = JsonRpcErrorResponse(
                jsonrpc="2.0",
                id=message_id,
                error=obj["error"],
            )
            logger.debug(" Error: %s", error_response)
            failure: Exception
            try:
                failure = JsonRpcException(error_response["error"])
            except (KeyError, TypeError):
                failure = ValueError(f"Malformed JSON-RPC error object: {obj['error']!r}")
            # Deliver the failure to the caller awaiting this request instead
            # of raising here, which would kill the reader task.
            future.set_exception(failure)
            return
        response = JsonRpcResponse(
            jsonrpc="2.0",
            id=message_id,
            result=obj.get("result"),
        )
        logger.debug(" Response: %s", response)
        future.set_result(response)

    def _fail_pending(self, reason: Exception) -> None:
        """Fail every in-flight request so that no caller waits forever."""
        for future in self.response_futures.values():
            if not future.done():
                future.set_exception(reason)

    async def _handle_stdout(self) -> None:
        """Parse framed messages from the child's stdout until it ends.

        Malformed bodies are logged and skipped. When the reader stops for any
        reason, all in-flight requests are failed with ``ConnectionError``.
        """
        logger = self.logger.getChild("stdout")
        logger.debug("Start")
        stdout = self._process.stdout
        try:
            while stdout is not None and not stdout.at_eof():
                content_length = await self._read_header(stdout, logger)
                if content_length is None:
                    break
                logger.debug(" Header matched")
                logger.debug(" Content length: %d", content_length)
                # read() may return fewer bytes than requested.
                body = await stdout.readexactly(content_length)
                try:
                    obj = json.loads(body)
                except ValueError:
                    logger.exception(" Discarding message with invalid JSON body")
                    continue
                if not isinstance(obj, dict):
                    logger.warning(" Discarding non-object message: %r", obj)
                    continue
                self._dispatch(obj, logger)
        except (BrokenPipeError, ConnectionResetError, Interrupted):
            logger.debug("Received interrupt")
        except asyncio.IncompleteReadError:
            logger.error("Stream ended in the middle of a message")
        finally:
            self._fail_pending(ConnectionError("JSON-RPC output stream closed"))

    async def _handle_stderr(self) -> None:
        """Forward every line of the child's stderr to the logger."""
        logger = self.logger.getChild("stderr")
        logger.debug("Start")
        stderr = self._process.stderr
        try:
            while stderr is not None and not stderr.at_eof():
                line = await stderr.readline()
                if line:  # readline() returns b"" at end of stream
                    logger.error(line)
        except (BrokenPipeError, ConnectionResetError, Interrupted):
            logger.debug("Received interrupt")

    async def _send(self, payload: str) -> None:
        """Write ``payload`` to the child's stdin with a Content-Length header."""
        logger = self.logger.getChild("send")
        logger.debug("Start")
        logger.debug(" Payload: %s", payload)
        assert self._process.stdin is not None
        body = payload.encode("utf-8")
        # The header counts bytes, not characters.
        content_length = len(body)
        logger.debug(" Content length: %d", content_length)
        header = f"Content-Length: {content_length}\r\n\r\n".encode("utf-8")
        self._process.stdin.write(header + body)
        logger.debug(" Drain")
        await self._process.stdin.drain()

    async def request(self, method: str, params: Any) -> JsonRpcResponse:
        """Send a request and wait for its response.

        Each call uses a fresh id, so requests may be issued concurrently.

        Args:
            method: Name of the remote method.
            params: JSON-serialisable arguments.

        Returns:
            The successful response.

        Raises:
            JsonRpcException: The peer answered with an error object.
            ConnectionError: The output stream closed before a response
                arrived.
        """
        logger = self.logger.getChild("request")
        logger.debug("Start")
        logger.debug(" Method: %s", method)
        logger.debug(" Params: %s", params)
        # Reserve the id before the first await so concurrent calls never
        # share one.
        request_id = self.request_id
        self.request_id += 1
        request = JsonRpcRequest(jsonrpc="2.0", method=method, params=params, id=request_id)
        request_str = json.dumps(request)
        future: asyncio.Future[JsonRpcResponse] = self._loop.create_future()
        self.response_futures[request_id] = future
        try:
            logger.debug(" Send request")
            await self._send(request_str)
            logger.debug(" Wait for response")
            response = await future
        finally:
            # Also runs on error or cancellation so the table cannot leak.
            self.response_futures.pop(request_id, None)
        logger.debug(" Return response %s", response)
        return response

    async def notify(self, method: str, params: Any) -> None:
        """Send a notification; no response is expected.

        Args:
            method: Name of the remote method.
            params: JSON-serialisable arguments.
        """
        logger = self.logger.getChild("notify")
        logger.debug("Start")
        logger.debug(" Method: %s", method)
        logger.debug(" Params: %s", params)
        notification = JsonRpcNotification(jsonrpc="2.0", method=method, params=params)
        notification_str = json.dumps(notification)
        logger.debug(" Send notification")
        await self._send(notification_str)

    async def respond(self, result: Any, *, request_id: int | None = None) -> None:
        """Send a successful response to a request received from the peer.

        Args:
            result: JSON-serialisable result value.
            request_id: Id of the request being answered. When omitted, the id
                of the request most recently returned by :meth:`get_request`
                is used; if none has been received yet, the current value of
                ``self.request_id`` is used.
        """
        logger = self.logger.getChild("respond")
        logger.debug("Start")
        logger.debug(" Result: %s", result)
        if request_id is None:
            if self._last_request_id is not None:
                request_id = self._last_request_id
            else:
                request_id = self.request_id
        response = JsonRpcResponse(jsonrpc="2.0", id=request_id, result=result)
        response_str = json.dumps(response)
        logger.debug(" Send response")
        await self._send(response_str)

    async def get_request(self, *, timeout: float = 10.0) -> JsonRpcRequest:
        """Return the next request received from the peer.

        Args:
            timeout: Maximum number of seconds to wait.

        Raises:
            asyncio.TimeoutError: No request arrived within ``timeout``.
        """
        logger = self.logger.getChild("get_request")
        logger.debug("Start")
        logger.debug(" Timeout: %s", timeout)
        try:
            request = await asyncio.wait_for(self._requests.get(), timeout=timeout)
        except asyncio.TimeoutError as e:
            logger.error("Timeout: %s", e)
            raise
        # Remembered so respond() can default to answering this request.
        self._last_request_id = request["id"]
        logger.debug(" Request: %s", request)
        return request

    async def get_notification(self, *, timeout: float = 10.0) -> JsonRpcNotification:
        """Return the next notification received from the peer.

        Args:
            timeout: Maximum number of seconds to wait.

        Raises:
            asyncio.TimeoutError: No notification arrived within ``timeout``.
        """
        logger = self.logger.getChild("get_notification")
        logger.debug("Start")
        logger.debug(" Timeout: %s", timeout)
        try:
            notification = await asyncio.wait_for(self._notifications.get(), timeout=timeout)
        except asyncio.TimeoutError as e:
            logger.error("Timeout: %s", e)
            raise
        logger.debug(" Notification: %s", notification)
        return notification

    async def flush_notifications(self) -> list[JsonRpcNotification]:
        """Remove and return all notifications that are currently queued."""
        logger = self.logger.getChild("flush_notifications")
        logger.debug("Start")
        notifications: list[JsonRpcNotification] = []
        while True:
            try:
                # Non-blocking: cannot hang if another consumer empties the
                # queue first.
                notifications.append(self._notifications.get_nowait())
            except asyncio.QueueEmpty:
                break
        logger.debug(" Notifications: %s", notifications)
        return notifications
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment