-
-
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
#!/usr/bin/env ipython -i | |
import datetime | |
import json | |
from typing import Optional | |
import sqlalchemy as sa | |
from sqlalchemy.orm import declarative_base, sessionmaker | |
from sqlalchemy.dialects.postgresql import JSONB | |
from pydantic import BaseModel, Field, parse_obj_as | |
from pydantic.json import pydantic_encoder | |
# -------------------------------------------------------------------------------------- | |
# Define pydantic-alchemy specific types (once per application) | |
# -------------------------------------------------------------------------------------- | |
class PydanticType(sa.types.TypeDecorator):
    """Pydantic type.

    SAVING:
      - Uses the SQLAlchemy JSON type under the hood.
      - Accepts a pydantic model and converts it to a dict on save.
      - The SQLAlchemy engine JSON-encodes the dict to a string.
    RETRIEVING:
      - Pulls the string from the database.
      - The SQLAlchemy engine JSON-decodes the string to a dict.
      - Uses the dict to create a pydantic model.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    def __init__(self, pydantic_type):
        """Remember which pydantic model class this column (de)serializes."""
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        # Use JSONB for PostgreSQL and the generic JSON type for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        # Compare against None explicitly: only a missing value maps to SQL
        # NULL. The previous truthiness check (`if value`) would also turn a
        # "falsy" model (e.g. one with a custom __bool__) into NULL.
        return value.dict() if value is not None else None
        # If you use FastAPI, you can replace the line above with their
        # jsonable_encoder(). E.g.,
        # from fastapi.encoders import jsonable_encoder
        # return jsonable_encoder(value) if value is not None else None

    def process_result_value(self, value, dialect):
        # NULL in the database maps to None; any other JSON payload
        # (including falsy ones such as {}) is validated into the model.
        return parse_obj_as(self.pydantic_type, value) if value is not None else None
def json_serializer(*args, **kwargs) -> str:
    """Serialize a value to a JSON string, handling pydantic models.

    Passed to the SQLAlchemy engine as its ``json_serializer`` so that any
    pydantic objects (datetimes, models, ...) reaching the JSON encoder are
    handled by pydantic's own encoder.
    """
    return json.dumps(*args, default=pydantic_encoder, **kwargs)
# --------------------------------------------------------------------------------------
# Configure SQLAlchemy engine, session and declarative base (once per application)
# The key is to define json_serializer while creating the engine.
# --------------------------------------------------------------------------------------
# In-memory SQLite keeps the example self-contained; json_serializer lets the
# engine encode pydantic objects when a PydanticType column is written.
engine = sa.create_engine("sqlite:///:memory:", json_serializer=json_serializer)
# expire_on_commit=False keeps loaded attributes usable after commit();
# future=True opts into the SQLAlchemy 2.0-style Session API.
Session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
# Declarative base for the ORM models defined below.
Base = declarative_base()
# --------------------------------------------------------------------------------------
# Define your Pydantic and SQLAlchemy models (as many as needed)
# --------------------------------------------------------------------------------------
class UserSettings(BaseModel):
    """Pydantic model stored as a JSON column on the users table."""

    # Defaults to the (naive, local-time) moment the instance is created.
    notify_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
class User(Base):
    """ORM model demonstrating a PydanticType-backed ``settings`` column."""

    __tablename__ = "users"
    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(sa.String, doc="User name", comment="User name")
    # Stored as JSON in the database; surfaces as a UserSettings instance
    # (or None) on the Python side.
    settings: Optional[UserSettings] = sa.Column(PydanticType(UserSettings), nullable=True)
# --------------------------------------------------------------------------------------
# Create tables (once per application)
# --------------------------------------------------------------------------------------
# Emits CREATE TABLE for every model registered on Base.
Base.metadata.create_all(engine)
# --------------------------------------------------------------------------------------
# Usage example (we use 2.0 querying style with selects)
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# --------------------------------------------------------------------------------------
session = Session()
# `settings` is a pydantic instance; PydanticType converts it to a dict on
# flush and the engine's json_serializer turns that into a JSON string.
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()
# Round trip: the JSON column comes back as a UserSettings instance.
same_user = session.execute(sa.select(User)).scalars().first()
Hey, @etoiledemer
You're right. It's not going to work out of the box with Alembic. For myself, I modify the migration script manually after creating it.
As you asked this question, I decided to see if there's a better solution and came up with the following snippet. Here I replace pydantic types with native SQLAlchemy JSON.
# file: env.py
def render_item(type_, obj, autogen_context):
    """Apply custom rendering for PydanticType.

    Makes Alembic autogenerate render any PydanticType column as a plain
    ``sa.JSON()``; returning False falls back to the default rendering.
    """
    if type_ == "type" and isinstance(obj, PydanticType):
        return "sa.JSON()"
    return False
...
context.configure(
...
render_item=render_item,
)
This will generate
...
sa.Column('settings', sa.JSON(), nullable=True),
...
Ref: Affecting the Rendering of Types Themselves from the Alembic documentation.
I've made a Python package SQLAlchemy-Nested-Mutable with inspiration from this work.
@wonderbeyond, wow, this looks great! Thanks.
To avoid a deprecation warning as of Pydantic 2.5 you will need to use the following
def process_result_value(self, value, dialect):
    # Pydantic >= 2.5: construct the model directly instead of using the
    # deprecated parse_obj_as() helper.
    return self.pydantic_type(**value) if value else None
I've got a problem: if I do `select(table.c.settings['a'])`, it still tries to parse the `a` value as the whole pydantic model :(
Is it possible for SQLAlchemy to detect changes in PydanticType field? When I change the field of the pydantic model, I have to manually call function flag_modified
to make SQLAlchemy flush the change.
@a1d4r I have the same problem on my end. Did you end up finding a solution ?
@a1d4r I have the same problem on my end. Did you end up finding a solution ?
Nope, I call flag_modified
every time I change the model. For example:
from datetime import UTC, datetime
from sqlalchemy.orm.attributes import flag_modified
user.settings.notify_at = datetime.now(UTC)
flag_modified(user.settings, "notify_at")
@a1d4r Alright, thank you for the quick response. I am working on something that might automate this, if I am able to make it work. I want to avoid doing what you did.
By the way, what implementation did you use for the PydanticJSONB ?
I did this:
class PydanticJSONB(TypeDecorator):
    """JSONB column type that (de)serializes pydantic models.

    Supports a plain model class as well as ``list[Model]`` and
    ``dict[str, Model]`` container annotations.
    """

    impl = JSONB

    def __init__(self, model_type: Any, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_type = model_type
        # get_origin() is list/dict for container annotations, None otherwise.
        self._type = get_origin(model_type)
        if self._type is list:
            self._item_type = get_args(model_type)[0]
        elif self._type is dict:
            # For dict[K, V] annotations, the model is the *value* type.
            self._item_type = get_args(model_type)[1]
        else:
            self._item_type = model_type
        # TypeAdapter validates the full annotation, container included.
        self._adapter = TypeAdapter(self.model_type)

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        """Convert models (or containers of models) to JSON-able values."""
        if value is None:
            return None
        if self._type is list:
            if not isinstance(value, list):
                raise TypeError(f"Expected list of {self._item_type}")
            # Items that are already plain data pass through unchanged.
            return [item.model_dump() if isinstance(item, self._item_type) else item for item in value]
        elif self._type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected dict of {self._item_type}")
            return {k: item.model_dump() if isinstance(item, self._item_type) else item for k, item in value.items()}
        else:
            if isinstance(value, self.model_type):
                return value.model_dump()
            return value

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        """Validate the raw JSON payload back into the annotated type."""
        if value is not None:
            return self._adapter.validate_python(value)
        # NOTE(review): SQL NULL is normalized to an empty container for
        # list/dict annotations rather than None — presumably intentional,
        # but confirm callers expect that.
        if self._type is list:
            return []
        elif self._type is dict:
            return {}
        else:
            return None
Might release this and more stuff in some kind of sqlmodel-utils
package on pypi one day (Because I use SQLModel, but it works with sqlalchemy as well)
@Seluj78 Here is my implementation:
https://gist.github.com/a1d4r/100b06239925a414446305c81433cc88
Basically, the same as the original one, but with typing.
Example:
data: Mapped[ReportData] = mapped_column(PydanticType(ReportData))
@a1d4r I DID IT !
Here's the code. More changes might be needed to support more complex types (Like List[PydanticModel]
or Dict[str, MyModel]
are untested right now).
But it works! It correctly flags the PydanticJSONB columns as changed!
import pydantic
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.attributes import flag_modified
# Where `PydanticJSONB` is my implementation, see previous comments
def flag_pydantic_changes(target):
    """Flag mutated PydanticJSONB columns so SQLAlchemy flushes them.

    Walks every column attribute of ``target``; for PydanticJSONB-typed
    columns, compares the loaded value with the current model dump and calls
    flag_modified() when they differ.
    """
    inspector = inspect(target)
    mapper = inspector.mapper
    for attr in inspector.attrs:
        key = attr.key
        prop = mapper.attrs.get(key)
        # Skip non-ColumnProperty attributes (relationships, etc.)
        if not isinstance(prop, ColumnProperty):
            continue
        # Check if any column in this property is PydanticJSONB
        is_pydantic_jsonb = any(
            isinstance(col.type, PydanticJSONB)
            for col in prop.columns
        )
        if is_pydantic_jsonb:
            hist = attr.history
            # NOTE(review): hist.unchanged appears to hold the loaded *model*
            # instance while current_dict below is a plain dict — confirm the
            # two compare as intended, otherwise this flags on every update.
            original_dict = hist.unchanged[0] if hist.unchanged else None
            if issubclass(attr.value.__class__, pydantic.BaseModel):
                current_dict = attr.value.model_dump()
            else:
                current_dict = attr.value
            if original_dict != current_dict:
                flag_modified(target, key)
# NOTE(review): `_BaseModel` is not defined in this snippet — presumably the
# project's declarative base class.
@event.listens_for(_BaseModel, "before_update")
def auto_flag_modified(mapper, connection, target):
    # Standard SQLAlchemy before_update hook signature.
    flag_pydantic_changes(target)


MODELS = [
    Users,
    # Add your models here
]
# NOTE(review): the hook is registered both via the decorator above and per
# model here — confirm the double registration is intended.
for model in MODELS:
    event.listen(model, "before_update", auto_flag_modified)
I'm quite happy and I will keep working on this and I will probably end up publishing this to github at some point.
@Seluj78 Awesome work! I will check it later. By the way, feel free to contact me if you need any help.
is it compatible with Alembic autogeneration of schema?
For me, Alembic generated the migration without the `UserSettings` stuff.