Created November 10, 2024 09:58
-
-
Save ikupenov/26f3775821c05f17b6f8b7a037fb2c7a to your computer and use it in GitHub Desktop.
Drizzle RLS DB with Policies
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// schema/entities/base.ts

/** Column default for `owner_id` on tables whose owner is required. */
const DEFAULT_OWNER_ID = "user_00000000000000000000000000";

/** Column default for `organization_id` on tables whose organization is required. */
const DEFAULT_ORGANIZATION_ID = "org_00000000000000000000000000";

/** Length of the `owner_id` / `organization_id` varchar columns. */
const EXTERNAL_ID_LENGTH = 256;

/** `owner_id` column builder shared by the required and optional variants. */
const ownerIdBuilder = () =>
  varchar("owner_id", { length: EXTERNAL_ID_LENGTH });

/** `organization_id` column builder shared by the required and optional variants. */
const organizationIdBuilder = () =>
  varchar("organization_id", { length: EXTERNAL_ID_LENGTH });

/**
 * Columns every entity table has: a UUID primary key generated by
 * `gen_random_uuid()` and a non-null `created_at` timestamp (millisecond
 * precision) defaulting to now.
 */
export const getBaseEntityProps = () => ({
  id: uuid("id")
    .default(sql`gen_random_uuid()`)
    .primaryKey(),
  createdAt: timestamp("created_at", { precision: 3 }).notNull().defaultNow(),
});

/** Base columns plus a non-null `owner_id` defaulting to `DEFAULT_OWNER_ID`. */
export const getOwnedBaseEntityProps = () => ({
  ...getBaseEntityProps(),
  ownerId: ownerIdBuilder().notNull().default(DEFAULT_OWNER_ID),
});

/** Base columns plus a nullable `owner_id` with no default. */
export const getOptionalOwnedBaseEntityProps = () => ({
  ...getBaseEntityProps(),
  ownerId: ownerIdBuilder(),
});

/** Base columns plus a non-null `organization_id` defaulting to `DEFAULT_ORGANIZATION_ID`. */
export const getOrgOwnedBaseEntityProps = () => ({
  ...getBaseEntityProps(),
  organizationId: organizationIdBuilder()
    .notNull()
    .default(DEFAULT_ORGANIZATION_ID),
});

/** Base columns plus a nullable `organization_id` with no default. */
export const getOptionalOrgOwnedBaseEntityProps = () => ({
  ...getBaseEntityProps(),
  organizationId: organizationIdBuilder(),
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { | |
and, | |
type DBQueryConfig, | |
type SQLWrapper, | |
type Table, | |
} from "drizzle-orm"; | |
import { drizzle } from "drizzle-orm/postgres-js"; | |
import postgres, { type Sql } from "postgres"; | |
import { type AnyArgs, isNil } from "@sheetah/common"; | |
import { | |
type DbClient, | |
type DeleteArgs, | |
type DeleteFn, | |
type FindArgs, | |
type FindFn, | |
type FromArgs, | |
type FromFn, | |
type JoinArgs, | |
type JoinFn, | |
type RlsDbClient, | |
type SetArgs, | |
type SetFn, | |
type TransactionArgs, | |
type TransactionFn, | |
type UpdateArgs, | |
type ValuesArgs, | |
type ValuesFn, | |
type WhereArgs, | |
type WhereFn, | |
} from "./db-client.types"; | |
import { getPolicy, type Principal } from "./orm"; | |
import { schema } from "./schema"; | |
import { | |
type getOrgOwnedBaseEntityProps, | |
type getOwnedBaseEntityProps, | |
} from "./schema/entities"; | |
/**
 * Creates a postgres-js `Sql` client for the given connection string.
 *
 * @param connectionString - Postgres connection URL handed to `postgres()`.
 */
export const connectDb = (connectionString: string) => {
  const sqlClient = postgres(connectionString);
  return sqlClient;
};
/**
 * Builds an unrestricted drizzle client over `client`, typed with the full
 * schema so the relational `db.query.*` API is available.
 */
export const createDbClient = (client: Sql): DbClient => {
  const drizzleConfig = { schema };
  return drizzle(client, drizzleConfig);
};
/**
 * Creates a drizzle client whose queries are automatically constrained by the
 * row-level policies registered through `policy()` (see `./orm`).
 *
 * The drizzle client is wrapped in a recursive `Proxy`. Every method call made
 * through it is routed to `intercept`, together with:
 * - the dotted property path that led to it (e.g. `db.select.from.where`);
 * - the calls made so far in the chain.
 *
 * `intercept` matches the path against the `overrides` table. On a hit, it
 * rewrites the call so that the principal's policy SQL is AND-ed into the
 * query. Calls matching no override pass through unchanged.
 *
 * Intercepted entry points, on `db` and on transaction handles `tx`:
 * - `execute`: always throws.
 * - `query.<table>.findMany` / `findFirst`: the policy is added to `where` and
 *   to every relation listed in `with`.
 * - `<builder>.from`, `<builder>.from…where`, and left/right/inner/full joins.
 * - `insert.values`: `ownerId` / `organizationId` are stamped from the principal.
 * - `update.set`, `update…where`, `delete`, `delete…where`.
 * - `transaction`, including nested `tx.transaction`: the callback receives a
 *   wrapped `tx`, so savepoint transactions cannot escape the policies.
 *
 * Policies themselves receive the plain (unwrapped) drizzle client as `db`.
 *
 * @param client - postgres-js client used to build the underlying drizzle client.
 * @param principal - Identity passed to every policy and used to stamp inserts.
 * @returns The wrapped client, typed without `execute`.
 */
export const createRlsDbClient = (
  client: Sql,
  principal: Principal,
): RlsDbClient => {
  const db = createDbClient(client);

  // Keys of the columns stamped onto inserted rows. Typing them as keys of the
  // base-entity factories turns a rename there into a compile error here.
  const ownerIdColumn: keyof ReturnType<typeof getOwnedBaseEntityProps> =
    "ownerId" as const;
  const orgIdColumn: keyof ReturnType<typeof getOrgOwnedBaseEntityProps> =
    "organizationId" as const;

  /** Position of an intercepted call inside the proxied call chain. */
  interface InvokeContext {
    /** Property names from the root (`db` or `tx`) down to the current member. */
    path?: string[];
    /** Every function called along `path`, in order, with its arguments. */
    fnPath?: { name: string; args: unknown[] }[];
  }

  /** A pending call: the original method (bound to its target) and its arguments. */
  interface InterceptFn {
    invoke: (...args: unknown[]) => unknown;
    name: string;
    args: unknown[];
  }

  /** A rewrite rule that replaces the call when the path matches `pattern`. */
  interface OverrideFn {
    pattern: string | string[];
    action: () => unknown;
  }

  // Path patterns are globs: "." is literal and "*" matches any run of
  // characters (including further dots). Each pattern is compiled once and
  // reused, instead of being rebuilt on every intercepted call.
  const compiledPatterns = new Map<string, RegExp>();
  const toPathRegExp = (pattern: string): RegExp => {
    let compiled = compiledPatterns.get(pattern);
    if (!compiled) {
      compiled = new RegExp(
        `^${pattern.replace(/\./g, "\\.").replace(/\*/g, ".*")}$`,
      );
      compiledPatterns.set(pattern, compiled);
    }
    return compiled;
  };

  /**
   * Runs `fn`, or the first override whose pattern matches the call path.
   * Overrides are tried in declaration order.
   */
  const intercept = (fn: InterceptFn, context: InvokeContext = {}) => {
    const { path = [], fnPath = [] } = context;
    const pathAsString = path.join(".");

    const matchPath = (pattern: string) =>
      toPathRegExp(pattern).test(pathAsString);

    const overrides: OverrideFn[] = [
      {
        // Raw SQL cannot be checked against policies, so it is refused.
        pattern: ["db.execute", "db.*.execute", "tx.execute", "tx.*.execute"],
        action: () => {
          throw new Error("'execute' in rls DB is not allowed");
        },
      },
      {
        // Relational API: <root>.query.<table>.findMany / findFirst.
        pattern: [
          "db.query.findMany",
          "db.query.*.findMany",
          "db.query.findFirst",
          "db.query.*.findFirst",
          "tx.query.findMany",
          "tx.query.*.findMany",
          "tx.query.findFirst",
          "tx.query.*.findFirst",
        ],
        action: () => {
          const findFn = fn.invoke as FindFn;
          const findArgs = fn.args as FindArgs;
          // The path segment after "query" is the schema key of the table.
          const tableIndex = path.findIndex((x) => x === "query") + 1;
          const tableProperty = path[tableIndex]! as keyof typeof db.query;
          // eslint-disable-next-line import/namespace
          const table = schema[tableProperty];
          const policy = getPolicy(table);

          if (!policy) {
            return findFn(...findArgs);
          }

          let [config] = findArgs;
          const policyWhere = policy(principal, { db, schema });

          if (!config?.where) {
            config = { ...config, where: policyWhere };
          } else if (typeof config.where === "function") {
            // drizzle also accepts `where: (fields, operators) => SQL`. The
            // callback must be invoked and its result combined; handing the
            // function itself to `and()` would not produce a valid condition.
            const userWhere = config.where as (
              ...args: unknown[]
            ) => ReturnType<typeof and>;
            config = {
              ...config,
              where: (...args: unknown[]) =>
                and(policyWhere, userWhere(...args)),
            };
          } else {
            config = {
              ...config,
              where: and(policyWhere, config.where as SQLWrapper),
            };
          }

          if (config.with) {
            // Recursively rewrite `with` so that every loaded relation is also
            // filtered by the policy of its own table.
            const processWithConfig = (
              // @ts-expect-error -- config is always defined at this point
              withConfig: typeof config.with,
            ): DBQueryConfig["with"] => {
              return (
                Object.keys(withConfig) as (keyof typeof withConfig)[]
              ).reduce<DBQueryConfig["with"]>((acc, key) => {
                const value = withConfig[key] as
                  | true
                  | null
                  | DBQueryConfig<"many">;

                // NOTE(review): both relation `where` callbacks below read the
                // related table from the callback's THIRD argument. Confirm the
                // installed drizzle version passes one. If it only passes
                // (fields, operators), `getPolicy` receives `undefined` and
                // nested relations come back unfiltered.
                if (value === true) {
                  // `with: { rel: true }` becomes a config holding only the
                  // relation's policy.
                  return {
                    ...acc,
                    [key]: {
                      where: (...args) => {
                        const [_, __, relatedTable] = args as unknown as [
                          (typeof args)[0],
                          (typeof args)[1],
                          Table,
                        ];
                        const relationPolicy = getPolicy(relatedTable);
                        return relationPolicy
                          ? relationPolicy(principal, { db, schema })
                          : undefined;
                      },
                    },
                  };
                }

                if (typeof value === "object" && value !== null) {
                  // Object config: recurse into its own `with`, and AND the
                  // relation's policy with the caller's `where` (SQL or callback).
                  return {
                    ...acc,
                    [key]: {
                      ...value,
                      with: value.with
                        ? processWithConfig(value.with)
                        : undefined,
                      where: (...args) => {
                        const [_, __, relatedTable] = args as unknown as [
                          (typeof args)[0],
                          (typeof args)[1],
                          Table,
                        ];
                        const relationPolicy = getPolicy(
                          relatedTable as unknown as Table,
                        );
                        const userWhere =
                          typeof value.where === "function"
                            ? value.where(...args)
                            : value.where;
                        return relationPolicy
                          ? and(
                              relationPolicy(principal, { db, schema }),
                              userWhere,
                            )
                          : userWhere;
                      },
                    },
                  };
                }

                // `null` / `false`-like entries are copied through unchanged.
                return { ...acc, [key]: value };
              }, {});
            };

            config = {
              ...config,
              with: processWithConfig(config.with),
            };
          }

          return findFn(...([config] as FindArgs));
        },
      },
      {
        // <builder>.from(table): seed the WHERE clause with the table's
        // policy. A later `.where()` on the chain is intercepted by the next
        // override, which re-adds the policy alongside the caller's filter.
        pattern: ["db.*.from", "tx.*.from"],
        action: () => {
          const fromFn = fn.invoke as FromFn;
          const fromArgs = fn.args as FromArgs;
          const [table] = fromArgs as [Table];
          const policy = getPolicy(table);
          if (policy) {
            return fromFn(...fromArgs).where(policy(principal, { db, schema }));
          }
          return fromFn(...fromArgs);
        },
      },
      {
        // .where() after .from(table): AND the `from` table's policy with the
        // caller's filter.
        pattern: [
          "db.*.from.where",
          "db.*.from.*.where",
          "tx.*.from.where",
          "tx.*.from.*.where",
        ],
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = fnPath.findLast((x) => x.name === "from")
            ?.args as FromArgs as [Table];
          const policy = getPolicy(table);
          if (policy) {
            const [whereFilter] = whereArgs as [SQLWrapper];
            return whereFn(and(policy(principal, { db, schema }), whereFilter));
          }
          return whereFn(...whereArgs);
        },
      },
      {
        // Joins: AND the joined table's policy into the ON condition.
        pattern: [
          "db.*.leftJoin",
          "db.*.rightJoin",
          "db.*.innerJoin",
          "db.*.fullJoin",
          "tx.*.leftJoin",
          "tx.*.rightJoin",
          "tx.*.innerJoin",
          "tx.*.fullJoin",
        ],
        action: () => {
          const joinFn = fn.invoke as JoinFn;
          const joinArgs = fn.args as JoinArgs;
          const [table, joinOptions] = joinArgs as unknown as [
            Table,
            SQLWrapper,
          ];
          const policy = getPolicy(table);
          if (policy) {
            return joinFn(
              table,
              and(policy(principal, { db, schema }), joinOptions),
            );
          }
          return joinFn(...joinArgs);
        },
      },
      {
        // insert(...).values(...): stamp owner / organization from the
        // principal, overwriting caller-supplied values for those keys when
        // the principal has them. Inserts are not otherwise policy-checked.
        pattern: ["db.insert.values", "tx.insert.values"],
        action: () => {
          const valuesFn = fn.invoke as ValuesFn;
          const valuesArgs = fn.args as ValuesArgs;
          let [valuesToInsert] = valuesArgs;
          if (!Array.isArray(valuesToInsert)) {
            valuesToInsert = [valuesToInsert];
          }
          // TODO: Extract that as an onInsert hook?
          const valuesToInsertWithOwner = valuesToInsert.map((value) => ({
            ...value,
            ...(isNil(principal.orgId)
              ? {}
              : { [orgIdColumn]: principal.orgId }),
            ...(isNil(principal.userId)
              ? {}
              : { [ownerIdColumn]: principal.userId }),
          }));
          return valuesFn(valuesToInsertWithOwner);
        },
      },
      {
        // update(table).set(...): restrict the update to rows the policy allows.
        pattern: ["db.update.set", "tx.update.set"],
        action: () => {
          const setFn = fn.invoke as SetFn;
          const setArgs = fn.args as SetArgs;
          const [table] = fnPath.findLast((x) => x.name === "update")
            ?.args as UpdateArgs as [Table];
          const policy = getPolicy(table);
          if (policy) {
            return setFn(...setArgs).where(policy(principal, { db, schema }));
          }
          return setFn(...setArgs);
        },
      },
      {
        // .where() on an update chain: AND the policy with the caller's filter.
        pattern: [
          "db.update.where",
          "db.update.*.where",
          "tx.update.where",
          "tx.update.*.where",
        ],
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = fnPath.findLast((x) => x.name === "update")
            ?.args as UpdateArgs as [Table];
          const policy = getPolicy(table);
          if (policy) {
            const [whereFilter] = whereArgs as [SQLWrapper];
            return whereFn(and(policy(principal, { db, schema }), whereFilter));
          }
          return whereFn(...whereArgs);
        },
      },
      {
        // delete(table): restrict the delete to rows the policy allows.
        pattern: ["db.delete", "tx.delete"],
        action: () => {
          const deleteFn = fn.invoke as DeleteFn;
          const deleteArgs = fn.args as DeleteArgs;
          const [table] = deleteArgs as [Table];
          const policy = getPolicy(table);
          if (policy) {
            return deleteFn(...deleteArgs).where(
              policy(principal, { db, schema }),
            );
          }
          return deleteFn(...deleteArgs);
        },
      },
      {
        // .where() on a delete chain: AND the policy with the caller's filter.
        pattern: [
          "db.delete.where",
          "db.delete.*.where",
          "tx.delete.where",
          "tx.delete.*.where",
        ],
        action: () => {
          const whereFn = fn.invoke as WhereFn;
          const whereArgs = fn.args as WhereArgs;
          const [table] = fnPath.findLast((x) => x.name === "delete")
            ?.args as DeleteArgs as [Table];
          const policy = getPolicy(table);
          if (policy) {
            const [whereOptions] = whereArgs as [SQLWrapper];
            return whereFn(
              and(policy(principal, { db, schema }), whereOptions),
            );
          }
          return whereFn(...whereArgs);
        },
      },
      {
        // Transactions, including nested ones started from a wrapped `tx`:
        // the callback must receive a wrapped handle. Otherwise a nested
        // (savepoint) transaction would get the raw tx and bypass every policy.
        pattern: ["db.transaction", "tx.transaction"],
        action: () => {
          const transactionFn = fn.invoke as TransactionFn;
          const transactionArgs = fn.args as TransactionArgs;
          const [callback, ...restArgs] = transactionArgs;
          const nextCallback: typeof callback = async (...args) => {
            const [tx] = args;
            const rlsTx = createInterceptProxy(tx, { path: ["tx"] });
            return callback(rlsTx);
          };
          return transactionFn(nextCallback, ...restArgs);
        },
      },
    ];

    const fnOverride = overrides.find(({ pattern }) =>
      (Array.isArray(pattern) ? pattern : [pattern]).some(matchPath),
    );

    return fnOverride ? fnOverride.action() : fn.invoke(...fn.args);
  };

  /**
   * Wraps `target` so that every method call goes through `intercept`, and so
   * that nested objects and non-array objects returned from calls (such as
   * query builders) are wrapped too, extending `path` / `fnPath` as the chain
   * grows.
   */
  const createInterceptProxy = <T extends object>(
    target: T,
    context: InvokeContext = {},
  ): T => {
    const { path = [], fnPath = [] } = context;
    return new Proxy<T>(target, {
      get: (innerTarget, innerTargetProp, innerTargetReceiver) => {
        const currentPath = path.concat(innerTargetProp.toString());
        const innerTargetPropValue = Reflect.get(
          innerTarget,
          innerTargetProp,
          innerTargetReceiver,
        );

        if (typeof innerTargetPropValue === "function") {
          return (...args: AnyArgs) => {
            const currentFnPath = [
              ...fnPath,
              { name: innerTargetProp.toString(), args },
            ];
            const result = intercept(
              {
                // Bound to the raw target, so the method's own `this` calls
                // do not re-enter the proxy.
                invoke: innerTargetPropValue.bind(
                  innerTarget,
                ) as InterceptFn["invoke"],
                name: innerTargetProp.toString(),
                args,
              },
              { path: currentPath, fnPath: currentFnPath },
            );
            if (
              typeof result === "object" &&
              result !== null &&
              !Array.isArray(result)
            ) {
              return createInterceptProxy(result, {
                path: currentPath,
                fnPath: currentFnPath,
              });
            }
            return result;
          };
        }

        if (
          typeof innerTargetPropValue === "object" &&
          innerTargetPropValue !== null &&
          !Array.isArray(innerTargetPropValue)
        ) {
          // Plain property access (e.g. `db.query`) extends the path but does
          // not record a call in fnPath.
          return createInterceptProxy(innerTargetPropValue, {
            path: currentPath,
            fnPath,
          });
        }

        return innerTargetPropValue;
      },
    });
  };

  return createInterceptProxy(db, { path: ["db"] });
};
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { type drizzle } from "drizzle-orm/postgres-js"; | |
import { type schema } from "./schema"; | |
/** drizzle client (postgres-js driver) typed with this package's schema. */
export type DbClient = ReturnType<typeof drizzle<typeof schema>>;

/** The schema module object handed to drizzle. */
export type DbSchema = typeof schema;

/** Export names of the schema module. */
export type DbTableName = keyof DbSchema;

/** Any value exported by the schema module. */
export type DbTable = DbSchema[DbTableName];

/** Client exposed to RLS callers: the full client minus raw `execute`. */
export type RlsDbClient = Omit<DbClient, "execute">;

/** Relational query namespace (`db.query`). */
type QueryApi = DbClient["query"];

/** `findFirst` / `findMany` of one (or, by default, any) relational table. */
export type FindFn<TKey extends keyof QueryApi = keyof QueryApi> = (
  ...args:
    | Parameters<QueryApi[TKey]["findFirst"]>
    | Parameters<QueryApi[TKey]["findMany"]>
) =>
  | ReturnType<QueryApi[TKey]["findFirst"]>
  | ReturnType<QueryApi[TKey]["findMany"]>;
export type FindArgs<TKey extends keyof QueryApi = keyof QueryApi> = Parameters<
  FindFn<TKey>
>;

// Signatures of the builder methods the RLS client intercepts.
export type TransactionFn = DbClient["transaction"];
export type TransactionArgs = Parameters<TransactionFn>;

export type SelectFn = DbClient["select"];
export type SelectArgs = Parameters<SelectFn>;
export type FromFn = ReturnType<SelectFn>["from"];
export type FromArgs = Parameters<FromFn>;
export type WhereFn = ReturnType<FromFn>["where"];
export type WhereArgs = Parameters<WhereFn>;
export type JoinFn = ReturnType<FromFn>["leftJoin"];
export type JoinArgs = Parameters<JoinFn>;

export type InsertFn = DbClient["insert"];
export type InsertArgs = Parameters<InsertFn>;
export type ValuesFn = ReturnType<InsertFn>["values"];
export type ValuesArgs = Parameters<ValuesFn>;

export type UpdateFn = DbClient["update"];
export type UpdateArgs = Parameters<UpdateFn>;
export type SetFn = ReturnType<UpdateFn>["set"];
export type SetArgs = Parameters<SetFn>;

export type DeleteFn = DbClient["delete"];
export type DeleteArgs = Parameters<DeleteFn>;
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// schema/entities/example-entity.ts | |
import { and, eq, isNotNull, or, sql } from "drizzle-orm"; | |
import { pgTable, text, uuid } from "drizzle-orm/pg-core"; | |
import { hasRole } from "@sheetah/common"; | |
import { policy } from "@sheetah/db/orm"; | |
import { | |
getOptionalOrgOwnedBaseEntityProps, | |
getOwnedBaseEntityProps, | |
} from "./base"; | |
// Column map of the `entity` table. Spread order matters: both factories emit
// `id` and `createdAt`, and the later spread's keys win.
const entityColumns = {
  ...getOwnedBaseEntityProps(),
  ...getOptionalOrgOwnedBaseEntityProps(),
  description: text("description"),
  transactionId: uuid("transaction_id").unique().notNull(),
  categoryId: uuid("category_id"),
  taxRateId: uuid("tax_rate_id"),
};

/** The `entity` table: rows with a required owner and an optional organization. */
export const entities = pgTable("entity", entityColumns);
// Row-level policy for `entities`: match rows owned by the caller, plus, for
// callers holding "org:admin", rows that belong to the caller's organization.
// Any branch that does not apply contributes a literal `false`.
policy(entities, ({ userId, orgId, role }) => {
  const never = () => sql`false`;

  const ownedByCaller = userId ? eq(entities.ownerId, userId) : never();

  const inAdminsOrganization =
    orgId && hasRole(role, ["org:admin"])
      ? and(
          isNotNull(entities.organizationId),
          eq(entities.organizationId, orgId),
        )
      : never();

  return or(ownedByCaller, inAdminsOrganization) ?? never();
});
/** Row type produced by selecting from the `entities` table. */
export type Entity = typeof entities.$inferSelect;
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// orm/policy.ts | |
import { getTableName, isTable, type SQL, type Table } from "drizzle-orm"; | |
import { type AuthorizationRole, isString } from "@sheetah/common"; | |
import { type DbClient, type DbSchema } from "@sheetah/db/db-client.types"; | |
/** The identity on whose behalf policies are evaluated. */
export interface Principal {
  /** Id of the acting user; `null` when there is none. */
  userId: string | null;
  /** Id of the active organization; `null` when there is none. */
  orgId: string | null;
  /** Authorization role of the principal; `null` when there is none. */
  role: AuthorizationRole | null;
}

/** Extra context handed to every policy callback. */
export interface PolicyOptions {
  /** drizzle client the policy may use to build sub-queries. */
  db: DbClient;
  /** The schema module, for referencing other tables. */
  schema: DbSchema;
}

// TODO: Add policies per method: select, insert, update, delete
/** Builds the SQL condition a row must satisfy to be visible to `who`. */
export type Policy = (who: Principal, context: PolicyOptions) => SQL;
/** Registered policies, keyed by table name. */
const policies = new Map<string, Policy>();

/**
 * Resolves the registry key for `table`: the table name (via `getTableName`)
 * for a drizzle table, the string itself for a string, `undefined` otherwise.
 *
 * NOTE(review): the key is the bare table name, so tables with the same
 * name in different Postgres schemas would share one policy. Confirm this
 * is acceptable.
 */
const toPolicyKey = (table: Table | string): string | undefined => {
  if (isTable(table)) {
    return getTableName(table);
  }
  if (isString(table)) {
    return table;
  }
  return undefined;
};

/**
 * Registers `fn` as the row-level policy for `table`, given either as a
 * drizzle table or by name. A later registration for the same table replaces
 * the earlier one. Values that are neither tables nor strings are ignored.
 */
export const policy = (table: Table | string, fn: Policy): void => {
  const key = toPolicyKey(table);
  if (key !== undefined) {
    policies.set(key, fn);
  }
};

/**
 * Returns the policy registered for `table` (drizzle table or name), or
 * `undefined` when none is registered or `table` is neither a table nor a string.
 */
export const getPolicy = (table: Table | string): Policy | undefined => {
  const key = toPolicyKey(table);
  return key === undefined ? undefined : policies.get(key);
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You're welcome :).
`AnyArgs` is just `any[]`.