Skip to content

Instantly share code, notes, and snippets.

@razor-x
Last active December 22, 2022 10:41
Show Gist options
  • Save razor-x/e19d7d776cdf58d04af1e223b0757064 to your computer and use it in GitHub Desktop.
Save razor-x/e19d7d776cdf58d04af1e223b0757064 to your computer and use it in GitHub Desktop.
AppSync using Apollo Client with subscription support, and custom domain via Lambda@Edge
import type {
NormalizedCacheObject,
PossibleTypesMap,
Resolvers,
TypePolicies
} from '@apollo/client'
import {
ApolloClient as Client,
InMemoryCache,
createHttpLink,
from,
split
} from '@apollo/client'
import { onError } from '@apollo/client/link/error'
import { setContext } from '@apollo/client/link/context'
import { getMainDefinition } from '@apollo/client/utilities'
import { CachePersistor } from 'apollo-cache-persist'
import type { AuthClient } from 'src/lib/auth-client'
import { createWebsocketLink } from './appsync'
/**
 * Options accepted by `createApolloClient`.
 */
type Config = {
  // --- Authentication ---
  /** Supplies the JWT (`getJwtToken`) and is signed out on HTTP 401. */
  authClient: AuthClient
  /** Sent as the `x-api-key` header when no JWT is available. */
  apiKey: string

  // --- Endpoints ---
  /** HTTP origin; concatenated with `path` to form the GraphQL URI. */
  origin: string
  /** Origin handed to the WebSocket link for subscriptions. */
  wsOrigin: string
  /** GraphQL path, shared by the HTTP and WebSocket links. */
  path: string
  /** Passed to the WebSocket link as its `host`. */
  appsyncHost: string

  // --- Cache / local state ---
  possibleTypes: PossibleTypesMap
  typePolicies: TypePolicies
  resolvers: Resolvers
  /** Current schema version; a mismatch with the stored value purges the persisted cache. */
  schemaVersion: string
  /** localStorage key under which the schema version is stored. */
  schemaVersionKey: string
}

/** Apollo client type produced by `createApolloClient`. */
export type ApolloClient = Client<NormalizedCacheObject>
/**
 * Create an Apollo client wired for AppSync.
 *
 * - Subscription operations are routed to the WebSocket link built by
 *   `createWebsocketLink`.
 * - Every other operation goes through `authLink -> errorLink -> httpLink`.
 *   The JWT is sent as `authorization` when one is available; otherwise the
 *   API key is sent as `x-api-key`. An HTTP 401 network error triggers
 *   `authClient.signOut()`.
 * - The in-memory cache is persisted to `window.localStorage` via
 *   `CachePersistor`.
 *
 * Call the returned `init` before using the client. It restores the
 * persisted cache when the value stored under `schemaVersionKey` equals
 * `schemaVersion`. Otherwise it purges the cache and stores the new version.
 * In both cases it then hands the current JWT to the WebSocket link.
 *
 * @returns `{ client, init }`
 */
export const createApolloClient = ({
  authClient,
  origin,
  wsOrigin,
  path,
  apiKey,
  appsyncHost,
  possibleTypes,
  typePolicies,
  resolvers,
  schemaVersion,
  schemaVersionKey
}: Config) => {
  const uri = [origin, path].join('')

  // Never rejects: a failed token lookup resolves to null so callers can
  // fall back to the API key.
  const getAuthorization = () => authClient.getJwtToken().catch(() => null)

  const httpLink = createHttpLink({
    uri
  })

  const { wsLink, setWsAuthorization } = createWebsocketLink({
    origin: wsOrigin,
    path,
    host: appsyncHost,
    getAuthorization
  })

  const authLink = setContext(async (_, { headers }) => {
    const authorization = await getAuthorization()
    return {
      headers: {
        ...headers,
        ...(authorization ? { authorization } : { 'x-api-key': apiKey })
      }
    }
  })

  const errorLink = onError(({ networkError }) => {
    if (
      networkError &&
      'statusCode' in networkError &&
      networkError.statusCode === 401
    ) {
      // The page is reloaded only if signing out itself fails.
      // `Location.reload()` takes no arguments in the DOM standard. The old
      // Firefox-only `forceGet` flag (`reload(true)`) is ignored by other
      // browsers and not accepted by current TypeScript DOM typings.
      authClient.signOut().catch(() => {
        window.location.reload()
      })
    }
  })

  // Subscriptions go over the WebSocket link; queries and mutations over
  // HTTP. Note that the auth and error links only apply to the HTTP branch.
  const link = split(
    ({ query }) => {
      const definition = getMainDefinition(query)
      return (
        definition.kind === 'OperationDefinition' &&
        definition.operation === 'subscription'
      )
    },
    wsLink,
    from([authLink, errorLink, httpLink])
  )

  const cache = new InMemoryCache({
    typePolicies,
    possibleTypes
  })

  const client = new Client<NormalizedCacheObject>({
    resolvers,
    link,
    cache
  })

  // NOTE(review): the cast presumably exists to satisfy apollo-cache-persist's
  // storage option type — confirm before tightening it.
  const storage = window.localStorage as any
  const persistor = new CachePersistor({
    cache,
    storage
  })

  const init = async () => {
    const currentSchemaVersion = window.localStorage.getItem(schemaVersionKey)
    if (currentSchemaVersion === schemaVersion) {
      await persistor.restore()
    } else {
      // Schema changed (or was never recorded): discard the old cache.
      await persistor.purge()
      window.localStorage.setItem(schemaVersionKey, schemaVersion)
    }
    const authorization = await getAuthorization()
    setWsAuthorization(authorization)
  }

  return {
    client,
    init
  }
}
import type { ClientOptions } from 'subscriptions-transport-ws'
import { SubscriptionClient } from 'subscriptions-transport-ws'
import { WebSocketLink } from '@apollo/client/link/ws'
import { print } from 'graphql/language/printer'
/** Options accepted by `createWebsocketLink`. */
type Config = {
  /** Resolves the current JWT, or null. */
  getAuthorization: () => Promise<string | null>
  /** Embedded, with the JWT, in each operation's `authorization` extension. */
  host: string
  /** WebSocket origin; `path` becomes its pathname. */
  origin: string
  path: string
}

/** What `createWebsocketLink` hands back to its caller. */
type WebSocketLinkWrapper = {
  /** Sets (or clears, when null) the `authorization` query param on the client URL. */
  setWsAuthorization: (authorization: string | null) => void
  wsLink: WebSocketLink
}
/**
 * Build the WebSocket URI from `origin` with `path` as its pathname.
 * `ws=true` is added to the query string; the Lambda@Edge handlers use
 * that parameter to recognise real-time requests.
 */
const getUri = (origin: string, path: string) => {
  const target = new URL(origin)
  target.pathname = path
  target.searchParams.set('ws', 'true')
  return target.href
}
/**
 * Create an Apollo `WebSocketLink` configured for AppSync real-time
 * subscriptions.
 *
 * A middleware rewrites each operation before it is sent. It stores the
 * serialized query and variables in `options.data`, and the current JWT
 * plus `host` in `options.extensions.authorization`.
 *
 * When the connection callback reports an error with `errorCode === 401`,
 * the JWT is fetched again and written into the client URL. Every pending
 * operation is then re-run through the middleware, and the socket is
 * closed so the client reconnects (`reconnect: true`).
 *
 * @returns The link and `setWsAuthorization`, which sets (or clears, when
 *   given null) the `authorization` query parameter on the client URL.
 */
export const createWebsocketLink = ({
  origin,
  path,
  host,
  getAuthorization
}: Config): WebSocketLinkWrapper => {
  const uri = getUri(origin, path)

  const middleware = {
    applyMiddleware: async (options: any, next: any) => {
      if (options.query) {
        const authorization = await getAuthorization()
        options.data = JSON.stringify({
          query:
            typeof options.query === 'string'
              ? options.query
              : print(options.query),
          variables: options.variables
        })
        options.extensions = {
          authorization: {
            authorization,
            host
          }
        }
      }
      next()
    }
  }

  let subscription: AWSSubscriptionClient | null = null

  const setWsAuthorization = (authorization: string | null) => {
    const url = new URL(uri)
    if (authorization) url.searchParams.set('authorization', authorization)
    if (subscription) subscription.setUrl(url.toString())
  }

  // Refresh the token, re-authorize pending operations, and force a
  // reconnect after the server rejected the connection with a 401.
  const handleUnauthorized = async (client: AWSSubscriptionClient) => {
    const authorization = await getAuthorization()
    setWsAuthorization(authorization)
    // Re-apply middleware to operation options since they could have an
    // invalid token embedded in them. Iterate the operation ids with
    // `for...of`: `for...in` over `Object.keys(...)` yields the array's
    // indices ("0", "1", ...), which are not operation ids.
    for (const id of Object.keys(client.operations)) {
      const operation = client.operations[id]
      if (operation) {
        operation.options = await client.applyMiddlewares(operation.options)
      }
    }
    // Force close after a 401.
    // This will auto-reconnect
    // if reconnect = true on the client options.
    client.close(false, false)
  }

  const connectionCallback = async (message: any) => {
    const error = message?.errors?.[0]
    if (!error || error.errorCode !== 401) return
    if (!subscription) return
    await handleUnauthorized(subscription)
  }

  subscription = new AWSSubscriptionClient(uri, {
    lazy: true,
    reconnect: true,
    timeout: 5 * 60 * 1000,
    connectionCallback
  })

  const wsLink = new WebSocketLink(subscription)
  // @ts-ignore
  wsLink.subscriptionClient.use([middleware])

  return {
    wsLink,
    setWsAuthorization
  }
}
/**
 * `SubscriptionClient` patched for AppSync's real-time endpoint.
 *
 * At construction it replaces two private base-class methods:
 * - `flushUnsentMessagesQueue`, with `flush`, which skips duplicate queued
 *   messages;
 * - `processReceivedData`, with `process`, which swallows AppSync's
 *   `start_ack`.
 *
 * It also adds `setUrl`, which overwrites the base client's `url` field.
 */
class AWSSubscriptionClient extends SubscriptionClient {
  constructor(
    url: string,
    options?: ClientOptions,
    webSocketImpl?: any,
    webSocketProtocols?: string | string[]
  ) {
    super(url, options, webSocketImpl, webSocketProtocols)
    // Since we are in TS and these functions are private
    // we cannot directly override in this child class,
    // so we use this trick (which is not safe) to override
    // the parent functions.
    // @ts-ignore
    this.flushUnsentMessagesQueue = this.flush
    // @ts-ignore
    this.processReceivedData = this.process
  }

  /** Overwrite the base client's private `url` field. */
  public setUrl(url: string) {
    // @ts-ignore
    this.url = url
  }

  // Send each queued message once, then clear the queue.
  // Duplicates are detected per (type, id) pair, not per id alone. This way
  // distinct messages for one operation — e.g. a queued `start` followed by
  // its `stop` — are all sent. Keying on id alone would drop the `stop`,
  // leaving the server-side subscription running after the client
  // unsubscribed.
  private flush() {
    const sent = new Set<string>()
    for (const message of this.getUnsentMessagesQueue()) {
      const key = `${message.type}:${message.id}`
      if (sent.has(key)) continue
      sent.add(key)
      // @ts-ignore
      super.sendMessageRaw(message)
    }
    this.setUnsentMessagesQueue([])
  }

  // Ignore start_ack message from AppSync:
  // they are not a valid GQL message type.
  private process(receivedData: string) {
    let message: any = null
    try {
      message = JSON.parse(receivedData)
    } catch {
      throw new Error(`Message must be JSON-parseable. Got: ${receivedData}`)
    }
    if (message.type === 'start_ack') {
      // Drop only the acknowledged `start`. Other queued messages for the
      // same id (such as a `stop`) must still be sent.
      const newQueue = this.getUnsentMessagesQueue().filter(
        (el) => !(el.id === message.id && el.type === 'start')
      )
      this.setUnsentMessagesQueue(newQueue)
      return
    }
    // @ts-ignore
    super.processReceivedData(receivedData)
  }

  private getUnsentMessagesQueue(): any[] {
    // @ts-ignore
    return this.unsentMessagesQueue || []
  }

  private setUnsentMessagesQueue(queue: any[]): void {
    // @ts-ignore
    this.unsentMessagesQueue = queue
  }
}
import { realtimeSearchParam } from './realtime-search-param.mjs'
// Query parameter that carries the client's JWT; getQuerystring moves it
// into the base64 `header` parameter and strips it from the query string.
const authorizationSearchParam = 'authorization'

// Lambda@Edge request handler. Requests flagged as real-time get three
// changes: the custom origin is switched from the `appsync-api` domain to
// the `appsync-realtime-api` domain, the Host header is set to the original
// domain, and the query string is rewritten by getQuerystring. All other
// requests pass through untouched.
export const handler = (event, context, callback) => {
  const { request } = event.Records[0].cf
  const searchParams = new URLSearchParams(request.querystring)

  if (isRealtimeReq(searchParams)) {
    const customOrigin = request.origin.custom
    const apiDomain = customOrigin.domainName
    customOrigin.domainName = toRealtimeDomain(apiDomain)
    // Must precede getQuerystring, which reads request.headers.host.
    request.headers.host = [{ key: 'Host', value: apiDomain }]
    request.querystring = getQuerystring(request, searchParams)
  }

  return callback(null, request)
}
// True exactly when the real-time marker parameter equals 'true'.
// Always yields a boolean. The previous `param && ...` form leaked `null`
// or `''` to callers, and `.toString()` on a string was a no-op.
const isRealtimeReq = (searchParams) =>
  searchParams.get(realtimeSearchParam) === 'true'
// Map an AppSync GraphQL domain to its real-time counterpart by replacing
// the first occurrence of `appsync-api` with `appsync-realtime-api`.
const toRealtimeDomain = (domainName) => {
  const apiSegment = 'appsync-api'
  const realtimeSegment = 'appsync-realtime-api'
  return domainName.replace(apiSegment, realtimeSegment)
}
// Rewrite the query string into the header/payload form used for the
// real-time connection:
// - the Host header value and the `authorization` parameter (when present)
//   are packed into a base64 JSON `header` parameter;
// - a base64 `{}` is set as `payload`;
// - the `ws` marker and `authorization` parameters are removed.
// Mutates `searchParams` in place and returns its serialized form.
const getQuerystring = (request, searchParams) => {
  const encode = (text) => Buffer.from(text).toString('base64')

  const firstHost = request.headers.host && request.headers.host[0]
  const token = searchParams.get(authorizationSearchParam)
  const header = {
    ...(firstHost ? { host: firstHost.value } : {}),
    ...(token ? { authorization: token } : {})
  }

  searchParams.set('header', encode(JSON.stringify(header)))
  searchParams.set('payload', encode('{}'))
  for (const name of [realtimeSearchParam, authorizationSearchParam]) {
    searchParams.delete(name)
  }
  return searchParams.toString()
}
import { realtimeSearchParam } from './realtime-search-param.mjs'
// Origins on this domain (or any of its subdomains) are allowed.
const rootDomain = 'pureskill.gg'

// Lambda@Edge response handler. It performs three steps:
// 1. Removes any CORS headers set upstream.
// 2. Adds the headers chosen by getHeaders: common security headers, plus
//    CORS headers for an allowed Origin.
// 3. Replaces the response with a 403 for real-time requests whose Origin
//    is missing or not allowed.
export const handler = (event, context, callback) => {
  const { request, response } = event.Records[0].cf
  const originHeader = request.headers.origin
  const headers = getHeaders(originHeader, request.method)

  Object.keys(corsHeaders()).forEach((name) => {
    delete response.headers[name]
  })
  Object.assign(response.headers, headers)

  const forbidden = isInvalidOriginForRealtimeReq(request, originHeader)
  if (forbidden) {
    response.status = 403
    response.statusDescription = 'Forbidden'
  }

  return callback(null, response)
}
// Choose the headers to add to a response:
// - always the common headers;
// - plus CORS headers when the Origin header is present and allowed;
// - plus the cache header when that allowed request is an OPTIONS preflight.
const getHeaders = (originHeader, method) => {
  const first = originHeader && originHeader[0]
  if (!first || !isValidOrigin(first.value)) return commonHeaders

  const preflightHeaders = method === 'OPTIONS' ? cacheHeaders : {}
  return Object.assign(
    {},
    commonHeaders,
    corsHeaders(first.value),
    preflightHeaders
  )
}
// A real-time request is rejected unless it carries an allowed Origin
// header. Non-real-time requests are never rejected here.
const isInvalidOriginForRealtimeReq = (request, originHeader) => {
  const realtime = isRealtimeReq(new URLSearchParams(request.querystring))
  if (!realtime) return false
  const first = originHeader && originHeader[0]
  return !(first && isValidOrigin(first.value))
}
// True exactly when the real-time marker parameter equals 'true'.
// Always yields a boolean. The previous `param && ...` form leaked `null`
// or `''` to callers, and `.toString()` on a string was a no-op.
const isRealtimeReq = (searchParams) =>
  searchParams.get(realtimeSearchParam) === 'true'
// An origin is allowed when it parses as an https URL whose hostname is
// rootDomain or ends with `.${rootDomain}`. Unparseable input is rejected.
const isValidOrigin = (origin) => {
  let parsed
  try {
    parsed = new URL(origin)
  } catch {
    return false
  }
  const { protocol, hostname } = parsed
  const onDomain =
    hostname === rootDomain || hostname.endsWith(`.${rootDomain}`)
  return protocol === 'https:' && onDomain
}
// Added by getHeaders for OPTIONS requests from an allowed origin:
// lets the preflight response be cached for one day (86400 s).
const cacheHeaders = {
  'cache-control': [{ key: 'Cache-Control', value: `max-age=${24 * 60 * 60}` }]
}
// CORS headers (in CloudFront's lower-case-key/array format) reflecting the
// given origin, with credentials allowed. Called with no argument only to
// obtain the header names.
const corsHeaders = (origin) => {
  const allowedMethods = ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT']
  const header = (key, value) => [{ key, value }]
  return {
    'access-control-allow-origin': header('Access-Control-Allow-Origin', origin),
    'access-control-allow-credentials': header('Access-Control-Allow-Credentials', 'true'),
    'access-control-allow-methods': header('Access-Control-Allow-Methods', allowedMethods.join(', '))
  }
}
// Headers that getHeaders adds to every response, keyed by lower-cased
// header name as CloudFront expects.
const commonHeaders = Object.fromEntries(
  [
    ['Vary', 'Origin'],
    ['Access-Control-Max-Age', '86400'],
    ['Referrer-Policy', 'same-origin'],
    ['Strict-Transport-Security', 'max-age=63072000; includeSubdomains; preload'],
    ['X-Content-Type-Options', 'nosniff'],
    ['X-Frame-Options', 'DENY'],
    ['X-XSS-Protection', '1; mode=block']
  ].map(([key, value]) => [key.toLowerCase(), [{ key, value }]])
)
// Query-string parameter that marks a request as a real-time (WebSocket)
// request. The client sets `ws=true`, and both Lambda@Edge handlers import
// this name to check for it.
export const realtimeSearchParam = 'ws'
@razor-x
Copy link
Author

razor-x commented Jul 26, 2021

Relevant versions:

{
  "@apollo/client": "^3.3.12",
  "apollo-cache-persist": "^0.1.1",
  "graphql": "^15.0.0",
  "subscriptions-transport-ws": "^0.9.16"
}

The WebSocket integration is based on https://gist.github.com/zachboyd/f5630736b0a5a9b627d61bfd25299c90, and the custom-domain support comes from awslabs/aws-mobile-appsync-sdk-js#517 (comment).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment