Last active
September 18, 2024 16:28
-
-
Save luisdelatorre012/91fc64876b92e17b4cc50531825a7936 to your computer and use it in GitHub Desktop.
json schema dynamic class creation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import logging | |
import os | |
from typing import Any, Dict, Optional, Type, List, Tuple | |
import jsonref # Library for resolving JSON schema references | |
from jsonschema import Draft7Validator | |
from sqlalchemy import ( | |
JSON, | |
Boolean, | |
Column, | |
DateTime, | |
Enum, | |
Float, | |
ForeignKey, | |
Integer, | |
String, | |
CheckConstraint, | |
) | |
from sqlalchemy.exc import InvalidRequestError | |
from sqlalchemy.inspection import inspect | |
from sqlalchemy.orm import ( | |
relationship, | |
registry, | |
) | |
# Configure logging once at import time; all functions below share this logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NamingConvention:
    """Single point of control for table and column naming rules.

    Both hooks are currently identity mappings: names coming from the
    JSON schema are used verbatim, preserving their original
    capitalization. Centralizing them here lets a future convention
    (snake_case, prefixes, ...) be applied in one place.
    """

    @staticmethod
    def table_name(name: str) -> str:
        """Return the table name for a given model name (identity)."""
        return name

    @staticmethod
    def column_name(name: str) -> str:
        """Return the column name for a given property name (identity)."""
        return name
# Create a SQLAlchemy registry
mapper_registry = registry()
# Generate the declarative base class; every dynamically created model inherits it.
Base = mapper_registry.generate_base()
def load_schema(file_path: str) -> Dict[str, Any]:
    """Load, validate, and $ref-resolve a JSON schema file.

    Args:
        file_path: Path to a JSON file containing a Draft-7 schema.

    Returns:
        The schema dict with all ``$ref`` references resolved.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        Exception: Re-raised from meta-schema validation or $ref resolution.
    """
    if not os.path.exists(file_path):
        logger.error(f"Schema file not found: {file_path}")
        raise FileNotFoundError(f"Schema file not found: {file_path}")
    # Explicit encoding avoids platform-dependent defaults when reading the schema.
    with open(file_path, 'r', encoding='utf-8') as f:
        try:
            schema = json.load(f)
            # Validate the schema document itself against the Draft7 meta-schema
            Draft7Validator.check_schema(schema)
            logger.info(f"Schema loaded and validated successfully from {file_path}")
            resolved_schema = resolve_refs(schema)
            return resolved_schema
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON schema: {e}")
            raise
        except Exception as e:
            logger.error(f"Schema validation error: {e}")
            raise
def resolve_refs(schema: Dict[str, Any], base_uri: str = '') -> Dict[str, Any]:
    """Expand every ``$ref`` in *schema* via the jsonref library.

    Failures are logged and re-raised unchanged to the caller.
    """
    try:
        expanded = jsonref.replace_refs(schema, base_uri=base_uri)
        logger.info("Resolved all $ref references in the schema.")
    except Exception as exc:
        logger.error(f"Error resolving references: {exc}")
        raise
    return expanded
def determine_column_type(name: str, details: Dict[str, Any]) -> Any:
    """Choose a SQLAlchemy column type for one schema property.

    Resolution order: name-based heuristics first (monetary/quantity-like
    names -> Float, time/date-like names -> DateTime), then
    ``contentMediaType``, then the JSON-schema ``format``, then the plain
    ``type`` keyword, and finally a String(255) fallback with a warning.
    """
    schema_type = details.get('type', 'string')
    fmt = details.get('format', '')
    media_type = details.get('contentMediaType', '')

    by_type = {
        'integer': Integer,
        'number': Float,
        'string': String,
        'boolean': Boolean,
        'object': JSON,
        'array': JSON,
    }
    by_format = {
        'date-time': DateTime,
        'email': String(255),
        'uuid': String(36),
    }

    # Name-based heuristics deliberately take precedence over declared type.
    lowered = name.lower()
    numeric_hints = ('weight', 'total', 'quantity', 'amount', 'price')
    if any(hint in lowered for hint in numeric_hints):
        return Float
    if 'time' in lowered or lowered.endswith('date') or schema_type == 'datetime':
        return DateTime

    if media_type == 'application/json':
        return JSON
    if fmt in by_format:
        return by_format[fmt]
    if schema_type in by_type:
        # Plain strings get a default bounded length for portability.
        return String(255) if schema_type == 'string' else by_type[schema_type]

    logger.warning(f"Unknown type '{schema_type}' for field '{name}', defaulting to String(255)")
    return String(255)
def create_column(name: str, details: Dict[str, Any], required: bool) -> Tuple[Optional[Column], List[CheckConstraint]]:
    """Build a Column plus any CHECK constraints for one schema property.

    Object- and array-typed properties return ``(None, [])`` because they
    are modeled as separate tables, not columns.
    """
    kind = details.get('type')
    if kind == 'array' and 'items' in details:
        # Arrays become child tables elsewhere.
        return None, []
    if kind == 'object':
        # Nested objects become child tables elsewhere.
        return None, []

    column_type = determine_column_type(name, details)
    kwargs: Dict[str, Any] = {'nullable': not required}
    checks: List[CheckConstraint] = []

    if 'default' in details:
        kwargs['default'] = details['default']
    if 'enum' in details:
        # An enumeration overrides whatever type the heuristics chose.
        column_type = Enum(*details['enum'], name=f"{name}_enum")
    if 'minimum' in details:
        checks.append(CheckConstraint(f"{name} >= {details['minimum']}", name=f"ck_{name}_min"))
    if 'maximum' in details:
        checks.append(CheckConstraint(f"{name} <= {details['maximum']}", name=f"ck_{name}_max"))
    if 'pattern' in details:
        # Regex patterns are left to application-level validation.
        logger.warning(f"Pattern constraint for '{name}' cannot be enforced at the DB level")

    try:
        column = Column(column_type, **kwargs)
        logger.debug(f"Created column '{name}' with type '{column_type}' and kwargs {kwargs}")
        return column, checks
    except Exception as exc:
        logger.error(f"Error creating column '{name}': {exc}")
        raise
def handle_complex_types(details: Dict[str, Any], prop: str):
    """Warn when a property uses a schema composition keyword.

    ``oneOf`` / ``anyOf`` / ``allOf`` are not mapped to columns; each
    occurrence only produces a warning.
    """
    for combinator in ('oneOf', 'anyOf', 'allOf'):
        if combinator in details:
            logger.warning(f"{combinator} encountered in property '{prop}'. This feature is not fully supported.")
            # Implement handling logic as per requirements
def schema_model(schema: Dict[str, Any]) -> Tuple[Dict[str, Column], List[CheckConstraint]]:
    """Translate a schema's properties into class attributes and constraints.

    Returns a mapping of attribute name -> Column, plus all accumulated
    CHECK constraints. Object/array properties contribute no column here
    (they become child tables).
    """
    attrs: Dict[str, Column] = {}
    constraints: List[CheckConstraint] = []
    required_fields = schema.get('required', [])
    for prop, details in schema.get('properties', {}).items():
        handle_complex_types(details, prop)
        column, checks = create_column(prop, details, prop in required_fields)
        constraints.extend(checks)
        if column is None:
            continue
        attr_name = NamingConvention.column_name(prop)
        if attr_name in attrs:
            # Avoid clobbering a previously generated attribute.
            attr_name = f"{attr_name}_field"
        attrs[attr_name] = column
    return attrs, constraints
def create_models_from_schema(
    schema: Dict[str, Any],
    model_name: Optional[str] = None,
    schema_name: Optional[str] = None,
) -> Dict[str, Type[Base]]:
    """Dynamically create SQLAlchemy models from a resolved JSON schema.

    Args:
        schema: A Draft-7 schema with ``$ref`` already resolved.
        model_name: Class/table name for the top-level object model;
            defaults to ``"TopLevelModel"``.
        schema_name: Optional database schema to place the tables in.

    Returns:
        Mapping of model key -> dynamically created declarative class.
    """
    models: Dict[str, Type[Base]] = {}
    schema_type = schema.get('type')
    if schema_type == 'object' and 'properties' in schema:
        if model_name is None:
            model_name = "TopLevelModel"
        class_name = model_name
        table_name = NamingConvention.table_name(model_name)
        # Every generated table gets a surrogate integer primary key.
        attrs = {
            '__tablename__': table_name,
            'id': Column(Integer, primary_key=True, autoincrement=True),
        }
        column_attrs, constraints = schema_model(schema)
        attrs.update(column_attrs)
        # Dynamically create the declarative class.
        DynamicModel = type(class_name, (Base,), attrs)
        # __table_args__: a tuple of constraints, optionally paired with a
        # kwargs dict carrying the database schema name.
        table_args = list(constraints)
        table_kwargs = {'schema': schema_name} if schema_name else {}
        if table_kwargs:
            setattr(DynamicModel, '__table_args__', (tuple(table_args), table_kwargs))
        elif table_args:
            setattr(DynamicModel, '__table_args__', tuple(table_args))
        models[class_name] = DynamicModel
        logger.info(f"Model '{class_name}' created with table name '{table_name}'")
        # Recurse into nested objects and arrays.
        create_nested_models(class_name, schema, models, parent_model=DynamicModel, schema_name=schema_name)
    elif 'properties' in schema:
        # Properties present but no top-level 'type': 'object' — build a
        # model for each object-typed property instead.
        for prop_name, prop_schema in schema['properties'].items():
            if prop_schema.get('type') == 'object':
                sub_models = create_models_from_schema(
                    prop_schema, model_name=prop_name.capitalize(), schema_name=schema_name
                )
                models.update(sub_models)
            # Non-object top-level properties are not modeled.
    # Any other schema shape produces no models.
    return models
def create_nested_models(
    parent_prop: str,
    schema: Dict[str, Any],
    models: Dict[str, Type[Base]],
    parent_model: Type[Base],
    schema_name: str = None,
):
    """
    Recursively create models for nested objects and arrays.

    For each object-typed property a child table is created with a foreign
    key back to the parent and wired as a one-to-one relationship
    (uselist=False). For each array-typed property a child table is
    created holding either the item object's columns or a single 'value'
    column for primitive items (one-to-many). Both sides get
    back_populates relationships. New models are added to *models* keyed
    by "<parent_prop>_<prop_name>".
    """
    schema_type = schema.get('type')
    if schema_type == 'object':
        properties = schema.get('properties', {})
        required_fields = schema.get('required', [])
        for prop_name, prop_schema in properties.items():
            required = prop_name in required_fields
            handle_complex_types(prop_schema, prop_name)
            prop_type = prop_schema.get('type')
            if prop_type == 'object':
                # Create model for the nested object
                class_name = f"{parent_prop}{prop_name.capitalize()}Model"
                table_name = NamingConvention.table_name(f"{parent_prop}_{prop_name}")
                # Prepare foreign key column name based on parent table
                foreign_key_column = f"{parent_model.__tablename__}_id"
                # Prepare foreign key target, including schema if specified
                if schema_name:
                    foreign_key_target = f"{schema_name}.{parent_model.__tablename__}.id"
                else:
                    foreign_key_target = f"{parent_model.__tablename__}.id"
                # Prepare attributes for the nested class
                attrs = {
                    '__tablename__': table_name,
                    'id': Column(Integer, primary_key=True, autoincrement=True),
                    foreign_key_column: Column(
                        ForeignKey(foreign_key_target),
                        nullable=not required  # Set nullable based on required
                    ),
                }
                # Add columns and constraints from nested schema
                column_attrs, constraints = schema_model(prop_schema)
                attrs.update(column_attrs)
                # Dynamically create the nested class
                NestedModel = type(class_name, (Base,), attrs)
                # Handle __table_args__ (constraints tuple, optional schema kwargs dict)
                table_args = []
                table_kwargs = {}
                if constraints:
                    table_args.extend(constraints)
                if schema_name:
                    table_kwargs['schema'] = schema_name
                if table_kwargs:
                    setattr(NestedModel, '__table_args__', (tuple(table_args), table_kwargs))
                elif table_args:
                    setattr(NestedModel, '__table_args__', tuple(table_args))
                models[f"{parent_prop}_{prop_name}"] = NestedModel
                logger.info(f"Nested model '{class_name}' created with table name '{table_name}'")
                # Add relationship to parent model (one-to-one: uselist=False)
                setattr(
                    parent_model,
                    NamingConvention.column_name(prop_name),
                    relationship(
                        NestedModel,
                        back_populates='parent',
                        uselist=False,
                        cascade='all, delete-orphan',
                    )
                )
                # Add back_populates in the NestedModel
                setattr(
                    NestedModel,
                    'parent',
                    relationship(
                        parent_model,
                        back_populates=NamingConvention.column_name(prop_name),
                    )
                )
                # Recursively process the nested object
                create_nested_models(f"{parent_prop}_{prop_name}", prop_schema, models, NestedModel, schema_name=schema_name)
            elif prop_type == 'array':
                # Handle array property: one child row per array element
                items_schema = prop_schema.get('items', {})
                items_type = items_schema.get('type')
                class_name = f"{parent_prop}{prop_name.capitalize()}Model"
                table_name = NamingConvention.table_name(f"{parent_prop}_{prop_name}")
                # Prepare foreign key column name based on parent table
                foreign_key_column = f"{parent_model.__tablename__}_id"
                # Prepare foreign key target, including schema if specified
                if schema_name:
                    foreign_key_target = f"{schema_name}.{parent_model.__tablename__}.id"
                else:
                    foreign_key_target = f"{parent_model.__tablename__}.id"
                # Prepare attributes for the array item class
                attrs = {
                    '__tablename__': table_name,
                    'id': Column(Integer, primary_key=True, autoincrement=True),
                    foreign_key_column: Column(
                        ForeignKey(foreign_key_target),
                        nullable=False  # For arrays, foreign key should not be nullable
                    ),
                }
                constraints = []
                # Handle items of the array
                if items_type == 'object':
                    column_attrs, item_constraints = schema_model(items_schema)
                    attrs.update(column_attrs)
                    constraints.extend(item_constraints)
                elif items_type in ['string', 'number', 'integer', 'boolean']:
                    # Handle array of primitives: one 'value' column per row
                    item_type = determine_column_type(prop_name, items_schema)
                    attrs['value'] = Column(item_type, nullable=False)
                else:
                    logger.warning(f"Unhandled array item type '{items_type}' for field '{prop_name}'.")
                # Dynamically create the array item class
                ArrayModel = type(class_name, (Base,), attrs)
                # Handle __table_args__ (constraints tuple, optional schema kwargs dict)
                table_args = []
                table_kwargs = {}
                if constraints:
                    table_args.extend(constraints)
                if schema_name:
                    table_kwargs['schema'] = schema_name
                if table_kwargs:
                    setattr(ArrayModel, '__table_args__', (tuple(table_args), table_kwargs))
                elif table_args:
                    setattr(ArrayModel, '__table_args__', tuple(table_args))
                models[f"{parent_prop}_{prop_name}"] = ArrayModel
                logger.info(f"Array item model '{class_name}' created with table name '{table_name}'")
                # Add relationship to parent model (one-to-many)
                setattr(
                    parent_model,
                    NamingConvention.column_name(prop_name),
                    relationship(
                        ArrayModel,
                        back_populates='parent',
                        cascade='all, delete-orphan',
                    )
                )
                # Add back_populates in the ArrayModel
                setattr(
                    ArrayModel,
                    'parent',
                    relationship(
                        parent_model,
                        back_populates=NamingConvention.column_name(prop_name),
                    )
                )
                # Recursively process the items if they are objects
                if items_type == 'object':
                    create_nested_models(f"{parent_prop}_{prop_name}", items_schema, models, ArrayModel, schema_name=schema_name)
                else:
                    # Primitive types are handled in schema_model
                    pass
    elif schema_type == 'array':
        # The schema itself is an array
        # Similar adjustments as above
        pass
    else:
        # Handle other types if necessary
        pass
def format_column_type(column_type):
    """Render a SQLAlchemy type instance as a short display string.

    Note: Enum must be tested BEFORE String because sqlalchemy.Enum is a
    subclass of String; with the String check first, enum columns would
    always render as VARCHAR and the ENUM branch would be unreachable.
    """
    if isinstance(column_type, Enum):
        return f"ENUM({', '.join(column_type.enums)})"
    elif isinstance(column_type, String):
        return f"VARCHAR(length={column_type.length})"
    elif isinstance(column_type, Integer):
        return "INTEGER()"
    elif isinstance(column_type, Float):
        return "FLOAT()"
    elif isinstance(column_type, DateTime):
        return "DATETIME()"
    elif isinstance(column_type, JSON):
        return "JSON()"
    elif isinstance(column_type, Boolean):
        return "BOOLEAN()"
    else:
        # Fall back to SQLAlchemy's own repr for anything unmapped.
        return str(column_type)
def print_model_info(models: Dict[str, Type[Base]]) -> None:
    """
    Print a human-readable summary of each generated model: class name,
    table name, database schema (if any), columns with formatted types
    and CHECK constraints, and relationships.
    """
    for name, model in models.items():
        print(f"Model: {model.__name__}")
        print(f"Tablename: {model.__tablename__}")
        # Print schema if available
        schema = ''
        if hasattr(model, '__table_args__') and model.__table_args__:
            # __table_args__ can be a tuple of constraints and a dict of kwargs;
            # only the kwargs dict can carry the 'schema' entry
            for arg in model.__table_args__:
                if isinstance(arg, dict):
                    schema = arg.get('schema', '')
        if schema:
            print(f"Schema: {schema}")
        print("Columns:")
        for column_name, column in inspect(model).columns.items():
            # Convert column type to string with desired formatting
            column_type_str = format_column_type(column.type)
            # Handle 'check' constraints if present
            if column.constraints:
                for constraint in column.constraints:
                    if isinstance(constraint, CheckConstraint):
                        check_str = constraint.sqltext
                        column_type_str += f", check=\"{check_str}\""
            print(f" {column_name}: {column_type_str}")
        print("Relationships:")
        for rel_name, rel in inspect(model).relationships.items():
            related_model = rel.mapper.class_.__name__
            print(f" {rel_name}: Relationship({related_model})")
        print("-" * 40)
def generate_models_code(models: Dict[str, Type[Base]]) -> str:
    """Generate Python source code for the given SQLAlchemy models.

    Emits a header (imports plus a declarative Base), then one class per
    model with __tablename__, __table_args__ (CHECK constraints and the
    database schema, when present), columns, and relationships.

    Returns:
        The generated module source as a single string.
    """
    lines = [
        "from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, ForeignKey, Enum, JSON, CheckConstraint",
        "from sqlalchemy.orm import relationship, declarative_base",
        "",
        "Base = declarative_base()",
        "",
    ]
    for name, model in models.items():
        class_def = f"class {model.__name__}(Base):"
        lines.append(class_def)
        indent = "    "
        lines.append(f"{indent}__tablename__ = '{model.__tablename__}'")
        # Collect CheckConstraints and table kwargs (e.g. schema) for __table_args__
        table_args_list = []
        table_kwargs = {}
        if hasattr(model, '__table_args__') and model.__table_args__:
            for arg in model.__table_args__:
                if isinstance(arg, CheckConstraint):
                    table_args_list.append(f"CheckConstraint('{arg.sqltext}', name='{arg.name}')")
                elif isinstance(arg, dict):
                    table_kwargs.update(arg)
        if table_args_list or table_kwargs:
            args_str = ", ".join(table_args_list)
            if table_kwargs:
                kwargs_str = ", ".join(f"'{k}': '{v}'" for k, v in table_kwargs.items())
                if args_str:
                    lines.append(f"{indent}__table_args__ = ({args_str}, {{ {kwargs_str} }})")
                else:
                    lines.append(f"{indent}__table_args__ = ({{ {kwargs_str} }},)")
            else:
                lines.append(f"{indent}__table_args__ = ({args_str},)")
        # Generate columns
        for column_name, column in inspect(model).columns.items():
            # Determine column type string (Enum before String: Enum subclasses String)
            if isinstance(column.type, Enum):
                enum_values = ', '.join(f"'{e}'" for e in column.type.enums)
                column_type_str = f"Enum({enum_values}, name='{column.type.name}')"
            elif isinstance(column.type, String):
                column_type_str = f"String(length={column.type.length})"
            elif isinstance(column.type, JSON):
                column_type_str = "JSON()"
            elif isinstance(column.type, DateTime):
                column_type_str = "DateTime"
            elif isinstance(column.type, Float):
                column_type_str = "Float"
            elif isinstance(column.type, Integer):
                column_type_str = "Integer"
            elif isinstance(column.type, Boolean):
                column_type_str = "Boolean"
            else:
                column_type_str = "String"
            # Handle primary key and nullable
            pk = "primary_key=True" if column.primary_key else ""
            nullable = "nullable=False" if not column.nullable else "nullable=True"
            # Handle ForeignKey
            fk = ""
            if column.foreign_keys:
                foreign_key = next(iter(column.foreign_keys))
                fk = f"ForeignKey('{foreign_key.target_fullname}')"
            # Handle default: use repr so string defaults emit valid Python
            # literals (plain interpolation produced e.g. default=foo, which
            # is a NameError in the generated module).
            default = f"default={column.default.arg!r}" if column.default else ""
            # Build list of arguments, skipping empty strings
            args = [column_type_str]
            if fk:
                args.append(fk)
            if pk:
                args.append(pk)
            if nullable:
                args.append(nullable)
            if default:
                args.append(default)
            # Join arguments with ', '
            args_joined = ", ".join(args)
            # Combine all parts
            column_line = f"{indent}{column_name} = Column({args_joined})"
            lines.append(column_line)
        # Handle relationships
        for rel_name, rel in inspect(model).relationships.items():
            rel_line = f"{indent}{rel_name} = relationship('{rel.mapper.class_.__name__}', back_populates='{rel.back_populates}')"
            lines.append(rel_line)
        lines.append("")  # Blank line after each class
    return '\n'.join(lines)
def generate_models_py(models: Dict[str, Type[Base]], output_file: str = 'models.py') -> None:
    """Write generated SQLAlchemy model definitions to a Python file.

    Args:
        models: Mapping of model key -> declarative class.
        output_file: Destination path for the generated module.
    """
    code = generate_models_code(models)
    # Explicit encoding so the generated file is identical across platforms.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(code)
    logger.info(f"SQLAlchemy models have been written to '{output_file}'.")
def generate_models_from_schema_py(schema_file: str, output_file: str = 'models.py', model_name: str = None, schema_name: str = None) -> None:
    """End-to-end pipeline: load a schema file, build the models, print a
    summary, and write the generated definitions to *output_file*.

    Errors are logged rather than propagated to the caller.
    """
    try:
        loaded = load_schema(schema_file)
        built = create_models_from_schema(loaded, model_name=model_name, schema_name=schema_name)
        print_model_info(built)
        generate_models_py(built, output_file)
    except InvalidRequestError as e:
        logger.error(f"SQLAlchemy InvalidRequestError: {e}")
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
def create_engine_and_create_tables(engine_url: str) -> Any:
    """
    Create a SQLAlchemy engine and create all tables.

    Args:
        engine_url: Database URL, e.g. 'sqlite:///my_database.db'.

    Returns:
        The created engine.
    """
    # create_engine is not among the module-level sqlalchemy imports; it is
    # imported here on demand.
    from sqlalchemy import create_engine
    engine = create_engine(engine_url)
    Base.metadata.create_all(engine)
    logger.info(f"Database tables created using engine: {engine_url}")
    return engine
def main():
    """
    Main function to load the schema, create models, display their information,
    and write the models to a .py file.

    The input/output paths and names below are placeholders; edit them
    before running.
    """
    schema_file = 'your_schema.json'  # Replace with your schema file
    output_file = 'models.py'
    top_level_model_name = 'YourModelName'  # Set your top-level model name here
    schema_name = 'your_schema_name'  # Set your desired schema name here
    try:
        generate_models_from_schema_py(schema_file, output_file, model_name=top_level_model_name, schema_name=schema_name)
        # Optionally, create the database tables
        # Uncomment and set your database URL
        # engine_url = 'sqlite:///my_database.db'
        # engine = create_engine_and_create_tables(engine_url)
    except InvalidRequestError as e:
        logger.error(f"SQLAlchemy InvalidRequestError: {e}")
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment