Skip to content

Instantly share code, notes, and snippets.

@lsmith77
Created February 20, 2026 12:35
Show Gist options
  • Select an option

  • Save lsmith77/7990945f35e0cad13c8ce66fabc1b0e4 to your computer and use it in GitHub Desktop.

Select an option

Save lsmith77/7990945f35e0cad13c8ce66fabc1b0e4 to your computer and use it in GitHub Desktop.
SQLAlchemy plugin for ZenStack
import {
getAllFields,
isRelationshipField,
isComputedField,
getAttribute,
hasAttribute,
getStringLiteral,
getAllAttributes,
getAttributeArgLiteral,
getLiteral,
} from '@zenstackhq/language/utils';
import {
DataModelAttribute,
DataFieldAttribute,
} from '@zenstackhq/language/ast';
import {
Model,
DataModel,
Enum,
DataField,
ArrayExpr,
Expression,
} from '@zenstackhq/language/ast';
/**
 * Returns the element expressions of a ZenStack array literal (`[a, b, c]`).
 *
 * Any other input (a non-array expression, or a missing/falsy value) yields an
 * empty array, so callers can map over the result unconditionally.
 */
function arrayLiteralItems(expr: ArrayExpr | Expression): Expression[] {
  const candidate = (expr as ArrayExpr | undefined)?.items;
  return Array.isArray(candidate) ? candidate : [];
}
/**
 * Extracts a field name from one item of a reference list such as
 * `fields: [authorId]` or `@@index([a, b])`.
 *
 * The lookup tries these sources in order:
 * 1. a top-level `$refText`;
 * 2. the resolved target's name;
 * 3. the target's raw source text (`target.$refText`);
 * 4. a literal `value`.
 *
 * Source 3 lets unresolved references (linking failed) still produce the
 * name as written, instead of an empty string.
 *
 * @returns the field name, or `''` when nothing usable is present
 */
function extractFieldName(item: Expression): string {
  const refItem = item as ReferenceItem;
  return (
    refItem.$refText ||
    refItem.target?.ref?.name ||
    refItem.target?.$refText ||
    refItem.value ||
    ''
  );
}
// --- Types ---
/** Options read from this plugin's block in the ZModel schema. */
interface PluginOptions {
// Path of the generated Python module; generate() falls back to 'models.py'.
output?: string;
}
/** The subset of the ZenStack plugin context that this generator reads. */
interface PluginContext {
model: Model;
pluginOptions: PluginOptions;
}
/**
 * Loose view of a reference-like AST node, e.g. one item of `fields: [...]`.
 * Every member is optional because callers probe several node shapes in turn.
 */
interface ReferenceItem {
$refText?: string;
target?: { $refText?: string; ref?: { name?: string } };
ref?: { name?: string };
value?: string;
}
/** One attribute argument: an optional `name:` label plus its value expression. */
interface AttributeArgument {
name?: string;
value: Expression;
}
/**
 * Loose view of a field's type node. getFieldTypeName() checks these members
 * one after another to recover the type name.
 */
interface FieldTypeNode {
type?: string | { name?: string };
reference?: { ref?: { name?: string } };
name?: string;
optional?: boolean;
}
// --- Constants ---
/**
 * Reserved Python keywords. A schema field with one of these names cannot be
 * used verbatim as a Python attribute, so getPyFieldName() suffixes it.
 */
const pythonKeywords = new Set(
  (
    'False None True and as assert async await break class continue def del ' +
    'elif else except finally for from global if import in is lambda nonlocal ' +
    'not or pass raise return try while with yield'
  ).split(' '),
);
// --- Type Mapping ---
/**
 * ZenStack builtin scalar type -> SQLAlchemy column type name.
 * Types missing from this table are emitted as `String` by the model emitter.
 */
const typeMap: Record<string, string> = Object.fromEntries([
  ['String', 'String'],
  ['Int', 'Integer'],
  ['Float', 'Float'],
  ['Boolean', 'Boolean'],
  ['DateTime', 'DateTime'],
  ['BigInt', 'BigInteger'],
  ['Json', 'JSON'],
]);
// --- Helper Functions ---
/**
 * Recovers a field's type name from its AST type node.
 *
 * The lookup tries these in order and returns the first non-empty one:
 * 1. the node itself when it is a string;
 * 2. a nested `type` (a string, or an object with a `name`);
 * 3. the name of the referenced declaration (a model or enum);
 * 4. a direct `name` string.
 *
 * @returns the type name, or `''` when none of the shapes match
 */
function getFieldTypeName(field: DataField): string {
  const node = field.type as FieldTypeNode | undefined;
  if (!node) return '';
  if (typeof node === 'string') return node;

  const inner = node.type;
  if (inner) {
    if (typeof inner === 'string') return inner;
    if (inner.name) return inner.name;
  }

  const referenced = node.reference?.ref?.name;
  if (referenced) return referenced;

  return typeof node.name === 'string' && node.name ? node.name : '';
}
/**
 * Python attribute name for a schema field. Reserved Python keywords get a
 * trailing underscore (`from` -> `from_`); every other name is returned as-is.
 */
function getPyFieldName(field: { name: string }): string {
  const { name } = field;
  return pythonKeywords.has(name) ? `${name}_` : name;
}
/**
 * Reads the `map:` argument of an attribute, i.e. an explicit database name.
 *
 * @returns the mapped name, or `undefined` when there is no attribute or no
 *   `map` argument
 */
function extractMapName(
  attr: DataFieldAttribute | DataModelAttribute | undefined,
): string | undefined {
  if (!attr) return undefined;
  return getAttributeArgLiteral<string>(attr, 'map');
}
/**
 * Renders a SQLAlchemy constraint/index constructor call as Python source.
 *
 * - `nameFirst = true` is for `Index(name, *columns)`. There the name is a
 *   required positional argument, so it is always emitted first: the quoted
 *   `mapName`, or `None` when unnamed. With `None`, SQLAlchemy's naming
 *   convention supplies the name. Previously an unnamed index put its first
 *   column in the name slot.
 * - Otherwise (e.g. `UniqueConstraint(*columns, name=...)`), the name is an
 *   optional keyword argument.
 *
 * @param ConstraintType - Python constructor name, e.g. `Index`
 * @param fields - one column name or a list of column names
 * @param mapName - explicit constraint name, if any
 * @param nameFirst - pass the name positionally before the columns
 * @returns Python expression text
 */
function buildConstraintStr(
  ConstraintType: string,
  fields: string | string[],
  mapName?: string,
  nameFirst = false,
): string {
  const columns = (typeof fields === 'string' ? [fields] : fields)
    .map((f) => `'${f}'`)
    .join(', ');
  if (nameFirst) {
    const nameArg = mapName ? `'${mapName}'` : 'None';
    return `${ConstraintType}(${nameArg}, ${columns})`;
  }
  return mapName
    ? `${ConstraintType}(${columns}, name='${mapName}')`
    : `${ConstraintType}(${columns})`;
}
/**
 * Pulls the pieces the generator needs out of the plugin context.
 *
 * @returns the plugin options (`{}` when absent), all top-level declarations,
 *   the set of enum names, and the data-model declarations in schema order
 * @throws Error when the context carries no model
 */
function buildSchemaData(ctx: PluginContext) {
  const config = ctx.pluginOptions || {};
  const schema = ctx.model;
  if (!schema) {
    const availableKeys = Object.keys(ctx).join(', ');
    throw new Error(
      'No model/schema found in plugin context. Available keys: ' +
        availableKeys,
    );
  }

  const decls = Array.isArray(schema.declarations) ? schema.declarations : [];
  const enumNames = new Set<string>();
  const models: DataModel[] = [];
  for (const decl of decls) {
    if (decl.$type === 'Enum') {
      enumNames.add((decl as Enum).name);
    } else if (decl.$type === 'DataModel') {
      models.push(decl as DataModel);
    }
  }
  return { config, decls, enumNames, models };
}
/** Per-field metadata gathered by collectModelMeta() and consumed by the emitter. */
type FieldMeta = {
// Field name as declared in the schema; the emitter uses it as the column name.
name: string;
// Type name from getFieldTypeName(): a builtin (e.g. 'String'), an enum, or a model.
type: string;
// Field is part of the primary key.
isPK: boolean;
// Field carries @unique.
isUnique: boolean;
// Set on relation fields whose @relation lists `fields:` (the FK-owning side).
isFK: boolean;
// `type` names an enum declared in the schema.
isEnum: boolean;
// isRelationshipField() reports this field as a relation.
isRelation: boolean;
// Present only for relation fields.
relation?: {
// Name of the related model.
target: string;
/** Schema-qualified table name of the target, e.g. "simap.institution" */
targetTableName?: string;
// Local FK columns from `@relation(fields: [...])`.
fkFields: string[];
// Referenced columns from `@relation(references: [...])`, index-aligned with fkFields.
refFields: string[];
// Python attribute name of the reciprocal relation on the target model.
backPopulates?: string;
// Constraint name(s) from `@relation(map: ...)`.
fkNames?: string[];
// Referential actions forwarded to ForeignKey(ondelete=... / onupdate=...).
onDelete?: string;
onUpdate?: string;
};
// The field's type is marked optional (`?`).
isOptional: boolean;
// isComputedField() is true; the emitter skips these fields.
isComputed: boolean;
// The underlying AST node, kept for attribute lookups at emit time.
field: DataField;
// Parsed @default: now(), autoincrement(), or another function name stored in `value`.
default?: {
type: 'now' | 'autoincrement' | 'value';
value?: string;
};
// Field carries @updatedAt.
isUpdatedAt?: boolean;
// Precision argument of @db.Timestamp(n), when given.
dbPrecision?: number;
};
/** Per-model metadata gathered by collectModelMeta(). */
type ModelMeta = {
// Model name; also the generated Python class name.
name: string;
// Database table name (from @@map, else derived from the model name).
tableName: string;
// @@schema value, if any.
schema?: string;
fields: FieldMeta[];
// Model-level @@index([...], map: ...).
indexes: { fields: string[]; name?: string }[];
// Model-level @@unique([...], map: ...).
uniques: { fields: string[]; name?: string }[];
// Field-level @unique(map: ...), emitted as single-column UniqueConstraints.
fieldUniques: { field: string; name?: string }[];
};
/**
 * ZenStack `ReferentialAction` enum members mapped to the SQL keywords that
 * SQLAlchemy copies verbatim into `ON DELETE` / `ON UPDATE` clauses.
 * A Map (not an object literal) avoids matching prototype keys.
 */
const referentialActionSql = new Map<string, string>([
  ['Cascade', 'CASCADE'],
  ['Restrict', 'RESTRICT'],
  ['NoAction', 'NO ACTION'],
  ['SetNull', 'SET NULL'],
  ['SetDefault', 'SET DEFAULT'],
]);

/**
 * Reads an `onDelete` / `onUpdate` argument of `@relation` as SQL text.
 *
 * In ZModel these are written as enum references (`onDelete: Cascade`).
 * `getStringLiteral` alone cannot read a reference, so both string literals
 * and references are accepted. Known members are translated
 * (`SetNull` -> `SET NULL`); any other text passes through unchanged.
 *
 * @param expr - the argument's value expression
 * @returns the SQL referential action, or `undefined` when nothing is readable
 */
function readReferentialAction(expr: Expression): string | undefined {
  const ref = expr as ReferenceItem;
  const name =
    getStringLiteral(expr) ||
    ref.target?.ref?.name ||
    ref.target?.$refText ||
    ref.$refText;
  if (!name) return undefined;
  return referentialActionSql.get(name) ?? name;
}

/**
 * Collects everything the emitter needs to know about one data model.
 *
 * It reads model-level attributes (`@@map`, `@@schema`, `@@id`, `@@index`,
 * `@@unique`). For each field returned by `getAllFields`, it also reads:
 * - the type;
 * - key and uniqueness flags;
 * - `@default`, `@updatedAt` and `@db.Timestamp` precision;
 * - `@relation` configuration.
 *
 * @param model - the ZenStack data model declaration
 * @param enumNames - names of all schema enums, used to flag enum-typed fields
 * @returns the model's metadata. `relation.targetTableName` and
 *   `relation.backPopulates` are left unset because they depend on other models.
 */
function collectModelMeta(model: DataModel, enumNames: Set<string>): ModelMeta {
  // 1. Table name, schema and composite primary key.
  // Without @@map the table name is the lower-cased model name.
  let tableName = model.name.toLowerCase();
  let schema: string | undefined = undefined;
  // Fields listed in a model-level @@id([...]). Each one is marked isPK so the
  // model still gets a primary key (e.g. many-to-many join tables).
  const compositeIdFields = new Set<string>();
  for (const attr of getAllAttributes(model)) {
    if (attr.decl?.$refText === '@@map' && attr.args?.[0]) {
      tableName = getStringLiteral(attr.args[0].value) || tableName;
    }
    if (attr.decl?.$refText === '@@schema' && attr.args?.[0]) {
      schema = getStringLiteral(attr.args[0].value);
    }
    if (attr.decl?.$refText === '@@id' && attr.args?.[0]) {
      for (const idField of arrayLiteralItems(attr.args[0].value).map(
        extractFieldName,
      )) {
        if (idField) compositeIdFields.add(idField);
      }
    }
  }

  // 2. Model-level indexes and unique constraints.
  const indexes: ModelMeta['indexes'] = [];
  const uniques: ModelMeta['uniques'] = [];
  const fieldUniques: ModelMeta['fieldUniques'] = [];
  for (const attr of getAllAttributes(model)) {
    const attrName = attr.decl?.$refText;
    if ((attrName === '@@index' || attrName === '@@unique') && attr.args?.[0]) {
      const constraint = {
        fields: arrayLiteralItems(attr.args[0].value).map(extractFieldName),
        name: extractMapName(attr),
      };
      (attrName === '@@index' ? indexes : uniques).push(constraint);
    }
  }

  // NOTE(review): presumably `false` excludes @ignore'd fields and the Set is a
  // recursion guard for inherited models — confirm against @zenstackhq/language/utils.
  const allFields = getAllFields(model, false, new Set());

  // 3. Field-level @unique, optionally named via @unique(map: ...).
  for (const field of allFields) {
    const uniqueAttr = Array.isArray(field.attributes)
      ? field.attributes.find((attr) => attr.decl?.$refText === '@unique')
      : undefined;
    if (uniqueAttr) {
      fieldUniques.push({ field: field.name, name: extractMapName(uniqueAttr) });
    }
  }

  // 4. Per-field metadata.
  const fields: FieldMeta[] = [];
  for (const field of allFields) {
    const type = getFieldTypeName(field);
    const isRelation = isRelationshipField(field);
    const attrs = Array.isArray(field.attributes) ? field.attributes : [];
    let isFK = false;
    let relation: FieldMeta['relation'] = undefined;
    let isUpdatedAt = false;
    let defaultValue: FieldMeta['default'] = undefined;
    let dbPrecision: number | undefined = undefined;

    for (const attr of attrs) {
      const attrName = attr.decl?.$refText;
      if (attrName === '@updatedAt') {
        isUpdatedAt = true;
      }
      if (attrName === '@default') {
        const defaultExpr = attr.args?.[0]?.value;
        if (defaultExpr) {
          // Function-call defaults (e.g. now()) are expected to expose the
          // callee under `function`; accept the shapes seen in practice.
          const funcRef = (
            defaultExpr as {
              function?: { name?: string; ref?: { name?: string } };
            }
          ).function;
          let funcName: string | undefined;
          if (funcRef?.name) {
            funcName = funcRef.name;
          } else if (funcRef?.ref?.name) {
            funcName = funcRef.ref.name;
          } else if (typeof funcRef === 'string') {
            funcName = funcRef;
          }
          // Literal defaults such as @default(0) have no callee and are not recorded.
          if (funcName === 'now') {
            defaultValue = { type: 'now' };
          } else if (funcName === 'autoincrement') {
            defaultValue = { type: 'autoincrement' };
          } else if (funcName) {
            defaultValue = { type: 'value', value: funcName };
          }
        }
      }
      if (attrName === '@db.Timestamp' && attr.args?.[0]) {
        const precValue = getLiteral<number>(attr.args[0].value);
        if (typeof precValue === 'number') {
          dbPrecision = precValue;
        }
      }
    }

    // Relation fields: record the target model and any @relation configuration.
    if (isRelation) {
      const rel: NonNullable<FieldMeta['relation']> = {
        target: type,
        fkFields: [],
        refFields: [],
      };
      for (const attr of attrs) {
        if (attr.decl?.$refText !== '@relation' || !attr.args) continue;
        let fkFields: string[] = [];
        let refFields: string[] = [];
        let onDelete: string | undefined;
        let onUpdate: string | undefined;
        let fkNames: string[] = [];
        for (const arg of attr.args as AttributeArgument[]) {
          switch (arg.name) {
            case 'fields':
              fkFields = arrayLiteralItems(arg.value).map(extractFieldName);
              break;
            case 'references':
              refFields = arrayLiteralItems(arg.value).map(extractFieldName);
              break;
            case 'onDelete':
              onDelete = readReferentialAction(arg.value);
              break;
            case 'onUpdate':
              onUpdate = readReferentialAction(arg.value);
              break;
            case 'map':
              fkNames = [getStringLiteral(arg.value) || ''];
              break;
          }
        }
        // Only the side that lists `fields:` owns the FK columns.
        if (fkFields.length > 0) {
          isFK = true;
          rel.fkFields = fkFields;
          rel.refFields = refFields;
        }
        if (fkNames.length > 0) rel.fkNames = fkNames;
        if (onDelete) rel.onDelete = onDelete;
        if (onUpdate) rel.onUpdate = onUpdate;
      }
      relation = rel;
    }

    fields.push({
      name: field.name,
      type,
      isPK: hasAttribute(field, '@id') || compositeIdFields.has(field.name),
      isUnique: hasAttribute(field, '@unique'),
      isFK,
      isEnum: enumNames.has(type),
      isRelation,
      relation,
      isOptional: field.type?.optional === true,
      isComputed: isComputedField(field),
      field,
      isUpdatedAt,
      default: defaultValue,
      dbPrecision,
    });
  }

  return {
    name: model.name,
    tableName,
    schema,
    fields,
    indexes,
    uniques,
    fieldUniques,
  };
}
/**
 * Emits a Python `str`-based `enum.Enum` class for every `Enum` declaration in
 * `decls`; data models are ignored.
 *
 * Each member is written as `NAME = 'value'`. The value is the member's
 * `value` property when present, unwrapping an inner `.value` and falling back
 * to the name when that is empty; otherwise the value is the member's name.
 *
 * @returns generated lines. Each class is preceded by a blank line, and the
 *   group is followed by one. The result is empty when there are no enums.
 */
function emitEnums(decls: (DataModel | Enum)[]): string[] {
  const out: string[] = [];
  const enumDecls = decls.filter((d): d is Enum => d.$type === 'Enum');
  for (const enumDecl of enumDecls) {
    out.push('', `class ${enumDecl.name}(str, enum.Enum):`);
    for (const member of enumDecl.fields) {
      const raw: unknown = Reflect.get(member, 'value') ?? member.name;
      const resolved =
        typeof raw === 'object' && raw !== null && 'value' in raw
          ? (raw as { value?: string }).value || member.name
          : raw;
      out.push(` ${member.name} = '${resolved}'`);
    }
  }
  if (out.length === 0) return out;
  out.push('');
  return out;
}
/**
 * Returns the SQLAlchemy column type expression for a scalar field.
 *
 * The first matching rule wins:
 * 1. `@db.Text` gives `Text`.
 * 2. `@db.VarChar(n)` gives `String(n)`, or plain `String` without a length.
 * 3. `DateTime` with a `@db.Timestamp` precision gives `TIMESTAMP(...)`. A
 *    precision of 0 is treated as "no precision".
 * 4. An enum type gives `Enum(<type>)`.
 * 5. Otherwise the type is looked up in `typeMap`, with `String` as the fallback.
 */
function sqlAlchemyColumnType(field: FieldMeta): string {
  if (getAttribute(field.field, '@db.Text')) return 'Text';
  const varCharAttr = getAttribute(field.field, '@db.VarChar');
  if (varCharAttr) {
    const lengthArg = varCharAttr.args?.[0];
    const length = lengthArg ? getLiteral<number>(lengthArg.value) : undefined;
    return typeof length === 'number' && Number.isInteger(length)
      ? `String(${length})`
      : 'String';
  }
  if (field.type === 'DateTime') {
    return field.dbPrecision
      ? `TIMESTAMP(precision=${field.dbPrecision}, timezone=False)`
      : 'DateTime';
  }
  if (field.isEnum) return `Enum(${field.type})`;
  return typeMap[field.type] || 'String';
}

/**
 * Emits the SQLAlchemy declarative class for one model.
 *
 * Field handling:
 * - Scalar fields become `Column(...)`s.
 * - Relation fields become `relationship(...)`s, each preceded by one
 *   `Column(..., ForeignKey(...))` per FK column listed in the relation's
 *   `fields:`. Those scalar FK fields are therefore not emitted separately.
 * - Computed fields are skipped.
 *
 * Required non-PK columns get an explicit `nullable=False`, because SQLAlchemy's
 * `Column` is nullable by default.
 *
 * @param meta - metadata from collectModelMeta(), with relation targets resolved
 * @returns lines of Python source, ending with a blank line
 */
function emitModelMeta(meta: ModelMeta): string[] {
  const lines: string[] = [
    `class ${meta.name}(Base):`,
    ` __tablename__ = '${meta.tableName}'`,
  ];

  const tableArgs: string[] = [
    // Index(name, *columns): the name is positional and comes first.
    ...meta.indexes.map((index) =>
      buildConstraintStr('Index', index.fields, index.name, true),
    ),
    ...meta.uniques.map((unique) =>
      buildConstraintStr('UniqueConstraint', unique.fields, unique.name),
    ),
    ...meta.fieldUniques.map((unique) =>
      buildConstraintStr('UniqueConstraint', unique.field, unique.name),
    ),
  ];
  // Table keyword options must be a dict in the last tuple position.
  if (meta.schema) {
    tableArgs.push(`{'schema': '${meta.schema}'}`);
  }
  if (tableArgs.length === 1) {
    // A one-element Python tuple needs a trailing comma; `(x)` is just `x`,
    // and declarative rejects a __table_args__ that is not a tuple or dict.
    lines.push(` __table_args__ = (${tableArgs[0]},)`);
  } else if (tableArgs.length > 1) {
    lines.push(` __table_args__ = (${tableArgs.join(', ')})`);
  }

  // FK columns are emitted next to their relation, so skip them as plain scalars.
  const fkFieldsInRelations = new Set<string>();
  for (const field of meta.fields) {
    if (field.isRelation && field.relation) {
      for (const fkField of field.relation.fkFields) {
        fkFieldsInRelations.add(fkField);
      }
    }
  }

  for (const field of meta.fields) {
    if (field.isComputed) continue;

    if (!field.isRelation) {
      if (fkFieldsInRelations.has(field.name)) continue;
      const colArgs = [sqlAlchemyColumnType(field)];
      if (field.isPK) colArgs.push('primary_key=True');
      // Field-level @unique is already a UniqueConstraint in __table_args__.
      if (
        field.isUnique &&
        !meta.fieldUniques.some((fu) => fu.field === field.name)
      ) {
        colArgs.push('unique=True');
      }
      if (field.isOptional) {
        colArgs.push('nullable=True');
      } else if (!field.isPK) {
        colArgs.push('nullable=False');
      }
      if (field.default?.type === 'now') {
        colArgs.push('server_default=func.now()');
      } else if (field.isUpdatedAt && !field.default) {
        // Prisma/ZenStack also fill @updatedAt on create; SQLAlchemy's
        // onupdate does not fire on INSERT, so give the column an insert default.
        colArgs.push('default=func.now()');
      }
      // autoincrement() needs no argument: an integer primary key already autoincrements.
      if (field.isUpdatedAt) {
        colArgs.push('onupdate=func.now()');
      }
      // Keep the real column name when the Python attribute had to be renamed.
      const pyFieldName = getPyFieldName(field);
      if (pyFieldName !== field.name) {
        colArgs.push(`name='${field.name}'`);
      }
      lines.push(` ${pyFieldName} = Column(${colArgs.join(', ')})`);
      continue;
    }

    const relation = field.relation;
    if (relation && relation.fkFields.length > 0) {
      const fkTable = relation.targetTableName ?? relation.target.toLowerCase();
      for (let i = 0; i < relation.fkFields.length; i++) {
        const fkField = relation.fkFields[i];
        const refField = relation.refFields[i] || 'id';
        const fkKwargs: string[] = [];
        if (relation.fkNames && relation.fkNames[i]) {
          fkKwargs.push(`name='${relation.fkNames[i]}'`);
        }
        if (relation.onDelete) fkKwargs.push(`ondelete='${relation.onDelete}'`);
        if (relation.onUpdate) fkKwargs.push(`onupdate='${relation.onUpdate}'`);
        const fkStr = [`'${fkTable}.${refField}'`, ...fkKwargs].join(', ');

        // Type, key and nullability come from the scalar FK field itself.
        const fkMeta = meta.fields.find((f) => f.name === fkField);
        const fkColArgs = [
          fkMeta && !fkMeta.isRelation ? sqlAlchemyColumnType(fkMeta) : 'Integer',
          `ForeignKey(${fkStr})`,
        ];
        if (fkMeta?.isPK) {
          fkColArgs.push('primary_key=True');
        } else if (fkMeta) {
          fkColArgs.push(fkMeta.isOptional ? 'nullable=True' : 'nullable=False');
        }
        const fkPyName = getPyFieldName({ name: fkField });
        if (fkPyName !== fkField) {
          fkColArgs.push(`name='${fkField}'`);
        }
        lines.push(` ${fkPyName} = Column(${fkColArgs.join(', ')})`);
      }
    }

    // To-many relations (array type) get uselist=True.
    const uselist = field.field.type?.array ? 'uselist=True' : 'uselist=False';
    const backPopulates = relation?.backPopulates
      ? `, back_populates='${relation.backPopulates}'`
      : '';
    const target = relation ? relation.target : field.type;
    lines.push(
      ` ${getPyFieldName(field)} = relationship('${target}', ${uselist}${backPopulates})`,
    );
  }

  lines.push('');
  return lines;
}
/**
 * Plugin entry point: generates a Python module of SQLAlchemy models from the
 * ZenStack schema in `ctx`.
 *
 * Steps:
 * 1. Collect per-model metadata.
 * 2. Resolve each relation's schema-qualified target table and its
 *    `back_populates` partner.
 * 3. Emit the imports, then the enums, then the model classes.
 * 4. Write the result to `config.output` (default `models.py`), creating the
 *    parent directory when needed.
 */
async function generate(ctx: PluginContext): Promise<void> {
  const { config, decls, enumNames, models } = buildSchemaData(ctx);
  const modelMetas = models.map((m) => collectModelMeta(m, enumNames));
  const metaByName = new Map(modelMetas.map((m) => [m.name, m] as const));

  // Resolve FK target tables and bidirectional relationship partners.
  for (const meta of modelMetas) {
    for (const field of meta.fields) {
      const relation = field.relation;
      if (!field.isRelation || !relation) continue;
      const targetMeta = metaByName.get(relation.target);
      if (!targetMeta) continue;

      // Schema-qualified table reference used inside ForeignKey(...).
      relation.targetTableName = targetMeta.schema
        ? `${targetMeta.schema}.${targetMeta.tableName}`
        : targetMeta.tableName;

      if (!relation.backPopulates) {
        // Exclude the field itself. On a self-relation it would otherwise match
        // first and point back_populates at itself.
        const backField = targetMeta.fields.find(
          (f) =>
            f !== field && f.isRelation && f.relation?.target === meta.name,
        );
        if (backField) {
          relation.backPopulates = getPyFieldName(backField);
        }
      }
    }
  }

  // Only import what the emitted models reference.
  const needsFuncImport = modelMetas.some((meta) =>
    meta.fields.some((f) => f.default || f.isUpdatedAt),
  );
  const needsTimestampImport = modelMetas.some((meta) =>
    meta.fields.some((f) => f.type === 'DateTime' && f.dbPrecision),
  );

  const codeLines: string[] = [
    'import enum',
    'from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, ForeignKey, Enum, JSON, BigInteger, Text, Index, UniqueConstraint' +
      (needsFuncImport ? ', func' : ''),
    ...(needsTimestampImport
      ? ['from sqlalchemy.dialects.postgresql import TIMESTAMP']
      : []),
    // declarative_base lives in sqlalchemy.orm since 1.4; the
    // sqlalchemy.ext.declarative location is deprecated in 2.0.
    'from sqlalchemy.orm import declarative_base, relationship',
    '',
    'Base = declarative_base()',
    '',
  ];

  const enumAndModelDecls = decls.filter(
    (d): d is DataModel | Enum => d.$type === 'DataModel' || d.$type === 'Enum',
  );
  codeLines.push(...emitEnums(enumAndModelDecls));
  for (const meta of modelMetas) {
    codeLines.push(...emitModelMeta(meta));
  }

  const fs = await import('fs');
  const path = await import('path');
  const outputPath = config.output || 'models.py';
  const outputDir = path.dirname(outputPath);
  if (outputDir && outputDir !== '.') {
    fs.mkdirSync(outputDir, { recursive: true });
  }
  fs.writeFileSync(outputPath, codeLines.join('\n'));
}
/** ZenStack plugin descriptor: a display name plus the generator entry point. */
const plugin: { name: string; generate: typeof generate } = {
  name: 'SQLAlchemy',
  generate,
};

export default plugin;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment