Last active
October 22, 2024 07:35
-
-
Save imankulov/4051b7805ad737ace7d8de3d3f934d6b to your computer and use it in GitHub Desktop.
Using pydantic models as SQLAlchemy JSON fields (convert between JSON and pydantic.BaseModel subclasses)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ipython -i | |
import datetime | |
import json | |
from typing import Optional | |
import sqlalchemy as sa | |
from sqlalchemy.orm import declarative_base, sessionmaker | |
from sqlalchemy.dialects.postgresql import JSONB | |
from pydantic import BaseModel, Field, parse_obj_as | |
from pydantic.json import pydantic_encoder | |
# -------------------------------------------------------------------------------------- | |
# Define pydantic-alchemy specific types (once per application) | |
# -------------------------------------------------------------------------------------- | |
class PydanticType(sa.types.TypeDecorator):
    """Column type that stores a pydantic value as JSON.

    SAVING:
    - Uses SQLAlchemy JSON type under the hood.
    - Accepts the pydantic model and converts it to a dict on save.
    - SQLAlchemy engine JSON-encodes the dict to a string.

    RETRIEVING:
    - Pulls the string from the database.
    - SQLAlchemy engine JSON-decodes the string to a dict.
    - Uses the dict to create a pydantic model.

    Only ``None`` is treated as "no value". Falsy-but-present values (for
    example an empty dict coming back from the database) are still passed
    through pydantic instead of being silently turned into ``None``.
    """

    # If you work with PostgreSQL, you can consider using
    # sqlalchemy.dialects.postgresql.JSONB instead of a
    # generic sa.types.JSON
    #
    # Ref: https://www.postgresql.org/docs/13/datatype-json.html
    impl = sa.types.JSON

    # The behaviour of this type depends only on the constructor argument, so
    # it may take part in SQLAlchemy's statement cache key. This requires
    # ``pydantic_type`` to be hashable (classes and typing generics are).
    cache_ok = True

    def __init__(self, pydantic_type):
        """Remember the target type handed to ``parse_obj_as`` on load.

        Args:
            pydantic_type: The pydantic model class (or other type accepted
                by ``parse_obj_as``) that stored values are parsed into.
        """
        super().__init__()
        self.pydantic_type = pydantic_type

    def load_dialect_impl(self, dialect):
        """Return the dialect-specific storage type.

        Uses JSONB for PostgreSQL and the generic JSON type for every other
        database.
        """
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(sa.JSON())

    def process_bind_param(self, value, dialect):
        """Convert the pydantic model to a plain dict before it is stored.

        Returns ``None`` unchanged. The explicit ``is None`` test (rather
        than truthiness) keeps models that happen to evaluate as falsy from
        being dropped.
        """
        if value is None:
            return None
        return value.dict()
        # If you use FastAPI, you can replace the line above with their
        # jsonable_encoder(). E.g.,
        #     from fastapi.encoders import jsonable_encoder
        #     return jsonable_encoder(value)

    def process_result_value(self, value, dialect):
        """Parse the decoded JSON value into ``self.pydantic_type``.

        Returns ``None`` only when the loaded value is ``None``; other falsy
        values such as ``{}`` are still validated by pydantic.
        """
        if value is None:
            return None
        return parse_obj_as(self.pydantic_type, value)
def json_serializer(*args, **kwargs) -> str:
    """Serialize to a JSON string, with ``pydantic_encoder`` as the fallback.

    All positional and keyword arguments are forwarded to ``json.dumps``;
    objects the standard encoder cannot handle are delegated to
    ``pydantic_encoder`` through the ``default`` hook.
    """
    encoded = json.dumps(*args, default=pydantic_encoder, **kwargs)
    return encoded
# -------------------------------------------------------------------------------------- | |
# Configure SQLAlchemy engine, session and declarative base (once per application) | |
# The key is to define json_serializer while creating the engine. | |
# -------------------------------------------------------------------------------------- | |
# In-memory SQLite engine. The custom serializer is what JSON columns use to
# encode their values.
engine = sa.create_engine(
    "sqlite:///:memory:",
    json_serializer=json_serializer,
)

# Session factory bound to the engine above.
Session = sessionmaker(
    bind=engine,
    expire_on_commit=False,
    future=True,
)

# Declarative base class for the ORM models defined below.
Base = declarative_base()
# -------------------------------------------------------------------------------------- | |
# Define your Pydantic and SQLAlchemy models (as many as needed) | |
# -------------------------------------------------------------------------------------- | |
class UserSettings(BaseModel):
    # Pydantic model holding a user's settings.
    # ``notify_at`` defaults to ``datetime.datetime.now()`` evaluated when the
    # model instance is created (a naive datetime).
    notify_at: datetime.datetime = Field(
        default_factory=datetime.datetime.now,
    )
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id: int = sa.Column(sa.Integer, primary_key=True)
    name: str = sa.Column(
        sa.String,
        doc="User name",
        comment="User name",
    )
    # Nullable column typed with PydanticType(UserSettings).
    settings: Optional[UserSettings] = sa.Column(
        PydanticType(UserSettings),
        nullable=True,
    )
# -------------------------------------------------------------------------------------- | |
# Create tables (once per application) | |
# -------------------------------------------------------------------------------------- | |
# Emit CREATE TABLE for every model registered on Base.
Base.metadata.create_all(bind=engine)
# -------------------------------------------------------------------------------------- | |
# Usage example (we use 2.0 querying style with selects) | |
# Ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#querying-2-0-style | |
# -------------------------------------------------------------------------------------- | |
session = Session()

# Persist a user whose settings are a pydantic model.
user = User(
    name="user",
    settings=UserSettings(),
)
session.add(user)
session.commit()

# Read the first user back with a 2.0-style select.
same_user = (
    session.execute(sa.select(User))
    .scalars()
    .first()
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Is it possible for SQLAlchemy to detect changes in a PydanticType field? When I change a field of the pydantic model, I have to manually call the function `flag_modified` to make SQLAlchemy flush the change.