Last active
August 5, 2024 19:53
-
-
Save fnimick/9fce65218a15b8b0ef3eb4592c20a0e4 to your computer and use it in GitHub Desktop.
Prisma utilities to create RLS (row-level security) clients that support transactions (unlike the official recommendation and the other library, which silently ignore them)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Necessary due to https://github.com/prisma/prisma/issues/20678
// Modeled on prisma-extension-enable-supabase-row-level-security:
// https://github.com/dthyresson/prisma-extension-supabase-rls
// Modified for transaction support using two clients: one for non-transactional operations and one
// for transactional operations. Inefficient, but it seems to work. Please test it yourself in your
// own environment! I am not responsible for any transactional behavior failures.
// I mostly used this by creating two clients and explicitly using one or the other depending on
// whether I needed to operate in a transaction. The auto-proxying was added just before we switched
// away from Prisma due to this and other issues (e.g. the lack of native SQL JOIN support).
// IMPORTANT NOTE: this will absolutely not work with nested transactions!
// ------------------------------------------------ | |
// USAGE | |
// prisma-client.ts | |
import { PrismaClient } from "@prisma/client";
import {
  useRowLevelSecurity,
  useRowLevelSecurityForTransactions,
  type ClaimsFunction,
} from "./rls-extension";

// NOTE(review): RLS_DB_URL is assumed to be defined elsewhere (e.g. read from the environment);
// presumably it connects as a database role that is subject to RLS policies — confirm.
const rlsDbUrlPrismaClient = new PrismaClient({
  datasources: { db: { url: RLS_DB_URL } },
});

/**
 * Creates the general-purpose RLS client for one request. Model operations set the JWT claims
 * before running; `$transaction` calls are forwarded to `transactionClient`.
 *
 * @param claimsFn returns the JWT claims for the current request, or `undefined` when anonymous
 * @param transactionClient a client built by `createExtendedClientForTransaction` with the same
 *   `claimsFn`, used to run `$transaction` calls
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function createExtendedClient(claimsFn: ClaimsFunction, transactionClient: any) {
  return rlsDbUrlPrismaClient.$extends(useRowLevelSecurity({ claimsFn }, transactionClient));
}

/**
 * Creates the RLS client used only for `$transaction` calls: the claims are set as the first
 * statement of each transaction. Non-transactional operations on it run without claims.
 *
 * @param claimsFn returns the JWT claims for the current request, or `undefined` when anonymous
 */
export function createExtendedClientForTransaction(claimsFn: ClaimsFunction) {
  return rlsDbUrlPrismaClient.$extends(useRowLevelSecurityForTransactions({ claimsFn }));
}
// somewhere in your request path - some-middleware.ts | |
import { createExtendedClient, createExtendedClientForTransaction } from "path-to-prisma-client";
import type { ClaimsFunction } from "path-to-rls-extension";

// Only build a claims function for authenticated sessions. `undefined` makes the extensions
// store an empty string in the claims setting (anonymous access).
let claimsFn: ClaimsFunction = undefined;
const session = getSession();
if (session) {
  // Parse once per request; the closure returns the same claims object for every query.
  const parsedJwt = getJwtFromSession(session);
  claimsFn = () => parsedJwt;
}

// The transaction client is created first so the general client can forward `$transaction` to it.
const rlsPrismaForTransactions = createExtendedClientForTransaction(claimsFn);
const rlsPrisma = createExtendedClient(claimsFn, rlsPrismaForTransactions);
req.rlsPrismaForTransactions = rlsPrismaForTransactions;
req.rlsPrisma = rlsPrisma;
// ------------------------------------------------ | |
// ACTUAL EXTENSION CODE | |
// rls-extension.ts | |
import { Prisma } from "@prisma/client"; | |
/**
 * Returns the JWT claims for the current request. `undefined` means "no claims": the
 * extensions then store an empty string in the claims setting.
 */
export type ClaimsFunction = undefined | (() => Record<string, unknown>);

export interface RowLevelSecurityOptions {
  /**
   * The client extension name.
   */
  name?: string;
  /**
   * The name of the Postgres setting that receives the serialized claims via `set_config`.
   *
   * @default 'request.jwt.claims'
   */
  claimsSetting?: string;
  /**
   * A function that returns the JWT claims to use for the current request (e.g. for Supabase,
   * the payload decoded from the Supabase access token). The returned object is serialized with
   * `JSON.stringify`; a falsy return value is treated as `{}`.
   *
   * Example payload:
   *
   * ```json
   * {
   *   "aud": "authenticated",
   *   "exp": 1675711033,
   *   "sub": "00000000-0000-0000-0000-000000000000",
   *   "email": "user@example.com",
   *   "phone": "",
   *   "app_metadata": { "provider": "email", "providers": ["email"] },
   *   "user_metadata": {},
   *   "role": "authenticated",
   *   "aal": "aal1",
   *   "amr": [{ "method": "otp", "timestamp": 1675696651 }],
   *   "session_id": "00000000-0000-0000-0000-000000000000"
   * }
   * ```
   */
  claimsFn?: ClaimsFunction;
  /**
   * Log query errors to the console before rethrowing them (used by `useRowLevelSecurity`).
   * @default false
   */
  logging?: boolean;
}

/** Defaults used by `useRowLevelSecurity` for any option left unset. */
const defaultRowLevelSecurityOptions = {
  name: "useRowLevelSecurity",
  claimsSetting: "request.jwt.claims",
  claimsFn: undefined,
  logging: false,
} as const satisfies RowLevelSecurityOptions;

/** Defaults used by `useRowLevelSecurityForTransactions`; identical apart from the extension name. */
const defaultRowLevelSecurityForTransactionOptions = {
  ...defaultRowLevelSecurityOptions,
  name: "useRowLevelSecurityForTransaction",
} as const satisfies RowLevelSecurityOptions;
/**
 * Creates a Prisma client extension that runs every model operation in a batch transaction whose
 * first statement sets the RLS claims with `set_config(..., TRUE)`. Because the setting is
 * transaction-local, it applies only to that operation.
 *
 * Calls to `$transaction` on the extended client are forwarded to `transactionClient`, which
 * should be constructed with `useRowLevelSecurityForTransactions` using the same claims. If no
 * `transactionClient` is given, `$transaction` throws instead of running without RLS claims.
 *
 * Limitations: only operations under `$allModels` are intercepted, so raw queries called directly
 * on this client presumably run without the claims (confirm for your Prisma version). Nested
 * transactions are not supported.
 *
 * @param options extension options; unset fields fall back to `defaultRowLevelSecurityOptions`
 * @param transactionClient client that handles `$transaction` calls; omit to disable proxying
 */
export const useRowLevelSecurity = (
  options: RowLevelSecurityOptions = defaultRowLevelSecurityOptions,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  transactionClient?: any,
) => {
  const name = options.name || defaultRowLevelSecurityOptions.name;
  const claimsFn = options.claimsFn || defaultRowLevelSecurityOptions.claimsFn;
  const claimsSetting = options.claimsSetting || defaultRowLevelSecurityOptions.claimsSetting;

  return Prisma.defineExtension((client) =>
    client.$extends({
      name,
      query: {
        $allModels: {
          async $allOperations({ args, query }) {
            // No claims function => store an empty string (anonymous request).
            const claims = claimsFn ? JSON.stringify(claimsFn() || {}) : "";
            try {
              // set_config and the query share one batch transaction, so the transaction-local
              // setting is visible to the query. Only the query's result is returned.
              const [, result] = await client.$transaction([
                client.$executeRaw`SELECT set_config(${claimsSetting}, ${claims}, TRUE)`,
                query(args),
              ]);
              return result;
            } catch (e: unknown) {
              // TODO: sentry
              if (options.logging) console.error(e);
              throw e;
            }
          },
        },
      },
      client: {
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        $transaction(...args: any[]) {
          if (!transactionClient) {
            // Refuse rather than silently running the transaction without RLS claims.
            throw new Error("Transactions not supported on this client.");
          }
          return transactionClient.$transaction(...args);
        },
      } as { $transaction: (typeof client)["$transaction"] },
    }),
  );
};
/**
 * Creates a Prisma client extension whose `$transaction` sets the RLS claims with
 * `set_config(..., TRUE)` as the first statement of the transaction, then runs the caller's
 * operations. Should *ONLY* be used for transactions: all other operations on this client run
 * without RLS variables. Nested transactions are not supported.
 *
 * Both `$transaction` forms are supported:
 * - batch, `$transaction([op1, op2], options?)`: resolves to one result per caller-supplied
 *   operation, in order (the internal `set_config` result is stripped);
 * - interactive, `$transaction(async (tx) => ..., options?)`: resolves to the callback's result.
 *
 * @param options extension options; unset fields fall back to
 *   `defaultRowLevelSecurityForTransactionOptions`
 */
export const useRowLevelSecurityForTransactions = (
  options: RowLevelSecurityOptions = defaultRowLevelSecurityForTransactionOptions,
) => {
  const name = options.name || defaultRowLevelSecurityForTransactionOptions.name;
  const claimsFn = options.claimsFn || defaultRowLevelSecurityForTransactionOptions.claimsFn;
  const claimsSetting =
    options.claimsSetting || defaultRowLevelSecurityForTransactionOptions.claimsSetting;

  return Prisma.defineExtension((client) =>
    client.$extends({
      name,
      client: {
        // `...rest: any` (not `any[]`) so it can be spread into Prisma's fixed-arity signature.
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        $transaction(first: any, ...rest: any) {
          // Claims are read once per transaction; no claims function => empty string.
          const claims = claimsFn ? JSON.stringify(claimsFn() || {}) : "";
          if (Array.isArray(first)) {
            // Batch form: prepend set_config, then drop its result so callers receive exactly
            // one result per operation they passed (e.g. `const [a, b] = await $transaction([qa, qb])`).
            return client
              .$transaction(
                [client.$executeRaw`SELECT set_config(${claimsSetting}, ${claims}, TRUE)`, ...first],
                ...rest,
              )
              .then((results) => results.slice(1));
          }
          // Interactive form: set the claims on the transaction client, then run the callback.
          return client.$transaction(
            async (tx) => {
              await tx.$executeRaw`SELECT set_config(${claimsSetting}, ${claims}, TRUE)`;
              return first(tx);
            },
            ...rest,
          );
        },
      } as { $transaction: (typeof client)["$transaction"] },
    }),
  );
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Please note that this is provided without warranty or support of any kind. This was my in-house attempt at fixing prisma/prisma#20678.