Last active
February 18, 2024 18:35
-
-
Save Tantas/b0e9f8a62807b3f8d5fcddda722c1492 to your computer and use it in GitHub Desktop.
SQLAlchemy Pydantic 2+ JSON Column
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pydantic import BaseModel, TypeAdapter | |
from sqlalchemy import JSON, TypeDecorator | |
from sqlalchemy.orm import DeclarativeBase | |
from sqlalchemy.sql import elements | |
class PydanticJson(TypeDecorator):
    """SQLAlchemy ``JSON`` column type that round-trips a Pydantic model.

    On write, values are serialized with ``TypeAdapter(model).dump_json``;
    on read, stored JSON is validated back into ``model`` with the same
    kind of adapter. The underlying ``JSON`` impl is built with
    ``none_as_null=True``, so a Python ``None`` (or ``sqlalchemy.null()``)
    is stored as SQL NULL, while ``JSON.NULL`` stores the JSON text
    ``null``. Either kind of null is read back as ``None``.

    Args:
        model: Pydantic model class used to serialize and validate values.
    """

    impl = JSON
    # The only constructor argument is a class (hashable), so this type may
    # take part in SQLAlchemy's compiled-statement cache key.
    cache_ok = True

    def __init__(self, model: type[BaseModel]):
        # TypeDecorator forwards constructor arguments to ``impl`` (JSON).
        super().__init__(none_as_null=True)
        self.model = model

    def _make_bind_processor(self, string_process, json_serializer):
        """Build the function that converts a Python value to a bind value.

        Args:
            string_process: Bind processor of the impl's string type, or
                ``None``; applied to the serialized JSON text.
            json_serializer: Callable turning a Python value into JSON text.

        Returns:
            A one-argument ``process(value)`` callable.
        """

        def process(value):
            if value is self.NULL:
                # Explicit JSON null: emit the JSON literal directly instead
                # of handing None to a model serializer.
                serialized = "null"
            elif isinstance(value, elements.Null) or (
                value is None and self.none_as_null
            ):
                # SQL NULL: let the driver bind a plain None.
                return None
            else:
                serialized = json_serializer(value)
            if string_process:
                serialized = string_process(serialized)
            return serialized

        return process

    def bind_processor(self, dialect):
        """Return a processor that serializes values to JSON text."""
        string_process = self._str_impl.bind_processor(dialect)
        adapter = TypeAdapter(self.model)

        def json_serializer(value):
            # dump_json returns UTF-8 encoded bytes. The stock JSON type binds
            # str (json.dumps output), and drivers may treat bytes as a
            # binary parameter, so decode to text.
            return adapter.dump_json(value).decode("utf-8")

        return self._make_bind_processor(string_process, json_serializer)

    def result_processor(self, dialect, coltype):
        """Return a processor that validates stored JSON into ``model``.

        SQL NULL and the JSON literal ``null`` are both returned as ``None``.
        """
        string_process = self._str_impl.result_processor(dialect, coltype)
        adapter = TypeAdapter(self.model)

        def process(value):
            if value is None:
                return None
            if not isinstance(value, (str, bytes, bytearray)):
                # Some drivers decode JSON columns themselves and return
                # dicts/lists; validate those as Python data instead.
                return adapter.validate_python(value)
            if string_process:
                value = string_process(value)
            if isinstance(value, (bytes, bytearray)):
                null_text = b"null"
            else:
                null_text = "null"
            if value.strip() == null_text:
                # Stored JSON null (e.g. written via JSON.NULL); a model can
                # never validate it, so surface it as None.
                return None
            return adapter.validate_json(value)

        return process
""" | |
# Usage example. | |
from pydantic import BaseModel | |
from sqlalchemy.orm import DeclarativeBase | |
class Base(DeclarativeBase): | |
pass | |
class PydanticType(BaseModel): | |
field: str | |
class Entity(Base): | |
id: Mapped[int] = mapped_column(primary_key=True) | |
data: Mapped[PydanticType] = mapped_column(PydanticJson(PydanticType)) | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment