Last active
July 19, 2025 06:40
-
-
Save danielpowell4/a0c6a98e7de5f63dffb86ec05d34dab9 to your computer and use it in GitHub Desktop.
RLS Extension for Prisma with wrapping jest test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Prisma, PrismaClient } from '@prisma/client'; | |
import get from 'lodash.get'; | |
/**
 * Integration test for the RLS Prisma extension.
 *
 * Builds a copy of the extension with instrumentation (`overrideCount`,
 * `rlsFromModelOperation`) and checks that each entry point — a plain model
 * operation, an interactive transaction and a batch transaction — sets the RLS
 * app context exactly once and that the settings are visible to later queries.
 */
describe('prisma extensions', () => {
  it('RLS extension sets app policy helpers', async () => {
    // #region setup
    const userId = 'user_123';
    const organizationId = 'org_456';
    const membershipRole = 'admin';
    const setRlsSql = Prisma.sql`SELECT set_rls_app_context(${userId}, ${organizationId}, ${membershipRole});`;
    // Reads back the settings and role that set_rls_app_context is expected to have applied
    const selectRlsSql = Prisma.sql`
      SELECT
        current_setting('app.current_user_id', true) AS current_user_id,
        current_setting('app.current_organization_id', true) AS current_organization_id,
        current_setting('app.current_user_role', true) AS current_user_role,
        current_user AS current_db_role;
    `;
    const expectedRes = [
      {
        current_db_role: 'app',
        current_organization_id: organizationId,
        current_user_id: userId,
        current_user_role: membershipRole,
      },
    ];
    // #endregion

    // Assigned from inside the extension while it runs; typed `unknown` because it
    // holds a raw query result and is captured by a closure (an un-annotated `let`
    // would be an implicit `any` here).
    let rlsFromModelOperation: unknown;
    let overrideCount = 0;

    // #region define extension
    const rlsExtension = Prisma.defineExtension((prismaExtension) =>
      prismaExtension.$extends({
        client: {
          $transaction: (async (...txnParams: Parameters<typeof prismaExtension.$transaction>) => {
            overrideCount++;
            const [txnOps, options] = txnParams;
            if (Array.isArray(txnOps)) {
              // Inject as the first step of a batch operation
              const [_rlsHelper, ...result] = await prismaExtension.$transaction(
                [prismaExtension.$executeRaw(setRlsSql), ...txnOps],
                options,
              );
              // Drop the injected step so callers get results only for their own operations
              return result;
            }
            // Inject as the first step of an interactive transaction
            return prismaExtension.$transaction(async (transaction) => {
              await transaction.$executeRaw(setRlsSql);
              return txnOps(transaction);
            }, options);
          }) as typeof prismaExtension.$transaction,
        },
        name: 'test-rls-client',
        query: {
          $allModels: {
            async $allOperations({ args, query, ...rest }) {
              overrideCount++;
              // `__internalParams` is not among the destructured callback args, hence the lodash `get`
              const existingTxn = get(rest, ['__internalParams', 'transaction']);
              if (existingTxn) return query(args); // handled via 'client' override
              const [_setRls, result, roleHelper] = await prismaExtension.$transaction([
                prismaExtension.$executeRaw(setRlsSql),
                query(args),
                prismaExtension.$queryRaw(selectRlsSql),
              ]);
              // Test-only instrumentation: expose what the model query's transaction saw
              rlsFromModelOperation = roleHelper;
              return result;
            },
          },
          async $allOperations({ args, model, query, ...rest }) {
            if (model) return query(args); // no override needed... should be handled in $allModels above
            // NOTE: inside a raw SQL operation like '$queryRaw' or '$executeRaw'
            const existingTxn = get(rest, ['__internalParams', 'transaction']);
            if (existingTxn) return query(args); // already handled via client override
            overrideCount++;
            const [_setRls, result] = await prismaExtension.$transaction([
              prismaExtension.$executeRaw(setRlsSql),
              query(args),
            ]);
            return result;
          },
        },
      }),
    );
    // #endregion

    const testClient = new PrismaClient().$extends(rlsExtension);
    try {
      // Plain model operation: wrapped by the $allModels hook
      await testClient.project.findFirst();
      expect(overrideCount).toEqual(1);
      expect(rlsFromModelOperation).toMatchObject(expectedRes);

      // Interactive transaction: context is set once by the client-level $transaction override
      const interactiveOp = await testClient.$transaction(async (transaction) => {
        // #region extra operations to ensure only set RLS 1x
        await transaction.$queryRaw(selectRlsSql);
        await transaction.project.findFirst();
        // #endregion
        return transaction.$queryRaw(selectRlsSql);
      });
      expect(overrideCount).toEqual(2);
      expect(interactiveOp).toMatchObject(expectedRes);

      // Batch transaction: context is injected as the first step of the batch
      const [batchOp] = await testClient.$transaction([testClient.$queryRaw(selectRlsSql)]);
      expect(overrideCount).toEqual(3);
      expect(batchOp).toMatchObject(expectedRes);
    } finally {
      // Release the connection pool opened by this test's own client, even when an assertion fails
      await testClient.$disconnect();
    }
  });
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** The Prisma client constructor options that callers of the builders in this file may supply. */
type BuildPrismaClientOptions = Pick<
  PrismaClientOptions,
  'datasourceUrl' | 'datasources' | 'errorFormat'
>;

/** Caller identity passed to `set_rls_app_context`; every field is optional. */
type RLSAuthContext = {
  membershipRole?: string;
  organizationId?: string;
  userId?: string;
};
/**
 * Returns the globally cached Prisma client when one exists, otherwise constructs a new one.
 *
 * @param options - constructor overrides; only used when a new client is created
 * @returns the cached client, or a fresh client logging at `LOG_LEVELS`
 */
const buildPlainPrismaClient = (options?: BuildPrismaClientOptions): PrismaClient => {
  const cachedClient = globalThis.globalPrisma || global.globalPrisma;
  if (cachedClient) return cachedClient;

  // Caller options are spread last so they take precedence over the default log setting
  const freshClient = new PrismaClient({ log: LOG_LEVELS, ...options });
  return freshClient as PrismaClient;
};
// Outside production (or when running under Playwright) store one client on globalThis,
// which buildPlainPrismaClient() returns in preference to constructing a new one.
if (process.env.NODE_ENV !== 'production' || IS_PLAYWRIGHT) {
  if (globalThis.globalPrisma == null) {
    globalThis.globalPrisma = buildPlainPrismaClient();
  }
}
/**
 * Wraps the plain client in the 'general-client' extension (which adds only a name here).
 *
 * @param options - forwarded to buildPlainPrismaClient
 */
const getExtendedPrismaClient = (options?: BuildPrismaClientOptions) => {
  const plainClient = buildPlainPrismaClient(options);
  return plainClient.$extends({ name: 'general-client' }) as PrismaClient;
};
/**
 * Builds a Prisma client that calls `set_rls_app_context` before the caller's queries.
 *
 * RLS extends client and must be last: any `extensions` passed in are applied
 * first and the RLS extension is added on top of them.
 *
 * @param ctx - identity used for the RLS context; missing fields are sent as empty
 *   strings. Pass `null` to get the extended client without the RLS extension.
 * @param options - forwarded to the underlying client builder
 * @param extensions - extra extensions, applied in order before the RLS extension
 * @returns `jestPrisma.client` when jest-prisma is active, otherwise the extended client
 */
export function getRlsPrismaClient(
  ctx: RLSAuthContext | null,
  options?: BuildPrismaClientOptions,
  extensions?: ReturnType<typeof Prisma.defineExtension>[],
) {
  // use jestPrisma if in test land https://github.com/Quramy/jest-prisma
  if (typeof jestPrisma !== 'undefined' && Boolean(jestPrisma?.client)) {
    return jestPrisma.client as PrismaClient;
  }

  let extendedPrisma = getExtendedPrismaClient(options);
  if (extensions?.length) {
    for (const extension of extensions) {
      extendedPrisma = extendedPrisma.$extends(extension) as unknown as PrismaClient;
    }
  }
  if (ctx == null) return extendedPrisma;

  // #region define extension
  const userId = ctx.userId ?? '';
  const organizationId = ctx.organizationId ?? '';
  const membershipRole = ctx.membershipRole ?? '';
  const setRlsSql = Prisma.sql`SELECT set_rls_app_context(${userId}, ${organizationId}, ${membershipRole});`;

  /** NOTE: if making any adjustments, ensure test coverage in src/server/utils/prisma.test.ts */
  const rlsExtension = Prisma.defineExtension((prismaExtension) =>
    prismaExtension.$extends({
      client: {
        $transaction: (async (...txnParams: Parameters<typeof prismaExtension.$transaction>) => {
          const [txnOps, txnOptions] = txnParams;
          if (Array.isArray(txnOps)) {
            // Inject as the first step of a batch operation.
            // This must execute setRlsSql: an empty raw statement here would leave the
            // whole batch running without the RLS context applied.
            const [_rlsHelper, ...result] = await prismaExtension.$transaction(
              [prismaExtension.$executeRaw(setRlsSql), ...txnOps],
              txnOptions,
            );
            // Drop the injected step so callers get results only for their own operations
            return result;
          }
          // Inject as the first step of an interactive transaction
          return prismaExtension.$transaction(async (transaction) => {
            await transaction.$executeRaw(setRlsSql);
            return txnOps(transaction);
          }, txnOptions);
        }) as typeof prismaExtension.$transaction,
      },
      name: 'prisma-rls-client',
      query: {
        $allModels: {
          async $allOperations({ args, query, ...rest }) {
            // `__internalParams` is not among the destructured callback args, hence the lodash `get`
            const existingTxn = get(rest, ['__internalParams', 'transaction']);
            if (existingTxn) return query(args); // handled via 'client' override
            // Standalone model operation: run it in a batch right after setting the context
            const [_setRls, result] = await prismaExtension.$transaction([
              prismaExtension.$executeRaw(setRlsSql),
              query(args),
            ]);
            return result;
          },
        },
        async $allOperations({ args, model, query, ...rest }) {
          if (model) return query(args); // no override needed... should be handled in $allModels above
          // NOTE: inside a raw SQL operation like '$queryRaw' or '$executeRaw'
          const existingTxn = get(rest, ['__internalParams', 'transaction']);
          if (existingTxn) return query(args); // already handled via client override
          const [_setRls, result] = await prismaExtension.$transaction([
            prismaExtension.$executeRaw(setRlsSql),
            query(args),
          ]);
          return result;
        },
      },
    }),
  );
  // #endregion

  return extendedPrisma.$extends(rlsExtension) as unknown as PrismaClient;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Postgres function that sets the app context for RLS policies.
-- Stores the caller's user id, organization id and role in custom settings and
-- switches to the `app` role. set_config(..., true) and SET LOCAL both scope the
-- change to the current transaction.
CREATE OR REPLACE FUNCTION set_rls_app_context(
  p_user_id TEXT,
  p_organization_id TEXT,
  p_user_role TEXT
)
RETURNS void
LANGUAGE plpgsql
AS $body$
BEGIN
  -- PERFORM discards set_config's return value
  PERFORM set_config('app.current_user_id', p_user_id, true);
  PERFORM set_config('app.current_organization_id', p_organization_id, true);
  PERFORM set_config('app.current_user_role', p_user_role, true);
  SET LOCAL ROLE app;
END;
$body$;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment