@imankulov
Last active March 10, 2025 10:05
Using pydantic models as SQLAlchemy JSON fields (convert between JSON and pydantic.BaseModel subclasses)
#!/usr/bin/env ipython -i
import datetime
import json
from typing import Optional
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB
from pydantic import BaseModel, Field, parse_obj_as
from pydantic.json import pydantic_encoder
# --------------------------------------------------------------------------------------
# Define pydantic-alchemy specific types (once per application)
# --------------------------------------------------------------------------------------
class PydanticType(sa.types.TypeDecorator):
    """Pydantic type.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    def __init__(self, pydantic_type):
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        # Use JSONB for PostgreSQL and JSON for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        else:
            return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        return value.dict() if value else None
        # If you use FastAPI, you can replace the line above with their jsonable_encoder().
        # E.g.,
        # from fastapi.encoders import jsonable_encoder
        # return jsonable_encoder(value) if value else None

    def process_result_value(self, value, dialect):
        return parse_obj_as(self.pydantic_type, value) if value else None


def json_serializer(*args, **kwargs) -> str:
    return json.dumps(*args, default=pydantic_encoder, **kwargs)
# --------------------------------------------------------------------------------------
# Configure SQLAlchemy engine, session and declarative base (once per application)
# The key is to define json_serializer while creating the engine.
# --------------------------------------------------------------------------------------
engine = sa.create_engine("sqlite:///:memory:", json_serializer=json_serializer)
Session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
Base = declarative_base()
# --------------------------------------------------------------------------------------
# Define your Pydantic and SQLAlchemy models (as many as needed)
# --------------------------------------------------------------------------------------
class UserSettings(BaseModel):
    notify_at: datetime.datetime = Field(default_factory=datetime.datetime.now)


class User(Base):
    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(sa.String, doc="User name", comment="User name")
    settings: Optional[UserSettings] = sa.Column(PydanticType(UserSettings), nullable=True)
# --------------------------------------------------------------------------------------
# Create tables (once per application)
# --------------------------------------------------------------------------------------
Base.metadata.create_all(engine)
# --------------------------------------------------------------------------------------
# Usage example (we use 2.0 querying style with selects)
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# --------------------------------------------------------------------------------------
session = Session()
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()
same_user = session.execute(sa.select(User)).scalars().first()
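
A note on why the engine gets a custom json_serializer: value.dict() leaves datetime objects untouched, and the standard json.dumps refuses to encode them unless it is given a `default` hook. The stdlib-only sketch below illustrates the mechanism. `encode_extra` is a hypothetical stand-in for pydantic_encoder that only knows about dates, so treat it as an illustration rather than the encoder the gist actually uses.

```python
import datetime
import json


def encode_extra(obj):
    """Hypothetical stand-in for pydantic_encoder: handles dates only."""
    if isinstance(obj, (datetime.datetime, datetime.date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def json_serializer(*args, **kwargs) -> str:
    # json.dumps calls `default` for every object it cannot encode itself.
    return json.dumps(*args, default=encode_extra, **kwargs)


# This is the shape of the dict that value.dict() hands to the engine.
payload = {"notify_at": datetime.datetime(2025, 3, 10, 10, 5)}

try:
    json.dumps(payload)
    plain_dumps_failed = False
except TypeError:
    # The plain encoder does not know what to do with datetime.
    plain_dumps_failed = True

encoded = json_serializer(payload)
print(plain_dumps_failed, encoded)
```

Passing such a function as json_serializer to create_engine is what lets the JSON column accept dicts that still contain datetimes.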
@etoiledemer

Is it compatible with Alembic's schema autogeneration?

For me, Alembic generated

sa.Column('data', some.path.to.PydanticType(), nullable=True)

without the UserSettings argument.

@imankulov
Author

Hey, @etoiledemer

You're right. It's not going to work out of the box with Alembic. For myself, I modify the migration script manually after creating it.

As you asked this question, I decided to see if there's a better solution and came up with the following snippet. Here I replace pydantic types with native SQLAlchemy JSON.

# file: env.py

def render_item(type_, obj, autogen_context):
    """Apply custom rendering for PydanticType."""
    if type_ == "type" and isinstance(obj, PydanticType):
        return "sa.JSON()"
    return False

...
context.configure(
    ...
    render_item=render_item,
)

This will generate

    ...
    sa.Column('settings', sa.JSON(), nullable=True),
    ...

Ref: Affecting the Rendering of Types Themselves from the Alembic documentation.

@wonderbeyond

I've made a Python package SQLAlchemy-Nested-Mutable with inspiration from this work.

@imankulov
Author

@wonderbeyond, wow, this looks great! Thanks.

@Filimoa

Filimoa commented Jan 26, 2024

To avoid a deprecation warning as of Pydantic 2.5, use the following:

def process_result_value(self, value, dialect):
    return self.pydantic_type(**value) if value else None
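
Building on the snippet above, here is a minimal sketch of what the whole type could look like on Pydantic 2. It is not the gist author's code: it assumes Pydantic 2 method names (model_dump, model_validate), uses `is not None` checks instead of truthiness, and sets cache_ok on the assumption that the wrapped model class is safe to use in SQLAlchemy's statement cache key. Because model_dump(mode="json") already produces JSON-safe values, this variant does not need a custom json_serializer on the engine.

```python
import datetime

import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.orm import declarative_base, sessionmaker


class PydanticType(sa.types.TypeDecorator):
    """Pydantic 2 flavour of the type (a sketch, not the original code)."""

    impl = sa.types.JSON
    cache_ok = True  # assumption: the model class is fine as part of a cache key

    def __init__(self, pydantic_type):
        super().__init__()
        self.pydantic_type = pydantic_type

    def process_bind_param(self, value, dialect):
        # mode="json" converts datetimes and similar values to JSON-safe ones.
        return value.model_dump(mode="json") if value is not None else None

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return self.pydantic_type.model_validate(value)


class UserSettings(BaseModel):
    notify_at: datetime.datetime = Field(default_factory=datetime.datetime.now)


Base = declarative_base()


class User(Base):
    __tablename__ = "users"

    id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String)
    settings = sa.Column(PydanticType(UserSettings), nullable=True)


engine = sa.create_engine("sqlite:///:memory:")  # no json_serializer here
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

settings = UserSettings()
with Session() as session:
    session.add(User(name="user", settings=settings))
    session.commit()

# Load the row in a fresh session to exercise process_result_value().
with Session() as session:
    loaded = session.execute(sa.select(User)).scalars().first()
    loaded_settings = loaded.settings
```

model_validate also covers the case where the stored dict contains nested models, which is where it differs from calling the class with `**value` only in style, not in outcome.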

@DeoLeung

I ran into a problem:

if I do select(table.c.settings['a']), it still tries to parse the value of a as the whole pydantic model :(
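
One possible workaround, offered as an assumption rather than something confirmed in this thread: re-type the column as plain JSON with sa.type_coerce before indexing it, so that the extracted value is decoded as ordinary JSON and never reaches process_result_value(). The sketch below uses a Pydantic 2 variant of the type and a hypothetical `theme` field. It runs against in-memory SQLite and relies on SQLite's JSON functions being available.

```python
import sqlalchemy as sa
from pydantic import BaseModel
from sqlalchemy.orm import declarative_base, sessionmaker


class PydanticType(sa.types.TypeDecorator):
    impl = sa.types.JSON
    cache_ok = True  # assumption: the model class is fine as part of a cache key

    def __init__(self, pydantic_type):
        super().__init__()
        self.pydantic_type = pydantic_type

    def process_bind_param(self, value, dialect):
        return value.model_dump(mode="json") if value is not None else None

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return self.pydantic_type.model_validate(value)


class UserSettings(BaseModel):
    theme: str = "dark"  # hypothetical field, for illustration only


Base = declarative_base()


class User(Base):
    __tablename__ = "users"

    id = sa.Column(sa.Integer, primary_key=True)
    settings = sa.Column(PydanticType(UserSettings), nullable=True)


engine = sa.create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

table = User.__table__

with Session() as session:
    session.add(User(settings=UserSettings(theme="light")))
    session.commit()

    # Treat the column as plain JSON for this expression only; the stored
    # data and the mapped attribute keep using PydanticType.
    raw_settings = sa.type_coerce(table.c.settings, sa.JSON)
    theme = session.execute(sa.select(raw_settings["theme"])).scalar_one()
```

The mapped attribute still returns a UserSettings instance; only expressions built from `raw_settings` skip the pydantic conversion.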

@a1d4r

a1d4r commented Jun 21, 2024

Is it possible for SQLAlchemy to detect changes in a PydanticType field? When I change a field of the pydantic model, I have to call flag_modified manually to make SQLAlchemy flush the change.

@Seluj78

Seluj78 commented Mar 3, 2025

@a1d4r I have the same problem on my end. Did you end up finding a solution?

@a1d4r

a1d4r commented Mar 3, 2025

> @a1d4r I have the same problem on my end. Did you end up finding a solution?

Nope, I call flag_modified every time I change the model. For example:

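from datetime import UTC, datetime  # datetime.UTC requires Python 3.11+
from sqlalchemy.orm.attributes import flag_modified

user.settings.notify_at = datetime.now(UTC)
# flag_modified() takes the mapped ORM object and the name of its attribute,
# not the pydantic model and one of its fields.
flag_modified(user, "settings")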
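
To make the behaviour concrete, here is a self-contained sketch of the manual approach. flag_modified() expects the mapped ORM object and the attribute name, so the call is flag_modified(user, "settings"). The sketch also shows an alternative that is my own suggestion, not something proposed in this thread: assigning a fresh model (via Pydantic 2's model_copy) instead of mutating the old one, which SQLAlchemy notices as a normal attribute assignment. The `theme` field and the Pydantic 2 variant of the type are assumptions made for illustration.

```python
import sqlalchemy as sa
from pydantic import BaseModel
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm.attributes import flag_modified


class PydanticType(sa.types.TypeDecorator):
    impl = sa.types.JSON
    cache_ok = True  # assumption: the model class is fine as part of a cache key

    def __init__(self, pydantic_type):
        super().__init__()
        self.pydantic_type = pydantic_type

    def process_bind_param(self, value, dialect):
        return value.model_dump(mode="json") if value is not None else None

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return self.pydantic_type.model_validate(value)


class UserSettings(BaseModel):
    theme: str = "light"  # hypothetical field, for illustration only


Base = declarative_base()


class User(Base):
    __tablename__ = "users"

    id = sa.Column(sa.Integer, primary_key=True)
    settings = sa.Column(PydanticType(UserSettings), nullable=True)


engine = sa.create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine, expire_on_commit=False)

session = Session()
user = User(settings=UserSettings())
session.add(user)
session.commit()

# In-place mutation: no attribute of `user` is assigned, so nothing is tracked.
user.settings.theme = "dark"
dirty_after_mutation = user in session.dirty

# Option 1: flag the mapped attribute on the ORM object by hand.
flag_modified(user, "settings")
dirty_after_flag = user in session.dirty
session.commit()

# Option 2: assign a new model instead of mutating the existing one.
user.settings = user.settings.model_copy(update={"theme": "blue"})
dirty_after_assign = user in session.dirty
session.commit()

# Read the row back in a separate session to see what was stored.
with Session() as check:
    stored_theme = check.execute(sa.select(User)).scalars().first().settings.theme
```

Option 2 trades a little allocation for not having to remember the flag_modified() call after every change.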

@Seluj78

Seluj78 commented Mar 3, 2025

@a1d4r Alright, thank you for the quick response. I am working on something that might automate this, if I can get it to work. I would like to avoid doing what you did.

By the way, what implementation did you use for the PydanticJSONB ?

I did this:

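from typing import Any, get_args, get_origin

from pydantic import TypeAdapter
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import TypeDecorator


class PydanticJSONB(TypeDecorator):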
    impl = JSONB

    def __init__(self, model_type: Any, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_type = model_type
        self._type = get_origin(model_type)
        if self._type is list:
            self._item_type = get_args(model_type)[0]
        elif self._type is dict:
            self._item_type = get_args(model_type)[1]
        else:
            self._item_type = model_type
        self._adapter = TypeAdapter(self.model_type)

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        if value is None:
            return None

        if self._type is list:
            if not isinstance(value, list):
                raise TypeError(f"Expected list of {self._item_type}")
            return [item.model_dump() if isinstance(item, self._item_type) else item for item in value]
        elif self._type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected dict of {self._item_type}")
            return {k: item.model_dump() if isinstance(item, self._item_type) else item for k, item in value.items()}
        else:
            if isinstance(value, self.model_type):
                return value.model_dump()
            return value

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        if value is not None:
            return self._adapter.validate_python(value)
        if self._type is list:
            return []
        elif self._type is dict:
            return {}
        else:
            return None

I might release this and more in some kind of sqlmodel-utils package on PyPI one day (I use SQLModel, but it works with SQLAlchemy as well).

@a1d4r

a1d4r commented Mar 3, 2025

@Seluj78 Here is my implementation:
https://gist.github.com/a1d4r/100b06239925a414446305c81433cc88

Basically, the same as the original one, but with typing.

Example:

data: Mapped[ReportData] = mapped_column(PydanticType(ReportData))

@Seluj78

Seluj78 commented Mar 3, 2025

@a1d4r I DID IT!

Here's the code. More changes might be needed to support more complex types (List[PydanticModel] and Dict[str, MyModel] are untested right now).

But it works! It correctly flags the PydanticJSONB columns as changed!!

import pydantic
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.attributes import flag_modified

# Where `PydanticJSONB` is my implementation, see previous comments

def flag_pydantic_changes(target):
    inspector = inspect(target)
    mapper = inspector.mapper

    for attr in inspector.attrs:
        key = attr.key
        prop = mapper.attrs.get(key)

        # Skip non-ColumnProperty attributes
        if not isinstance(prop, ColumnProperty):
            continue

        # Check if any column in this property is PydanticJSONB
        is_pydantic_jsonb = any(
            isinstance(col.type, PydanticJSONB)
            for col in prop.columns
        )

        if is_pydantic_jsonb:
            hist = attr.history
            original_dict = hist.unchanged[0] if hist.unchanged else None
            if isinstance(attr.value, pydantic.BaseModel):
                current_dict = attr.value.model_dump()
            else:
                current_dict = attr.value

            if original_dict != current_dict:
                flag_modified(target, key)


@event.listens_for(_BaseModel, "before_update")
def auto_flag_modified(mapper, connection, target):
    flag_pydantic_changes(target)


MODELS = [
    Users,
    # Add your models here
]


for model in MODELS:
    event.listen(model, "before_update", auto_flag_modified)

I'm quite happy with this. I will keep working on it and will probably end up publishing it on GitHub at some point.

@a1d4r

a1d4r commented Mar 3, 2025

@Seluj78 Awesome work! I will check it later. By the way, feel free to contact me if you need any help.
