Last active
September 21, 2024 10:34
-
-
Save florimondmanca/36d304e50f9cda79fe9b12f0b3349cbb to your computer and use it in GitHub Desktop.
DRF-inspired ModelSerializer implementation backed by Pydantic and Tortoise ORM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import decimal | |
import json | |
from typing import Any, Optional, Type, Union, Dict, Tuple | |
from pydantic import BaseModel, ValidationError, create_model | |
from tortoise import fields as tortoise_fields | |
from tortoise.fields import Field | |
from tortoise.models import Model as TortoiseModel | |
# Placeholder for "an ORM model instance"; intentionally untyped.
Model = Any

# Payload accepted/produced by serializers: a JSON-style object or array.
Data = Union[list, dict]

# One field definition as handed to Pydantic's `create_model`: either a bare
# default value, or a `(type, default)` pair.
PydanticDefinition = Union[Any, Tuple[Type, Any]]
# Python type that Pydantic should use for each supported Tortoise field
# class (looked up when building read/write schemas). Fields sharing a type
# are grouped; insertion order matches the per-field listing.
_TORTOISE_FIELD_TO_PYDANTIC_TYPE = {
    **dict.fromkeys(
        (
            tortoise_fields.IntField,
            tortoise_fields.BigIntField,
            tortoise_fields.SmallIntField,
        ),
        int,
    ),
    **dict.fromkeys(
        (tortoise_fields.CharField, tortoise_fields.TextField),
        str,
    ),
    tortoise_fields.BooleanField: bool,
    tortoise_fields.DecimalField: decimal.Decimal,
    tortoise_fields.DatetimeField: datetime.datetime,
    tortoise_fields.DateField: datetime.date,
    tortoise_fields.TimeDeltaField: datetime.timedelta,
    tortoise_fields.FloatField: float,
    # JSON columns may hold either a JSON object or a JSON array.
    tortoise_fields.JSONField: Union[dict, list],
}
class BaseModelSerializer:
    """DRF-inspired serializer skeleton.

    Subclasses (or mixins) provide ``to_internal_value`` /
    ``to_representation`` for (de)serialization and ``create`` / ``update``
    for persistence.

    Args:
        instance: An existing object to serialize, or to update on ``save()``.
        data: Raw input data to validate with ``is_valid()``.
    """

    def __init__(self, instance: Model = None, data: Data = None):
        self.instance = instance
        self.initial_data = data
        self._validated_data: Optional[dict] = None
        self._data: Optional[dict] = None
        self._errors: Optional[list] = None
        # Exception caught by the last failed validation, kept so that
        # `is_valid(raise_exceptions=True)` can re-raise it unchanged.
        self._validation_error: Optional[ValidationError] = None

    def to_internal_value(self, data: Data) -> Data:
        """Convert native input data into validated Python values."""
        raise NotImplementedError

    def to_representation(self, instance: Any) -> Data:
        """Convert an object (or validated values) into native data."""
        raise NotImplementedError

    def run_validation(self, data: Data) -> Data:
        """Validate ``data``; hook point for extra validation in subclasses."""
        return self.to_internal_value(data)

    def is_valid(self, raise_exceptions: bool = False) -> bool:
        """Validate ``initial_data`` once and cache the outcome.

        Args:
            raise_exceptions: Re-raise the validation error instead of
                returning ``False`` when the data is invalid.

        Returns:
            ``True`` if the data is valid, ``False`` otherwise.

        Raises:
            ValidationError: if the data is invalid and ``raise_exceptions``
                is true.
        """
        assert self.initial_data is not None, (
            "Cannot call `.is_valid()` as no `data=` keyword argument "
            "was passed when instanciating the serializer."
        )
        if self._validated_data is None:
            try:
                self._validated_data = self.run_validation(self.initial_data)
            except ValidationError as exc:
                self._validated_data = {}
                self._errors = exc.raw_errors
                self._validation_error = exc
            else:
                self._errors = []
        if self._errors and raise_exceptions:
            # Re-raise the original exception: Pydantic's `ValidationError`
            # cannot be rebuilt from the raw errors alone (it also needs the
            # model they refer to).
            raise self._validation_error
        return not self._errors

    @property
    def validated_data(self) -> dict:
        """Validated values; only available after ``is_valid()``."""
        assert (
            self._validated_data is not None
        ), "You must call `.is_valid()` before accessing `.validated_data`."
        return self._validated_data

    @property
    def data(self):
        """Native representation of the serializer's current state.

        - An instance and no validation errors: the instance is represented.
        - Valid data and no instance: the validated values are represented.
        - Otherwise (e.g. invalid data): the raw ``initial_data`` is returned.
        """
        # Only serializers that were given `data=` must be validated first;
        # read-only serializers (`instance=` only) can be represented directly.
        assert (
            self.initial_data is None or self._validated_data is not None
        ), (
            "When a serializer is passed a `data` keyword argument you "
            "must call `.is_valid()` before attempting to access the "
            "serialized `.data` representation.\n"
            "You should either call `.is_valid()` first, "
            "or access `.initial_data` instead."
        )
        if self._data is None:
            if self.instance is not None and not self._errors:
                self._data = self.to_representation(self.instance)
            elif self._validated_data is not None and not self._errors:
                self._data = self.to_representation(self._validated_data)
            else:
                self._data = self.initial_data
        return self._data

    async def create(self, validated_data: dict) -> Model:
        """Create and return a new object from ``validated_data``."""
        raise NotImplementedError

    async def update(self, instance: Model, validated_data: dict) -> Model:
        """Apply ``validated_data`` to ``instance`` and return it."""
        raise NotImplementedError

    async def save(self, **kwargs: Any) -> Model:
        """Create or update ``self.instance`` from the validated data.

        Extra keyword arguments override/extend the validated values.
        """
        validated_data = {**self.validated_data, **kwargs}
        assert (
            not self._errors
        ), "You cannot call `.save()` on a serializer with invalid data."
        if self.instance is None:
            self.instance = await self.create(validated_data)
        else:
            self.instance = await self.update(self.instance, validated_data)
        # Drop any cached representation so `.data` reflects the saved instance.
        self._data = None
        return self.instance
class PydanticMixin:
    """Serialize/deserialize using Pydantic.

    Two Pydantic models are built lazily and cached per serializer instance:
    a *read* schema (used by ``to_representation``) and a *write* schema
    (used by ``to_internal_value``). Their fields come from
    ``get_field_names()`` and each field's definition from ``build_field()``,
    which subclasses must implement.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._read_schema: Optional[Type[BaseModel]] = None
        self._write_schema: Optional[Type[BaseModel]] = None

    @staticmethod
    def _own_annotations(klass: type) -> dict:
        """Return the annotations declared directly on ``klass`` (not inherited)."""
        try:
            from inspect import get_annotations  # Python 3.10+
        except ImportError:
            return vars(klass).get("__annotations__", {})
        return get_annotations(klass)

    def get_declared_fields(self) -> set:
        """Return the names declared as class annotations on the serializer.

        Annotations are merged across the class hierarchy, so fields declared
        on a base serializer are inherited. Annotated names that are also
        class attributes (e.g. a type hint for the ``validated_data``
        property) describe the serializer itself rather than payload fields,
        and are skipped.
        """
        # NOTE: `getattr(self, "__annotations__")` is avoided on purpose: it
        # yields only the first annotated class found in the MRO, which can be
        # an unrelated mixin whose type hints would then become "fields".
        cls = type(self)
        declared = set()
        for klass in cls.__mro__:
            declared.update(self._own_annotations(klass))
        return {name for name in declared if not hasattr(cls, name)}

    def get_field_names(self) -> set:
        """Return the declared field names minus ``Meta.exclude`` (if any)."""
        try:
            excluded_fields = set(self.Meta.exclude)
        except AttributeError:
            excluded_fields = set()
        return self.get_declared_fields().difference(excluded_fields)

    @property
    def read_schema(self) -> Type[BaseModel]:
        """Pydantic model used to serialize objects (built once, cached)."""
        if self._read_schema is None:
            self._read_schema = self._get_schema("read")
        return self._read_schema

    @property
    def write_schema(self) -> Type[BaseModel]:
        """Pydantic model used to validate input data (built once, cached)."""
        if self._write_schema is None:
            self._write_schema = self._get_schema("write")
        return self._write_schema

    def build_field(
        self, field_name: str, operation: str
    ) -> Optional[PydanticDefinition]:
        """Return the Pydantic definition of a field, or ``None`` to skip it.

        ``operation`` is ``"read"`` or ``"write"``.
        """
        raise NotImplementedError

    def _get_schema(self, operation: str) -> Type[BaseModel]:
        """Create a Pydantic model holding every non-skipped field."""
        definitions: Dict[str, PydanticDefinition] = {}
        # Sorted so the schema (and thus the serialized key order) does not
        # depend on set/hash ordering.
        for field_name in sorted(self.get_field_names()):
            definition = self.build_field(field_name, operation)
            if definition is None:
                continue
            definitions[field_name] = definition
        # Name the schema after the model class when one is available
        # (`_model_class`), otherwise after the serializer class.
        model_class = getattr(self, "_model_class", type(self))
        return create_model(model_class.__name__, **definitions)

    def to_internal_value(self, data: dict) -> dict:
        """Convert a native dict of values to a Python dict of values.

        Raises:
            ValidationError: if ``data`` does not satisfy the write schema.
        """
        return self.write_schema(**data).dict()

    def to_representation(self, instance: Union[Model, dict]) -> dict:
        """Convert a Python object (or dict) to a native dict of values."""
        schema = self.read_schema
        if isinstance(instance, dict):
            data = {attr: instance[attr] for attr in schema.__fields__}
        else:
            data = {attr: getattr(instance, attr) for attr in schema.__fields__}
        # Round-trip through JSON so values come out as JSON-native types.
        return json.loads(schema(**data).json())
class TortoiseMixin:
    """Create, update and save instances using Tortoise ORM.

    The model class is read from the serializer's ``Meta.model``.
    """

    # NOTE: no class-level annotations on purpose. Serializer fields may be
    # discovered from class annotations, so a type hint such as
    # `validated_data: dict` here would surface as a bogus field.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            self._model_class: Type[TortoiseModel] = self.Meta.model
        except AttributeError as exc:
            raise AttributeError(
                "Model class could not be retrieved from `.Meta.model`."
            ) from exc

    async def create(self, validated_data: dict) -> TortoiseModel:
        """Create an instance through the model's ``create`` classmethod."""
        return await self._model_class.create(**validated_data)

    async def update(
        self, instance: TortoiseModel, validated_data: dict
    ) -> TortoiseModel:
        """Assign each validated value onto ``instance``, save it, return it."""
        for attr, value in validated_data.items():
            setattr(instance, attr, value)
        await instance.save()
        return instance
class ModelSerializer(PydanticMixin, TortoiseMixin, BaseModelSerializer):
    """A concrete model serializer class backed by Pydantic and Tortoise.

    Fields are the Tortoise model's fields plus any names declared as
    annotations on the serializer class (e.g. model properties), minus
    ``Meta.exclude``.
    """

    def get_declared_fields(self) -> set:
        """Add the Tortoise model's field names to the annotated ones."""
        return (
            super().get_declared_fields().union(self._model_class._meta.fields)
        )

    @staticmethod
    def _pydantic_type_for(field: Field, field_name: str) -> Any:
        """Look up the Pydantic type for ``field``, honouring field subclasses."""
        for klass in type(field).__mro__:
            if klass in _TORTOISE_FIELD_TO_PYDANTIC_TYPE:
                return _TORTOISE_FIELD_TO_PYDANTIC_TYPE[klass]
        raise KeyError(
            f"No Pydantic type registered for Tortoise field "
            f"{type(field).__name__!r} (field {field_name!r})."
        )

    def build_field(
        self, field_name: str, operation: str
    ) -> Optional[PydanticDefinition]:
        """Build the Pydantic definition of ``field_name`` for ``operation``.

        Args:
            field_name: A Tortoise field name, or another declared name
                such as a model property.
            operation: ``"read"`` (serialization) or ``"write"``
                (deserialization).

        Returns:
            A ``(type, default)`` pair for ``create_model``, or ``None`` when
            the field is left out of the schema for this operation.

        Raises:
            KeyError: if the Tortoise field class has no registered type.
        """
        try:
            # Tortoise field
            field: Field = self._model_class._meta.fields_map[field_name]
        except KeyError:
            # Not a Tortoise field (e.g. a property): read-only, so it is
            # left out of the write schema.
            if operation == "write":
                return None
            # The actual value is read from the object by `to_representation`.
            # Declare an explicit, permissive `(type, default)` pair instead of
            # passing a sample value: `create_model` would misread a 2-tuple
            # value as `(type, default)`, reject non-Pydantic types, and the
            # sample would require `self.instance` to exist and not be a dict.
            return (Any, None)

        read_only = any(
            (
                field.pk,
                getattr(field, "auto_now", False),
                getattr(field, "auto_now_add", False),
            )
        )
        if operation == "write" and read_only:
            return None
        # TODO: write-only fields (excluded from "read") are not supported yet.

        type_ = self._pydantic_type_for(field, field_name)
        if field.null:
            # A nullable column may hold NULL even when it has a non-None
            # default, so the type itself must accept None.
            type_ = Optional[type_]

        value = ...  # NOTE: `...` (ellipsis) means "required" for Pydantic.
        if field.default is not None:
            assert not callable(
                field.default
            ), "Callable default values are not supported yet."
            value = field.default
        # Consider `None` to be a default only if the field is nullable.
        elif field.null:
            value = None
        return (type_, value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment