Last active
May 28, 2023 08:56
-
-
Save gunhaxxor/1e8d2593697e174e418d128e2319e4e4 to your computer and use it in GitHub Desktop.
tRPC adapter for uWebSockets.js. The ws-adapter is a modified version of the original ws-adapter in the tRPC repo. It also adds some utility functions that make it easier to set up subscriptions that don't send data to the triggering client.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { | |
createTRPCProxyClient, | |
createWSClient, | |
wsLink, | |
} from '@trpc/client'; | |
import { Unsubscribable } from '@trpc/server/observable'; | |
import AbortController from 'abort-controller'; | |
import fetch from 'node-fetch'; | |
import ws from 'ws'; | |
import type { AppRouter } from './server'; | |
// Install the globals the tRPC client looks up at runtime (fetch, WebSocket,
// AbortController) using the imported Node implementations.
const globalAny = global as any;
Object.assign(globalAny, {
  AbortController,
  fetch,
  WebSocket: ws,
});

// Integer in [0, 999] used to give each demo client a distinct-ish identity.
const randomInt = Math.trunc(Math.random() * 1000);

// The server reads the raw query string (here `user-<n>`) as this client's token.
const wsClient = createWSClient({
  url: `ws://localhost:2022?user-${randomInt}`,
});

const trpc = createTRPCProxyClient<AppRouter>({
  links: [wsLink({ client: wsClient })],
});
/**
 * Subscribes to `room.onRoomUpdate` and resolves with the subscription handle
 * once the server has acknowledged it (`onStarted`), so mutations issued
 * afterwards are observed. Rejects if the subscription errors before starting;
 * without this the returned promise would never settle.
 *
 * @param excludeSelf - when true, updates triggered by this client are filtered out server-side.
 */
function subscribeToRoomUpdates(excludeSelf: boolean): Promise<Unsubscribable> {
  return new Promise<Unsubscribable>((resolve, reject) => {
    const sub = trpc.room.onRoomUpdate.subscribe({ excludeSelf }, {
      onData: (data) => console.log('received subscribed roomState:', data),
      onStarted() {
        resolve(sub);
      },
      onError(err) {
        reject(err);
      },
    });
  });
}

/**
 * Demo flow: fetch our token, subscribe with and without `excludeSelf`,
 * create rooms and update our position, logging everything received.
 * The websocket client is always closed, even when a step fails.
 */
async function main() {
  try {
    const myToken = await trpc.room.getMyToken.query();
    console.log('MY TOKEN IS:', myToken);

    const subShouldNotTrigger = await subscribeToRoomUpdates(true);
    const createdRoom = await trpc.room.createAndJoinRoom.mutate('coolRoom');
    console.log('created room: ', createdRoom);
    await trpc.room.updateMyPosition.mutate({ x: 1, y: 2, z: 3 });

    const subShouldTrigger = await subscribeToRoomUpdates(false);
    const createdRoom2 = await trpc.room.createAndJoinRoom.mutate('boringRoom');
    console.log('created room 2:', createdRoom2);

    subShouldTrigger.unsubscribe();
    subShouldNotTrigger.unsubscribe();
  } finally {
    // An open socket keeps the process alive, so release it on success and failure alike.
    wsClient.close();
  }
}

main().catch((err) => {
  console.error('client demo failed:', err);
  process.exitCode = 1;
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { initTRPC, TRPCError } from '@trpc/server'; | |
import { applyWSHandler } from './ws-adapter'; | |
import { z } from 'zod'; | |
import uWebSockets from 'uWebSockets.js'; | |
import { TypedEmitter } from "tiny-typed-emitter"; | |
import { attachFilteredEmitter, FilteredEvents } from "./trpc-utils"; | |
/** Data attached to each uWS socket at upgrade time; `token` is used as the client id. */
type UData = {
  token: string;
};

/**
 * Events carried by a client's private emitter. Each listener gets one extra
 * trailing filter argument (a client token) that filtered subscriptions
 * compare against to decide whether to drop the event.
 */
type ClientEvents = FilteredEvents<
  {
    roomState: (room: RoomStateMessage) => void;
    kickedFromRoom: (roomId: string) => void;
  },
  UData['token']
>;

// tRPC instance whose context is the socket's UData.
const t = initTRPC.context<UData>().create();
const router = t.router;
const publicProcedure = t.procedure;
/** Server-side record of one connected client, including its private event emitter. */
const clientInfo = z.object({
  id: z.string(),
  role: z.union([z.literal('admin'), z.literal('user'), z.literal('guest')]),
  position: z.tuple([z.number(), z.number(), z.number()]).optional(),
  currentRoom: z.string().optional(),
  clientEmitter: z.custom<TypedEmitter<ClientEvents>>((value) => value instanceof TypedEmitter),
});
type ClientInfo = z.infer<typeof clientInfo>;

/** Wire-safe projection of a client record: everything except the emitter. */
const clientInfoMessage = clientInfo.omit({ clientEmitter: true });
type ClientInfoMessage = z.infer<typeof clientInfoMessage>;

/** Parses a client record through the message schema, which drops the emitter field. */
function getClientInfoMessage(clientInfo: ClientInfo): ClientInfoMessage {
  return clientInfoMessage.parse(clientInfo);
}
/** Server-side room: its id plus full client records keyed by client id. */
const roomState = z.object({
  roomId: z.string(),
  clients: z.object({}).catchall(clientInfo),
});
type RoomState = z.infer<typeof roomState>;

/** Wire format of a room: same shape, but each client reduced to its public message form. */
const roomStateMessage = roomState.extend({
  clients: z.object({}).catchall(clientInfoMessage),
});
type RoomStateMessage = z.infer<typeof roomStateMessage>;

/** Produces the serializable message for a room by parsing it through `roomStateMessage`. */
function getRoomStateMessage(roomState: RoomState): RoomStateMessage {
  const message = roomStateMessage.parse(roomState);
  return message;
}
/** Connected clients keyed by their token (see the ws `open` handler). */
const connectedClients: Map<string, ClientInfo> = new Map();
/** All rooms keyed by room id. */
const rooms: Map<string, RoomState> = new Map();

/**
 * Moves the client `userId` into the room `roomId`.
 *
 * A client is tracked in a single `currentRoom`, so it is first removed from
 * the member list of any other room it was in. Otherwise that room would keep
 * broadcasting to it and listing it as a member.
 *
 * @returns the room that was joined.
 * @throws TRPCError NOT_FOUND if the room or the client does not exist.
 */
function addUserToRoom(userId: string, roomId: string) {
  const room = rooms.get(roomId);
  if (!room) {
    throw new TRPCError({ code: 'NOT_FOUND', message: 'no room with that id found' });
  }
  const client = connectedClients.get(userId);
  if (!client) {
    throw new TRPCError({ code: 'NOT_FOUND', message: 'no client with that id found' });
  }
  if (client.currentRoom !== undefined && client.currentRoom !== roomId) {
    const previousRoom = rooms.get(client.currentRoom);
    if (previousRoom) {
      delete previousRoom.clients[userId];
    }
  }
  room.clients[userId] = client;
  client.currentRoom = room.roomId;
  return room;
}
/**
 * Emits the current state of `room` on every member's emitter.
 *
 * @param room - room whose state is broadcast.
 * @param triggeringClient - id of the client whose action caused the update.
 *   It is passed as the emitter's filter argument, so only that client's
 *   `excludeSelf` subscriptions (filter === own id) skip the event.
 */
function broadcastRoomState(room: RoomState, triggeringClient: string) {
  const roomMessage = getRoomStateMessage(room);
  for (const client of Object.values(room.clients)) {
    if (!client.clientEmitter) {
      continue;
    }
    // Previously the recipient's own id was sent here, which made every
    // excludeSelf subscription drop *all* updates, not only self-triggered ones.
    client.clientEmitter.emit('roomState', roomMessage, triggeringClient);
  }
}
/**
 * Looks up the connected-client record for `userId` (the caller's token).
 *
 * @throws TRPCError NOT_FOUND if no client with that id is connected.
 */
function getMe(userId: string): ClientInfo {
  const me = connectedClients.get(userId);
  if (!me) {
    throw new TRPCError({ code: 'NOT_FOUND', message: 'did not find self among backend clients' });
  }
  return me;
}
const roomRouter = router({
  /** Returns the caller's token (the raw query string of its websocket URL). */
  getMyToken: publicProcedure.query(({ ctx }) => ctx.token),

  /** Stores the caller's position. The change is not broadcast. */
  updateMyPosition: publicProcedure
    .input(z.object({
      x: z.number(),
      y: z.number(),
      z: z.number(),
    }))
    .mutation(({ input, ctx }) => {
      const me = getMe(ctx.token);
      me.position = [input.x, input.y, input.z];
    }),

  /** Returns the caller's current room, or NOT_FOUND if it is not in one. */
  getMyRoom: publicProcedure.query(({ ctx }) => {
    const me = getMe(ctx.token);
    const room = me.currentRoom ? rooms.get(me.currentRoom) : undefined;
    if (!room) {
      throw new TRPCError({ code: 'NOT_FOUND', message: 'you are not in a room' });
    }
    return getRoomStateMessage(room);
  }),

  /**
   * Creates room `input` if it does not exist yet, then joins it and
   * broadcasts the new state. An existing room is reused rather than
   * replaced, so its current members are not silently dropped while their
   * `currentRoom` still points at it.
   */
  createAndJoinRoom: publicProcedure
    .input(z.string())
    .mutation(({ input: roomId, ctx }) => {
      if (!rooms.has(roomId)) {
        rooms.set(roomId, {
          roomId,
          clients: {},
        });
      }
      const room = addUserToRoom(ctx.token, roomId);
      broadcastRoomState(room, ctx.token);
      return getRoomStateMessage(room);
    }),

  /** Joins an existing room and broadcasts the new state to its members. */
  joinRoom: publicProcedure
    .input(z.string())
    .mutation(({ input: roomName, ctx }) => {
      const room = addUserToRoom(ctx.token, roomName);
      broadcastRoomState(room, ctx.token);
      return getRoomStateMessage(room);
    }),

  /**
   * Streams room-state updates from the caller's emitter. With
   * `excludeSelf`, updates whose trigger id equals the caller's id are dropped.
   */
  onRoomUpdate: publicProcedure
    .input(z.object({ excludeSelf: z.boolean() }))
    .subscription(({ input: { excludeSelf }, ctx }) => {
      const me = getMe(ctx.token);
      // Trigger ids are string tokens, so an undefined filter never matches (nothing dropped).
      const filter = excludeSelf ? me.id : undefined;
      return attachFilteredEmitter(me.clientEmitter, 'roomState', filter);
    }),
});
// Root router: all room procedures live under the `room` namespace (e.g. `room.getMyToken`).
const appRouter = router({ room: roomRouter });

/** Router type, imported by the client for end-to-end type inference. */
export type AppRouter = typeof appRouter;
// ws server
// The adapter keeps per-socket tRPC state; the uWS lifecycle callbacks below forward to it.
const { onSocketMessage, onSocketOpen, onSocketClose } = applyWSHandler<AppRouter, UData>({
  router: appRouter,
});

const app = uWebSockets.App().ws<UData>('/*', {
  upgrade: (res, req, ctx) => {
    // The raw query string (e.g. `user-123` from `ws://host?user-123`) becomes the client's token.
    const token = req.getQuery();
    res.upgrade<UData>(
      { token },
      /* Spell these correctly */
      req.getHeader('sec-websocket-key'),
      req.getHeader('sec-websocket-protocol'),
      req.getHeader('sec-websocket-extensions'),
      ctx,
    );
  },
  open: (ws) => {
    // Shallow copy of the socket's user data, used as the tRPC context for this socket.
    const uData = Object.assign({}, ws.getUserData());
    onSocketOpen(ws, uData);
    const clientEmitter: ClientInfo['clientEmitter'] = new TypedEmitter();
    connectedClients.set(uData.token, {
      id: uData.token,
      role: 'user',
      clientEmitter,
    });
  },
  message: (ws, msg) => {
    const msgStr = Buffer.from(msg).toString();
    // onSocketMessage is async: surface unexpected failures instead of leaving an unhandled rejection.
    onSocketMessage(ws, msgStr).catch((err) => {
      console.error('failed to handle ws message:', err);
    });
  },
  close: (ws, code, msg) => {
    const msgStr = Buffer.from(msg).toString();
    onSocketClose(ws, msgStr);
    const { token } = ws.getUserData();
    // Remove the client from its room so later broadcasts and room messages
    // no longer include a dead connection.
    const client = connectedClients.get(token);
    if (client?.currentRoom) {
      const room = rooms.get(client.currentRoom);
      if (room) {
        delete room.clients[token];
      }
    }
    connectedClients.delete(token);
  },
});
// uWebSockets.js passes `false` instead of a listen socket when binding the port fails.
app.listen(2022, (listenSocket) => {
  if (!listenSocket) {
    console.error('failed to listen on port 2022');
    process.exitCode = 1;
    return;
  }
  console.log('listening on port 2022', listenSocket);
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import {observable} from '@trpc/server/observable'; | |
import { ListenerSignature, TypedEmitter } from 'tiny-typed-emitter'; | |
// Internal utility types

/** Listener type registered for event `K` of event map `E`. */
type EmitterCallback<E extends ListenerSignature<E>, K extends keyof E> = E[K];

/** Type of the first argument passed to a listener of event `K`. */
type EventArgument<E extends ListenerSignature<E>, K extends keyof E> = Parameters<EmitterCallback<E, K>>[0];

/** Function type `TFunc` with one extra trailing parameter of type `TFilter`. */
type AddFilterParam<TFunc extends (...args: any) => any, TFilter> = (
  ...args: [...parameters: Parameters<TFunc>, filter: TFilter]
) => ReturnType<TFunc>;

/** Applies `AddFilterParam` to every event of an event map. */
type AddFilterToEvents<TEvents extends ListenerSignature<TEvents>, TFilter> = {
  [K in keyof TEvents]: AddFilterParam<TEvents[K], TFilter>;
};

/**
 * Event map in which every single-argument listener also receives a trailing
 * filter value, e.g. `(room) => void` becomes `(room, filter) => void`.
 */
export type FilteredEvents<E extends { [K in keyof E]: (p: any) => void }, FilterType> = AddFilterToEvents<E, FilterType>;
// export function attachEmitter<E extends ListenerSignature<E>, K extends keyof E>(emitter: TypedEmitter<E>, event: K){ | |
// return observable<EventArgument<E, typeof event>>(emit => { | |
// const onEvent = (data: EventArgument<E,typeof event>): void => { | |
// console.log('emitter triggered'); | |
// emit.next(data); | |
// } | |
// emitter.on(event, onEvent as E[typeof event]); | |
// return () => { | |
// emitter.off(event, onEvent as E[typeof event]); | |
// } | |
// }) | |
// } | |
/**
 * Exposes one event of a typed emitter as a tRPC observable.
 *
 * Listeners for `event` are called as `(data, triggerId)`. Events whose
 * `triggerId` strictly equals `filter` are dropped; all others forward `data`
 * to the subscriber. Tearing down the observable removes the listener.
 */
export function attachFilteredEmitter<E extends ListenerSignature<E>, K extends keyof E, FilterType>(
  emitter: TypedEmitter<E>,
  event: K,
  filter: FilterType,
) {
  return observable<EventArgument<E, K>>((observer) => {
    const listener = (data: EventArgument<E, K>, triggerId: FilterType): void => {
      if (triggerId !== filter) {
        observer.next(data);
      }
    };
    emitter.on(event, listener as E[K]);
    return () => {
      emitter.off(event, listener as E[K]);
    };
  });
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { | |
AnyRouter, | |
ProcedureType, | |
callProcedure, | |
TRPCError, | |
} from '@trpc/server'; | |
import { OnErrorFunction } from '@trpc/server/dist/internals/types'; | |
import { Unsubscribable, isObservable } from '@trpc/server/observable'; | |
import { | |
JSONRPC2, | |
TRPCClientOutgoingMessage, | |
TRPCReconnectNotification, | |
TRPCResponse, | |
TRPCResponseMessage, | |
} from '@trpc/server/rpc'; | |
import { CombinedDataTransformer } from '@trpc/server/dist/transformer'; | |
/**
 * Validates a decoded JSON value as a tRPC client message.
 *
 * `subscription.stop` messages are returned without params. Any other message
 * must carry a procedure-type `method` and a string `params.path`, and its
 * `params.input` is passed through the transformer's input deserializer.
 *
 * @throws Error when any part of the envelope is malformed.
 */
function parseMessage(
  obj: unknown,
  transformer: CombinedDataTransformer,
): TRPCClientOutgoingMessage {
  assertIsObject(obj);
  const { id, jsonrpc, method, params } = obj;
  assertIsRequestId(id);
  assertIsJSONRPC2OrUndefined(jsonrpc);

  if (method === 'subscription.stop') {
    return { id, jsonrpc, method };
  }

  assertIsProcedureType(method);
  assertIsObject(params);
  const { path, input: rawInput } = params;
  assertIsString(path);

  return {
    id,
    jsonrpc,
    method,
    params: {
      input: transformer.input.deserialize(rawInput),
      path,
    },
  };
}
/** Sends one already-serialized text message. */
type BasicSendFunction = (message: string) => void;

/** The only socket capability the adapter relies on (a uWebSockets.js socket satisfies it). */
interface MinimalWSInterface {
  send: BasicSendFunction;
}

/** tRPC's `onError` callback signature without the `req` field, since there is no HTTP request here. */
type OnErrorWithoutRequest = (
  opts: Omit<Parameters<OnErrorFunction<AnyRouter, undefined>>[0], 'req'>,
) => void;

export interface WSHandlerOptions<TRouter extends AnyRouter> {
  /** Invoked when a procedure call or a subscription fails. */
  onError?: OnErrorWithoutRequest;
  /** Router whose procedures are served over the socket. */
  router: TRouter;
}
/**
 * Creates a transport-agnostic tRPC WebSocket handler.
 *
 * The caller forwards its socket lifecycle events to the returned
 * `onSocketOpen` / `onSocketMessage` / `onSocketClose` callbacks. Per-socket
 * context and active subscriptions are keyed by the socket object, so the
 * same object must be passed to every callback. Once `onSocketClose` has run
 * for a socket, nothing more is sent on it.
 *
 * @param opts - router to serve and an optional `onError` hook.
 * @returns the lifecycle callbacks plus `broadcastReconnectNotification`,
 *   which sends a tRPC `reconnect` notification to every registered socket.
 */
export function applyWSHandler<TRouter extends AnyRouter, Ctx>(opts: WSHandlerOptions<TRouter>) {
  const { router } = opts;
  const { transformer } = router._def._config;
  const websockets: Map<MinimalWSInterface, { subscriptions: Map<number | string, Unsubscribable>, ctx: Ctx }> = new Map();

  /** A socket counts as open from onSocketOpen until onSocketClose. */
  const isOpen = (ws: MinimalWSInterface) => websockets.has(ws);

  const onSocketOpen = (ws: MinimalWSInterface, ctx: Ctx) => {
    websockets.set(ws, { subscriptions: new Map(), ctx });
  };

  const onSocketMessage = async (ws: MinimalWSInterface, stringifiedMessage: string) => {
    try {
      const msgJSON: unknown = JSON.parse(stringifiedMessage);
      // A single frame may carry a batch (array) of messages.
      const msgs: unknown[] = Array.isArray(msgJSON) ? msgJSON : [msgJSON];
      const promises = msgs
        .map((raw) => parseMessage(raw, transformer))
        .map((msg) => handleRequest(ws, msg));
      await Promise.all(promises);
    } catch (cause) {
      const error = new TRPCError({
        code: 'PARSE_ERROR',
        cause: getCauseFromUnknown(cause),
      });
      respond(ws, {
        id: null,
        error: router.getErrorShape({
          error,
          type: 'unknown',
          path: undefined,
          input: undefined,
          ctx: undefined,
        }),
      });
    }
  };

  const onSocketClose = (ws: MinimalWSInterface, _msg: string) => {
    const wsData = websockets.get(ws);
    if (!wsData) {
      return;
    }
    // Unregister first so anything firing during teardown sees the socket as closed.
    websockets.delete(ws);
    const { subscriptions } = wsData;
    for (const sub of subscriptions.values()) {
      sub.unsubscribe();
    }
    subscriptions.clear();
  };

  /** Serializes a response with the router's transformer and sends it, unless the socket is closed. */
  function respond(ws: MinimalWSInterface, untransformedJSON: TRPCResponseMessage) {
    // Late results (e.g. a mutation finishing after disconnect) are dropped:
    // sending on a closed uWebSockets.js socket throws.
    if (!isOpen(ws)) {
      return;
    }
    ws.send(JSON.stringify(transformTRPCResponse(router, untransformedJSON)));
  }

  /** Unsubscribes and tells the client the subscription `id` has stopped. */
  function stopSubscription(
    ws: MinimalWSInterface,
    subscription: Unsubscribable,
    { id, jsonrpc }: { id: JSONRPC2.RequestId } & JSONRPC2.BaseEnvelope,
  ) {
    subscription.unsubscribe();
    respond(ws, {
      id,
      jsonrpc,
      result: {
        type: 'stopped',
      },
    });
  }

  /**
   * Executes one parsed client message: stops a subscription, answers a
   * query/mutation, or starts a subscription.
   *
   * @throws TRPCError for a null id or an unregistered socket (reported as a
   *   parse error by onSocketMessage). Procedure failures are answered directly.
   */
  const handleRequest = async (ws: MinimalWSInterface, msg: TRPCClientOutgoingMessage) => {
    if (!ws) {
      throw new Error('handler was called with undefined websocket instance');
    }
    const { id, jsonrpc } = msg;
    if (id === null) {
      throw new TRPCError({
        code: 'BAD_REQUEST',
        message: '`id` is required',
      });
    }
    const wsData = websockets.get(ws);
    if (!wsData) {
      throw new TRPCError({
        code: 'INTERNAL_SERVER_ERROR',
        message: 'websocket instance not found in the adapter/handler',
      });
    }
    const { ctx, subscriptions } = wsData;

    if (msg.method === 'subscription.stop') {
      const sub = subscriptions.get(id);
      if (sub) {
        stopSubscription(ws, sub, { id, jsonrpc });
      }
      subscriptions.delete(id);
      return;
    }

    const { path, input } = msg.params;
    const type = msg.method;
    try {
      const result = await callProcedure({
        procedures: router._def.procedures,
        path,
        rawInput: input,
        ctx,
        type,
      });

      if (type !== 'subscription') {
        // Queries and mutations are answered with a single data message.
        respond(ws, {
          id,
          jsonrpc,
          result: {
            type: 'data',
            data: result,
          },
        });
        return;
      }
      if (!isObservable(result)) {
        throw new TRPCError({
          message: `Subscription ${path} did not return an observable`,
          code: 'INTERNAL_SERVER_ERROR',
        });
      }

      const sub = result.subscribe({
        next(data) {
          respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'data',
              data,
            },
          });
        },
        error(err) {
          const error = getTRPCErrorFromUnknown(err);
          // if there was an error callback provided we call it here
          opts.onError?.({ error, path, type, ctx, input });
          respond(ws, {
            id,
            jsonrpc,
            error: router.getErrorShape({
              error,
              type,
              path,
              input,
              ctx,
            }),
          });
        },
        complete() {
          respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'stopped',
            },
          });
        },
      });

      if (!isOpen(ws)) {
        // The socket closed while the procedure was resolving. onSocketClose has
        // already run, so nothing else would ever unsubscribe this observable.
        sub.unsubscribe();
        return;
      }
      if (subscriptions.has(id)) {
        // duplicate request ids for client
        stopSubscription(ws, sub, { id, jsonrpc });
        throw new TRPCError({
          message: `Duplicate id ${id}`,
          code: 'BAD_REQUEST',
        });
      }
      subscriptions.set(id, sub);
      respond(ws, {
        id,
        jsonrpc,
        result: {
          type: 'started',
        },
      });
    } catch (cause) {
      // procedure threw an error
      const error = getTRPCErrorFromUnknown(cause);
      opts.onError?.({ error, path, type, ctx, input });
      respond(ws, {
        id,
        jsonrpc,
        error: router.getErrorShape({
          error,
          type,
          path,
          input,
          ctx,
        }),
      });
    }
  };

  return {
    onSocketOpen,
    onSocketMessage,
    onSocketClose,
    broadcastReconnectNotification: () => {
      const response: TRPCReconnectNotification = {
        id: null,
        method: 'reconnect',
      };
      const data = JSON.stringify(response);
      for (const client of websockets.keys()) {
        client.send(data);
      }
    },
  };
}
/** Throws unless `obj` is a non-null object that is not an array. */
function assertIsObject(obj: unknown): asserts obj is Record<string, unknown> {
  const isRecordLike = obj !== null && typeof obj === 'object' && !Array.isArray(obj);
  if (!isRecordLike) {
    throw new Error('Not an object');
  }
}
/** Throws unless `obj` names a tRPC procedure kind: query, mutation or subscription. */
function assertIsProcedureType(obj: unknown): asserts obj is ProcedureType {
  switch (obj) {
    case 'query':
    case 'mutation':
    case 'subscription':
      return;
    default:
      throw new Error('Invalid procedure type');
  }
}
/**
 * Throws unless `obj` is a valid JSON-RPC request id: `null`, a string, or a
 * number that is not NaN.
 *
 * The previous condition combined the checks with `&&`, so it only ever
 * rejected NaN and let `undefined`, objects and booleans through.
 */
function assertIsRequestId(
  obj: unknown,
): asserts obj is number | string | null {
  const isValid =
    obj === null ||
    typeof obj === 'string' ||
    (typeof obj === 'number' && !Number.isNaN(obj));
  if (!isValid) {
    throw new Error('Invalid request id');
  }
}
/** Throws unless `obj` is a string primitive. */
function assertIsString(obj: unknown): asserts obj is string {
  if (typeof obj === 'string') {
    return;
  }
  throw new Error('Invalid string');
}
/** Throws unless `obj` is absent or exactly the JSON-RPC version string '2.0'. */
function assertIsJSONRPC2OrUndefined(
  obj: unknown,
): asserts obj is '2.0' | undefined {
  const acceptable = obj === undefined || obj === '2.0';
  if (!acceptable) {
    throw new Error('Must be JSONRPC 2.0');
  }
}
/** Returns a thrown value's message: strings as-is, `Error#message`, otherwise `fallback`. */
function getMessageFromUnkownError(
  err: unknown,
  fallback: string,
): string {
  if (typeof err === 'string') {
    return err;
  }
  return err instanceof Error && typeof err.message === 'string' ? err.message : fallback;
}

/** Normalizes any thrown value into an `Error`, returning existing `Error`s unchanged. */
function getErrorFromUnknown(cause: unknown): Error {
  return cause instanceof Error
    ? cause
    : new Error(getMessageFromUnkownError(cause, 'Unknown error'));
}
/**
 * Converts any thrown value into a `TRPCError`.
 *
 * Values that already are TRPCErrors are detected by `name` rather than
 * `instanceof` (unreliable, see https://github.com/trpc/trpc/issues/331) and
 * returned unchanged. Anything else is wrapped as INTERNAL_SERVER_ERROR,
 * keeping the original message, cause and stack.
 */
function getTRPCErrorFromUnknown(cause: unknown): TRPCError {
  const normalized = getErrorFromUnknown(cause);
  if (normalized.name !== 'TRPCError') {
    const wrapped = new TRPCError({
      code: 'INTERNAL_SERVER_ERROR',
      cause: normalized,
      message: normalized.message,
    });
    wrapped.stack = normalized.stack;
    return wrapped;
  }
  return cause as TRPCError;
}
/** Returns `cause` when it is an `Error`, otherwise `undefined`. */
function getCauseFromUnknown(cause: unknown): Error | undefined {
  return cause instanceof Error ? cause : undefined;
}
/**
 * Serializes one response item with the router's output transformer: the
 * `error` field of error responses, or `result.data` when present. Other
 * items (e.g. `started`/`stopped`) are returned unchanged.
 */
function transformTRPCResponseItem<
  TResponseItem extends TRPCResponse | TRPCResponseMessage,
>(router: AnyRouter, item: TResponseItem): TResponseItem {
  // Resolved lazily and called on `output` so the transformer keeps its `this`.
  const serialize = (value: unknown) => router._def._config.transformer.output.serialize(value);

  if ('error' in item) {
    return { ...item, error: serialize(item.error) };
  }
  if ('data' in item.result) {
    return {
      ...item,
      result: { ...item.result, data: serialize(item.result.data) },
    };
  }
  return item;
}

/**
 * Takes an unserialized `TRPCResponse` (or an array of them) and serializes
 * each item with the router's transformer.
 **/
function transformTRPCResponse<
  TResponse extends
    | TRPCResponse
    | TRPCResponse[]
    | TRPCResponseMessage
    | TRPCResponseMessage[],
>(router: AnyRouter, itemOrItems: TResponse) {
  if (Array.isArray(itemOrItems)) {
    return itemOrItems.map((item) => transformTRPCResponseItem(router, item));
  }
  return transformTRPCResponseItem(router, itemOrItems);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment