Skip to content

Instantly share code, notes, and snippets.

@teaishealthy
Created June 2, 2025 17:10
Show Gist options
  • Save teaishealthy/c243e7441548a6bc9281d0cec1e68b59 to your computer and use it in GitHub Desktop.
Save teaishealthy/c243e7441548a6bc9281d0cec1e68b59 to your computer and use it in GitHub Desktop.
teaRPC - portable Redis* RPC
import asyncio
import json
import traceback
import uuid
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Literal,
Protocol,
TypedDict,
TypeVar,
cast,
)
import redis.asyncio as aioredis
type Flexible = str | int | float | bool | None | dict[str, Flexible] | list[Flexible]
R = TypeVar("R")
type Ack = Callable[[], Awaitable[None]]
type CoroT[**P, R] = Callable[P, Awaitable[R]]
type MaybeCoroT[**P, R] = CoroT[P, R] | Callable[P, R]
class RPCCall(TypedDict):
    """Payload of a request asking a peer to run ``method`` with the given arguments."""

    type: Literal["call"]
    method: str
    args: list[Flexible]
    kwargs: dict[str, Flexible]


class RPCErrorDetail(TypedDict):
    """Error description carried inside an ``RPCError`` payload."""

    code: int
    message: str


class RPCError(TypedDict):
    """Payload reporting that a call failed."""

    type: Literal["error"]
    # Precise shape instead of dict[str, int | str]; mirrors the TS RPCError type.
    error: RPCErrorDetail


class RPCResult(TypedDict):
    """Payload carrying the value returned by a successful call."""

    type: Literal["result"]
    result: Flexible


class RPCAcknowledged(TypedDict):
    """Payload telling the caller that a handler picked up its request."""

    type: Literal["acknowledged"]


class RPCIntrospectRequest(TypedDict):
    """Payload asking peers on the channel to list the methods they serve."""

    type: Literal["introspect"]


class RPCIntrospectResponse(TypedDict):
    """Payload listing the methods one peer serves."""

    type: Literal["introspect-response"]
    methods: list[str]
class RPCMessageCall(TypedDict):
id: str
payload: RPCCall
class RPCMessageIntrospectRequest(TypedDict):
id: str
payload: RPCIntrospectRequest
class RPCMessageResult(TypedDict):
id: str
payload: RPCResult
class RPCMessageError(TypedDict):
id: str
payload: RPCError
class RPCMessageAcknowledged(TypedDict):
id: str
payload: RPCAcknowledged
class RPCMessageIntrospectResponse(TypedDict):
id: str
payload: RPCIntrospectResponse
type RPCMessage = (
RPCMessageCall
| RPCMessageIntrospectRequest
| RPCMessageResult
| RPCMessageError
| RPCMessageAcknowledged
| RPCMessageIntrospectResponse
)
class Callback(Protocol):
    """Signature of a method implementation registered via ``RPCClient.define``.

    It receives an ``ack`` coroutine function first, then the call's positional
    arguments, and returns an awaitable producing a JSON-compatible result.
    """

    def __call__(self, ack: Ack, *args: Any) -> Awaitable[Flexible]: ...


# Typing-only views of RPCMessage: what a caller publishes ...
RPCClientMessage = RPCMessageCall | RPCMessageIntrospectRequest
# ... and what a responder publishes back.
RPCServerMessage = (
    RPCMessageError
    | RPCMessageResult
    | RPCMessageAcknowledged
    | RPCMessageIntrospectResponse
)
async def maybe_coro(x: Coroutine[R, Any, Any] | R) -> R:
"""If x is a coroutine, await it; otherwise, return x."""
if asyncio.iscoroutine(x):
return await x
return x
class RPCClient:
    """A teaRPC peer: calls remote methods and serves its own over Redis pub/sub.

    All traffic goes through the single ``teaRPC`` channel. Every peer receives
    every message and routes it by payload type. ``call`` requests go to the
    methods registered with :meth:`define`. ``introspect`` requests are answered
    with the registered method names. Responses are matched to pending calls by
    their ``id``.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost",
        *,
        logging: bool = False,
        ack_timeout: float = 1.0,
        total_timeout: float = 5.0,
    ):
        """Store configuration; no connection is opened until :meth:`connect`.

        :param redis_url: URL handed to ``redis.asyncio.from_url``.
        :param logging: print every received message and every dropped message.
        :param ack_timeout: seconds a call may stay unacknowledged before it
            fails; also the collection window used by :meth:`introspect`.
        :param total_timeout: seconds a call may take in total before it fails.
        """
        self.redis_url = redis_url
        self.ack_timeout = ack_timeout
        self.total_timeout = total_timeout
        self.logging = logging
        # Routing tables: by payload type, by method name ("call"), by message id (responses).
        self.type_listeners: dict[str, Callable[[RPCMessage], Awaitable[None]]] = {}
        self.method_listeners: dict[str, Callable[[RPCMessage], Awaitable[None]]] = {}
        self.callback_listeners: dict[str, Callable[[RPCMessage], Awaitable[None]]] = {}
        # The event loop only keeps weak references to tasks, so we hold strong ones.
        self._reader_task: asyncio.Task[None] | None = None
        self._handler_tasks: set[asyncio.Task[None]] = set()
        self.type_listeners["introspect"] = self._introspect_request_handler  # type: ignore
        self.type_listeners["call"] = self._call_request_handler  # type: ignore
        for t in ["acknowledged", "result", "error", "introspect-response"]:
            self.type_listeners[t] = self._response_handler

    async def connect(self):
        """Open producer/consumer connections and start routing ``teaRPC`` messages."""
        self.producer = await aioredis.from_url(self.redis_url)
        self.consumer = await aioredis.from_url(self.redis_url)
        pubsub = self.consumer.pubsub()
        await pubsub.subscribe("teaRPC")

        async def reader():
            async for message in pubsub.listen():
                if message["type"] != "message":
                    continue
                self._dispatch(message["data"])

        self._reader_task = asyncio.create_task(reader())
        self._reader_task.add_done_callback(self._on_task_done)

    def _dispatch(self, data: bytes | str) -> None:
        """Parse one raw channel message and run its handler in its own task.

        Handlers run as separate tasks so that a slow method, or one that itself
        awaits ``self.call``, cannot stall delivery of other messages. Tasks start
        in arrival order, and response handlers finish without suspending, so
        ack/result ordering is preserved. Malformed messages are dropped instead
        of killing the reader loop.
        """
        if self.logging:
            text = data.decode("utf-8", "replace") if isinstance(data, bytes) else data
            print(f"< {text}")
        try:
            parsed = cast(RPCMessage, json.loads(data))
            handler = self.type_listeners.get(parsed["payload"]["type"])
        except (ValueError, TypeError, KeyError) as exc:
            if self.logging:
                print(f"< dropping malformed message: {exc!r}")
            return
        if handler is None:
            return
        task = asyncio.create_task(handler(parsed))
        self._handler_tasks.add(task)
        task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task[None]) -> None:
        """Forget a finished task and print the exception it raised, if any."""
        self._handler_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            traceback.print_exception(exc)

    async def dispose(self):
        """Stop message processing and close both Redis connections."""
        current = asyncio.current_task()
        # Never cancel the task running dispose() itself (e.g. a handler disposing the client).
        tasks = [
            task
            for task in (self._reader_task, *self._handler_tasks)
            if task is not None and task is not current
        ]
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        self._reader_task = None
        await self.producer.close()
        await self.consumer.close()

    async def _response_handler(self, message: RPCMessage):
        """Forward a response message to the callback waiting on its id, if any."""
        cb = self.callback_listeners.get(message["id"])
        if cb:
            await maybe_coro(cb(message))

    async def _introspect_request_handler(self, request: RPCMessageIntrospectRequest):
        """Answer an introspection request with the names of locally defined methods."""
        response: RPCMessageIntrospectResponse = {
            "id": request["id"],
            "payload": {
                "type": "introspect-response",
                "methods": list(self.method_listeners.keys()),
            },
        }
        await self.producer.publish("teaRPC", json.dumps(response))

    async def _call_request_handler(self, request: RPCMessageCall):
        """Run the local implementation of the requested method, if one is defined."""
        method = request["payload"]["method"]
        listener = self.method_listeners.get(method)
        if listener:
            await maybe_coro(listener(request))

    async def safe_call(self, method: str, *args: Flexible) -> dict[str, Any]:
        """Call ``method`` on whichever peer serves it, without raising on RPC failure.

        :returns: ``{"success": True, "result": ...}`` on success. Otherwise
            ``{"success": False, "error": {"message": ..., "code": ...}}``:
            code -1 is "acknowledgement timeout" when no ack arrived within
            ``ack_timeout``, code -2 is "total timeout", and remote failures
            carry the error object the peer sent.
        :raises Exception: whatever ``producer.publish`` raises if the request
            cannot be sent.
        """
        id_ = uuid.uuid4().hex
        request: RPCMessageCall = {
            "id": id_,
            "payload": {
                "type": "call",
                "method": method,
                "args": list(args),
                "kwargs": {},
            },
        }
        loop = asyncio.get_running_loop()
        fut: asyncio.Future[dict[str, Any]] = loop.create_future()
        acknowledged = False

        def settle(outcome: dict[str, Any]) -> None:
            # First outcome wins; late timers or duplicate responses are ignored.
            # (Previously a result arriving before ack_timeout without an ack made
            # the ack timer call set_result on a finished future.)
            if not fut.done():
                fut.set_result(outcome)

        def ack_timeout_handler():
            if not acknowledged:
                settle({
                    "success": False,
                    "error": {"message": "acknowledgement timeout", "code": -1},
                })

        def total_timeout_handler():
            settle({
                "success": False,
                "error": {"message": "total timeout", "code": -2},
            })

        async def cb(msg: RPCMessage):
            nonlocal acknowledged
            t = msg["payload"]["type"]
            if t == "acknowledged":
                acknowledged = True
            elif t == "result":
                # .get: peers that serialise an undefined result omit the key entirely.
                result = cast(RPCMessageResult, msg)["payload"].get("result")
                settle({"success": True, "result": result})
            elif t == "error":
                error = cast(RPCMessageError, msg)["payload"]["error"]
                settle({"success": False, "error": error})

        self.callback_listeners[id_] = cb
        handles: list[asyncio.TimerHandle] = []
        try:
            await self.producer.publish("teaRPC", json.dumps(request))
            handles.append(loop.call_later(self.ack_timeout, ack_timeout_handler))
            handles.append(loop.call_later(self.total_timeout, total_timeout_handler))
            return await fut
        finally:
            # Runs on success, timeout, publish failure and cancellation alike,
            # so neither the listener nor the timers outlive the call.
            self.callback_listeners.pop(id_, None)
            for handle in handles:
                handle.cancel()

    async def call(self, method: str, *args: Any) -> Any:
        """Like :meth:`safe_call`, but return the bare result.

        :raises Exception: with the error message when the call fails or times out.
        """
        result = await self.safe_call(method, *args)
        if not result["success"]:
            raise Exception(result["error"]["message"])
        return result["result"]

    # decorator
    def define(self, method: str) -> Callable[[Callback], Callback]:
        """Register the decorated coroutine function as the implementation of ``method``.

        The function is called with an ``ack`` coroutine function followed by the
        call's positional arguments.

        :raises ValueError: if the decorated object is not a coroutine function.
        """
        def decorator(fn: Callback) -> Callback:
            if not asyncio.iscoroutinefunction(fn):
                raise ValueError(f"Function {fn.__qualname__} must be a coroutine function")
            self._define(method, fn)
            return fn

        return decorator

    def _define(self, method: str, fn: Callback):
        """Install a handler that runs ``fn`` and publishes its result or error."""
        async def handler(request: RPCMessage):
            async def ack():
                await self.producer.publish("teaRPC", json.dumps({
                    "id": request["id"],
                    "payload": {"type": "acknowledged"},
                }))

            try:
                result = await fn(ack, *request["payload"].get("args", []))
                response: RPCMessageResult = {
                    "id": request["id"],
                    "payload": {"type": "result", "result": result},
                }
                # Serialising inside the try turns an unserialisable result into an error reply.
                await self.producer.publish("teaRPC", json.dumps(response))
            except Exception:
                error: RPCMessageError = {
                    "id": request["id"],
                    "payload": {
                        "type": "error",
                        "error": {"message": traceback.format_exc(), "code": -1},
                    },
                }
                await self.producer.publish("teaRPC", json.dumps(error))

        self.method_listeners[method] = handler

    async def introspect(self) -> list[str]:
        """Collect the method names advertised by peers within ``ack_timeout`` seconds.

        The list includes this client's own methods, because it answers its own
        request. It may be empty if no peer answers in time.

        :raises Exception: whatever ``producer.publish`` raises if the request
            cannot be sent.
        """
        id_ = uuid.uuid4().hex
        methods: set[str] = set()
        loop = asyncio.get_running_loop()

        async def cb(msg: RPCMessage):
            if msg["payload"]["type"] == "introspect-response":
                methods.update(cast(RPCMessageIntrospectResponse, msg)["payload"]["methods"])

        # The collection window starts before publishing, as answers may come back immediately.
        deadline = loop.time() + self.ack_timeout
        self.callback_listeners[id_] = cb
        try:
            await self.producer.publish("teaRPC", json.dumps({
                "id": id_,
                "payload": {"type": "introspect"},
            }))
            await asyncio.sleep(max(0.0, deadline - loop.time()))
        finally:
            self.callback_listeners.pop(id_, None)
        return list(methods)
import { Redis } from "ioredis";
/**
 * Any JSON-compatible value: `null`, a boolean, a number, a string, an array of
 * such values, or an object whose string keys map to such values.
 *
 * Call arguments and results in the teaRPC payload types use this type, so
 * whatever they carry survives a `JSON.stringify` / `JSON.parse` round trip.
 */
export type Flexible =
  | null
  | boolean
  | number
  | string
  | Array<Flexible>
  | { [key: string]: Flexible };
/**
 * Payload of a request asking a peer to run `method` with the given arguments.
 */
export interface RPCCall {
  type: "call";
  method: string;
  args: Flexible[];
  kwargs: Record<string, Flexible>;
}

/**
 * Payload reporting that a call failed. `error` holds a numeric code and a
 * human-readable message.
 */
export interface RPCError {
  type: "error";
  error: {
    code: number;
    message: string;
  };
}

/**
 * Payload carrying the value returned by a successful call.
 */
export interface RPCResult {
  type: "result";
  result: Flexible;
}

/**
 * Payload sent by the callee to signal that it received the request and is
 * working on it. Callers treat a call without this signal as lost once the
 * acknowledgement timeout passes.
 */
export interface RPCAcknowledged {
  type: "acknowledged";
}

/**
 * Payload a client broadcasts to ask peers which methods they serve.
 */
export interface RPCIntrospectRequest {
  type: "introspect";
}

/**
 * Payload answering an introspection request with one peer's method names.
 */
export interface RPCIntrospectResponse {
  type: "introspect-response";
  methods: string[];
}
/**
 * Envelope around every teaRPC message: `id` ties a request to its responses,
 * `payload` carries the typed content.
 */
export interface RPCMessage<T> {
  id: string;
  payload: T;
}

/** Payload kinds published by the calling side. */
type ClientPayload = RPCCall | RPCIntrospectRequest;

/** Payload kinds published back by the responding side. */
type ServerPayload =
  | RPCError
  | RPCResult
  | RPCAcknowledged
  | RPCIntrospectResponse;

/** A message sent by a caller. */
export type RPCClientMessage = RPCMessage<ClientPayload>;

/** A message sent by a responder. */
export type RPCServerMessage = RPCMessage<ServerPayload>;
/**
 * Outcome of an RPC call, discriminated by `success`.
 *
 * On success it carries the returned `result`. On failure it carries the
 * `error` object (code and message).
 */
export type Result =
  | { success: true; result: any }
  | { success: false; error: RPCError["error"] };
export class RPCClient {
  private producer: Redis;
  private consumer: Redis;
  private ackTimeout = 1000;
  private totalTimeout = 5000;
  private logging: boolean;
  // listeners bound to a type of message (handlers may be sync or async)
  private typeListeners: Map<string, (message: any) => unknown> = new Map();
  // listeners bound to a method (for type 'call')
  private methodListeners: Map<string, (message: any) => unknown> = new Map();
  // listeners bound to an id (for responses)
  private callbackListeners: Map<string, (message: any) => void> = new Map();

  /**
   * Constructs a new instance of the class.
   *
   * If no client is provided, a new Redis client with default options will be created.
   * The subscribing connection is always a `duplicate()` of the publishing one.
   *
   * `ackTimeout` and `totalTimeout` are in milliseconds and default to 1000 and 5000.
   * `logging` attaches console loggers to both connections.
   */
  constructor(
    {
      client,
      config: { logging, ackTimeout, totalTimeout },
    }: {
      client?: Redis;
      config: {
        logging?: boolean;
        ackTimeout?: number;
        totalTimeout?: number;
      };
    } = {
      config: {},
    },
  ) {
    // `??` rather than `||`, so an explicit 0 is respected instead of replaced by the default.
    this.ackTimeout = ackTimeout ?? this.ackTimeout;
    this.totalTimeout = totalTimeout ?? this.totalTimeout;
    this.producer = client || new Redis();
    this.consumer = this.producer.duplicate();
    this.logging = logging || false;
    if (logging) {
      this.attachLogger(">", this.producer);
      this.attachLogger("<", this.consumer);
    }
    this.typeListeners.set(
      "introspect",
      this.introspectRequestHandler.bind(this),
    );
    this.typeListeners.set("call", this.callRequestHandler.bind(this));
    this.typeListeners.set("acknowledged", this.responseHandler.bind(this));
    this.typeListeners.set("result", this.responseHandler.bind(this));
    this.typeListeners.set("error", this.responseHandler.bind(this));
    this.typeListeners.set(
      "introspect-response",
      this.responseHandler.bind(this),
    );
  }

  /** Forwards a response message to the callback waiting on its id, if any. */
  private responseHandler(message: RPCServerMessage) {
    const cb = this.callbackListeners.get(message.id);
    if (cb) {
      cb(message);
    }
  }

  /** Answers an introspection request with the names of locally defined methods. */
  private async introspectRequestHandler(
    request: RPCMessage<RPCIntrospectRequest>,
  ) {
    const response: RPCMessage<RPCIntrospectResponse> = {
      id: request.id,
      payload: {
        type: "introspect-response",
        methods: Array.from(this.methodListeners.keys()),
      },
    };
    await this.producer.publish("teaRPC", JSON.stringify(response));
  }

  /** Runs the local implementation of the requested method, if one is defined. */
  private async callRequestHandler(request: RPCMessage<RPCCall>) {
    const method = request.payload.method;
    const listener = this.methodListeners.get(method);
    if (listener) {
      await listener(request);
    }
    // we don't know if there is another peer that can handle this request
  }

  private attachLogger(direction: string, redis: Redis) {
    redis.on("error", (error) => {
      console.error(`${direction} ${error}`);
    });
    redis.on("connect", () => {
      console.log(`${direction} connected`);
    });
    redis.on("close", () => {
      console.log(`${direction} closed`);
    });
    redis.on("message", (channel, message) => {
      if (channel === "teaRPC") console.log(`${direction} ${message}`);
    });
  }

  /**
   * Establishes a connection for the consumer and producer.
   *
   * This method subscribes the consumer to the "teaRPC" channel
   * and waits for the subscription to be confirmed and for both the producer
   * and the consumer to be ready before resolving.
   */
  async connect(): Promise<void> {
    this.consumer.on("message", (channel: string, message: string) => {
      if (channel === "teaRPC") {
        this.dispatch(message);
      }
    });
    // Awaiting SUBSCRIBE guarantees that responses to calls made right after
    // connect() cannot be published before this client is listening.
    await Promise.all([
      this.consumer.subscribe("teaRPC"),
      RPCClient.whenReady(this.producer),
      RPCClient.whenReady(this.consumer),
    ]);
  }

  /**
   * Resolves once `redis` is ready. Resolves immediately if it already is:
   * the "ready" event has then already fired and would never be seen.
   */
  private static whenReady(redis: Redis): Promise<void> {
    if (redis.status === "ready") {
      return Promise.resolve();
    }
    return new Promise<void>((resolve) => {
      redis.once("ready", () => resolve());
    });
  }

  /**
   * Routes one raw channel message to the listener registered for its payload type.
   *
   * Malformed messages are dropped. Listener failures (sync throws or rejected
   * promises) are reported on the console instead of escaping as unhandled
   * rejections, which terminate recent Node.js versions.
   */
  private dispatch(raw: string) {
    const message = this.parseMessage(raw);
    if (!message) {
      return;
    }
    const listener = this.typeListeners.get(message.payload.type);
    if (!listener) {
      return;
    }
    const report = (error: unknown) =>
      console.error(`teaRPC: ${message.payload.type} handler failed:`, error);
    try {
      Promise.resolve(listener(message)).catch(report);
    } catch (error) {
      report(error);
    }
  }

  /** Parses a raw message, returning `undefined` unless it is JSON with a string `payload.type`. */
  private parseMessage(
    raw: string,
  ): RPCServerMessage | RPCClientMessage | undefined {
    try {
      const parsed = JSON.parse(raw);
      if (typeof parsed?.payload?.type === "string") {
        return parsed;
      }
    } catch {
      // not valid JSON; reported below
    }
    if (this.logging) {
      console.error(`< dropping malformed message: ${raw}`);
    }
    return undefined;
  }

  /**
   * Disposes of the resources used by the client.
   *
   * This method will asynchronously quit both the producer and consumer,
   * ensuring that all resources are properly released.
   */
  async dispose() {
    await this.producer.quit();
    await this.consumer.quit();
  }

  /**
   * Makes a safe RPC call to the specified method with the provided arguments.
   *
   * The promise will resolve with an error if:
   * - The acknowledgement timeout is reached before the call is acknowledged or completed (code -1).
   * - The total timeout is reached before a result is received (code -2).
   *
   * The {@link Result} object will contain the following properties:
   * - `success`: A boolean indicating whether the call was successful.
   * - `result`: The result of the call if successful.
   * - `error`: An error object containing the error message and code if the call was not successful.
   *
   * @throws If the request itself cannot be published.
   */
  async safeCall(method: string, ...args: any[]): Promise<Result> {
    const id = Math.random().toString(36).slice(2);
    const request: RPCClientMessage = {
      id: id,
      payload: {
        type: "call",
        method,
        args: args,
        kwargs: {},
      },
    };
    let ackTimer: ReturnType<typeof setTimeout> | undefined;
    let totalTimer: ReturnType<typeof setTimeout> | undefined;
    const promise = new Promise<Result>((resolve) => {
      let acknowledged = false;
      ackTimer = setTimeout(() => {
        if (!acknowledged) {
          resolve({
            success: false,
            error: {
              message: "acknowledgement timeout",
              code: -1,
            },
          });
        }
      }, this.ackTimeout);
      totalTimer = setTimeout(() => {
        resolve({
          success: false,
          error: {
            message: "total timeout",
            code: -2,
          },
        });
      }, this.totalTimeout);
      // Later resolve() calls are no-ops, so duplicate or late responses are harmless.
      this.callbackListeners.set(id, (message: RPCServerMessage) => {
        const payload = message.payload;
        if (payload.type === "acknowledged") {
          acknowledged = true;
        } else if (payload.type === "error") {
          resolve({ success: false, error: payload.error });
        } else if (payload.type === "result") {
          resolve({ success: true, result: payload.result });
        }
      });
    });
    try {
      await this.producer.publish("teaRPC", JSON.stringify(request));
      return await promise;
    } finally {
      // Runs for results, errors, timeouts and publish failures alike, so
      // neither the listener nor the timers outlive the call.
      clearTimeout(ackTimer);
      clearTimeout(totalTimer);
      this.callbackListeners.delete(id);
    }
  }

  /**
   * Makes an asynchronous call to a specified method with the provided arguments.
   * Like {@link safeCall}, but instead throws an error if the call is not successful.
   *
   * @throws An error if the call is not successful.
   */
  async call(method: string, ...args: any[]): Promise<any> {
    const result = await this.safeCall(method, ...args);
    if (!result.success) {
      throw new Error(result.error.message);
    }
    return result.result;
  }

  /**
   * Defines a method that can be called remotely via RPC.
   *
   * The provided function `fn` is called with an object containing:
   * - `ack`: A function that should be called to acknowledge the request.
   * - `args`: An array of arguments passed to the method.
   *
   * The function should return a result or throw an error. The result or error will be sent back to the caller.
   *
   * If the function throws an error, or its result cannot be serialised to JSON,
   * an error response with code -1 is sent back to the caller.
   * The error will be converted to a string as best as possible.
   */
  async define(
    method: string,
    fn: (args: { ack: () => Promise<void>; args: any[] }) => any,
  ): Promise<void> {
    this.methodListeners.set(method, async (request: RPCMessage<RPCCall>) => {
      // if the request doesn't get acknowledged in ackTimeout ms, it's considered lost
      const ack = async () => {
        const response: RPCMessage<RPCAcknowledged> = {
          id: request.id,
          payload: {
            type: "acknowledged",
          },
        };
        await this.producer.publish("teaRPC", JSON.stringify(response));
      };
      let serialized: string;
      try {
        const result = await fn({
          ack,
          args: request.payload.args,
        });
        const response: RPCMessage<RPCResult> = {
          id: request.id,
          payload: {
            type: "result",
            result,
          },
        };
        // Serialised inside the try so an unserialisable result becomes an error reply.
        serialized = JSON.stringify(response);
      } catch (error) {
        const response: RPCMessage<RPCError> = {
          id: request.id,
          payload: {
            type: "error",
            // `code` is required by RPCError and matches what the Python peer sends.
            error: {
              message: RPCClient.describeError(error),
              code: -1,
            },
          },
        };
        serialized = JSON.stringify(response);
      }
      await this.producer.publish("teaRPC", serialized);
    });
  }

  /**
   * Converts anything that was thrown into a message string.
   *
   * Because in JS you can throw anything, we try our best to convert it to a string.
   */
  private static describeError(error: unknown): string {
    if (error instanceof Error) {
      return error.message;
    }
    if (typeof error === "string") {
      return error;
    }
    if (typeof error === "object") {
      try {
        return JSON.stringify(error);
      } catch {
        // circular structures and BigInt members cannot be serialised
      }
    }
    return "unknown error";
  }

  /**
   * Sends an introspection request to the RPC server and returns a list of available methods.
   *
   * Answers are collected for `ackTimeout` milliseconds. The list includes this client's
   * own methods, because it answers its own request, and may be empty if no peer answers in time.
   *
   * @returns {Promise<string[]>} A promise that resolves to an array of method names available on the RPC server.
   *
   * @throws {Error} If the introspection request cannot be published.
   *
   * @example
   * const methods = await client.introspect();
   * console.log(methods); // ['method1', 'method2', ...]
   */
  async introspect(): Promise<string[]> {
    const id = Math.random().toString(36).slice(2);
    const methods = new Set<string>();
    this.callbackListeners.set(id, (message: RPCServerMessage) => {
      if (message.payload.type === "introspect-response") {
        for (const method of message.payload.methods) {
          methods.add(method);
        }
      }
    });
    let timer: ReturnType<typeof setTimeout> | undefined;
    // The collection window starts before publishing, as answers may come back immediately.
    const collectionWindow = new Promise<void>((resolve) => {
      timer = setTimeout(() => resolve(), this.ackTimeout);
    });
    const request: RPCMessage<RPCIntrospectRequest> = {
      id: id,
      payload: {
        type: "introspect",
      },
    };
    try {
      await this.producer.publish("teaRPC", JSON.stringify(request));
      await collectionWindow;
      return Array.from(methods);
    } finally {
      clearTimeout(timer);
      this.callbackListeners.delete(id);
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment