
@fnimick
Last active August 5, 2024 19:53
Prisma utilities for creating RLS clients that support transactions (unlike the official recommendation and the other library, which silently ignore them)
// Necessary due to https://github.com/prisma/prisma/issues/20678
// Modeled on prisma-extension-enable-supabase-row-level-security:
// https://github.com/dthyresson/prisma-extension-supabase-rls
// Modified for transaction support using two clients: one for non-transactional operations and
// one for transactional operations. This is inefficient, but seems to work. Please test it
// yourself in your own environment! I am not responsible for any transactional behavior failures.
// I mostly used this by creating both clients and explicitly picking one based on whether or not
// I needed to operate in a transaction. The auto-proxying was added just before we switched away
// from Prisma because of this and other issues (e.g. the lack of native SQL JOIN support).
// IMPORTANT NOTE: this will absolutely not work with nested transactions!
// ------------------------------------------------
// USAGE
// prisma-client.ts
import { PrismaClient } from "@prisma/client";
import {
  useRowLevelSecurity,
  useRowLevelSecurityForTransactions,
  type ClaimsFunction,
} from "./rls-extension";
// RLS_DB_URL is assumed to be defined elsewhere (e.g. loaded from your environment).
const rlsDbUrlPrismaClient = new PrismaClient({
  datasources: { db: { url: RLS_DB_URL } },
});
export function createExtendedClient(
  claimsFn: ClaimsFunction,
  transactionClient: ReturnType<typeof createExtendedClientForTransaction>,
) {
  return rlsDbUrlPrismaClient.$extends(useRowLevelSecurity({ claimsFn }, transactionClient));
}
export function createExtendedClientForTransaction(claimsFn: ClaimsFunction) {
  return rlsDbUrlPrismaClient.$extends(useRowLevelSecurityForTransactions({ claimsFn }));
}
// somewhere in your request path - some-middleware.ts
// (getSession, getJwtFromSession and req stand in for your framework's equivalents)
import { createExtendedClient, createExtendedClientForTransaction } from "path-to-prisma-client";
import type { ClaimsFunction } from "path-to-rls-extension";
let claimsFn: ClaimsFunction = undefined;
const session = getSession();
if (session) {
  const parsedJwt = getJwtFromSession(session);
  claimsFn = () => parsedJwt;
}
// Create the transaction client first: the regular client proxies $transaction calls to it.
const rlsPrismaForTransactions = createExtendedClientForTransaction(claimsFn);
const rlsPrisma = createExtendedClient(claimsFn, rlsPrismaForTransactions);
req.rlsPrismaForTransactions = rlsPrismaForTransactions;
req.rlsPrisma = rlsPrisma;
// ------------------------------------------------
// ACTUAL EXTENSION CODE
// rls-extension.ts
import { Prisma } from "@prisma/client";
export type ClaimsFunction = undefined | (() => Record<string, unknown>);
export interface RowLevelSecurityOptions {
  /**
   * The client extension name
   */
  name?: string;
  /**
   * The name of the Postgres setting to use for the claims, `request.jwt.claims` by default.
   *
   * @default 'request.jwt.claims'
   */
  claimsSetting?: string;
  /**
   * A function that returns the JWT claims to use for the current request (e.g. for Supabase,
   * decoded from the Supabase access token).
   *
   * E.g.:
   *
   * {
   *   "aud": "authenticated",
   *   "exp": 1675711033,
   *   "sub": "00000000-0000-0000-0000-000000000000",
   *   "email": "[email protected]",
   *   "phone": "",
   *   "app_metadata": { "provider": "email", "providers": ["email"] },
   *   "user_metadata": {},
   *   "role": "authenticated",
   *   "aal": "aal1",
   *   "amr": [{ "method": "otp", "timestamp": 1675696651 }],
   *   "session_id": "000000000000-0000-0000-0000-000000000000"
   * }
   */
  claimsFn?: ClaimsFunction;
  /**
   * Log errors to the console.
   * @default false
   */
  logging?: boolean;
}
const defaultRowLevelSecurityOptions = {
  name: "useRowLevelSecurity",
  claimsSetting: "request.jwt.claims",
  claimsFn: undefined,
  logging: false,
} as const satisfies RowLevelSecurityOptions;
const defaultRowLevelSecurityForTransactionOptions = {
  ...defaultRowLevelSecurityOptions,
  name: "useRowLevelSecurityForTransaction",
} as const satisfies RowLevelSecurityOptions;
/**
 * Creates a Prisma client extension that runs every model operation in its own batch transaction,
 * so that the RLS variables can be set in the transaction settings. If `$transaction` is called on
 * this client, the call is proxied to the provided transactionClient, which should be constructed
 * with `useRowLevelSecurityForTransactions`.
 */
export const useRowLevelSecurity = (
  options: RowLevelSecurityOptions = defaultRowLevelSecurityOptions,
  transactionClient: any,
) => {
  const name = options.name || defaultRowLevelSecurityOptions.name;
  const claimsFn = options.claimsFn || defaultRowLevelSecurityOptions.claimsFn;
  const claimsSetting = options.claimsSetting || defaultRowLevelSecurityOptions.claimsSetting;
  return Prisma.defineExtension((client) =>
    client.$extends({
      name,
      query: {
        $allModels: {
          async $allOperations({ args, query }) {
            const claims = claimsFn ? JSON.stringify(claimsFn() || {}) : "";
            try {
              // set_config(..., TRUE) makes the setting transaction-local, so it must run in the
              // same batch transaction as the query it applies to
              const [, result] = await client.$transaction([
                client.$executeRaw`SELECT set_config(${claimsSetting}, ${claims}, TRUE)`,
                query(args),
              ]);
              return result;
            } catch (e) {
              // TODO: sentry
              if (options.logging) console.error(e);
              throw e;
            }
          },
        },
      },
      client: {
        $transaction(...args: any[]) {
          // NOTE: to disable auto-proxying, throw an error rather than calling the second client
          return transactionClient.$transaction(...args);
          // throw new Error("Transactions not supported on this client.");
        },
      } as { $transaction: (typeof client)["$transaction"] },
    }),
  );
};
/**
 * Creates a Prisma client extension that sets the RLS variables in the transaction settings before
 * running the provided transaction. Should *ONLY* be used for transactions: all other operations
 * will run without the RLS variables.
 */
export const useRowLevelSecurityForTransactions = (
  options: RowLevelSecurityOptions = defaultRowLevelSecurityForTransactionOptions,
) => {
  const name = options.name || defaultRowLevelSecurityForTransactionOptions.name;
  const claimsFn = options.claimsFn || defaultRowLevelSecurityForTransactionOptions.claimsFn;
  const claimsSetting =
    options.claimsSetting || defaultRowLevelSecurityForTransactionOptions.claimsSetting;
  return Prisma.defineExtension((client) =>
    client.$extends({
      name,
      client: {
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        $transaction(first: any, ...rest: any) {
          const claims = claimsFn ? JSON.stringify(claimsFn() || {}) : "";
          if (Array.isArray(first)) {
            // batch operation: prepend set_config, then drop its result so callers get exactly
            // one result per query they passed in
            return client
              .$transaction(
                [client.$executeRaw`SELECT set_config(${claimsSetting}, ${claims}, TRUE)`, ...first],
                ...rest,
              )
              .then(([, ...results]) => results);
          }
          // interactive (function) operation: set the claims on the transaction client first
          return client.$transaction(
            async (tx) => {
              await tx.$executeRaw`SELECT set_config(${claimsSetting}, ${claims}, TRUE)`;
              return first(tx);
            },
            ...rest,
          );
        },
      } as { $transaction: (typeof client)["$transaction"] },
    }),
  );
};
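The ordering that `useRowLevelSecurityForTransactions` relies on can be checked without a database. Below is a minimal, dependency-free TypeScript model of it, assuming a toy client with separate batch and interactive transaction methods. `RecordingClient`, `Tx`, `Op`, `serializeClaims` and `withRlsClaims` are hypothetical names made up for this sketch and are not Prisma APIs; the real Prisma batch form takes query promises rather than callbacks, and callbacks are used here only to keep the model self-contained. It illustrates the two rules the extension follows: `set_config(..., TRUE)` must run inside the same transaction as the queries it protects, and it must run before them.

```typescript
// Hypothetical, dependency-free model of the wrapping done by
// useRowLevelSecurityForTransactions. None of these names are Prisma APIs.

type ClaimsFunction = undefined | (() => Record<string, unknown>);

// One recorded statement: BEGIN/COMMIT markers or a parameterized query.
interface Statement {
  sql: string;
  params: unknown[];
}

// A transaction handle that can run one parameterized statement.
interface Tx {
  execute(sql: string, params: unknown[]): Promise<unknown>;
}

// One queued operation in a batch transaction.
type Op = (tx: Tx) => Promise<unknown>;

// Stand-in client that records statements instead of talking to Postgres.
// Error handling and ROLLBACK are omitted to keep the sketch short.
class RecordingClient {
  readonly log: Statement[] = [];

  private handle(): Tx {
    return {
      execute: async (sql, params) => {
        this.log.push({ sql, params });
        return sql; // echo the SQL back as the "result"
      },
    };
  }

  // Batch form: run every op in order inside one transaction, return all results.
  async batch(ops: Op[]): Promise<unknown[]> {
    const tx = this.handle();
    this.log.push({ sql: "BEGIN", params: [] });
    const results: unknown[] = [];
    for (const op of ops) results.push(await op(tx));
    this.log.push({ sql: "COMMIT", params: [] });
    return results;
  }

  // Interactive form: give the callback a transaction handle.
  async interactive<T>(fn: (tx: Tx) => Promise<T>): Promise<T> {
    const tx = this.handle();
    this.log.push({ sql: "BEGIN", params: [] });
    const result = await fn(tx);
    this.log.push({ sql: "COMMIT", params: [] });
    return result;
  }
}

// Same rule as the gist: no claims function means an empty claims string.
function serializeClaims(claimsFn: ClaimsFunction): string {
  return claimsFn ? JSON.stringify(claimsFn() || {}) : "";
}

function withRlsClaims(
  client: RecordingClient,
  claimsFn: ClaimsFunction,
  claimsSetting = "request.jwt.claims",
) {
  // Build the set_config op; claims are evaluated once per transaction call.
  const setClaims = (): Op => {
    const claims = serializeClaims(claimsFn);
    return (tx) => tx.execute("SELECT set_config($1, $2, TRUE)", [claimsSetting, claims]);
  };
  return {
    // Prepend set_config, then drop its result so callers get one result per op.
    async batch(ops: Op[]): Promise<unknown[]> {
      const [, ...results] = await client.batch([setClaims(), ...ops]);
      return results;
    },
    // Run set_config on the same handle before the caller's callback.
    async interactive<T>(fn: (tx: Tx) => Promise<T>): Promise<T> {
      const set = setClaims();
      return client.interactive(async (tx) => {
        await set(tx);
        return fn(tx);
      });
    },
  };
}
```

In this model, `withRlsClaims(client, () => ({ sub: "user-1" })).batch([(tx) => tx.execute("SELECT 1", [])])` resolves to `["SELECT 1"]`, while `client.log` records `BEGIN`, the `set_config` call, `SELECT 1` and `COMMIT` in that order. Because the prepended `set_config` also produces a batch result, any batch wrapper has to discard the first element, as `$allOperations` does with `const [, result]`.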
@fnimick

fnimick commented Aug 5, 2024

Please note that this is provided without warranty or support of any kind. This was my in-house attempt at fixing prisma/prisma#20678.
