Created
March 31, 2020 08:36
-
-
Save c1ay/86940b5effb446d4bf1ff39ee7c6b1fa to your computer and use it in GitHub Desktop.
Django REST framework: dynamically returned fields (动态返回字段)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
from collections import OrderedDict | |
from django.db import models | |
from django.db.models.fields import related_descriptors | |
from rest_framework import serializers | |
from rest_framework.fields import SkipField | |
from rest_framework.relations import PKOnlyObject | |
class DynamicFieldModelSerializer(serializers.ModelSerializer):
    """ModelSerializer whose output can be restricted to a requested subset of fields.

    The requested field names are looked up in the serializer context under
    the key named by ``include_parameters`` and, failing that, in the request
    query string (for example ``?include=id,name``).  When no names are
    requested, every readable field is rendered.
    """

    # Name of the context key / query-string parameter that lists the fields to keep.
    include_parameters = 'include'

    def get_fields_for_serializer(
            self, parameter_name: typing.Optional[str] = None) -> typing.Optional[typing.Set[str]]:
        """Return the set of requested field names, or ``None`` if nothing was requested.

        Args:
            parameter_name: Context key / query parameter to read.  Falls back
                to ``self.include_parameters`` when omitted or empty.

        Returns:
            A set of field names, or ``None`` when neither the context nor the
            query string carries the parameter (or the context value is ``None``).
        """
        include_name = parameter_name or self.include_parameters
        if include_name in self.context:
            requested = self.context[include_name]
        elif "request" in self.context and include_name in self.context["request"].query_params:
            requested = self.context['request'].query_params[include_name]
        else:
            return None

        if requested is None:
            return None
        if isinstance(requested, str):
            # A comma-separated string: trim whitespace and drop empty entries so
            # that "a, b" and a trailing comma behave as expected.  Passing a raw
            # string to set() would split it into single characters.
            names = (name.strip() for name in requested.split(','))
            return {name for name in names if name}
        return set(requested)

    def to_representation(self, instance):
        """Serialize ``instance``, keeping only the requested fields.

        An empty or missing selection means "no filtering": all readable
        fields are rendered.
        """
        ret = OrderedDict()
        include_fields = self.get_fields_for_serializer()

        for field in self._readable_fields:
            if include_fields and field.field_name not in include_fields:
                continue

            try:
                attribute = field.get_attribute(instance)
            except SkipField:
                continue

            # For PKOnlyObject the None-check applies to the wrapped primary
            # key rather than to the wrapper itself.
            check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
            if check_for_none is None:
                ret[field.field_name] = None
            else:
                # pass child serializer fields
                ret[field.field_name] = field.to_representation(attribute)

        return ret
class AutoOptimizationMixin:
    """View mixin that adds ``select_related`` / ``prefetch_related`` to the queryset.

    The relations to load are derived from the view's serializer via
    ``auto_optimization``.  The mixin relies on the host class providing
    ``get_serializer_class``, ``get_serializer_context`` and a parent
    ``get_queryset``.
    """

    def get_queryset(self):
        """Return the parent queryset with related-object loading applied."""
        serializer = self.get_serializer_class()(context=self.get_serializer_context())
        include_fields = set()
        if isinstance(serializer, DynamicFieldModelSerializer):
            # get_fields_for_serializer() returns None when nothing was requested.
            # The serializer renders every field in that case (and for an empty
            # selection), so consider every field here instead of passing None on.
            include_fields = serializer.get_fields_for_serializer() or set(serializer.fields)

        optimization = auto_optimization('', serializer, include_fields)
        queryset = super().get_queryset()
        if optimization is None:
            # auto_optimization had nothing to inspect (no Meta.model).
            return queryset

        select_related, prefetch_related = optimization
        if select_related:
            queryset = queryset.select_related(*select_related)
        if prefetch_related:
            queryset = queryset.prefetch_related(*prefetch_related)
        return queryset
def auto_optimization(prefix: str, serializer: 'DynamicFieldModelSerializer',
                      include_fields: typing.Optional[typing.Iterable[str]]
                      ) -> typing.Tuple[typing.Set[str], typing.Set[str]]:
    """Work out which relations of the serializer's model should be eagerly loaded.

    Only nested serializer fields are inspected; the first segment of each
    field's ``source`` is matched against the relation descriptors on
    ``serializer.Meta.model``.

    Args:
        prefix: String prepended to every relation name that is returned.
        serializer: Serializer instance whose ``fields`` are inspected.
        include_fields: Names of the fields to consider.  ``None`` means
            every field is considered.

    Returns:
        A ``(select_related, prefetch_related)`` pair of sets of relation
        names.  Both sets are empty when the serializer has no ``Meta.model``.
    """
    select_related = set()
    prefetch_related = set()
    if not hasattr(serializer, 'Meta') or not hasattr(serializer.Meta, 'model'):
        # Always hand back a pair so callers can unpack the result.
        return select_related, prefetch_related

    model_class = serializer.Meta.model
    # Materialize once: membership tests below must not consume an iterator.
    wanted = None if include_fields is None else set(include_fields)

    # Accessing a relation on the model *class* yields a descriptor object,
    # not the model field, so relations are classified by descriptor type.
    single_valued = (
        related_descriptors.ForwardManyToOneDescriptor,  # ForeignKey / forward OneToOne
        related_descriptors.ReverseOneToOneDescriptor,
    )
    # ManyToManyDescriptor is a subclass of ReverseManyToOneDescriptor.
    multi_valued = (related_descriptors.ReverseManyToOneDescriptor,)

    for field_name, field in serializer.fields.items():
        if wanted is not None and field_name not in wanted:
            continue
        # BaseSerializer also covers ListSerializer, i.e. nested ``many=True``.
        if not isinstance(field, serializers.BaseSerializer):
            continue

        source = field.source
        if '.' not in source:
            descriptor = getattr(model_class, source, None)
            if isinstance(descriptor, multi_valued):
                prefetch_related.add(prefix + source)
            elif isinstance(descriptor, single_valued):
                select_related.add(prefix + source)
            # Anything else (property, method, missing attribute) is not a
            # relation the ORM can join or prefetch, so it is skipped.
        else:
            relation_name = source.split('.', 1)[0]
            # Check whether the first segment can be followed as a foreign key.
            descriptor = getattr(model_class, relation_name, None)
            if isinstance(descriptor, related_descriptors.ForwardManyToOneDescriptor):
                select_related.add(prefix + relation_name)

    return select_related, prefetch_related
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment