Skip to content

Instantly share code, notes, and snippets.

@mostlylikeable
Last active October 5, 2023 23:01
Show Gist options
  • Save mostlylikeable/38d20313af3bbd1c6df97c0f51ce2b8f to your computer and use it in GitHub Desktop.
Save mostlylikeable/38d20313af3bbd1c6df97c0f51ce2b8f to your computer and use it in GitHub Desktop.
import { Prisma } from '@prisma/client'
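/**
 * USAGE
 *
 * Register the extension, then open a transaction with `$begin()` and end it with
 * `$commit()` or `$rollback()`:
 *
 *   const prisma = new PrismaClient().$extends(flatTransactionExt)
 *
 *   const withTx = async (
 *     fn: (tx: FlatTransactionClient) => Promise<void>
 *   ): Promise<void> => {
 *     // begin the transaction and get its client
 *     const tx = await prisma.$begin()
 *     try {
 *       // run the work against the transaction client
 *       await fn(tx)
 *     } finally {
 *       // always roll back (even if fn throws) so nothing fn wrote is persisted;
 *       // call `tx.$commit()` instead to keep the changes
 *       await tx.$rollback()
 *     }
 *   }
 */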
/**
 * Sink for this module's diagnostic messages (only `info` and `error` are used).
 * Swap in another logger with the same methods to redirect or silence them.
 */
const logger = globalThis.console
/** Flat transaction controls that `$begin()` layers on top of Prisma's transaction client. */
type FlatTransactionControls = {
  /** Ends the open transaction by committing it. */
  $commit: () => Promise<void>
  /** Ends the open transaction by rolling it back. */
  $rollback: () => Promise<void>
}

/**
 * Transaction client returned by `$begin()`: a regular Prisma transaction client plus
 * `$commit()` / `$rollback()` for ending the transaction outside of a callback.
 */
export type FlatTransactionClient = Prisma.TransactionClient &
  FlatTransactionControls
/**
 * Callback form of `$transaction`: receives the transaction client and resolves with the
 * transaction's result.
 *
 * The type parameter belongs on the alias, not on the call signature. A generic *call
 * signature* (`<R>(tx) => Promise<R>`) only accepts callbacks that can produce a
 * `Promise<R>` for any `R` the caller picks, which rejects ordinary callbacks such as
 * `async (tx) => 5`.
 *
 * @typeParam R - the callback's result type (defaults to `unknown`)
 */
type PrismaTxCallback<R = unknown> = (
  tx: Prisma.TransactionClient
) => Promise<R>
// Rejection value that signals Prisma to roll the interactive transaction back.
// `Symbol.for` resolves the key through the global symbol registry, so this is the same
// symbol any other code (presumably Prisma itself) obtains with the same key string.
// It is compared by identity (`===`) wherever a rejection is inspected.
const rollbackKey: symbol = Symbol.for('prisma.client.extension.rollback')
const ROLLBACK = { [rollbackKey]: true }
/**
 * Prisma client extension adding `$begin()`. It opens an interactive transaction and returns
 * its transaction client extended with `$commit()` and `$rollback()`, so the transaction can be
 * managed outside a `$transaction(callback)` block.
 *
 * This is useful during testing: you can begin a transaction before each test and roll it back
 * after the test completes, keeping the test data set clean.
 *
 * @see https://github.com/prisma/prisma-client-extensions/blob/main/callback-free-itx/script.ts
 */
const flatTransactionExt: Partial<Prisma.Extension> = {
  name: 'flat-transaction',
  client: {
    /**
     * Opens an interactive transaction and resolves with its client once Prisma provides it.
     *
     * `$commit()` / `$rollback()` on the returned client end the transaction and resolve once
     * Prisma has finished. If the transaction failed for any other reason (for example, it was
     * aborted before either was called), that error is rethrown from the call instead.
     *
     * @returns the transaction client, with `$commit` and `$rollback` added
     * @throws TypeError if the extended client does not support `$transaction`
     * @throws the rejection from `$transaction` if the transaction could not be started
     */
    async $begin(): Promise<FlatTransactionClient> {
      const ctx = Prisma.getExtensionContext(this)
      if (typeof ctx.$transaction !== 'function') {
        throw new TypeError('Client does not support transactions')
      }

      // Captures the inner tx client from the $transaction callback so the proxy returned
      // below can wrap it.
      let setTxClient: (client: Prisma.TransactionClient) => void
      const txClientPromise = new Promise<Prisma.TransactionClient>(
        (resolve) => {
          setTxClient = (tx) => resolve(tx)
        }
      )

      // Control functions that settle the promise holding the $transaction callback open, so
      // the proxy can end the transaction via "flat" $commit / $rollback calls. Both are
      // assigned synchronously right after setTxClient runs, i.e. before the await on the
      // tx client below resumes.
      let commit: FlatTransactionClient['$commit']
      let rollback: FlatTransactionClient['$rollback']

      // Settles once Prisma has committed or rolled back the transaction (or failed to).
      const txPromise = ctx.$transaction(async (tx) => {
        // Provide the tx client to $begin so the proxy below can access it.
        setTxClient(tx)
        // Keeps the $transaction open until the proxy calls commit (resolve) or rollback
        // (reject with ROLLBACK, which makes Prisma roll back).
        return new Promise((resolve, reject) => {
          commit = async () => {
            logger.info('$commit: resolving tx promise')
            resolve(undefined)
          }
          rollback = async () => {
            logger.info('$rollback: rolling back tx promise')
            reject(ROLLBACK)
          }
        })
      })

      // Runs a control function, then waits for Prisma to finish the transaction.
      const actionHandler = (
        action: () => Promise<void>
      ): (() => Promise<void>) => {
        return async () => {
          logger.info('actionHandler: executing')
          await action()
          await txPromise.catch((err: unknown) => {
            logger.info('actionHandler: checking tx error')
            // ROLLBACK is the expected outcome of $rollback, so it is not propagated; any
            // other error (e.g. a failed commit) is rethrown to the caller.
            // NOTE: This can't be caught inside the $transaction callback, because Prisma
            // needs to see the ROLLBACK rejection to roll the transaction back.
            if (err !== ROLLBACK) throw err
            logger.info('actionHandler: tx rolled back')
          })
        }
      }

      // Wait for the tx client. If $transaction rejects before invoking the callback (e.g. the
      // transaction cannot be started), txClientPromise never settles, so race it against
      // txPromise to surface that error instead of hanging forever. The race also attaches a
      // handler to txPromise. A failure that happens while nobody is awaiting $commit or
      // $rollback therefore does not become an unhandled promise rejection; it is still
      // reported by the next $commit/$rollback call.
      const txClient: Prisma.TransactionClient = await Promise.race([
        txClientPromise,
        txPromise.then(() => txClientPromise),
      ])

      // Prisma doesn't support nested transactions, so calling `$transaction` on the
      // TransactionClient would fail. Instead, the work runs directly in the already-open
      // transaction. A callback receives this same flat client, so it can nest further. An
      // array of promises is awaited together and resolves with their results, rejecting if
      // any of them rejects, like a top-level `$transaction([...])`.
      const nestedTransaction = async (
        promisesOrFn: Promise<unknown>[] | PrismaTxCallback
      ): Promise<unknown> => {
        if (typeof promisesOrFn === 'function') {
          // Calling code passed a callback, so give it the tx client.
          return await promisesOrFn(flatClient)
        }
        // Calling code provided N promises to await in the tx.
        try {
          const results = await Promise.all(promisesOrFn)
          logger.info(`$transaction: resolved ${promisesOrFn.length} promises`)
          return results
        } catch (err: unknown) {
          logger.error('$transaction: error resolving promises', err)
          throw err
        }
      }

      // Proxy to the transaction client that adds $commit / $rollback and the nested
      // $transaction shim; every other property is read from the real tx client.
      const flatClient: FlatTransactionClient = new Proxy(
        txClient as FlatTransactionClient,
        {
          get(target, prop) {
            switch (prop) {
              case '$commit':
                logger.info('$commit: called')
                return actionHandler(commit)
              case '$rollback':
                logger.info('$rollback: called')
                return actionHandler(rollback)
              case '$transaction':
                logger.info('$transaction: called')
                return nestedTransaction
              default:
                return target[prop as keyof typeof target]
            }
          },
        }
      )
      return flatClient
    },
  },
}
// Module entry point: pass this to `PrismaClient.$extends(...)` to add `$begin()`.
export { flatTransactionExt as default }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment