-
-
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
#!/usr/bin/env ipython -i | |
import datetime | |
import json | |
from typing import Optional | |
import sqlalchemy as sa | |
from sqlalchemy.orm import declarative_base, sessionmaker | |
from sqlalchemy.dialects.postgresql import JSONB | |
from pydantic import BaseModel, Field, parse_obj_as | |
from pydantic.json import pydantic_encoder | |
# -------------------------------------------------------------------------------------- | |
# Define pydantic-alchemy specific types (once per application) | |
# -------------------------------------------------------------------------------------- | |
class PydanticType(sa.types.TypeDecorator):
    """Column type that stores a pydantic value as JSON.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.

    Args:
        pydantic_type: The type handed to ``parse_obj_as`` when a value is
            loaded (e.g. a ``BaseModel`` subclass).
    """

    # Generic JSON is the declared implementation; load_dialect_impl() below
    # swaps in JSONB when the dialect is PostgreSQL.
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    # The type's behaviour depends only on ``pydantic_type`` (a constructor
    # argument stored under the same attribute name), so it is safe to take
    # part in SQLAlchemy's statement cache. Without this flag SQLAlchemy 1.4+
    # emits a warning and skips caching for statements using this type.
    cache_ok = True

    def __init__(self, pydantic_type):
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        """Return JSONB for PostgreSQL and generic JSON for other dialects."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        """Convert the Python-side value into something JSON-encodable.

        Returns None for None, ``value.dict()`` for a pydantic model, and
        any other value unchanged.
        """
        # Compare against None explicitly: a truthiness check would silently
        # replace falsy-but-valid values (e.g. an empty dict or list) by NULL.
        if value is None:
            return None
        if isinstance(value, BaseModel):
            # If you use FastAPI, you can replace value.dict() with their
            # jsonable_encoder(). E.g.,
            #   from fastapi.encoders import jsonable_encoder
            #   return jsonable_encoder(value)
            return value.dict()
        # Not a model (e.g. a list or dict when pydantic_type is a container
        # type): pass it through for the engine's JSON serializer to encode.
        return value

    def process_result_value(self, value, dialect):
        """Parse the decoded JSON value into ``self.pydantic_type``.

        Returns None only when the stored value is None.
        """
        # ``is None`` rather than truthiness, so that a stored ``{}`` still
        # yields a model built from its defaults instead of None.
        if value is None:
            return None
        return parse_obj_as(self.pydantic_type, value)
def json_serializer(*args, **kwargs) -> str:
    """Dump to a JSON string with ``pydantic_encoder`` as the fallback hook.

    All positional and keyword arguments are forwarded to ``json.dumps``;
    ``pydantic_encoder`` is passed as its ``default`` callable.
    """
    encoded = json.dumps(*args, default=pydantic_encoder, **kwargs)
    return encoded
# -------------------------------------------------------------------------------------- | |
# Configure SQLAlchemy engine, session and declarative base (once per application) | |
# The key is to define json_serializer while creating the engine. | |
# -------------------------------------------------------------------------------------- | |
# Declarative base class for the ORM models.
Base = declarative_base()

# In-memory SQLite engine; the custom serializer is what lets JSON columns
# encode values through json_serializer().
engine = sa.create_engine(
    "sqlite:///:memory:",
    json_serializer=json_serializer,
)

# Session factory bound to the engine above.
Session = sessionmaker(
    bind=engine,
    expire_on_commit=False,
    future=True,
)
# -------------------------------------------------------------------------------------- | |
# Define your Pydantic and SQLAlchemy models (as many as needed) | |
# -------------------------------------------------------------------------------------- | |
class UserSettings(BaseModel):
    """Pydantic model holding a user's settings."""

    # When no value is supplied, the default is produced by calling
    # datetime.datetime.now (a naive, local-time datetime).
    notify_at: datetime.datetime = Field(
        default_factory=datetime.datetime.now,
    )
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(
        sa.String,
        doc="User name",
        comment="User name",
    )
    # Nullable column typed with PydanticType(UserSettings), so the attribute
    # holds a UserSettings instance (or None) on the Python side.
    settings: Optional[UserSettings] = sa.Column(
        PydanticType(UserSettings),
        nullable=True,
    )
# -------------------------------------------------------------------------------------- | |
# Create tables (once per application) | |
# -------------------------------------------------------------------------------------- | |
# Emit CREATE TABLE for every model registered on Base.
Base.metadata.create_all(bind=engine)
# -------------------------------------------------------------------------------------- | |
# Usage example (we use 2.0 querying style with selects) | |
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style | |
# -------------------------------------------------------------------------------------- | |
# Open a session and persist one user with default settings.
session = Session()
user = User(name="user", settings=UserSettings())
session.add(user)
session.commit()

# Read the row back with a 2.0-style select; first() returns the first User
# or None when the table is empty.
same_user = (
    session.execute(sa.select(User))
    .scalars()
    .first()
)
Hey, @etoiledemer
You're right. It's not going to work out of the box with Alembic. For myself, I modify the migration script manually after creating it.
As you asked this question, I decided to see if there's a better solution and came up with the following snippet. Here I replace pydantic types with native SQLAlchemy JSON.
# file: env.py
def render_item(type_, obj, autogen_context):
    """Render PydanticType columns as plain ``sa.JSON()`` in migrations.

    Returns the string ``"sa.JSON()"`` when a PydanticType is being rendered
    as a type, and ``False`` in every other case.
    """
    renders_pydantic_type = type_ == "type" and isinstance(obj, PydanticType)
    if not renders_pydantic_type:
        return False
    return "sa.JSON()"
...
context.configure(
...
render_item=render_item,
)
This will generate
...
sa.Column('settings', sa.JSON(), nullable=True),
...
Ref: Affecting the Rendering of Types Themselves from the Alembic documentation.
I've made a Python package SQLAlchemy-Nested-Mutable with inspiration from this work.
@wonderbeyond, wow, this looks great! Thanks.
To avoid a deprecation warning as of Pydantic 2.5, you will need to use the following:
def process_result_value(self, value, dialect):
    """Build ``self.pydantic_type`` from the decoded JSON mapping.

    Returns None only when the stored value is None; an empty mapping is
    still passed to the constructor, so the model is built from its defaults.
    """
    # Explicit None check: a truthiness test would turn a stored ``{}``
    # into None instead of a model instance.
    if value is None:
        return None
    return self.pydantic_type(**value)
I've got a problem: if I do `select(table.c.settings['a'])`, it still tries to parse the value of `a` as the whole pydantic model :(
Is it possible for SQLAlchemy to detect changes in a PydanticType field? When I change a field of the pydantic model, I have to manually call the `flag_modified` function to make SQLAlchemy flush the change.
is it compatible with Alembic autogeneration of schema?
for me, alembic generated
without the
UserSettings
stuff