-
-
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
#!/usr/bin/env ipython -i | |
import datetime | |
import json | |
from typing import Optional | |
import sqlalchemy as sa | |
from sqlalchemy.orm import declarative_base, sessionmaker | |
from sqlalchemy.dialects.postgresql import JSONB | |
from pydantic import BaseModel, Field, parse_obj_as | |
from pydantic.json import pydantic_encoder | |
# -------------------------------------------------------------------------------------- | |
# Define pydantic-alchemy specific types (once per application) | |
# -------------------------------------------------------------------------------------- | |
class PydanticType(sa.types.TypeDecorator):
    """Pydantic type.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    # Safe to participate in SQLAlchemy's statement caching: two instances
    # constructed with the same pydantic_type render identical SQL.
    # Without this, SQLAlchemy 1.4+ emits a performance warning.
    cache_ok = True

    def __init__(self, pydantic_type):
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        # Use JSONB for PostgreSQL and JSON for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        # Compare against None explicitly (not truthiness) so a falsy value is
        # still persisted instead of silently becoming NULL.
        return value.dict() if value is not None else None
        # If you use FastAPI, you can replace the line above with their jsonable_encoder().
        # E.g.,
        # from fastapi.encoders import jsonable_encoder
        # return jsonable_encoder(value) if value is not None else None

    def process_result_value(self, value, dialect):
        # `is not None` matters here: an empty JSON object `{}` is falsy, and the
        # old truthiness check would return None instead of parsing it back into
        # a model with all-default fields.
        return parse_obj_as(self.pydantic_type, value) if value is not None else None
def json_serializer(*args, **kwargs) -> str:
    """JSON-encode *args like ``json.dumps`` but with pydantic support.

    Passed to ``sa.create_engine(json_serializer=...)`` so the engine can
    serialize pydantic models stored in JSON columns.

    Uses ``setdefault`` rather than passing ``default=`` positionally so a
    caller-supplied ``default=`` keyword does not raise a duplicate-argument
    ``TypeError`` — it simply takes precedence over ``pydantic_encoder``.
    """
    kwargs.setdefault("default", pydantic_encoder)
    return json.dumps(*args, **kwargs)
# -------------------------------------------------------------------------------------- | |
# Configure SQLAlchemy engine, session and declarative base (once per application) | |
# The key is to define json_serializer while creating the engine. | |
# -------------------------------------------------------------------------------------- | |
# Declarative base for all ORM models.
Base = declarative_base()
# The key part: hand our pydantic-aware json_serializer to the engine.
engine = sa.create_engine("sqlite:///:memory:", json_serializer=json_serializer)
# expire_on_commit=False keeps loaded attributes readable after commit;
# future=True opts into the 2.0-style Session API.
Session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
# -------------------------------------------------------------------------------------- | |
# Define your Pydantic and SQLAlchemy models (as many as needed) | |
# -------------------------------------------------------------------------------------- | |
class UserSettings(BaseModel):
    """Example pydantic model that is stored as a JSON column on ``User``."""

    # default_factory is evaluated per instance, so each new UserSettings()
    # gets the timestamp of its own creation.
    notify_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
class User(Base):
    """SQLAlchemy ORM model whose ``settings`` column round-trips a pydantic model."""

    __tablename__ = "users"
    # Surrogate primary key.
    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(sa.String, doc="User name", comment="User name")
    # Persisted as JSON/JSONB via PydanticType; surfaced in Python as a
    # UserSettings instance, or None (column is nullable).
    settings: Optional[UserSettings] = sa.Column(PydanticType(UserSettings), nullable=True)
# --------------------------------------------------------------------------------------
# Create tables (once per application)
# --------------------------------------------------------------------------------------
Base.metadata.create_all(engine)
# --------------------------------------------------------------------------------------
# Usage example (we use 2.0 querying style with selects)
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# --------------------------------------------------------------------------------------
session = Session()
# settings is assigned as a pydantic model; PydanticType dumps it to a dict
# and the engine's json_serializer encodes it on commit.
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()
# On load, the JSON column comes back as a UserSettings instance, not a dict.
same_user = session.execute(sa.select(User)).scalars().first()
To avoid a deprecation warning as of Pydantic 2.5, you will need to use the following:
def process_result_value(self, value, dialect):
    # Pydantic v2 variant: construct the model directly instead of the
    # deprecated parse_obj_as(); falsy/NULL JSON values map to None.
    return self.pydantic_type(**value) if value else None
I've got a problem: if I do `select(table.c.settings['a'])`,
it still tries to parse the `a` value as the whole pydantic model :(
Is it possible for SQLAlchemy to detect changes in PydanticType field? When I change the field of the pydantic model, I have to manually call function flag_modified
to make SQLAlchemy flush the change.
@a1d4r I have the same problem on my end. Did you end up finding a solution ?
@a1d4r I have the same problem on my end. Did you end up finding a solution ?
Nope, I call flag_modified
every time I change the model. For example:
from datetime import UTC, datetime
from sqlalchemy.orm.attributes import flag_modified
user.settings.notify_at = datetime.now(UTC)
flag_modified(user.settings, "notify_at")
@a1d4r Alright, thank you for the quick response. I am working on something that might automate this, if I can make it work. I want to avoid doing what you did.
By the way, what implementation did you use for the PydanticJSONB ?
I did this:
class PydanticJSONB(TypeDecorator):
    """JSONB column type that (de)serializes Pydantic v2 models.

    Supports a bare model annotation as well as ``list[Model]`` and
    ``dict[Key, Model]`` container annotations.
    """

    impl = JSONB
    # NOTE(review): no ``cache_ok`` is set; SQLAlchemy 1.4+ emits a caching
    # warning for TypeDecorator subclasses without it — confirm intended.

    def __init__(self, model_type: Any, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_type = model_type
        # get_origin() distinguishes container annotations (list/dict) from a
        # plain model class (origin is None).
        self._type = get_origin(model_type)
        if self._type is list:
            # list[Model] -> the element type
            self._item_type = get_args(model_type)[0]
        elif self._type is dict:
            # dict[Key, Model] -> the value type
            self._item_type = get_args(model_type)[1]
        else:
            self._item_type = model_type
        # TypeAdapter validates against the full annotation, containers included.
        self._adapter = TypeAdapter(self.model_type)

    def process_bind_param(self, value: Any, dialect: Any) -> Any:
        """Convert models (or containers of models) into JSON-compatible data."""
        if value is None:
            return None
        if self._type is list:
            if not isinstance(value, list):
                raise TypeError(f"Expected list of {self._item_type}")
            # Items that are already plain data (not model instances) pass through.
            return [item.model_dump() if isinstance(item, self._item_type) else item for item in value]
        elif self._type is dict:
            if not isinstance(value, dict):
                raise TypeError(f"Expected dict of {self._item_type}")
            return {k: item.model_dump() if isinstance(item, self._item_type) else item for k, item in value.items()}
        else:
            if isinstance(value, self.model_type):
                return value.model_dump()
            return value

    def process_result_value(self, value: Any, dialect: Any) -> Any:
        """Validate database JSON back into the annotated Python type.

        A NULL column deliberately maps to an empty container ([] / {}) for
        container annotations, and to None for a bare model annotation.
        """
        if value is not None:
            return self._adapter.validate_python(value)
        if self._type is list:
            return []
        elif self._type is dict:
            return {}
        else:
            return None
Might release this and more stuff in some kind of sqlmodel-utils
package on pypi one day (Because I use SQLModel, but it works with sqlalchemy as well)
@Seluj78 Here is my implementation:
https://gist.github.com/a1d4r/100b06239925a414446305c81433cc88
Basically, the same as the original one, but with typing.
Example:
data: Mapped[ReportData] = mapped_column(PydanticType(ReportData))
@a1d4r I DID IT !
Here's the code. More changes might be needed to support more complex types
(e.g. `List[PydanticModel]`
or `Dict[str, MyModel]`
are untested right now).
But it works! It correctly flags the PydanticJSONB columns as changed!
import pydantic
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.attributes import flag_modified

# Where `PydanticJSONB` is my implementation, see previous comments


def flag_pydantic_changes(target):
    """Flag modified PydanticJSONB attributes on *target* before flush.

    For each column-mapped attribute whose type is PydanticJSONB, compares the
    current value (dumped to a dict when it is a pydantic model) against the
    attribute's unchanged history snapshot, and calls flag_modified() when they
    differ — so SQLAlchemy persists in-place mutations of the model.
    """
    inspector = inspect(target)
    mapper = inspector.mapper
    for attr in inspector.attrs:
        key = attr.key
        prop = mapper.attrs.get(key)
        # Skip non-ColumnProperty attributes
        if not isinstance(prop, ColumnProperty):
            continue
        # Check if any column in this property is PydanticJSONB
        is_pydantic_jsonb = any(
            isinstance(col.type, PydanticJSONB)
            for col in prop.columns
        )
        if is_pydantic_jsonb:
            hist = attr.history
            # NOTE(review): hist.unchanged[0] is the loaded attribute value,
            # which may be a pydantic model instance rather than a plain dict —
            # confirm the comparison below compares like with like.
            original_dict = hist.unchanged[0] if hist.unchanged else None
            if issubclass(attr.value.__class__, pydantic.BaseModel):
                current_dict = attr.value.model_dump()
            else:
                current_dict = attr.value
            if original_dict != current_dict:
                flag_modified(target, key)


# NOTE(review): `_BaseModel` is a placeholder for the project's declarative
# base, not defined in this snippet. Registering on the base here AND per
# model in the loop below runs the hook twice per update — confirm intended.
@event.listens_for(_BaseModel, "before_update")
def auto_flag_modified(mapper, connection, target):
    flag_pydantic_changes(target)


MODELS = [
    Users,
    # Add your models here
]
for model in MODELS:
    event.listen(model, "before_update", auto_flag_modified)
I'm quite happy and I will keep working on this and I will probably end up publishing this to github at some point.
@Seluj78 Awesome work! I will check it later. By the way, feel free to contact me if you need any help.
@wonderbeyond, wow, this looks great! Thanks.