Last active
December 22, 2022 10:41
-
-
Save razor-x/e19d7d776cdf58d04af1e223b0757064 to your computer and use it in GitHub Desktop.
AppSync using Apollo Client with subscription support, and custom domain via Lambda@Edge
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import type { | |
NormalizedCacheObject, | |
PossibleTypesMap, | |
Resolvers, | |
TypePolicies | |
} from '@apollo/client' | |
import { | |
ApolloClient as Client, | |
InMemoryCache, | |
createHttpLink, | |
from, | |
split | |
} from '@apollo/client' | |
import { onError } from '@apollo/client/link/error' | |
import { setContext } from '@apollo/client/link/context' | |
import { getMainDefinition } from '@apollo/client/utilities' | |
import { CachePersistor } from 'apollo-cache-persist' | |
import type { AuthClient } from 'src/lib/auth-client' | |
import { createWebsocketLink } from './appsync' | |
/**
 * Options accepted by {@link createApolloClient}.
 */
interface Config {
  /** Provides the JWT (`getJwtToken`) and `signOut`, used after an HTTP 401. */
  authClient: AuthClient
  /** HTTP origin; concatenated with `path` to form the GraphQL endpoint URL. */
  origin: string
  /** Origin of the WebSocket endpoint used for subscriptions. */
  wsOrigin: string
  /** GraphQL path, shared by the HTTP and WebSocket endpoints. */
  path: string
  /** Sent as `x-api-key` on HTTP requests when no JWT is available. */
  apiKey: string
  /** AppSync host handed to the WebSocket link for subscription authorization. */
  appsyncHost: string
  /** Forwarded to `InMemoryCache`. */
  possibleTypes: PossibleTypesMap
  /** Forwarded to `InMemoryCache`. */
  typePolicies: TypePolicies
  /** Local resolvers forwarded to the Apollo Client. */
  resolvers: Resolvers
  /** Current cache schema version; a mismatch purges the persisted cache. */
  schemaVersion: string
  /** `localStorage` key under which the cache schema version is recorded. */
  schemaVersionKey: string
}

/** Apollo Client instance type produced by {@link createApolloClient}. */
export type ApolloClient = Client<NormalizedCacheObject>
/**
 * Creates an Apollo Client for AppSync, together with an async `init` step.
 *
 * Routing: subscription operations go through the WebSocket link built by
 * `createWebsocketLink`. Every other operation goes through the
 * auth -> error -> HTTP link chain.
 *
 * HTTP auth: when `authClient.getJwtToken()` resolves to a token, it is sent
 * as the `authorization` header. Otherwise (no token, or a rejected promise)
 * the request falls back to the `x-api-key` header.
 *
 * A 401 network error on the HTTP path signs the user out. If signing out
 * itself fails, the page is reloaded.
 *
 * @returns `client`, the Apollo Client, and `init`. `init` restores the
 *   persisted cache from `localStorage`. It purges the cache instead when the
 *   stored schema version differs from `schemaVersion` or when restoring
 *   fails. It then pushes the current authorization into the WebSocket URL.
 */
export const createApolloClient = ({
  authClient,
  origin,
  wsOrigin,
  path,
  apiKey,
  appsyncHost,
  possibleTypes,
  typePolicies,
  resolvers,
  schemaVersion,
  schemaVersionKey
}: Config) => {
  const uri = [origin, path].join('')

  // Any failure to obtain a token is treated as "no token".
  const getAuthorization = () => authClient.getJwtToken().catch(() => null)

  const httpLink = createHttpLink({ uri })

  const { wsLink, setWsAuthorization } = createWebsocketLink({
    origin: wsOrigin,
    path,
    host: appsyncHost,
    getAuthorization
  })

  const authLink = setContext(async (_, { headers }) => {
    const authorization = await getAuthorization()
    return {
      headers: {
        ...headers,
        ...(authorization ? { authorization } : { 'x-api-key': apiKey })
      }
    }
  })

  const errorLink = onError(({ networkError }) => {
    if (
      networkError &&
      'statusCode' in networkError &&
      networkError.statusCode === 401
    ) {
      // `Location.reload()` takes no arguments in the standard. The old
      // `forceGet` flag was Firefox-only and is rejected by current DOM typings.
      authClient.signOut().catch(() => {
        window.location.reload()
      })
    }
  })

  const link = split(
    ({ query }) => {
      const definition = getMainDefinition(query)
      return (
        definition.kind === 'OperationDefinition' &&
        definition.operation === 'subscription'
      )
    },
    wsLink,
    from([authLink, errorLink, httpLink])
  )

  const cache = new InMemoryCache({
    typePolicies,
    possibleTypes
  })

  const client = new Client<NormalizedCacheObject>({
    resolvers,
    link,
    cache
  })

  const storage = window.localStorage as any
  const persistor = new CachePersistor({
    cache,
    storage
  })

  // Restore the persisted cache only if it was written under the current
  // schema version. Otherwise, or if restoring rejects (e.g. unreadable
  // persisted data), start from an empty cache and record the current version.
  const restoreOrPurgeCache = async () => {
    if (window.localStorage.getItem(schemaVersionKey) === schemaVersion) {
      try {
        await persistor.restore()
        return
      } catch {
        // Fall through and purge instead of failing `init`.
      }
    }
    await persistor.purge()
    window.localStorage.setItem(schemaVersionKey, schemaVersion)
  }

  const init = async () => {
    await restoreOrPurgeCache()
    const authorization = await getAuthorization()
    setWsAuthorization(authorization)
  }

  return {
    client,
    init
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import type { ClientOptions } from 'subscriptions-transport-ws' | |
import { SubscriptionClient } from 'subscriptions-transport-ws' | |
import { WebSocketLink } from '@apollo/client/link/ws' | |
import { print } from 'graphql/language/printer' | |
/** Options accepted by {@link createWebsocketLink}. */
interface Config {
  /** WebSocket origin; its path is replaced by `path`. */
  origin: string
  /** Endpoint path for the subscription socket. */
  path: string
  /** AppSync host sent in each operation's `extensions.authorization`. */
  host: string
  /** Resolves to the current token, or `null` when there is none. */
  getAuthorization: () => Promise<string | null>
}

/** Result of {@link createWebsocketLink}. */
interface WebSocketLinkWrapper {
  /** Apollo link that carries subscription operations. */
  wsLink: WebSocketLink
  /** Sets the token in the socket URL, or clears it when given `null`. */
  setWsAuthorization: (authorization: string | null) => void
}
/**
 * Builds the subscription endpoint URL.
 *
 * Takes `origin`, replaces its pathname with `path`, and sets `ws=true` in the
 * query string. The edge request handler checks for that flag to route the
 * request to the real-time endpoint. Any existing query parameters on
 * `origin` are kept.
 */
function getUri(origin: string, path: string): string {
  const endpoint = new URL(origin)
  endpoint.pathname = path
  endpoint.searchParams.set('ws', 'true')
  return endpoint.href
}
/**
 * Creates the WebSocket link used for AppSync subscriptions, plus a setter for
 * the `authorization` query parameter on the connection URL.
 *
 * - The socket URL is `origin` with `path` and `ws=true` (see `getUri`).
 * - A middleware rewrites each operation's options into two fields, which is
 *   the shape the AppSync real-time protocol uses:
 *   - `data`: the JSON-encoded query and variables;
 *   - `extensions.authorization`: `{ authorization, host }`.
 * - When a connection error with `errorCode === 401` arrives, the handler:
 *   1. refreshes the URL token;
 *   2. re-applies the middleware to every tracked operation so each carries
 *      the new token;
 *   3. force-closes the socket so the client reconnects (`reconnect: true`).
 */
export const createWebsocketLink = ({
  origin,
  path,
  host,
  getAuthorization
}: Config): WebSocketLinkWrapper => {
  const uri = getUri(origin, path)

  const middleware = {
    applyMiddleware: async (options: any, next: any) => {
      try {
        if (options.query) {
          const authorization = await getAuthorization()
          options.data = JSON.stringify({
            query:
              typeof options.query === 'string'
                ? options.query
                : print(options.query),
            variables: options.variables
          })
          options.extensions = {
            authorization: {
              authorization,
              host
            }
          }
        }
      } catch (error) {
        // Report the failure through the middleware chain. Otherwise an
        // exception here means `next` is never called for this operation.
        next(error)
        return
      }
      next()
    }
  }

  let subscription: AWSSubscriptionClient | null = null

  const setWsAuthorization = (authorization: string | null) => {
    // Rebuild from the base URI so a `null` token clears any previous one.
    const url = new URL(uri)
    if (authorization) url.searchParams.set('authorization', authorization)
    if (subscription) subscription.setUrl(url.toString())
  }

  const connectionCallback = async (message: any) => {
    const error = message?.errors?.[0]
    const client = subscription
    if (!error || error.errorCode !== 401 || !client) return

    const authorization = await getAuthorization()
    setWsAuthorization(authorization)

    // Re-apply the middleware to each operation's options, since they may
    // embed the expired token. Iterate the operation ids with `for...of`;
    // `for...in` over the `Object.keys` array would yield array indices
    // instead of ids.
    for (const id of Object.keys(client.operations)) {
      const operation = client.operations[id]
      if (operation) {
        operation.options = await client.applyMiddlewares(operation.options)
      }
    }

    // Force close after a 401. The client auto-reconnects because
    // `reconnect: true` is set below.
    client.close(false, false)
  }

  const subscriptionClient = new AWSSubscriptionClient(uri, {
    lazy: true,
    reconnect: true,
    timeout: 5 * 60 * 1000,
    connectionCallback
  })
  subscription = subscriptionClient

  const wsLink = new WebSocketLink(subscriptionClient)
  // Register the middleware on the client passed to WebSocketLink, rather
  // than reaching into the link's non-public `subscriptionClient` field.
  subscriptionClient.use([middleware])

  return {
    wsLink,
    setWsAuthorization
  }
}
/**
 * `SubscriptionClient` patched for AppSync's real-time endpoint.
 *
 * On each instance it replaces two of the parent's private methods:
 * - the unsent-message queue flush, which here skips repeated messages;
 * - incoming-message processing, which here consumes AppSync `start_ack`
 *   messages (not a valid GQL message type) and removes the acknowledged
 *   operation from the unsent queue.
 *
 * It also exposes `setUrl` to change the connection URL. The new URL
 * presumably applies from the next (re)connect; confirm against
 * subscriptions-transport-ws.
 */
class AWSSubscriptionClient extends SubscriptionClient {
  constructor(
    url: string,
    options?: ClientOptions,
    webSocketImpl?: any,
    webSocketProtocols?: string | string[]
  ) {
    super(url, options, webSocketImpl, webSocketProtocols)
    // These parent methods are private, so TypeScript will not let this class
    // override them. Assigning on the instance shadows the prototype methods
    // instead. This is unsafe: it depends on the parent's internals.
    // @ts-ignore
    this.flushUnsentMessagesQueue = this.flush
    // @ts-ignore
    this.processReceivedData = this.process
  }

  /** Replaces the URL stored on the client. */
  public setUrl(url: string) {
    // @ts-ignore
    this.url = url
  }

  // Send each queued message once, skipping repeats, then clear the queue.
  // A repeat is the same operation id AND the same message type. So a `stop`
  // queued after the `start` for the same id is still sent, and distinct
  // id-less messages are not collapsed together.
  private flush() {
    const sent = new Set<string>()
    for (const message of this.getUnsentMessagesQueue()) {
      const dedupeKey = JSON.stringify([message.id, message.type])
      if (sent.has(dedupeKey)) continue
      sent.add(dedupeKey)
      // @ts-ignore
      super.sendMessageRaw(message)
    }
    this.setUnsentMessagesQueue([])
  }

  // Ignore `start_ack` messages from AppSync, since they are not a valid GQL
  // message type. Also forget any queued message for the acknowledged
  // operation.
  private process(receivedData: string) {
    let message: any = null
    try {
      message = JSON.parse(receivedData)
    } catch {
      throw new Error(`Message must be JSON-parseable. Got: ${receivedData}`)
    }
    if (message && message.type === 'start_ack') {
      const remaining = this.getUnsentMessagesQueue().filter(
        (queued) => queued.id !== message.id
      )
      this.setUnsentMessagesQueue(remaining)
      return
    }
    // @ts-ignore
    super.processReceivedData(receivedData)
  }

  private getUnsentMessagesQueue(): any[] {
    // @ts-ignore
    return this.unsentMessagesQueue || []
  }

  private setUnsentMessagesQueue(queue: any[]): void {
    // @ts-ignore
    this.unsentMessagesQueue = queue
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { realtimeSearchParam } from './realtime-search-param.mjs' | |
// Query parameter carrying the client's token; it is folded into the encoded
// real-time `header` and then removed from the forwarded query string.
const authorizationSearchParam = 'authorization'

/**
 * Lambda@Edge request handler (it reads and writes `request.origin.custom`).
 *
 * Requests without the real-time flag pass through untouched. For flagged
 * requests, the handler:
 * - points the custom origin at the real-time domain;
 * - sets `Host` to the original (API) domain, which `getQuerystring` reads
 *   for the encoded `host` field;
 * - rewrites the query string into the real-time handshake format.
 */
export const handler = (event, context, callback) => {
  const { request } = event.Records[0].cf
  const searchParams = new URLSearchParams(request.querystring)

  if (isRealtimeReq(searchParams)) {
    const customOrigin = request.origin.custom
    const apiDomain = customOrigin.domainName
    customOrigin.domainName = toRealtimeDomain(apiDomain)
    request.headers.host = [{ key: 'Host', value: apiDomain }]
    request.querystring = getQuerystring(request, searchParams)
  }

  return callback(null, request)
}
/**
 * Truthy only when the real-time flag parameter equals `'true'`.
 * When the parameter is missing or empty, the raw value (`null` / `''`) is
 * returned; otherwise a boolean.
 */
const isRealtimeReq = (searchParams) => {
  const flag = searchParams.get(realtimeSearchParam)
  return flag ? String(flag) === 'true' : flag
}
/**
 * Maps an AppSync GraphQL domain to its real-time counterpart by replacing
 * the first occurrence of `appsync-api` with `appsync-realtime-api`.
 */
const toRealtimeDomain = (domainName) => {
  const apiMarker = 'appsync-api'
  const realtimeMarker = 'appsync-realtime-api'
  return domainName.replace(apiMarker, realtimeMarker)
}
/**
 * Rewrites the query string into the AppSync real-time handshake form:
 * - `header`: base64 JSON of `{ host, authorization }`; each key is included
 *   only when present, with `host` taken from the request's Host header;
 * - `payload`: base64 of `{}`.
 * The internal real-time flag and `authorization` parameters are removed.
 * Note: `searchParams` is mutated in place.
 */
const getQuerystring = (request, searchParams) => {
  const hostValues = request.headers.host
  const token = searchParams.get(authorizationSearchParam)

  const realtimeHeader = {}
  if (hostValues && hostValues[0]) realtimeHeader.host = hostValues[0].value
  if (token) realtimeHeader.authorization = token

  const toBase64 = (text) => Buffer.from(text).toString('base64')
  searchParams.set('header', toBase64(JSON.stringify(realtimeHeader)))
  searchParams.set('payload', toBase64('{}'))

  for (const internalParam of [realtimeSearchParam, authorizationSearchParam]) {
    searchParams.delete(internalParam)
  }

  return searchParams.toString()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { realtimeSearchParam } from './realtime-search-param.mjs' | |
// Allowed browser origins are this domain and its subdomains (over https).
const rootDomain = 'pureskill.gg'

/**
 * Lambda@Edge response handler.
 *
 * - Removes any upstream `access-control-allow-*` headers that this function
 *   manages.
 * - Applies the headers from `getHeaders`.
 * - Turns real-time requests from a missing or disallowed `Origin` into a
 *   403 response.
 *
 * NOTE(review): the request handler in this setup deletes the real-time flag
 * from the query string before forwarding. If the `request` seen here is
 * that modified request, `isInvalidOriginForRealtimeReq` never fires.
 * Confirm against the Lambda@Edge event structure.
 */
export const handler = (event, context, callback) => {
  const { request, response } = event.Records[0].cf
  const originHeader = request.headers.origin
  const headers = getHeaders(originHeader, request.method)

  Object.keys(corsHeaders()).forEach((name) => {
    delete response.headers[name]
  })
  Object.assign(response.headers, headers)

  if (isInvalidOriginForRealtimeReq(request, originHeader)) {
    response.status = 403
    response.statusDescription = 'Forbidden'
  }

  return callback(null, response)
}
/**
 * Headers to apply to a response:
 * - the shared `commonHeaders`, always;
 * - CORS headers, when `Origin` is an allowed origin;
 * - `Cache-Control`, when that allowed request is an `OPTIONS` request.
 */
const getHeaders = (originHeader, method) => {
  const originEntry = originHeader && originHeader[0]
  if (!originEntry || !isValidOrigin(originEntry.value)) {
    return commonHeaders
  }

  const isOptionsRequest = method === 'OPTIONS'
  return {
    ...commonHeaders,
    ...corsHeaders(originEntry.value),
    ...(isOptionsRequest ? cacheHeaders : {})
  }
}
/**
 * True for a real-time request whose `Origin` is missing or not allowed.
 * False for all non-real-time requests.
 */
const isInvalidOriginForRealtimeReq = (request, originHeader) => {
  if (!isRealtimeReq(new URLSearchParams(request.querystring))) return false
  const originEntry = originHeader && originHeader[0]
  return !(originEntry && isValidOrigin(originEntry.value))
}
/**
 * Truthy only when the real-time flag parameter equals `'true'`.
 * When the parameter is missing or empty, that raw value is returned;
 * otherwise a boolean.
 */
const isRealtimeReq = (searchParams) => {
  const value = searchParams.get(realtimeSearchParam)
  if (!value) return value
  return `${value}` === 'true'
}
/**
 * Allows only `https:` origins whose hostname is `rootDomain` or a subdomain
 * of it. Unparseable origins are rejected.
 */
const isValidOrigin = (origin) => {
  let parsed
  try {
    parsed = new URL(origin)
  } catch {
    return false
  }
  const { protocol, hostname } = parsed
  const isOwnDomain =
    hostname === rootDomain || hostname.endsWith(`.${rootDomain}`)
  return protocol === 'https:' && isOwnDomain
}
// Added to `OPTIONS` responses for allowed origins (see `getHeaders`).
// Entries are [lower-case name, header key, value] in CloudFront's header shape.
const cacheHeaders = Object.fromEntries(
  [['cache-control', 'Cache-Control', 'max-age=86400']].map(
    ([name, key, value]) => [name, [{ key, value }]]
  )
)
/**
 * CORS headers that allow `origin`, with credentials and the listed methods,
 * in CloudFront's `{ name: [{ key, value }] }` header shape.
 */
const corsHeaders = (origin) => {
  const allowedMethods = ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT']
  const entries = [
    ['access-control-allow-origin', 'Access-Control-Allow-Origin', origin],
    ['access-control-allow-credentials', 'Access-Control-Allow-Credentials', 'true'],
    ['access-control-allow-methods', 'Access-Control-Allow-Methods', allowedMethods.join(', ')]
  ]
  return Object.fromEntries(
    entries.map(([name, key, value]) => [name, [{ key, value }]])
  )
}
// Security and caching headers applied to every response (see `getHeaders`).
// Entries are [lower-case name, header key, value] in CloudFront's header shape.
const commonHeaders = Object.fromEntries(
  [
    ['vary', 'Vary', 'Origin'],
    ['access-control-max-age', 'Access-Control-Max-Age', '86400'],
    ['referrer-policy', 'Referrer-Policy', 'same-origin'],
    [
      'strict-transport-security',
      'Strict-Transport-Security',
      'max-age=63072000; includeSubdomains; preload'
    ],
    ['x-content-type-options', 'X-Content-Type-Options', 'nosniff'],
    ['x-frame-options', 'X-Frame-Options', 'DENY'],
    ['x-xss-protection', 'X-XSS-Protection', '1; mode=block']
  ].map(([name, key, value]) => [name, [{ key, value }]])
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Query-string flag marking a request as a real-time (WebSocket) request.
 * The edge handlers treat `<flag>=true` as real-time.
 */
const realtimeSearchParam = 'ws'

export { realtimeSearchParam }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Relevant versions:
The WebSocket integration is based on https://gist.github.com/zachboyd/f5630736b0a5a9b627d61bfd25299c90, and the custom-domain support comes from awslabs/aws-mobile-appsync-sdk-js#517 (comment).