Created
December 1, 2018 23:07
-
-
Save kalda341/8a894e057dd7b6c86d5558ac567e9539 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from rest_framework_json_api import serializers | |
from rest_framework_json_api.utils import get_included_serializers | |
from rest_framework.relations import ManyRelatedField | |
from django.core.exceptions import FieldDoesNotExist | |
class ModelSerializer(serializers.ModelSerializer):
    """Model serializer that can describe the prefetches it requires.

    ``get_prefetches`` inspects the serializer's fields and its included
    serializers and returns a mapping of prefetch lists:

    * the special key ``'__all__'`` holds lookups that are needed regardless
      of which includes were requested;
    * every other key is a dot-separated include path (``'author'``,
      ``'author.books'``) whose value lists the lookups needed when that
      include is requested.

    Lookups are model field names joined with ``'__'``.
    NOTE(review): the consumer of this mapping (presumably a view calling
    ``prefetch_related``) is not part of this file - confirm against callers.
    """

    # Manually declared prefetches, in the same shape that ``get_prefetches``
    # returns. Subclasses override this (or ``get_extra_prefetches``); the
    # mapping is deep-copied before use, so it is never mutated in place.
    extra_prefetches = {}

    def get_extra_prefetches(self):
        """Return the manually declared prefetch mapping for this serializer."""
        return self.extra_prefetches

    @classmethod
    def get_prefetches(cls, max_depth=2):
        """Build the prefetch mapping for this serializer class.

        Args:
            max_depth: How many levels of included serializers to follow.
                Values of zero or less stop the descent immediately.

        Returns:
            A dict mapping ``'__all__'`` and include paths to lists of
            lookup strings. Manually declared entries from
            ``get_extra_prefetches`` take precedence over derived ones.
        """
        # Unfortunately there's no way of accessing fields without creating
        # an instance.
        instance = cls()
        # Will be used for figuring out prefetches for includes.
        included_serializers = get_included_serializers(instance)
        # Start from a copy of the manually defined prefetches so that the
        # class-level mapping is never modified.
        prefetches = copy.deepcopy(instance.get_extra_prefetches())
        # '__all__' is a special case - lookups listed there are prefetched
        # regardless of includes.
        always_prefetch = prefetches.setdefault('__all__', [])

        model = cls.Meta.model
        for field_name, field in instance.fields.items():
            # The serializer field may not have the same name as the
            # equivalent field on the model.
            model_field_name = getattr(field, 'source', field_name)

            # Ensure the field exists on the model. If it doesn't, then
            # there's nothing to prefetch.
            try:
                model._meta.get_field(model_field_name)
            except FieldDoesNotExist:
                continue

            # Many related fields always need a prefetch so that the ids can
            # be shown. Store the *model* field name: these entries are later
            # joined onto model lookups with '__', so the serializer-facing
            # name would be wrong whenever ``source`` differs.
            if (isinstance(field, ManyRelatedField)
                    and model_field_name not in always_prefetch):
                always_prefetch.append(model_field_name)

            # Include the prefetches from included serializers when they are
            # requested.
            if field_name in included_serializers:
                related_prefetches = cls.get_prefetches_for_include(
                    field_name,
                    model_field_name,
                    included_serializers[field_name],
                    max_depth,
                )
                # Don't overwrite manually specified prefetches.
                for key, value in related_prefetches.items():
                    prefetches.setdefault(key, value)

        return prefetches

    @classmethod
    def get_prefetches_for_include(cls, field_name, model_field_name, serializer, max_depth):
        """Return the prefetches contributed by one included serializer.

        Args:
            field_name: Name of the relation on this serializer; used as the
                include path prefix (joined with ``'.'``).
            model_field_name: Name of the relation on the model; used as the
                lookup prefix (joined with ``'__'``).
            serializer: Serializer class used for the include.
            max_depth: Remaining include depth; nothing is returned once it
                reaches zero (or below).

        Returns:
            A dict mapping include paths to lists of lookups. It is empty
            when ``serializer`` is not a subclass of this ``ModelSerializer``
            or the depth budget is exhausted.
        """
        prefetches = {}
        # ``<= 0`` rather than ``== 0`` so that a negative depth cannot skip
        # the base case and recurse forever through mutually-including
        # serializers.
        if not issubclass(serializer, ModelSerializer) or max_depth <= 0:
            return prefetches

        # Include the field itself in the prefetches.
        prefetches[field_name] = [model_field_name]

        nested = serializer.get_prefetches(max_depth=max_depth - 1)
        for prefetch_name, values in nested.items():
            if prefetch_name == '__all__':
                # The nested serializer's unconditional prefetches are needed
                # whenever this include itself is requested.
                prefetch_name = field_name
            else:
                prefetch_name = '.'.join([field_name, prefetch_name])
            prefetches[prefetch_name] = (
                prefetches.get(prefetch_name, [])
                + ['__'.join([model_field_name, x]) for x in values]
            )
        return prefetches
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment