Skip to content

Instantly share code, notes, and snippets.

@airhorns
Created July 5, 2020 23:54
Show Gist options
  • Save airhorns/1e97831d1bbc3782a93bdb909ac38872 to your computer and use it in GitHub Desktop.
Save airhorns/1e97831d1bbc3782a93bdb909ac38872 to your computer and use it in GitHub Desktop.
declare let window: any;
// Resolve the host's global object: Node's `global` first, then a browser
// `window`, and finally an empty object so the property lookups below never throw.
const _global: any = (() => {
  if (typeof global !== "undefined") {
    return global;
  }
  if (typeof window !== "undefined") {
    return window;
  }
  return {};
})();
// Standard constructor first; `MozWebSocket` is the vendor-prefixed fallback.
const NativeWebSocket = _global.WebSocket || _global.MozWebSocket;
import Backoff from "backo2";
import { default as EventEmitterType, EventEmitter, ListenerFn } from "eventemitter3";
import { ExecutionResult } from "graphql/execution/execute";
import { DocumentNode } from "graphql/language/ast";
import { print } from "graphql/language/printer";
import { getOperationAST } from "graphql/utilities/getOperationAST";
import { WS_TIMEOUT } from "subscriptions-transport-ws/dist/defaults";
import MessageTypes from "subscriptions-transport-ws/dist/message-types";
import { GRAPHQL_WS } from "subscriptions-transport-ws/dist/protocol";
import isObject from "subscriptions-transport-ws/dist/utils/is-object";
import isString from "subscriptions-transport-ws/dist/utils/is-string";
import $$observable from "symbol-observable";
/** Push-style consumer of an `Observable`; every callback is optional. */
export interface Observer<T> {
  /** Receives each produced value. */
  next?: (value: T) => void;
  /** Receives a terminal error. */
  error?: (error: Error) => void;
  /** Called when the source finishes without error. */
  complete?: () => void;
}
/** Minimal Observable contract returned by `SubscriptionClient.request`. */
export interface Observable<T> {
  /** Starts delivering values to `observer`; the returned handle stops delivery. */
  subscribe(
    observer: Observer<T>
  ): {
    unsubscribe: () => void;
  };
}
/** Description of a single GraphQL operation. */
export interface OperationOptions {
  /** Query source text or an already-parsed document. */
  query?: string | DocumentNode;
  variables?: Record<string, any>;
  operationName?: string;
  /** Extra fields are permitted (middleware may add them). */
  [key: string]: any;
}
/**
 * An `Error` that may carry the raw value it was built from.
 * (The misspelled name is kept for API compatibility.)
 */
export type FormatedError = Error & {
  originalError?: any;
};
/** Bookkeeping for one in-flight operation. */
export interface Operation {
  options: OperationOptions;
  /** `error` is null on success; `(null, null)` signals completion. */
  handler: (error: Error[] | null, result?: any) => void;
}
/** In-flight operations keyed by operation id. */
export interface Operations {
  [id: string]: Operation;
}
/**
 * Hook run before an operation is sent. Call `next()` to continue with the
 * next middleware, or `next(error)` to abort the operation with that error.
 */
export interface Middleware {
  applyMiddleware(options: OperationOptions, next: (error?: any) => void): void;
}
/** Arbitrary payload sent with the connection-init message. */
export type ConnectionParams = {
  [paramName: string]: any;
};
/**
 * Connection params as a plain object, a Promise of one, or a factory
 * (sync or async) invoked each time a connection is initialised.
 */
export type ConnectionParamsOptions =
  | ConnectionParams
  | (() => ConnectionParams | Promise<ConnectionParams>)
  | Promise<ConnectionParams>;
/** Constructor options for `SubscriptionClient`; every field is optional. */
export interface ClientOptions {
  /** Payload for the connection-init message; a factory is re-invoked on every (re)connect. */
  connectionParams?: ConnectionParamsOptions;
  /** Milliseconds: keep-alive check period and cap on the per-attempt connect timeout. Defaults to `WS_TIMEOUT`. */
  timeout?: number;
  /** Reconnect automatically after the connection drops or fails to open in time. Defaults to false. */
  reconnect?: boolean;
  /** Maximum reconnection attempts before giving up. Defaults to Infinity. */
  reconnectionAttempts?: number;
  /**
   * Called with no arguments when the server acknowledges the connection, or
   * with the server's error payload when it rejects it.
   */
  connectionCallback?: (error: Error[], result?: any) => void;
  /** Defer connecting until the first request. Defaults to false. */
  lazy?: boolean;
  /** Close the socket after this many ms with no active operations; 0 (default) disables. */
  inactivityTimeout?: number;
}
/**
 * Subscription client
 *
 * Runs GraphQL operations over a single WebSocket using the
 * subscriptions-transport-ws message types. Supports lazy connection, request
 * middleware, keep-alive monitoring, closing after a period with no active
 * operations, and optional reconnection with backoff.
 *
 * Lifecycle events (`connecting`, `connected`, `reconnecting`, `reconnected`,
 * `disconnected`, `error`) can be observed via `on` or the `on*` helpers.
 */
export class SubscriptionClient {
  /** The current WebSocket instance, or `null` while disconnected. */
  public client: any;
  /** Active operations keyed by operation id. */
  public operations: Operations;
  private url: string;
  private nextOperationId: number;
  private connectionParamsGenerator: () => Promise<ConnectionParams>;
  /** Keep-alive check period and upper bound of the connect-timeout backoff (ms). */
  private wsTimeout: number;
  /** Messages queued while the websocket is opening; flushed once it is open. */
  private unsentMessagesQueue: Array<any>;
  private reconnect: boolean;
  /** Set while recovering from a dropped/failed connection; cleared on connection ack. */
  private reconnecting: boolean;
  private reconnectionAttempts: number;
  /** Backoff spacing out reconnection attempts. */
  private backoff: any;
  private connectionCallback: any;
  private eventEmitter: EventEmitterType;
  private lazy: boolean;
  /** Milliseconds with no active operations before closing; 0 disables. */
  private inactivityTimeout: number;
  private inactivityTimeoutId: any;
  /** When true, the socket's `onclose` handler does not call `close(false, false)`. */
  private closedByUser: boolean;
  private wsImpl: any;
  private wsProtocols: string | string[];
  /** `undefined` until the first keep-alive, then toggled by the keep-alive watchdog. */
  private wasKeepAliveReceived: boolean | undefined;
  private tryReconnectTimeoutId: any;
  private checkConnectionIntervalId: any;
  private maxConnectTimeoutId: any;
  private middlewares: Middleware[];
  /** Backoff producing the per-attempt connect timeout (1s growing to `wsTimeout`). */
  private maxConnectTimeGenerator: any;

  /**
   * @param url - WebSocket endpoint.
   * @param options - Client behaviour; every field is optional.
   * @param webSocketImpl - WebSocket constructor to use instead of the host's
   *   global `WebSocket` / `MozWebSocket`.
   * @param webSocketProtocols - Sub-protocol(s) requested on connect; defaults to `GRAPHQL_WS`.
   * @throws Error when no WebSocket implementation can be found.
   */
  constructor(url: string, options?: ClientOptions, webSocketImpl?: any, webSocketProtocols?: string | string[]) {
    const {
      connectionCallback = undefined,
      connectionParams = {},
      timeout = WS_TIMEOUT,
      reconnect = false,
      reconnectionAttempts = Infinity,
      lazy = false,
      inactivityTimeout = 0,
    } = options || {};
    this.wsImpl = webSocketImpl || NativeWebSocket;
    if (!this.wsImpl) {
      throw new Error("Unable to find native implementation, or alternative implementation for WebSocket!");
    }
    this.wsProtocols = webSocketProtocols || GRAPHQL_WS;
    this.connectionCallback = connectionCallback;
    this.url = url;
    this.operations = {};
    this.nextOperationId = 0;
    this.wsTimeout = timeout;
    this.unsentMessagesQueue = [];
    this.reconnect = reconnect;
    this.reconnecting = false;
    this.reconnectionAttempts = reconnectionAttempts;
    this.lazy = !!lazy;
    this.inactivityTimeout = inactivityTimeout;
    this.closedByUser = false;
    this.backoff = new Backoff({ jitter: 0.5 });
    this.eventEmitter = new EventEmitter();
    this.middlewares = [];
    this.client = null;
    // Must run after wsTimeout is assigned: the generator is capped by it.
    this.maxConnectTimeGenerator = this.createMaxConnectTimeGenerator();
    this.connectionParamsGenerator = this.getConnectionParamsGenerator(connectionParams);
    if (!this.lazy) {
      this.connect();
    }
  }

  /** `readyState` of the current socket, or the implementation's `CLOSED` constant when there is none. */
  public get status() {
    if (this.client === null) {
      return this.wsImpl.CLOSED;
    }
    return this.client.readyState;
  }

  /**
   * Closes the current socket, if any, and emits `disconnected`.
   *
   * @param isForced - When true (default), also cancels pending connect and
   *   reconnect timers, unsubscribes every operation and sends
   *   GQL_CONNECTION_TERMINATE. When false, operations are kept and a
   *   reconnect is attempted afterwards (if enabled).
   * @param closedByUser - Stored in `closedByUser`; when true the socket's
   *   `onclose` handler does not treat the close as unexpected.
   */
  public close(isForced = true, closedByUser = true) {
    this.clearInactivityTimeout();
    if (this.client === null) {
      return;
    }
    this.closedByUser = closedByUser;
    // A closed socket can never deliver another keep-alive, so the watchdog must
    // stop on every close. Left running after a non-forced close it either leaks
    // (reconnect disabled) or can close the next connection before that
    // connection's first keep-alive arrives.
    this.clearCheckConnectionInterval();
    if (isForced) {
      this.clearMaxConnectTimeout();
      this.clearTryReconnectTimeout();
      this.unsubscribeAll();
      this.sendMessage(undefined, MessageTypes.GQL_CONNECTION_TERMINATE, null);
    }
    this.client.close();
    this.client = null;
    this.eventEmitter.emit("disconnected");
    if (!isForced) {
      this.tryReconnect();
    }
  }

  /**
   * Creates a lazy Observable for `request`. Each `subscribe()` call starts its
   * own operation (connecting first if needed); the handle it returns stops
   * exactly that operation.
   */
  public request(request: OperationOptions): Observable<ExecutionResult> {
    const getObserver = this.getObserver.bind(this);
    const executeOperation = this.executeOperation.bind(this);
    const unsubscribe = this.unsubscribe.bind(this);
    this.clearInactivityTimeout();
    return {
      [$$observable]() {
        return this;
      },
      subscribe(
        observerOrNext: Observer<ExecutionResult> | ((v: ExecutionResult) => void),
        onError?: (error: Error) => void,
        onComplete?: () => void
      ) {
        const observer = getObserver(observerOrNext, onError, onComplete);
        // The id is scoped to this subscribe() call. A variable shared across
        // calls would let an earlier subscription's unsubscribe() stop the most
        // recently started operation instead of its own.
        let opId: string | undefined = executeOperation(request, (error: Error[] | null, result: any) => {
          if (error === null && result === null) {
            if (observer.complete) {
              observer.complete();
            }
          } else if (error) {
            if (observer.error) {
              observer.error(error[0]);
            }
          } else if (observer.next) {
            observer.next(result);
          }
        });
        return {
          unsubscribe: () => {
            if (opId) {
              unsubscribe(opId);
              opId = undefined;
            }
          },
        };
      },
    };
  }

  /**
   * Subscribes `callback` to a client event.
   *
   * @returns A function that removes this listener.
   */
  public on(eventName: string, callback: ListenerFn, context?: any): () => void {
    const handler = this.eventEmitter.on(eventName, callback, context);
    return () => {
      handler.off(eventName, callback, context);
    };
  }

  /** Listens for `connected` (first connection acknowledged by the server). */
  public onConnected(callback: ListenerFn, context?: any): () => void {
    return this.on("connected", callback, context);
  }

  /** Listens for `connecting` (socket opened, before the server ack). */
  public onConnecting(callback: ListenerFn, context?: any): () => void {
    return this.on("connecting", callback, context);
  }

  /** Listens for `disconnected` (emitted by every `close`). */
  public onDisconnected(callback: ListenerFn, context?: any): () => void {
    return this.on("disconnected", callback, context);
  }

  /** Listens for `reconnected` (server ack received while reconnecting). */
  public onReconnected(callback: ListenerFn, context?: any): () => void {
    return this.on("reconnected", callback, context);
  }

  /** Listens for `reconnecting` (socket reopened while reconnecting, before the ack). */
  public onReconnecting(callback: ListenerFn, context?: any): () => void {
    return this.on("reconnecting", callback, context);
  }

  /** Listens for `error` (socket errors and messages that could not be sent). */
  public onError(callback: ListenerFn, context?: any): () => void {
    return this.on("error", callback, context);
  }

  /** Stops every active operation (sending GQL_STOP for each). */
  public unsubscribeAll() {
    Object.keys(this.operations).forEach((subId) => {
      this.unsubscribe(subId);
    });
  }

  /**
   * Runs the registered middlewares in order on the same `options` object.
   * Each middleware is invoked with `this` bound to the client.
   *
   * @returns A promise resolving with `options` once every middleware has
   *   called `next()`, or rejecting with the first error passed to `next`.
   */
  public applyMiddlewares(options: OperationOptions): Promise<OperationOptions> {
    return new Promise((resolve, reject) => {
      const queue = (funcs: Middleware[], scope: any) => {
        const next = (error?: any) => {
          if (error) {
            reject(error);
          } else if (funcs.length > 0) {
            const f = funcs.shift();
            if (f) {
              f.applyMiddleware.apply(scope, [options, next]);
            }
          } else {
            resolve(options);
          }
        };
        next();
      };
      // Copy so shifting does not consume the registered list.
      queue([...this.middlewares], this);
    });
  }

  /**
   * Registers middlewares, in order.
   *
   * @throws Error if an entry lacks an `applyMiddleware` function; entries
   *   before it remain registered.
   * @returns This client, for chaining.
   */
  public use(middlewares: Middleware[]): SubscriptionClient {
    middlewares.forEach((middleware) => {
      if (typeof middleware.applyMiddleware !== "function") {
        throw new Error("Middleware must implement the applyMiddleware function.");
      }
      this.middlewares.push(middleware);
    });
    return this;
  }

  /**
   * Normalises the `connectionParams` option into a Promise-returning function.
   * A factory is invoked on every call, and a factory that throws yields a
   * rejected promise. Plain objects and Promises are resolved as-is.
   */
  private getConnectionParamsGenerator(connectionParams: ConnectionParamsOptions) {
    return (): Promise<ConnectionParams> =>
      new Promise((resolve, reject) => {
        if (typeof connectionParams === "function") {
          try {
            return resolve(connectionParams());
          } catch (error) {
            return reject(error);
          }
        }
        resolve(connectionParams);
      });
  }

  /**
   * Registers an operation (connecting first if there is no socket), then runs
   * middlewares and validation before sending GQL_START.
   *
   * If the operation was unsubscribed while middlewares ran, nothing is sent.
   * Middleware or validation failures unsubscribe the operation and report
   * formatted errors to `handler`.
   *
   * @returns The new operation id, synchronously.
   */
  private executeOperation(options: OperationOptions, handler: (error: Error[] | null, result?: any) => void): string {
    if (this.client === null) {
      this.connect();
    }
    const opId = this.generateOperationId();
    this.operations[opId] = { options: options, handler };
    this.applyMiddlewares(options)
      .then((processedOptions) => {
        this.checkOperationOptions(processedOptions, handler);
        if (this.operations[opId]) {
          this.operations[opId] = { options: processedOptions, handler };
          this.sendMessage(opId, MessageTypes.GQL_START, processedOptions);
        }
      })
      .catch((error) => {
        this.unsubscribe(opId);
        handler(this.formatErrors(error));
      });
    return opId;
  }

  /** Normalises a `next` callback plus optional error/complete callbacks into an observer object. */
  private getObserver<T>(observerOrNext: Observer<T> | ((v: T) => void), error?: (e: Error) => void, complete?: () => void) {
    if (typeof observerOrNext === "function") {
      return {
        next: (v: T) => observerOrNext(v),
        error: (e: Error) => error && error(e),
        complete: () => complete && complete(),
      };
    }
    return observerOrNext;
  }

  /** Backoff for connect timeouts: starts at 1000 ms, grows by 1.2x, capped at `wsTimeout`. */
  private createMaxConnectTimeGenerator() {
    const minValue = 1000;
    const maxValue = this.wsTimeout;
    return new Backoff({
      min: minValue,
      max: maxValue,
      factor: 1.2,
    });
  }

  private clearCheckConnectionInterval() {
    if (this.checkConnectionIntervalId) {
      clearInterval(this.checkConnectionIntervalId);
      this.checkConnectionIntervalId = null;
    }
  }

  private clearMaxConnectTimeout() {
    if (this.maxConnectTimeoutId) {
      clearTimeout(this.maxConnectTimeoutId);
      this.maxConnectTimeoutId = null;
    }
  }

  private clearTryReconnectTimeout() {
    if (this.tryReconnectTimeoutId) {
      clearTimeout(this.tryReconnectTimeoutId);
      this.tryReconnectTimeoutId = null;
    }
  }

  private clearInactivityTimeout() {
    if (this.inactivityTimeoutId) {
      clearTimeout(this.inactivityTimeoutId);
      this.inactivityTimeoutId = null;
    }
  }

  /**
   * When enabled and no operations remain, schedules a forced close after
   * `inactivityTimeout` ms (skipped if an operation exists when it fires).
   */
  private setInactivityTimeout() {
    if (this.inactivityTimeout > 0 && Object.keys(this.operations).length === 0) {
      // Never orphan a pending timer by overwriting its id.
      this.clearInactivityTimeout();
      this.inactivityTimeoutId = setTimeout(() => {
        if (Object.keys(this.operations).length === 0) {
          this.close();
        }
      }, this.inactivityTimeout);
    }
  }

  /**
   * Validates operation options.
   *
   * @throws Error when the query or handler is missing, or option types are wrong.
   */
  private checkOperationOptions(options: OperationOptions, handler: (error: Error[], result?: any) => void) {
    const { query, variables, operationName } = options;
    if (!query) {
      throw new Error("Must provide a query.");
    }
    if (!handler) {
      throw new Error("Must provide an handler.");
    }
    if (
      (!isString(query) && !getOperationAST(query, operationName)) ||
      (operationName && !isString(operationName)) ||
      (variables && !isObject(variables))
    ) {
      throw new Error(
        "Incorrect option types. query must be a string or a document," +
          "`operationName` must be a string, and `variables` must be an object."
      );
    }
  }

  /** Builds a protocol message; a document-valued `payload.query` is printed to a string. */
  private buildMessage(id: string | undefined, type: string, payload: any) {
    const payloadToReturn =
      payload && payload.query
        ? {
            ...payload,
            query: typeof payload.query === "string" ? payload.query : print(payload.query),
          }
        : payload;
    return {
      id,
      type,
      payload: payloadToReturn,
    };
  }

  /** Ensures we have an array of errors, unwrapping `{ errors }` containers. */
  private formatErrors(errors: any): FormatedError[] {
    if (Array.isArray(errors)) {
      return errors;
    }
    // TODO we should not pass ValidationError to callback in the future.
    // ValidationError
    if (errors && errors.errors) {
      return this.formatErrors(errors.errors);
    }
    if (errors && errors.message) {
      return [errors];
    }
    return [
      {
        name: "FormatedError",
        message: "Unknown error",
        originalError: errors,
      },
    ];
  }

  private sendMessage(id: string | undefined, type: string, payload: any) {
    this.sendMessageRaw(this.buildMessage(id, type, payload));
  }

  /**
   * Sends `message` if the socket is open, or queues it while connecting.
   * Otherwise it is dropped, with an `error` event unless a reconnect is in progress.
   */
  private sendMessageRaw(message: Record<string, any>) {
    switch (this.status) {
      case this.wsImpl.OPEN: {
        // NOTE(review): JSON.stringify on an object either returns valid JSON or
        // throws (e.g. circular payload), so this parse check never fires; an
        // unserializable payload throws from the stringify call instead.
        const serializedMessage: string = JSON.stringify(message);
        try {
          JSON.parse(serializedMessage);
        } catch (e) {
          this.eventEmitter.emit("error", new Error(`Message must be JSON-serializable. Got: ${message}`));
        }
        this.client.send(serializedMessage);
        break;
      }
      case this.wsImpl.CONNECTING:
        this.unsentMessagesQueue.push(message);
        break;
      default:
        if (!this.reconnecting) {
          this.eventEmitter.emit(
            "error",
            new Error(
              "A message was not sent because socket is not connected, is closing or " +
                "is already closed. Message was: " +
                JSON.stringify(message)
            )
          );
        }
    }
  }

  /** Sequential operation ids starting at "1". */
  private generateOperationId(): string {
    return String(++this.nextOperationId);
  }

  /**
   * Schedules a reconnect after a backoff delay, unless reconnection is
   * disabled or attempts are exhausted. On the first attempt of a cycle, queues
   * GQL_START for every active operation so they resume on the new socket.
   */
  private tryReconnect() {
    if (!this.reconnect || this.backoff.attempts >= this.reconnectionAttempts) {
      return;
    }
    if (!this.reconnecting) {
      Object.keys(this.operations).forEach((key) => {
        this.unsentMessagesQueue.push(this.buildMessage(key, MessageTypes.GQL_START, this.operations[key].options));
      });
      this.reconnecting = true;
    }
    this.clearTryReconnectTimeout();
    const delay = this.backoff.duration();
    this.tryReconnectTimeoutId = setTimeout(() => {
      this.connect();
    }, delay);
  }

  /** Sends every queued message. */
  private flushUnsentMessagesQueue() {
    // Detach the queue first: anything sendMessageRaw re-queues (socket not open
    // yet) lands in the fresh array rather than being wiped after the loop.
    const pending = this.unsentMessagesQueue;
    this.unsentMessagesQueue = [];
    pending.forEach((message) => {
      this.sendMessageRaw(message);
    });
  }

  /**
   * Keep-alive watchdog tick: consumes the "keep-alive received" flag; if none
   * arrived since the previous tick (and we are not already reconnecting),
   * performs a non-forced close so a reconnect can be attempted.
   */
  private checkConnection() {
    if (this.wasKeepAliveReceived) {
      this.wasKeepAliveReceived = false;
      return;
    }
    if (!this.reconnecting) {
      this.close(false, true);
    }
  }

  /** Arms the connect timeout: if the socket is not open by then, close non-forced and retry. */
  private checkMaxConnectTimeout() {
    this.clearMaxConnectTimeout();
    // Max timeout trying to connect
    this.maxConnectTimeoutId = setTimeout(() => {
      if (this.status !== this.wsImpl.OPEN) {
        this.reconnecting = true;
        this.close(false, true);
      }
    }, this.maxConnectTimeGenerator.duration());
  }

  /**
   * Opens a new socket and wires its handlers. On open, sends
   * GQL_CONNECTION_INIT with the resolved connection params, or
   * GQL_CONNECTION_ERROR if resolving them failed. It then flushes queued
   * messages without waiting for the server's ack.
   */
  private connect() {
    this.client = new this.wsImpl(this.url, this.wsProtocols);
    this.checkMaxConnectTimeout();
    this.client.onopen = async () => {
      if (this.status === this.wsImpl.OPEN) {
        this.clearMaxConnectTimeout();
        this.closedByUser = false;
        this.eventEmitter.emit(this.reconnecting ? "reconnecting" : "connecting");
        try {
          const connectionParams: ConnectionParams = await this.connectionParamsGenerator();
          this.sendMessage(undefined, MessageTypes.GQL_CONNECTION_INIT, connectionParams);
        } catch (error) {
          this.sendMessage(undefined, MessageTypes.GQL_CONNECTION_ERROR, error);
        }
        this.flushUnsentMessagesQueue();
      }
    };
    this.client.onclose = () => {
      if (!this.closedByUser) {
        this.close(false, false);
      }
    };
    this.client.onerror = (err: Error) => {
      // Capture and ignore errors to prevent unhandled exceptions, wait for
      // onclose to fire before attempting a reconnect.
      this.eventEmitter.emit("error", err);
    };
    this.client.onmessage = ({ data }: { data: any }) => {
      this.processReceivedData(data);
    };
  }

  /**
   * Dispatches one server frame.
   *
   * @throws Error if the frame is not JSON or has an unknown message type.
   */
  private processReceivedData(receivedData: any) {
    let parsedMessage: any;
    let opId: string;
    try {
      parsedMessage = JSON.parse(receivedData);
      opId = parsedMessage.id;
    } catch (e) {
      throw new Error(`Message must be JSON-parseable. Got: ${receivedData}`);
    }
    // Ignore per-operation frames for operations we no longer track.
    if (
      // eslint-disable-next-line lodash/prefer-includes
      [MessageTypes.GQL_DATA, MessageTypes.GQL_COMPLETE, MessageTypes.GQL_ERROR].indexOf(parsedMessage.type) !== -1 &&
      !this.operations[opId]
    ) {
      this.unsubscribe(opId);
      return;
    }
    switch (parsedMessage.type) {
      case MessageTypes.GQL_CONNECTION_ERROR:
        if (this.connectionCallback) {
          this.connectionCallback(parsedMessage.payload);
        }
        break;
      case MessageTypes.GQL_CONNECTION_ACK:
        this.eventEmitter.emit(this.reconnecting ? "reconnected" : "connected");
        this.reconnecting = false;
        this.backoff.reset();
        this.maxConnectTimeGenerator.reset();
        if (this.connectionCallback) {
          this.connectionCallback();
        }
        break;
      case MessageTypes.GQL_COMPLETE: {
        // Forget the operation before notifying: an unsubscribe triggered from the
        // handler must not send GQL_STOP for an operation the server already
        // finished, and this may have been the last live operation.
        const { handler } = this.operations[opId];
        delete this.operations[opId];
        this.setInactivityTimeout();
        handler(null, null);
        break;
      }
      case MessageTypes.GQL_ERROR: {
        const { handler } = this.operations[opId];
        delete this.operations[opId];
        this.setInactivityTimeout();
        handler(this.formatErrors(parsedMessage.payload), null);
        break;
      }
      case MessageTypes.GQL_DATA: {
        const parsedPayload = !parsedMessage.payload.errors
          ? parsedMessage.payload
          : { ...parsedMessage.payload, errors: this.formatErrors(parsedMessage.payload.errors) };
        this.operations[opId].handler(null, parsedPayload);
        break;
      }
      case MessageTypes.GQL_CONNECTION_KEEP_ALIVE: {
        const firstKA = typeof this.wasKeepAliveReceived === "undefined";
        this.wasKeepAliveReceived = true;
        if (firstKA) {
          this.checkConnection();
        }
        if (this.checkConnectionIntervalId) {
          clearInterval(this.checkConnectionIntervalId);
          this.checkConnection();
        }
        // Re-arm the watchdog so it only fires after a full wsTimeout of silence.
        this.checkConnectionIntervalId = setInterval(this.checkConnection.bind(this), this.wsTimeout);
        break;
      }
      default:
        throw new Error("Invalid message type!");
    }
  }

  /** Forgets an active operation, arms the inactivity timer if it was the last one, and sends GQL_STOP. */
  private unsubscribe(opId: string) {
    if (this.operations[opId]) {
      delete this.operations[opId];
      this.setInactivityTimeout();
      this.sendMessage(opId, MessageTypes.GQL_STOP, undefined);
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment