This gist demonstrates a custom SQLAlchemy column type that allows you to save and load Pydantic models directly in your SQLAlchemy database. It simplifies the process of storing and retrieving complex pydantic data models in your database without manual conversion to JSON or dictionaries.
Kindly check the Git repo, or follow the steps below:
- Building custom Column Type
- Building a nested data structure using Pydantic model
- Building SQLAlchemy table models using 1 and 2
- Insert and select Script for testing
from pydantic import BaseModel, TypeAdapter # pydantic version > 2.0.0
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON, TypeDecorator
# from pydantic import parse_obj_as # pydantic version <2.0.0
class PydanticColumn(TypeDecorator):
    """PydanticColumn type: stores a pydantic model as JSON(B).

    * For custom column type implementation see
      https://docs.sqlalchemy.org/en/20/core/custom_types.html
    * Uses sqlalchemy.dialects.postgresql.JSONB if dialect == postgresql,
      else the generic sqlalchemy.types.JSON.
    * Save:
        - Accepts the pydantic model and converts it to a JSON-compatible
          dict on save.
        - SQLAlchemy engine JSON-encodes the dict to a string.
    * Load:
        - Pulls the string from the database.
        - SQLAlchemy engine JSON-decodes the string to a dict.
        - Validates the dict back into a pydantic model instance.
    """

    impl = JSON
    cache_ok = True

    def __init__(self, pydantic_type):
        """Create the column type.

        Args:
            pydantic_type: the pydantic model class (subclass of BaseModel)
                stored in this column.

        Raises:
            ValueError: if ``pydantic_type`` is not a BaseModel subclass.
        """
        super().__init__()
        if not issubclass(pydantic_type, BaseModel):
            raise ValueError("Column Type Should be Pydantic Class")
        self.pydantic_type = pydantic_type
        # Build the TypeAdapter once here instead of once per row load.
        self._adapter = TypeAdapter(pydantic_type)

    def load_dialect_impl(self, dialect):
        # Use JSONB for PostgreSQL and JSON for other databases.
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # mode="json" emits JSON-serializable primitives (e.g. datetime ->
        # ISO string, Enum -> value) so the engine's json.dumps cannot
        # choke on rich field types.
        if value is None:
            return None
        return value.model_dump(mode="json")

    def process_result_value(self, value, dialect):
        # Bug fix: a NULL column must round-trip to None; the original
        # passed None into validation, raising a ValidationError unless
        # the model happened to accept it.
        if value is None:
            return None
        return self._adapter.validate_python(value)
from enum import Enum
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
class ProjecType(str, Enum):
    # Kind of project. NOTE(review): the class name "ProjecType" is a typo
    # for "ProjectType" — kept as-is because other modules import it by
    # this spelling.
    DEPLOYABLE = "DEPLOYABLE"
    INMEMORY = "INMEMORY"
    SINGLESHOT = "SINGLESHOT"
class ProjecStatus(str, Enum):
    """Lifecycle stage of a project.

    NOTE(review): the class name "ProjecStatus" is a typo for
    "ProjectStatus" — kept as-is because other modules import it by this
    spelling.
    """

    INIT = "INIT"
    # Bug fix: DATALOAD previously carried the copy-pasted value
    # "DEPLOYABLE"; its value now matches its name like every other member.
    # (sqlalchemy.Enum persists member *names* by default, so stored rows
    # are unaffected.)
    DATALOAD = "DATALOAD"
    PREPROCESS = "PREPROCESS"
    POSTPROCESS = "POSTPROCESS"
class Dtypes(str, Enum):
    # Logical data type of a dataset column.
    # NOTE(review): "INTERGER" is a typo for INTEGER and its value
    # "integers" is plural unlike the other members — both kept as-is
    # because callers and serialized data may rely on these spellings.
    INTERGER = "integers"
    FLOAT = "float"
    BOOLEAN = "bool"
    CATEGORICAL = "categorical"
    DATE = "date"
class ColumnType(str, Enum):
    # Role a column plays in the dataset.
    FEATURES = "features"
    TARGET = "target"
    INDEX = "index"
    UNIQUEID = "unique-id"
class ImputationScheme(str, Enum):
    # Strategy for filling missing values in a column.
    MEAN = "mean"
    MEDIAN = "median"
    MODE = "mode"
    VALUE = "value"
class ColumnsDescription(BaseModel):
    """Per-column statistics and metadata for a single dataset column."""

    # Column name as it appears in the dataset.
    name: str
    # Role of the column (feature / target / index / unique-id).
    col_type: ColumnType
    # Logical data type of the column's values.
    dtype: Dtypes
    # Summary statistics; None until computed.
    mean: Optional[float] = Field(default=None)
    median: Optional[float] = Field(default=None)
    mode: Optional[float] = Field(default=None)
    null_count: Optional[int] = Field(default=None)
    # NOTE(review): "unique_valus" is a typo for "unique_values" — kept
    # because the field name is part of the serialized JSON schema.
    unique_valus: Optional[int] = Field(default=None)
    # Strategy used to fill missing values in this column.
    imputation_scheme: ImputationScheme
class DatasetDescriptor(BaseModel):
    """Dataset-level profile stored in Projects.dataset_info."""

    # Total number of rows; must be non-negative.
    row_count: int = Field(ge=0)
    # NOTE(review): "cloumns_info" is a typo for "columns_info" — kept
    # because callers construct the model with this keyword and it is part
    # of the serialized JSON schema.
    cloumns_info: List[ColumnsDescription] = Field(default_factory=list)
    duplicate_row_count: int = Field(default=0, ge=0)
    duplicate_columns: List[str] = Field(default_factory=list)
    outlier_count: int = Field(ge=0)
    # None when class balance has not been evaluated.
    is_imbalance: Optional[bool] = Field(default=None)
from typing import Optional
from sqlalchemy import Enum, MetaData, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from src.custom_pydantic_column import PydanticColumn # custom_column.py
from src.pydantic_model import DatasetDescriptor, ProjecStatus, ProjecType # pydantic_model.py
# Shared MetaData collection that all tables register against.
meta = MetaData()
class Base(DeclarativeBase):
    """Declarative base bound to the shared MetaData instance."""
    metadata = meta
class Projects(Base):
    """ORM model for the "projects" table.

    ``dataset_info`` round-trips a DatasetDescriptor pydantic model through
    the custom PydanticColumn JSON type.
    """
    __tablename__ = "projects"
    # Surrogate primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    title: Mapped[str] = mapped_column(String(100), nullable=False)
    # Optional free-text description, capped at 100 characters.
    descriptions: Mapped[Optional[str]] = mapped_column(
        String(100), nullable=True, default=None
    )
    ptype: Mapped[ProjecType] = mapped_column(Enum(ProjecType), nullable=False)
    status: Mapped[ProjecStatus] = mapped_column(Enum(ProjecStatus), nullable=False)
    # Stored as JSON (JSONB on PostgreSQL); NULL when no dataset profiled.
    dataset_info: Mapped[Optional[DatasetDescriptor]] = mapped_column(
        PydanticColumn(DatasetDescriptor), nullable=True
    )
import os
import logging
from typing import Dict, Optional
from sqlalchemy import URL, create_engine, insert, select
from sqlalchemy.orm import Session, sessionmaker
from src.orms import Projects
from src.pydantic_model import (ColumnsDescription, ColumnType,
DatasetDescriptor, Dtypes, ImputationScheme,
ProjecStatus, ProjecType)
logger = logging.getLogger(__name__)
def insert_project(db: Session, object_in: Dict) -> Optional[Projects]:
    """Insert one project row and return the ORM instance, or None on error.

    Args:
        db: Active session; rolled back if the insert fails.
        object_in: Column-name -> value mapping for the new row.

    Returns:
        The inserted Projects row (with its generated id), or None if the
        insert raised.
    """
    query_stmt = insert(Projects).values(**object_in).returning(Projects)
    try:
        result = db.execute(query_stmt)
        data = result.scalars().one()
    except Exception:
        # Broad catch is deliberate: callers receive None instead of an
        # exception. logger.exception records the traceback that the old
        # logger.error(f"...") call threw away. Returning here (not from a
        # finally block) no longer swallows non-Exception interrupts.
        db.rollback()
        logger.exception("Session id %s| Error while Insert", id(db))
        return None
    logger.info("Session id %s| Successfully inserted, id %s", id(db), data.id)
    return data
def select_project(db: Session, project_id: int) -> Optional[Projects]:
    """Fetch one project by primary key, or None if missing / on error.

    Args:
        db: Active session; rolled back if the query fails.
        project_id: Primary key of the row to fetch.

    Returns:
        The matching Projects row, or None when no row (or more than one)
        matches or the query raised.
    """
    query_stmt = select(Projects).where(Projects.id == project_id)
    try:
        result = db.execute(query_stmt)
        data = result.scalars().one()
    except Exception:
        # Bug fix: the log message said "Error while Insert" (copy-paste
        # from insert_project). logger.exception also keeps the traceback.
        db.rollback()
        logger.exception("Session id %s| Error while Select", id(db))
        return None
    logger.info("Session id %s| Successfully Selected, id %s", id(db), data.id)
    return data
if __name__ == "__main__":
    # Configure logging — without this, every logger.info below is dropped
    # because no handler/level is installed on the root logger.
    logging.basicConfig(level=logging.INFO)

    # Build a small dataset profile to round-trip through PydanticColumn.
    dataset = DatasetDescriptor(
        row_count=100,
        duplicate_columns=[],
        duplicate_row_count=0,
        outlier_count=0,
        is_imbalance=False,
        cloumns_info=[
            ColumnsDescription(
                name="job_type",
                col_type=ColumnType.FEATURES,
                dtype=Dtypes.CATEGORICAL,
                mean=None,
                median=None,
                mode=None,
                null_count=0,
                unique_valus=5,
                imputation_scheme=ImputationScheme.MODE,
            )
        ],
    )
    logger.info("dataset %s", dataset.model_dump())

    object_in = {
        "title": "Example Project 1",
        "descriptions": "Example Project 1 description",
        "ptype": ProjecType.DEPLOYABLE,
        "status": ProjecStatus.INIT,
        "dataset_info": dataset,
    }

    DB_PATH = os.path.join(os.curdir, "test_sqlite3.db")
    db_url = URL.create(drivername="sqlite", database=DB_PATH)
    logger.info("DB URL %s", db_url)
    engine = create_engine(db_url, echo=False)

    # Bug fix: create the schema before use — on a fresh database the
    # original script failed with "no such table: projects".
    Projects.metadata.create_all(engine)

    Session = sessionmaker(engine)
    with Session.begin() as session:
        data = insert_project(db=session, object_in=object_in)
        project = select_project(db=session, project_id=data.id)
        logger.info("dataset_info %s, %s", project.ptype, project.dataset_info)
Kindly check the Git repo for the complete project.