Skip to content

Instantly share code, notes, and snippets.

@kalda341
Created December 1, 2018 23:07
Show Gist options
  • Save kalda341/8a894e057dd7b6c86d5558ac567e9539 to your computer and use it in GitHub Desktop.
Save kalda341/8a894e057dd7b6c86d5558ac567e9539 to your computer and use it in GitHub Desktop.
import copy
from rest_framework_json_api import serializers
from rest_framework_json_api.utils import get_included_serializers
from rest_framework.relations import ManyRelatedField
from django.core.exceptions import FieldDoesNotExist
class ModelSerializer(serializers.ModelSerializer):
    """JSON:API model serializer that can compute the prefetches it needs.

    ``get_prefetches()`` returns a dict mapping an include path (a serializer
    field name such as ``'author'``, or a dotted path such as
    ``'author.books'``) to a list of ``prefetch_related`` lookups (model field
    names joined with ``'__'``). The special key ``'__all__'`` holds lookups
    that are needed regardless of which includes were requested.
    """

    # Manually specified prefetches, in the same shape get_prefetches()
    # returns. get_prefetches() deep-copies this before extending it, so the
    # shared class-level dict is never mutated.
    extra_prefetches = {}

    def get_extra_prefetches(self):
        """Return the manually specified prefetches (``extra_prefetches``).

        Subclasses may override this to compute them dynamically.
        """
        return self.extra_prefetches

    @classmethod
    def get_prefetches(cls, max_depth=2):
        """Build the prefetch mapping for this serializer.

        :param max_depth: how many levels of included serializers to recurse
            into; a value <= 0 adds no include-derived prefetches.
        :returns: dict of include path -> list of prefetch lookups, always
            containing an ``'__all__'`` key.

        Manually specified entries (from ``get_extra_prefetches()``) are
        never overwritten by computed ones.
        """
        # Unfortunately there's no way of accessing fields without creating an instance
        instance = cls()
        # Used to figure out the prefetches needed for each include
        included_serializers = get_included_serializers(instance)
        # Start from the manually defined prefetches (copied, see extra_prefetches)
        prefetches = copy.deepcopy(instance.get_extra_prefetches())
        # '__all__' is a special case: its lookups are prefetched regardless of includes.
        # Copy into a list so a tuple supplied in extra_prefetches can be extended.
        always = list(prefetches.get('__all__', []))
        prefetches['__all__'] = always
        model = cls.Meta.model

        for field_name, field in instance.fields.items():
            # The serializer field may be named differently from the model field
            model_field_name = getattr(field, 'source', field_name)
            # Only real model fields can be prefetched; skip computed/dotted sources
            try:
                model._meta.get_field(model_field_name)
            except FieldDoesNotExist:
                continue

            # Many related fields always need a prefetch so that we can show the ids.
            # The lookup must be the model field name, not the serializer field name.
            if isinstance(field, ManyRelatedField) and model_field_name not in always:
                always.append(model_field_name)

            # Include the prefetches from included serializers when they are requested
            if field_name in included_serializers:
                related_prefetches = cls.get_prefetches_for_include(
                    field_name,
                    model_field_name,
                    included_serializers[field_name],
                    max_depth,
                )
                # Don't overwrite manually specified prefetches
                for key, value in related_prefetches.items():
                    prefetches.setdefault(key, value)

        return prefetches

    @classmethod
    def get_prefetches_for_include(cls, field_name, model_field_name, serializer, max_depth):
        """Compute the prefetches needed when ``field_name`` is included.

        :param field_name: serializer field name, used to build include paths.
        :param model_field_name: model field the serializer field reads from,
            used to build ``prefetch_related`` lookups.
        :param serializer: the included serializer class.
        :param max_depth: remaining recursion depth; nothing is returned when
            it is <= 0 or when ``serializer`` is not a ``ModelSerializer``.
        :returns: dict of include path -> list of prefetch lookups, with the
            nested serializer's paths prefixed by ``field_name + '.'`` and its
            lookups prefixed by ``model_field_name + '__'``.
        """
        prefetches = {}
        # `<= 0` (not `== 0`) so a negative starting depth cannot recurse forever
        if not issubclass(serializer, ModelSerializer) or max_depth <= 0:
            return prefetches

        # Including the field requires prefetching the relation itself
        prefetches[field_name] = [model_field_name]

        nested = serializer.get_prefetches(max_depth=max_depth - 1)
        for nested_name, lookups in nested.items():
            # The nested serializer's '__all__' lookups apply whenever this field is included
            if nested_name == '__all__':
                include_path = field_name
            else:
                include_path = '.'.join([field_name, nested_name])
            prefixed = ['__'.join([model_field_name, lookup]) for lookup in lookups]
            prefetches[include_path] = prefetches.get(include_path, []) + prefixed

        return prefetches
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment