Created
January 25, 2023 21:55
-
-
Save kurtbrose/dc11fd9149f63159008f829d519ce517 to your computer and use it in GitHub Desktop.
Helper for converting type annotations into sqlalchemy columns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This module is a helper for converting a dataclass-like annotation class into a sqlalchemy ORM. | |
""" | |
from dataclasses import dataclass
from datetime import date, datetime
import enum
import functools
import inspect
import re
import types
import typing

from sqlalchemy import (
    Boolean,
    Column,
    Date,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
)
from sqlalchemy.orm import relationship
# Python types that map directly onto a single SQLAlchemy column type.
# Lookups are by exact type (dict key), so bool is not confused with its base
# class int, nor datetime with its base class date.
_TYPE_MAP = {
    int: Integer,
    float: Float,
    str: String,
    bool: Boolean,
    date: Date,
    datetime: DateTime,
}
def _to_snake_case(camel_case):
    """
    Convert a CamelCase identifier to snake_case.

    An underscore is inserted before every ASCII capital letter except one at
    the very start, then the whole string is lower-cased:
    ``CamelCase -> camel_case``, ``SnakeCase -> snake_case``.
    Runs of capitals are split letter by letter (``HTTPServer -> h_t_t_p_server``).
    """
    pieces = []
    for position, char in enumerate(camel_case):
        # only ASCII A-Z triggers a split, and never at position 0
        if position > 0 and "A" <= char <= "Z":
            pieces.append("_")
        pieces.append(char)
    return "".join(pieces).lower()
def _is_optional(annotation): | |
"""Check if an annotation was declared as Optional[type].""" | |
if typing.get_origin(annotation) in (typing.Union, types.UnionType): | |
args = typing.get_args(annotation) | |
if len(args) == 2 and args[-1] is None.__class__: | |
return True, args[0] | |
raise ValueError(f"unsupported type annotation: {annotation}") | |
return False, annotation | |
def _pytype2sqla_cols(name, type_) -> dict:
    """
    Translate one annotated attribute into the ORM class-body entries it needs.

    The annotation may be wrapped in ``Optional[...]`` / ``X | None``, which
    makes the generated column nullable. Supported inner annotations:

    * a key of ``_TYPE_MAP`` -> ``{name: Column(<mapped type>)}``
    * a model name given as a string or ``typing.ForwardRef`` ->
      ``{name + "_id": <indexed Integer foreign-key Column>,
      name: relationship(<model name>)}``; the referenced table is the
      snake_cased model name plus ``"s"`` and its ``id`` column
    * an ``enum.Enum`` subclass -> ``{name: Column(Enum(<enum class>))}``

    Only a deliberately small subset of annotations is handled.

    Raises:
        ValueError: for any other annotation.
    """
    nullable, target = _is_optional(type_)
    kwargs = {"nullable": nullable}

    if target in _TYPE_MAP:
        return {name: Column(_TYPE_MAP[target], **kwargs)}

    # forward references may arrive wrapped in typing.ForwardRef
    # (e.g. from Optional["Model"]) rather than as bare strings
    if isinstance(target, typing.ForwardRef):
        target = target.__forward_arg__

    if isinstance(target, str):
        referenced_table = _to_snake_case(target) + "s"
        fk_column = Column(
            Integer,
            ForeignKey(f"{referenced_table}.id"),
            index=True,
            **kwargs,
        )
        return {
            name + "_id": fk_column,
            name: relationship(target, foreign_keys=[fk_column]),
        }

    if isinstance(target, type) and issubclass(target, enum.Enum):
        return {name: Column(Enum(target), **kwargs)}

    raise ValueError(f"unsupported type annotation: {target}")
def auto_orm(cls: type) -> type:
    """
    Build a SQLAlchemy ORM model class from a dataclass-like annotated class.

    Every annotation declared directly on ``cls`` (read with
    ``inspect.get_annotations``, so inherited annotations are not included) is
    converted by ``_pytype2sqla_cols`` into column/relationship attributes.
    The generated class also gets:

    * an integer primary key ``id``
    * ``__tablename__`` = snake_cased class name + ``"s"``
    * ``__original_class__`` pointing back at ``cls``

    The new class is created with ``cls.__bases__`` as its bases; ``cls`` itself
    is not a base, so methods and default values defined on ``cls`` are not
    carried over. Presumably those bases are expected to include a SQLAlchemy
    declarative base (TODO confirm at call sites).

    ``__module__``, ``__qualname__`` and ``__doc__`` are copied from ``cls``.
    Without this, the model would report this helper module as its module
    (e.g. in reprs and when pickling).

    Raises:
        ValueError: if ``cls`` has no annotations, if an annotation is not
            supported, or if two generated attributes would share a name
            (e.g. an annotation named ``id`` clashing with the primary key).
    """
    annotations = inspect.get_annotations(cls)
    if not annotations:
        raise ValueError(f"{cls} does not define any columns!")
    body = dict(
        __module__=cls.__module__,
        __qualname__=cls.__qualname__,
        __doc__=cls.__doc__,
        __original_class__=cls,
        __tablename__=_to_snake_case(cls.__name__) + "s",
        id=Column(Integer, primary_key=True),
    )
    for name, type_ in annotations.items():
        new_attrs = _pytype2sqla_cols(name, type_)
        # Refuse to silently overwrite the primary key, __tablename__, or an
        # attribute generated for an earlier annotation (e.g. `foo: "Bar"`
        # produces `foo_id`, which would clash with an explicit `foo_id: int`).
        clashes = body.keys() & new_attrs.keys()
        if clashes:
            raise ValueError(
                f"{cls.__name__}.{name} would overwrite existing attribute(s): "
                f"{sorted(clashes)}"
            )
        body.update(new_attrs)
    return type(cls.__name__, cls.__bases__, body)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment