Skip to content

Instantly share code, notes, and snippets.

@imankulov
Last active October 22, 2024 07:35
Show Gist options
  • Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
Using pydantic models as SQLAlchemy JSON fields (convert between JSON and pydantic.BaseModel subclasses)
#!/usr/bin/env ipython -i
import datetime
import json
from typing import Optional
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB
from pydantic import BaseModel, Field, parse_obj_as
from pydantic.json import pydantic_encoder
# --------------------------------------------------------------------------------------
# Define pydantic-alchemy specific types (once per application)
# --------------------------------------------------------------------------------------
class PydanticType(sa.types.TypeDecorator):
    """Pydantic type.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.

    Args:
        pydantic_type: The pydantic model class (or any other type accepted
            by ``parse_obj_as``) that values read from the database are
            parsed into.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    # The behaviour of this type depends only on its constructor argument,
    # which is hashable and stored under the same attribute name, so the type
    # can safely take part in SQLAlchemy's compiled-statement cache. Without
    # this flag SQLAlchemy warns and skips caching for statements that use
    # the type.
    cache_ok = True

    def __init__(self, pydantic_type):
        """Remember which pydantic type loaded values should be parsed into."""
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        """Return the concrete column type to use for the given dialect."""
        # Use JSONB for PostgreSQL and JSON for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        """Convert a pydantic model to a dict before it is sent to the driver.

        Returns ``None`` unchanged. The check is an identity check rather
        than a truthiness check so that a model which happens to evaluate as
        falsy is still stored instead of being silently replaced by ``None``.
        """
        if value is None:
            return None
        # If you use FastAPI, you can replace the line below with their
        # jsonable_encoder(). E.g.,
        #     from fastapi.encoders import jsonable_encoder
        #     return jsonable_encoder(value)
        return value.dict()

    def process_result_value(self, value, dialect):
        """Parse a decoded JSON value from the database into ``pydantic_type``.

        Only ``None`` maps to ``None``. Falsy-but-present values such as an
        empty object ``{}`` are still parsed, so a model whose fields all
        have defaults round-trips as a model rather than as ``None``.
        """
        if value is None:
            return None
        return parse_obj_as(self.pydantic_type, value)
def json_serializer(*args, **kwargs) -> str:
    """Dump the arguments to a JSON string.

    All positional and keyword arguments are forwarded to ``json.dumps``;
    objects the standard encoder cannot handle are passed to
    ``pydantic_encoder`` via the ``default`` hook.
    """
    encoded = json.dumps(*args, default=pydantic_encoder, **kwargs)
    return encoded
# --------------------------------------------------------------------------------------
# Configure SQLAlchemy engine, session and declarative base (once per application)
# The key is to define json_serializer while creating the engine.
# --------------------------------------------------------------------------------------
# Engine for an in-memory SQLite database. The custom json_serializer is
# passed here so that it is used when JSON column values are encoded.
engine = sa.create_engine("sqlite:///:memory:", json_serializer=json_serializer)
# Session factory bound to the engine above; sessions are created by calling it.
Session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
# Base class that the declarative ORM models in this file inherit from.
Base = declarative_base()
# --------------------------------------------------------------------------------------
# Define your Pydantic and SQLAlchemy models (as many as needed)
# --------------------------------------------------------------------------------------
# Example pydantic model; instances are stored in the User.settings column.
class UserSettings(BaseModel):
    # When no value is supplied, the field is filled by calling
    # datetime.datetime.now at instantiation time.
    notify_at: datetime.datetime = Field(
        default_factory=datetime.datetime.now,
    )
# ORM model mapped to the "users" table.
# NOTE(review): the plain (non-Mapped) annotations below follow the
# SQLAlchemy 1.4 style referenced in this file; confirm they are accepted by
# the SQLAlchemy version actually in use.
class User(Base):
    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(sa.String, doc="User name", comment="User name")
    # Nullable column whose values are converted to and from UserSettings
    # by PydanticType.
    settings: Optional[UserSettings] = sa.Column(
        PydanticType(UserSettings),
        nullable=True,
    )
# --------------------------------------------------------------------------------------
# Create tables (once per application)
# --------------------------------------------------------------------------------------
# Create the tables for all models registered on Base.metadata, using the
# engine configured earlier in this file.
Base.metadata.create_all(engine)
# --------------------------------------------------------------------------------------
# Usage example (we use 2.0 querying style with selects)
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style
# --------------------------------------------------------------------------------------
# Open a session from the session factory.
session = Session()
# Build a user whose settings are a UserSettings model with default values.
user = User(name="user", settings=UserSettings())
session.add(user)
# Persist the pending user; the settings value is converted by
# PydanticType.process_bind_param on its way to the database.
session.commit()
# Read the first User back with a 2.0-style select; its settings value is
# converted by PydanticType.process_result_value.
same_user = session.execute(sa.select(User)).scalars().first()
@Filimoa
Copy link

Filimoa commented Jan 26, 2024

To avoid a deprecation warning as of Pydantic 2.5 you will need to use the following

def process_result_value(self, value, dialect):
    """Build the pydantic model from the decoded JSON value.

    The value is unpacked as keyword arguments to ``self.pydantic_type``;
    any falsy value (including ``None``) yields ``None``.
    """
    return self.pydantic_type(**value) if value else None

@DeoLeung
Copy link

got a problem

If I do select(table.c.settings['a']), it still tries to parse the value of 'a' as the whole pydantic model :(

@a1d4r
Copy link

a1d4r commented Jun 21, 2024

Is it possible for SQLAlchemy to detect changes in a PydanticType field? When I change a field of the pydantic model, I have to manually call the flag_modified function to make SQLAlchemy flush the change.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment