Skip to content

Instantly share code, notes, and snippets.

@danielpowell4
Last active July 19, 2025 06:40
Show Gist options
  • Save danielpowell4/a0c6a98e7de5f63dffb86ec05d34dab9 to your computer and use it in GitHub Desktop.
Save danielpowell4/a0c6a98e7de5f63dffb86ec05d34dab9 to your computer and use it in GitHub Desktop.
RLS extension for Prisma, with a wrapping Jest test
import { Prisma, PrismaClient } from '@prisma/client';
import get from 'lodash.get';
describe('prisma extensions', () => {
  it('RLS extension sets app policy helpers', async () => {
    // #region setup
    const userId = 'user_123';
    const organizationId = 'org_456';
    const membershipRole = 'admin';
    // Mirrors the `setRlsSql` statement built in getRlsPrismaClient.
    const setRlsSql = Prisma.sql`SELECT set_rls_app_context(${userId}, ${organizationId}, ${membershipRole});`;
    // Reads back the app settings written by set_rls_app_context plus the active database role.
    const selectRlsSql = Prisma.sql`
SELECT
current_setting('app.current_user_id', true) AS current_user_id,
current_setting('app.current_organization_id', true) AS current_organization_id,
current_setting('app.current_user_role', true) AS current_user_role,
current_user AS current_db_role;
`;
    const expectedRes = [
      {
        current_db_role: 'app',
        current_organization_id: organizationId,
        current_user_id: userId,
        current_user_role: membershipRole,
      },
    ];
    // #endregion

    // Row captured by the model-operation hook. Typed `unknown` because a bare `let` assigned
    // inside a closure is an implicit `any` under strict mode.
    let rlsFromModelOperation: unknown;
    // Counts only the operations where the extension actually injected `setRlsSql`;
    // pass-throughs for an already-open transaction are deliberately not counted.
    let overrideCount = 0;

    // #region define extension
    const rlsExtension = Prisma.defineExtension((prismaExtension) =>
      prismaExtension.$extends({
        client: {
          // Every $transaction call runs set_rls_app_context as its first step.
          $transaction: (async (...txnParams: Parameters<typeof prismaExtension.$transaction>) => {
            overrideCount++;
            const [txnOps, options] = txnParams;
            if (Array.isArray(txnOps)) {
              // Inject as the first step of a batch operation and drop its result
              const [_rlsHelper, ...result] = await prismaExtension.$transaction(
                [prismaExtension.$executeRaw(setRlsSql), ...txnOps],
                options,
              );
              return result;
            }
            // Inject as the first step of an interactive transaction
            return prismaExtension.$transaction(async (transaction) => {
              await transaction.$executeRaw(setRlsSql);
              return txnOps(transaction);
            }, options);
          }) as typeof prismaExtension.$transaction,
        },
        name: 'test-rls-client',
        query: {
          $allModels: {
            async $allOperations({ args, query, ...rest }) {
              // NOTE(review): `__internalParams` is read untyped via lodash.get and appears to be
              // Prisma-internal; re-verify on Prisma upgrades.
              const existingTxn = get(rest, ['__internalParams', 'transaction']);
              if (existingTxn) return query(args); // handled via 'client' override
              // Count only after the guard, matching the raw-operation hook below.
              overrideCount++;
              const [_setRls, result, roleHelper] = await prismaExtension.$transaction([
                prismaExtension.$executeRaw(setRlsSql),
                query(args),
                prismaExtension.$queryRaw(selectRlsSql),
              ]);
              rlsFromModelOperation = roleHelper;
              return result;
            },
          },
          async $allOperations({ args, model, query, ...rest }) {
            if (model) return query(args); // no override needed... should be handled in $allModels above
            // NOTE: inside a raw SQL operation like '$queryRaw' or '$executeRaw'
            const existingTxn = get(rest, ['__internalParams', 'transaction']);
            if (existingTxn) return query(args); // already handled via client override
            overrideCount++;
            const [_setRls, result] = await prismaExtension.$transaction([
              prismaExtension.$executeRaw(setRlsSql),
              query(args),
            ]);
            return result;
          },
        },
      }),
    );
    // #endregion
    const testClient = new PrismaClient().$extends(rlsExtension);

    // Model operation outside a transaction: wrapped once by the $allModels hook.
    await testClient.project.findFirst();
    expect(overrideCount).toEqual(1);
    expect(rlsFromModelOperation).toMatchObject(expectedRes);

    // Interactive transaction: RLS is set once by the client override, regardless of inner ops.
    const interactiveOp = await testClient.$transaction(async (transaction) => {
      // #region extra operations to ensure only set RLS 1x
      await transaction.$queryRaw(selectRlsSql);
      await transaction.project.findFirst();
      // #endregion
      return transaction.$queryRaw(selectRlsSql);
    });
    expect(overrideCount).toEqual(2);
    expect(interactiveOp).toMatchObject(expectedRes);

    // Batch transaction: RLS is injected as the first batch step by the client override.
    const [batchOp] = await testClient.$transaction([testClient.$queryRaw(selectRlsSql)]);
    expect(overrideCount).toEqual(3);
    expect(batchOp).toMatchObject(expectedRes);
  });
});
/** Subset of PrismaClient constructor options that callers are allowed to override. */
type BuildPrismaClientOptions = Pick<PrismaClientOptions, 'errorFormat' | 'datasources' | 'datasourceUrl'>;
/** Identity values used to populate the RLS app context; every field is optional. */
type RLSAuthContext = { userId?: string; organizationId?: string; membershipRole?: string };
/**
 * Return the cached global PrismaClient when one exists; otherwise construct a new one.
 *
 * The cache is checked on `globalThis.globalPrisma` first, then `global.globalPrisma`.
 *
 * NOTE(review): when a cached client exists, `options` is ignored, so callers cannot override
 * `datasourceUrl` / `datasources` / `errorFormat` in that case — confirm this is intended.
 * When no cached client exists, every call constructs a fresh PrismaClient.
 *
 * @param options - constructor overrides spread over the default `log: LOG_LEVELS` setting.
 * @returns the cached or newly constructed client.
 */
const buildPlainPrismaClient = (options?: BuildPrismaClientOptions): PrismaClient => {
  // Read the cache once; narrowing on the local replaces the non-null assertion.
  const cachedClient = globalThis.globalPrisma || global.globalPrisma;
  if (cachedClient) {
    return cachedClient;
  }
  return new PrismaClient({ log: LOG_LEVELS, ...options }) as PrismaClient;
};
// Outside production (and under Playwright), store one PrismaClient on globalThis so that
// later buildPlainPrismaClient() calls return this same instance.
if (process.env.NODE_ENV !== 'production' || IS_PLAYWRIGHT) {
  if (globalThis.globalPrisma == null) {
    globalThis.globalPrisma = buildPlainPrismaClient();
  }
}
/** Plain client tagged with the `general-client` extension name, typed as `PrismaClient`. */
const getExtendedPrismaClient = (options?: BuildPrismaClientOptions) => {
  const plainClient = buildPlainPrismaClient(options);
  return plainClient.$extends({ name: 'general-client' }) as PrismaClient;
};
/**
 * Build a Prisma client that runs `set_rls_app_context` before every query.
 *
 * - Under jest-prisma (`jestPrisma.client` defined), that client is returned unchanged.
 * - `extensions` are applied in order; the RLS extension is deliberately applied last.
 * - When `ctx` is null/undefined, the client is returned with `extensions` but without RLS.
 * - Missing `ctx` fields are passed to `set_rls_app_context` as empty strings.
 *
 * Queries outside a transaction are wrapped in a batch `$transaction` whose first step is
 * `set_rls_app_context`. Calls to `$transaction` are overridden so that statement runs first.
 *
 * @param ctx - identity used to populate the RLS app context, or null to skip RLS.
 * @param options - PrismaClient constructor overrides (see buildPlainPrismaClient).
 * @param extensions - additional extensions applied before the RLS extension.
 * @returns the extended client, typed as `PrismaClient`.
 */
export function getRlsPrismaClient(
  ctx: RLSAuthContext | null,
  options?: BuildPrismaClientOptions,
  extensions?: ReturnType<typeof Prisma.defineExtension>[],
) {
  // use jestPrisma if in test land https://github.com/Quramy/jest-prisma
  if (typeof jestPrisma !== 'undefined' && Boolean(jestPrisma?.client)) {
    return jestPrisma.client as PrismaClient;
  }

  let extendedPrisma = getExtendedPrismaClient(options);
  for (const extension of extensions ?? []) {
    extendedPrisma = extendedPrisma.$extends(extension) as unknown as PrismaClient;
  }
  if (ctx == null) return extendedPrisma;

  // #region define extension
  const userId = ctx.userId ?? '';
  const organizationId = ctx.organizationId ?? '';
  const membershipRole = ctx.membershipRole ?? '';
  const setRlsSql = Prisma.sql`SELECT set_rls_app_context(${userId}, ${organizationId}, ${membershipRole});`;

  /** NOTE: if making any adjustments, ensure test coverage in src/server/utils/prisma.test.ts */
  const rlsExtension = Prisma.defineExtension((prismaExtension) =>
    prismaExtension.$extends({
      client: {
        $transaction: (async (...txnParams: Parameters<typeof prismaExtension.$transaction>) => {
          const [txnOps, txnOptions] = txnParams;
          if (Array.isArray(txnOps)) {
            // Inject as the first step of a batch operation. This step must execute
            // `setRlsSql`; an empty raw statement would leave the whole batch without
            // the RLS app context or role.
            const [_rlsHelper, ...result] = await prismaExtension.$transaction(
              [prismaExtension.$executeRaw(setRlsSql), ...txnOps],
              txnOptions,
            );
            return result;
          }
          // Inject as the first step of an interactive transaction
          return prismaExtension.$transaction(async (transaction) => {
            await transaction.$executeRaw(setRlsSql);
            return txnOps(transaction);
          }, txnOptions);
        }) as typeof prismaExtension.$transaction,
      },
      name: 'prisma-rls-client',
      query: {
        $allModels: {
          async $allOperations({ args, query, ...rest }) {
            // NOTE(review): `__internalParams` is read untyped via lodash.get and appears to be
            // Prisma-internal; re-verify on Prisma upgrades.
            const existingTxn = get(rest, ['__internalParams', 'transaction']);
            if (existingTxn) return query(args); // handled via 'client' override
            const [_setRls, result] = await prismaExtension.$transaction([
              prismaExtension.$executeRaw(setRlsSql),
              query(args),
            ]);
            return result;
          },
        },
        async $allOperations({ args, model, query, ...rest }) {
          if (model) return query(args); // no override needed... should be handled in $allModels above
          // NOTE: inside a raw SQL operation like '$queryRaw' or '$executeRaw'
          const existingTxn = get(rest, ['__internalParams', 'transaction']);
          if (existingTxn) return query(args); // already handled via client override
          const [_setRls, result] = await prismaExtension.$transaction([
            prismaExtension.$executeRaw(setRlsSql),
            query(args),
          ]);
          return result;
        },
      },
    }),
  );
  // #endregion
  return extendedPrisma.$extends(rlsExtension) as unknown as PrismaClient;
}
-- PostgreSQL function that sets the app context used by RLS policies.
--
-- set_rls_app_context(p_user_id, p_organization_id, p_user_role)
--   Writes the caller's identity into the custom settings app.current_user_id,
--   app.current_organization_id and app.current_user_role, then switches the
--   current role to `app`.
--   Every effect is transaction-local: set_config(..., TRUE) and SET LOCAL both revert at
--   the end of the current transaction. The function must therefore run inside the
--   same transaction as the queries it should govern. The Prisma RLS client does
--   this by executing it as the first statement of each $transaction.
--   NOTE(review): policies presumably read these values via
--   current_setting('app.…', true); confirm against the policy definitions.
CREATE OR REPLACE FUNCTION set_rls_app_context(
p_user_id TEXT,
p_organization_id TEXT,
p_user_role TEXT
) RETURNS VOID AS $$
BEGIN
-- Third argument TRUE = is_local: each setting lasts only until the transaction ends.
PERFORM set_config('app.current_user_id', p_user_id, TRUE);
PERFORM set_config('app.current_organization_id', p_organization_id, TRUE);
PERFORM set_config('app.current_user_role', p_user_role, TRUE);
-- Executed after the settings are written; the role switch also lasts only until transaction end.
SET LOCAL ROLE app;
END;
$$ LANGUAGE plpgsql;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment