Skip to content

Instantly share code, notes, and snippets.

@florimondmanca
Last active September 21, 2024 10:34
Show Gist options
  • Save florimondmanca/36d304e50f9cda79fe9b12f0b3349cbb to your computer and use it in GitHub Desktop.
Save florimondmanca/36d304e50f9cda79fe9b12f0b3349cbb to your computer and use it in GitHub Desktop.
DRF-inspired ModelSerializer implementation backed by Pydantic and Tortoise ORM
import datetime
import decimal
import json
from typing import Any, Optional, Type, Union, Dict, Tuple
from pydantic import BaseModel, ValidationError, create_model
from tortoise import fields as tortoise_fields
from tortoise.fields import Field
from tortoise.models import Model as TortoiseModel
# Raw, not-yet-validated payload handed to a serializer.
Data = Union[list, dict]
# One field definition as accepted by pydantic's ``create_model``: either a
# bare default value, or a ``(type, default)`` pair.
PydanticDefinition = Union[Any, Tuple[Type, Any]]
# An ORM model instance (left untyped in this module).
Model = Any

# Tortoise field class -> Python type used for that field in the generated
# Pydantic schema. Keys are concrete field classes and are looked up by exact
# class, so subclasses of these fields are not matched automatically.
_TORTOISE_FIELD_TO_PYDANTIC_TYPE = {
    # Integer columns of every width become a plain ``int``.
    **dict.fromkeys(
        (
            tortoise_fields.IntField,
            tortoise_fields.BigIntField,
            tortoise_fields.SmallIntField,
        ),
        int,
    ),
    # Bounded and unbounded text both become ``str``.
    **dict.fromkeys(
        (tortoise_fields.CharField, tortoise_fields.TextField),
        str,
    ),
    tortoise_fields.BooleanField: bool,
    tortoise_fields.DecimalField: decimal.Decimal,
    tortoise_fields.DatetimeField: datetime.datetime,
    tortoise_fields.DateField: datetime.date,
    tortoise_fields.TimeDeltaField: datetime.timedelta,
    tortoise_fields.FloatField: float,
    # A JSON column may hold either an object or an array at the top level.
    tortoise_fields.JSONField: Union[dict, list],
}
class BaseModelSerializer:
    """DRF-style serializer workflow: validate input, render output, persist.

    Concrete classes (or mixins) supply ``to_internal_value`` and
    ``to_representation`` for (de)serialization, and ``create`` / ``update``
    for persistence.

    :param instance: existing object to render and/or update.
    :param data: raw input to validate with :meth:`is_valid`.
    """

    def __init__(self, instance: Model = None, data: Data = None):
        self.instance = instance
        self.initial_data = data
        self._validated_data: Optional[dict] = None
        self._data: Optional[dict] = None
        self._errors: Optional[list] = None
        # Exception raised by the last failed validation, kept so that
        # ``is_valid(raise_exceptions=True)`` can re-raise it later.
        self._validation_error: Optional[ValidationError] = None

    def to_internal_value(self, data: Data) -> Data:
        """Convert raw input into validated Python values."""
        raise NotImplementedError

    def to_representation(self, instance: Any) -> Data:
        """Convert an object (or validated data) into native output values."""
        raise NotImplementedError

    def run_validation(self, data: Data) -> Data:
        """Validate ``data``; by default this delegates to ``to_internal_value``."""
        return self.to_internal_value(data)

    def is_valid(self, raise_exceptions: bool = False) -> bool:
        """Validate ``initial_data`` once and cache the outcome.

        :param raise_exceptions: when true and validation failed, re-raise the
            original pydantic ``ValidationError``.
        :returns: ``True`` when validation produced no errors.
        """
        assert self.initial_data is not None, (
            "Cannot call `.is_valid()` as no `data=` keyword argument "
            "was passed when instanciating the serializer."
        )
        if self._validated_data is None:
            try:
                self._validated_data = self.run_validation(self.initial_data)
            except ValidationError as exc:
                self._validated_data = {}
                self._errors = exc.raw_errors
                self._validation_error = exc
            else:
                self._errors = []
        if self._errors and raise_exceptions:
            # Re-raise the stored exception: pydantic's ValidationError cannot
            # be rebuilt from the raw errors alone (its constructor also
            # requires the model), so ``ValidationError(errors)`` would fail.
            raise self._validation_error
        return not self._errors

    @property
    def validated_data(self) -> dict:
        """Validated values; only available after ``is_valid()``."""
        assert (
            self._validated_data is not None
        ), "You must call `.is_valid()` before accessing `.validated_data`."
        return self._validated_data

    @property
    def data(self):
        """Output representation, computed once and cached.

        Renders ``instance`` when there is one and no validation errors;
        otherwise renders the validated data; when validation failed, returns
        ``initial_data`` unchanged.
        """
        # Only a serializer that was given ``data`` must have been validated;
        # an instance-only serializer can be rendered directly.
        assert self.initial_data is None or self._validated_data is not None, (
            "When a serializer is passed a `data` keyword argument you "
            "must call `.is_valid()` before attempting to access the "
            "serialized `.data` representation.\n"
            "You should either call `.is_valid()` first, "
            "or access `.initial_data` instead."
        )
        if self._data is None:
            # ``_errors`` is None (never validated) or [] (valid) here means
            # "no errors"; only a non-empty list means validation failed.
            if self.instance is not None and not self._errors:
                self._data = self.to_representation(self.instance)
            elif self._validated_data is not None and not self._errors:
                # No instance yet (e.g. before ``save()``): render the input.
                self._data = self.to_representation(self._validated_data)
            else:
                self._data = self.initial_data
        return self._data

    async def create(self, validated_data: dict) -> Model:
        """Persist and return a new object built from ``validated_data``."""
        raise NotImplementedError

    async def update(self, instance: Model, validated_data: dict) -> Model:
        """Apply ``validated_data`` to ``instance``, persist and return it."""
        raise NotImplementedError

    async def save(self, **kwargs: Any) -> Model:
        """Create or update ``instance`` from validated data plus ``kwargs``.

        :param kwargs: extra values merged over the validated data.
        :returns: the created or updated instance (also stored on ``self``).
        """
        validated_data = {**self.validated_data, **kwargs}
        assert not self._errors, (
            "You cannot call `.save()` on a serializer with invalid data."
        )
        if self.instance is None:
            self.instance = await self.create(validated_data)
        else:
            self.instance = await self.update(self.instance, validated_data)
        # Any representation cached before saving is now stale.
        self._data = None
        return self.instance
class PydanticMixin:
    """Serialize/deserialize using Pydantic.

    Builds one Pydantic schema for reads and one for writes, lazily and once
    per serializer instance, from the field definitions returned by
    ``build_field``. The host class must provide ``_model_class`` (its
    ``__name__`` names the generated schemas).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._read_schema: Optional[Type[BaseModel]] = None
        self._write_schema: Optional[Type[BaseModel]] = None

    def get_declared_fields(self) -> set:
        """Return the names annotated on this serializer class hierarchy.

        Only classes that derive from this mixin are inspected. Annotations on
        sibling mixins/bases (plain type hints, not serializer fields) would
        otherwise leak into the schema through ordinary attribute lookup.
        """
        declared = set()
        for klass in type(self).__mro__:
            if klass is PydanticMixin or not issubclass(klass, PydanticMixin):
                continue
            # Read the class's *own* annotations, not inherited ones.
            declared.update(vars(klass).get("__annotations__", {}))
        return declared

    def get_field_names(self) -> set:
        """Return declared field names minus those listed in ``Meta.exclude``."""
        meta = getattr(self, "Meta", None)
        excluded = getattr(meta, "exclude", ())
        if isinstance(excluded, str):
            # ``exclude = "password"`` means one name, not its characters.
            excluded = (excluded,)
        return self.get_declared_fields().difference(excluded)

    @property
    def read_schema(self) -> Type[BaseModel]:
        """Schema used to render objects (built on first access)."""
        if self._read_schema is None:
            self._read_schema = self._get_schema("read")
        return self._read_schema

    @property
    def write_schema(self) -> Type[BaseModel]:
        """Schema used to validate input (built on first access)."""
        if self._write_schema is None:
            self._write_schema = self._get_schema("write")
        return self._write_schema

    def build_field(
        self, field_name: str, operation: str
    ) -> Optional[PydanticDefinition]:
        """Return the ``create_model`` definition of a field, or ``None`` to omit it."""
        raise NotImplementedError

    def _get_schema(self, operation: str) -> Type[BaseModel]:
        """Create a Pydantic model class for ``operation`` ("read"/"write")."""
        definitions: Dict[str, PydanticDefinition] = {}
        for field_name in self.get_field_names():
            definition = self.build_field(field_name, operation)
            if definition is not None:
                definitions[field_name] = definition
        return create_model(self._model_class.__name__, **definitions)

    def to_internal_value(self, data: dict) -> dict:
        """Convert a native dict of values to a Python dict of values.

        :raises ValidationError: if ``data`` does not satisfy the write schema,
            including when it is not a mapping at all.
        """
        # ``parse_obj`` (unlike ``schema(**data)``) reports non-mapping input
        # as a ValidationError, which ``is_valid()`` turns into errors rather
        # than letting a TypeError escape.
        return self.write_schema.parse_obj(data).dict()

    def to_representation(self, instance: Union[Model, dict]) -> dict:
        """Convert a Python object (or dict) to a native dict of values."""
        schema = self.read_schema
        if isinstance(instance, dict):
            data = {name: instance[name] for name in schema.__fields__}
        else:
            data = {name: getattr(instance, name) for name in schema.__fields__}
        # Round-trip through JSON so every value becomes a JSON-native type.
        return json.loads(schema(**data).json())
class TortoiseMixin:
    """Create, update and save instances using Tortoise ORM.

    The model class is read from ``Meta.model`` on the concrete serializer.
    """

    # NOTE: no class-level annotations here on purpose. PydanticMixin treats
    # annotations found on the serializer as field declarations, and a hint
    # such as ``validated_data: dict`` would leak in as a bogus field.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            self._model_class: Type[TortoiseModel] = self.Meta.model
        except AttributeError as exc:
            raise AttributeError(
                "Model class could not be retrieved from `.Meta.model`."
            ) from exc

    async def create(self, validated_data: dict) -> TortoiseModel:
        """Insert and return a new model row built from ``validated_data``."""
        return await self._model_class.create(**validated_data)

    async def update(
        self, instance: TortoiseModel, validated_data: dict
    ) -> TortoiseModel:
        """Assign ``validated_data`` onto ``instance``, save and return it."""
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        await instance.save()
        return instance
class ModelSerializer(PydanticMixin, TortoiseMixin, BaseModelSerializer):
    """A concrete model serializer class backed by Pydantic and Tortoise.

    Its fields are the names declared on the serializer class plus every field
    of ``Meta.model``, minus anything listed in ``Meta.exclude``.
    """

    def get_declared_fields(self) -> set:
        """Return serializer-declared names plus all field names of the model."""
        return (
            super().get_declared_fields().union(self._model_class._meta.fields)
        )

    def build_field(
        self, field_name: str, operation: str
    ) -> Optional[PydanticDefinition]:
        """Return the ``create_model`` definition of one field, or ``None`` to omit it.

        :param field_name: a Tortoise field name or a serializer-declared name.
        :param operation: ``"read"`` (render output) or ``"write"`` (validate input).
        :raises NotImplementedError: if the Tortoise field class has no known
            Pydantic type, or its default value is callable.
        """
        field: Optional[Field] = self._model_class._meta.fields_map.get(field_name)
        if field is None:
            # Not a Tortoise field: a declared attribute such as a model
            # property. Those are output-only, so skip them on write.
            if operation == "write":
                return None
            assert self.instance is not None, (
                f"Cannot build read field {field_name!r} without an instance."
            )
            # Always hand create_model a (type, default) pair: a bare value
            # that happens to be a tuple would otherwise be misread as one.
            return (Any, getattr(self.instance, field_name))

        # Primary keys and auto-timestamps are set by the ORM, never by input.
        read_only = any(
            (
                field.pk,
                getattr(field, "auto_now", False),
                getattr(field, "auto_now_add", False),
            )
        )
        if operation == "write" and read_only:
            return None
        write_only = False  # Not implemented yet
        if operation == "read" and write_only:
            return None

        type_ = _TORTOISE_FIELD_TO_PYDANTIC_TYPE.get(type(field))
        if type_ is None:
            raise NotImplementedError(
                f"Field {field_name!r} of type {type(field).__name__} "
                "is not supported yet."
            )

        value = ...  # NOTE: `...` (ellipsis) means "required" for Pydantic.
        if field.default is not None:
            if callable(field.default):
                # Explicit raise (not assert) so it also holds under `python -O`.
                raise NotImplementedError(
                    "Callable default values are not supported yet."
                )
            value = field.default
        # Consider `None` to be a default only if the field is nullable.
        elif field.null:
            value = None
        return (type_, value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment