Last active
May 27, 2023 00:17
-
-
Save b0o/2ecbf884e671e2ada6ab2248a87e3191 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Ported from https://github.com/trpc/trpc/blob/main/packages/server/src/adapters/ws.ts | |
import type { ServeOptions, Server, ServerWebSocket } from 'bun' | |
import { | |
AnyRouter, | |
CombinedDataTransformer, | |
ProcedureType, | |
TRPCError, | |
callProcedure, | |
getTRPCErrorFromUnknown, | |
inferRouterContext, | |
} from '@trpc/server' | |
import { Unsubscribable, isObservable } from '@trpc/server/observable' | |
import { JSONRPC2, TRPCClientOutgoingMessage, TRPCResponseMessage } from '@trpc/server/rpc' | |
import { transformTRPCResponse } from '@trpc/server/shared' | |
import { FetchCreateContextOption } from '@trpc/server/adapters/fetch' | |
/**
 * Narrows an arbitrary thrown value to something usable as an `Error` cause.
 *
 * @param cause - any value caught from a `catch` clause
 * @returns `cause` itself when it is an `Error` instance, otherwise `undefined`
 */
function getCauseFromUnknown(cause: unknown): Error | undefined {
  return cause instanceof Error ? cause : undefined
}
/**
 * Asserts that `obj` is a non-null, non-array object.
 *
 * @throws Error('Not an object') for primitives, `null`, `undefined` and arrays
 */
function assertIsObject(obj: unknown): asserts obj is Record<string, unknown> {
  const isRecord = obj !== null && typeof obj === 'object' && !Array.isArray(obj)
  if (!isRecord) {
    throw new Error('Not an object')
  }
}
/**
 * Asserts that `obj` names one of the three tRPC procedure kinds.
 *
 * @throws Error('Invalid procedure type') for anything other than
 *   'query', 'mutation' or 'subscription'
 */
function assertIsProcedureType(obj: unknown): asserts obj is ProcedureType {
  switch (obj) {
    case 'query':
    case 'mutation':
    case 'subscription':
      return
    default:
      throw new Error('Invalid procedure type')
  }
}
/**
 * Asserts that `obj` is a usable JSON-RPC request id: `null`, a string, or a
 * (non-NaN) number.
 *
 * The previous condition only rejected `NaN`, so `undefined`, objects and booleans
 * slipped through despite the `asserts` signature. Those values then reached
 * `handleRequest` and were used as subscription map keys.
 *
 * @throws Error('Invalid request id') for any other value
 */
function assertIsRequestId(obj: unknown): asserts obj is number | string | null {
  const isValidId =
    obj === null || typeof obj === 'string' || (typeof obj === 'number' && !Number.isNaN(obj))
  if (!isValidId) {
    throw new Error('Invalid request id')
  }
}
/**
 * Asserts that `obj` is a primitive string.
 *
 * @throws Error('Invalid string') otherwise
 */
function assertIsString(obj: unknown): asserts obj is string {
  const isString = typeof obj === 'string'
  if (isString) {
    return
  }
  throw new Error('Invalid string')
}
/**
 * Asserts that a message's `jsonrpc` field is either absent or exactly '2.0'.
 *
 * @throws Error('Must be JSONRPC 2.0') for any other value
 */
function assertIsJSONRPC2OrUndefined(obj: unknown): asserts obj is '2.0' | undefined {
  if (obj === undefined || obj === '2.0') {
    return
  }
  throw new Error('Must be JSONRPC 2.0')
}
/**
 * Validates one decoded JSON value as a tRPC client message.
 *
 * A 'subscription.stop' message only needs `id`/`jsonrpc`. Every other message must
 * carry a procedure type in `method` and a `params` object with a string `path`.
 * Its `input` is deserialized with the router's transformer.
 *
 * @param obj - one element of the (possibly batched) JSON payload
 * @param transformer - the router's combined data transformer
 * @returns the typed outgoing-message envelope
 * @throws Error from the assertion helpers when a field has the wrong shape
 */
export function parseMessage(obj: unknown, transformer: CombinedDataTransformer): TRPCClientOutgoingMessage {
  assertIsObject(obj)
  const { method, params, id, jsonrpc } = obj
  assertIsRequestId(id)
  assertIsJSONRPC2OrUndefined(jsonrpc)

  // Stop requests carry no params, so they are returned before the procedure checks.
  if (method === 'subscription.stop') {
    return { id, jsonrpc, method }
  }

  assertIsProcedureType(method)
  assertIsObject(params)
  const { input: rawInput, path } = params
  assertIsString(path)

  return {
    id,
    jsonrpc,
    method,
    params: {
      input: transformer.input.deserialize(rawInput),
      path,
    },
  }
}
/** Per-connection data attached to each Bun `ServerWebSocket` (available as `ws.data`). */
interface BunWSSHandlerData {
  /** The HTTP request that was upgraded into this WebSocket connection. */
  req: Request
  /** Connection id assigned by `BunWSSHandler.open()`; unset before the socket is opened. */
  id?: number
}
/** Payload shapes a Bun WebSocket `message` callback may deliver. */
type WSSMessage = string | ArrayBuffer | Uint8Array

/**
 * Converts a WebSocket frame payload to text.
 *
 * Strings are returned unchanged; binary payloads are decoded as UTF-8.
 */
function wssMessageToString(data: WSSMessage): string {
  if (typeof data === 'string') {
    return data
  }
  // View a bare ArrayBuffer as bytes so both binary shapes share one decode call.
  const bytes = data instanceof ArrayBuffer ? new Uint8Array(data) : data
  return Buffer.from(bytes).toString('utf-8')
}
/**
 * Decodes a WebSocket payload to text and parses it as JSON.
 *
 * @throws SyntaxError when the payload is not valid JSON
 */
function parseWssMessage(data: WSSMessage): unknown {
  const text = wssMessageToString(data)
  return JSON.parse(text)
}
/**
 * Options for `BunWSSHandler`: the router to serve, optional lifecycle hooks, and
 * the fetch adapter's `createContext` option.
 *
 * Each hook is called with `this` bound to the handler. It receives the socket plus the
 * arguments of the matching Bun `websocket` callback. When a hook throws, the handler
 * calls `ws.close()` with a code and reason derived from the error.
 */
export type BunWSSHandlerOptions<TData extends BunWSSHandlerData, TRouter extends AnyRouter> =
  FetchCreateContextOption<TRouter> & {
    /** Router whose procedures are served over the socket. */
    router: TRouter
    /** Runs when a socket opens, after it is registered and before `createContext`. */
    onOpen?: (
      this: BunWSSHandler<TData, TRouter>,
      client: ServerWebSocket<TData>,
    ) => void | Promise<void>
    /** Runs for every incoming frame, before it is parsed as tRPC JSON-RPC. */
    onMessage?: (
      this: BunWSSHandler<TData, TRouter>,
      client: ServerWebSocket<TData>,
      message: WSSMessage,
    ) => void | Promise<void>
    /** Runs when a socket closes, with Bun's close code and reason. */
    onClose?: (
      this: BunWSSHandler<TData, TRouter>,
      client: ServerWebSocket<TData>,
      code: number,
      reason: string,
    ) => void | Promise<void>
    /** Runs when Bun reports the socket's send buffer has drained. */
    onDrain?: (
      this: BunWSSHandler<TData, TRouter>,
      client: ServerWebSocket<TData>,
    ) => void | Promise<void>
  }
/** Server-side state tracked for one connected socket. */
export interface BunWSSClient<TData, TContext> {
  /** Request context produced by `createContext`; stays unset until that resolves. */
  ctx?: TContext
  /** Active subscriptions, keyed by the request id that started them. */
  subscriptions: Map<number | string, Unsubscribable>
  /** The underlying Bun socket. */
  ws: ServerWebSocket<TData>
}
/**
 * WebSocket `readyState` values (0-3, in declaration order), compared against
 * `ws.readyState` to detect sockets that are no longer open.
 */
enum ReadyState {
  CONNECTING,
  OPEN,
  CLOSING,
  CLOSED,
}
/**
 * Serves a tRPC router over Bun's native WebSocket server (a port of tRPC's `ws` adapter).
 *
 * Bun's `websocket` callbacks map onto `open`, `message`, `close` and `drain`; `serve()`
 * wires them up through `Bun.serve`. Each socket receives a numeric id (stored in
 * `ws.data.id`). That id keys its entry in `clients`, which holds the socket's context
 * and active subscriptions.
 */
export class BunWSSHandler<TData extends BunWSSHandlerData, TRouter extends AnyRouter> {
  private opts: BunWSSHandlerOptions<TData, TRouter>
  private router: TRouter
  /** Registered sockets keyed by `ws.data.id`; added in `open()`, removed in `close()`. */
  private clients: Map<number, BunWSSClient<TData, inferRouterContext<TRouter>>>
  /** Id handed to the next socket passed to `open()`. */
  private nextId = 0

  constructor(opts: BunWSSHandlerOptions<TData, TRouter>) {
    this.opts = opts
    this.router = opts.router
    this.clients = new Map()
  }

  /** Number of sockets currently registered with this handler. */
  get connectionCount() {
    return this.clients.size
  }

  /**
   * Closes `ws` after a user-supplied hook failed.
   *
   * The close reason is the error's message ('Unknown error' for non-Error values). The
   * close code is the error's numeric `code` property when present, otherwise 1008.
   */
  private handleError(ws: ServerWebSocket<TData>, error: unknown) {
    const reason = error instanceof Error ? error.message : 'Unknown error'
    const code =
      error instanceof Error && 'code' in error && typeof error.code === 'number' ? error.code : 1008
    ws.close(code, reason)
  }

  /**
   * Invokes a hook whose completion is not awaited (used by `close` and `drain`).
   *
   * Synchronous throws and rejected promises both go to `handleError`, so a failing
   * async hook cannot surface as an unhandled promise rejection.
   */
  private runDetachedHook(ws: ServerWebSocket<TData>, invoke: () => void | Promise<void>) {
    try {
      void Promise.resolve(invoke()).catch(error => this.handleError(ws, error))
    } catch (error) {
      this.handleError(ws, error)
    }
  }

  /** Unsubscribes all of the socket's active subscriptions and unregisters its client. */
  private releaseClient(ws: ServerWebSocket<TData>) {
    const { id } = ws.data
    if (id === undefined) {
      return
    }
    const client = this.clients.get(id)
    if (client) {
      for (const sub of client.subscriptions.values()) {
        sub.unsubscribe()
      }
      client.subscriptions.clear()
    }
    this.clients.delete(id)
  }

  /** Serializes a response (after applying the router's transformer) and sends it on `ws`. */
  private respond(ws: ServerWebSocket<TData>, untransformedJSON: TRPCResponseMessage) {
    ws.send(JSON.stringify(transformTRPCResponse(this.router, untransformedJSON)))
  }

  /**
   * Executes one parsed client message for `ws`.
   *
   * - 'subscription.stop' unsubscribes the matching subscription (if any) and replies 'stopped'.
   * - Queries and mutations reply with a single 'data' result.
   * - Subscriptions reply 'started', then stream 'data' results until the observable
   *   completes ('stopped') or errors.
   *
   * Errors raised while calling the procedure are sent back as error responses for this id.
   *
   * @throws TRPCError (BAD_REQUEST) when the socket has no registered client or the
   *   message id is `null`
   */
  async handleRequest(ws: ServerWebSocket<TData>, msg: TRPCClientOutgoingMessage) {
    if (ws.data.id === undefined) {
      throw new TRPCError({ code: 'BAD_REQUEST', message: 'Missing client id' })
    }
    const client = this.clients.get(ws.data.id)
    if (!client) {
      throw new TRPCError({ code: 'BAD_REQUEST', message: 'Unknown client id' })
    }
    const { id, jsonrpc } = msg
    /* istanbul ignore next -- @preserve */
    if (id === null) {
      throw new TRPCError({
        code: 'BAD_REQUEST',
        message: 'missing message id',
      })
    }
    const stopSubscription = (
      subscription: Unsubscribable,
      { id, jsonrpc }: { id: JSONRPC2.RequestId } & JSONRPC2.BaseEnvelope,
    ) => {
      subscription.unsubscribe()
      this.respond(ws, {
        id,
        jsonrpc,
        result: {
          type: 'stopped',
        },
      })
    }
    if (msg.method === 'subscription.stop') {
      const sub = client.subscriptions.get(id)
      if (sub) {
        stopSubscription(sub, { id, jsonrpc })
      }
      client.subscriptions.delete(id)
      return
    }
    const { path, input } = msg.params
    const type = msg.method
    try {
      // TODO
      // await ws.ctxPromise // asserts context has been set
      const result = await callProcedure({
        procedures: this.opts.router._def.procedures,
        path,
        rawInput: input,
        ctx: client.ctx,
        type,
      })
      if (type === 'subscription') {
        if (!isObservable(result)) {
          throw new TRPCError({
            message: `Subscription ${path} did not return an observable`,
            code: 'INTERNAL_SERVER_ERROR',
          })
        }
      } else {
        // send the value as data if the method is not a subscription
        this.respond(ws, {
          id,
          jsonrpc,
          result: {
            type: 'data',
            data: result,
          },
        })
        return
      }
      const observable = result
      const sub = observable.subscribe({
        next: data => {
          this.respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'data',
              data,
            },
          })
        },
        error: err => {
          const error = getTRPCErrorFromUnknown(err)
          // this.opts.onError?.({ error, path, type, ctx: ws.ctx, req: this.opts.req, input })
          this.respond(ws, {
            id,
            jsonrpc,
            error: this.opts.router.getErrorShape({
              error,
              type,
              path,
              input,
              ctx: client.ctx,
            }),
          })
        },
        complete: () => {
          this.respond(ws, {
            id,
            jsonrpc,
            result: {
              type: 'stopped',
            },
          })
        },
      })
      if (ws.readyState !== ReadyState.OPEN) {
        // if the client got disconnected whilst initializing the subscription
        // no need to send stopped message if the client is disconnected
        sub.unsubscribe()
        return
      }
      if (client.subscriptions.has(id)) {
        // duplicate request ids for client
        stopSubscription(sub, { id, jsonrpc })
        throw new TRPCError({
          message: `Duplicate id ${id}`,
          code: 'BAD_REQUEST',
        })
      }
      client.subscriptions.set(id, sub)
      this.respond(ws, {
        id,
        jsonrpc,
        result: {
          type: 'started',
        },
      })
    } catch (cause) /* istanbul ignore next -- @preserve */ {
      // procedure threw an error
      const error = getTRPCErrorFromUnknown(cause)
      // this.opts.onError?.({ error, path, type, ctx: ws.ctx, req: this.opts.req, input })
      this.respond(ws, {
        id,
        jsonrpc,
        error: this.opts.router.getErrorShape({
          error,
          type,
          path,
          input,
          ctx: client.ctx,
        }),
      })
    }
  }

  /**
   * Bun `open` callback. It proceeds in three steps:
   *
   * 1. Assigns the socket an id and registers its client.
   * 2. Awaits `onOpen`. If `onOpen` throws, the socket is closed via `handleError` and no
   *    context is created.
   * 3. Builds the request context with `createContext`. If `createContext` throws, an error
   *    response with `id: null` is sent and the socket is closed on the next tick.
   */
  async open(ws: ServerWebSocket<TData>) {
    ws.data.id = this.nextId++
    const client: BunWSSClient<TData, inferRouterContext<TRouter>> = {
      ws,
      subscriptions: new Map(),
    }
    this.clients.set(ws.data.id, client)
    if (this.opts.onOpen) {
      try {
        await this.opts.onOpen.call(this, ws)
      } catch (error) {
        this.handleError(ws, error)
        return
      }
    }
    // TODO
    const resHeaders = new Headers()
    const ctxPromise =
      this.opts.createContext?.({
        req: ws.data.req,
        resHeaders,
      }) ?? Promise.resolve()
    try {
      client.ctx = await ctxPromise
    } catch (cause) {
      const error = getTRPCErrorFromUnknown(cause)
      // TODO
      // this.opts.onError?.({
      //   error,
      //   path: undefined,
      //   type: 'unknown',
      //   ctx: ws.ctx,
      //   req: this.opts.req,
      //   input: undefined,
      // })
      this.respond(ws, {
        id: null,
        error: this.opts.router.getErrorShape({
          error,
          type: 'unknown',
          path: undefined,
          input: undefined,
          ctx: client.ctx,
        }),
      })
      // close in next tick
      global.setImmediate(() => {
        ws.close()
      })
    }
  }

  /**
   * Bun `message` callback.
   *
   * First awaits `onMessage`; if it throws, the socket is closed and processing stops.
   * The payload is then parsed as a single message or a batch (JSON array), and each
   * message is handled concurrently. Any error escaping parsing or `handleRequest` is
   * answered with a PARSE_ERROR response whose id is `null`.
   */
  async message(ws: ServerWebSocket<TData>, message: WSSMessage) {
    if (this.opts.onMessage) {
      try {
        await this.opts.onMessage.call(this, ws, message)
      } catch (error) {
        this.handleError(ws, error)
        return
      }
    }
    try {
      const msgJSON: unknown = parseWssMessage(message)
      const msgs: unknown[] = Array.isArray(msgJSON) ? msgJSON : [msgJSON]
      const promises = msgs
        .map(raw => parseMessage(raw, this.opts.router._def._config.transformer))
        .map(msg => this.handleRequest(ws, msg))
      await Promise.all(promises)
    } catch (cause) {
      const error = new TRPCError({
        code: 'PARSE_ERROR',
        cause: getCauseFromUnknown(cause),
      })
      this.respond(ws, {
        id: null,
        error: this.opts.router.getErrorShape({
          error,
          type: 'unknown',
          path: undefined,
          input: undefined,
          ctx: undefined,
        }),
      })
    }
  }

  /**
   * Bun `close` callback: runs `onClose`, then unsubscribes the client's subscriptions
   * and unregisters it.
   *
   * Cleanup runs in a `finally`, so a throwing `onClose` can no longer leak the client
   * entry. Subscriptions are now stopped as well, instead of continuing to push data at
   * a closed socket.
   */
  close(ws: ServerWebSocket<TData>, code: number, reason: string) {
    const { onClose } = this.opts
    try {
      if (onClose) {
        // TODO: Is it safe to call close() (via handleError) on a socket that is already closing?
        this.runDetachedHook(ws, () => onClose.call(this, ws, code, reason))
      }
    } finally {
      this.releaseClient(ws)
    }
  }

  /** Bun `drain` callback: runs `onDrain`, closing the socket if that hook fails. */
  drain(ws: ServerWebSocket<TData>) {
    const { onDrain } = this.opts
    if (onDrain) {
      this.runDetachedHook(ws, () => onDrain.call(this, ws))
    }
  }

  /**
   * Starts a Bun server that upgrades every request to a WebSocket handled by this instance.
   *
   * Requests that cannot be upgraded (for example, plain HTTP requests without WebSocket
   * upgrade headers) receive a 400 response instead of no response at all.
   */
  serve(opts: Omit<ServeOptions, 'fetch' | 'websocket'>) {
    return Bun.serve({
      ...opts,
      fetch: (req: Request, server: Server) => {
        const data: BunWSSHandlerData = {
          req,
        }
        if (server.upgrade(req, { data })) {
          // Bun now owns the connection; no HTTP response may be returned.
          return undefined
        }
        return new Response('WebSocket upgrade failed', { status: 400 })
      },
      websocket: {
        open: (ws: ServerWebSocket<TData>) => this.open(ws),
        message: (ws: ServerWebSocket<TData>, message: WSSMessage) => this.message(ws, message),
        close: (ws: ServerWebSocket<TData>, code: number, reason: string) => this.close(ws, code, reason),
        drain: (ws: ServerWebSocket<TData>) => this.drain(ws),
      },
    })
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment