-
-
Save mazulo/ef953616c4877e26c319f6c4f86e57b7 to your computer and use it in GitHub Desktop.
Optimize the queries of Django REST Framework model views.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db import ProgrammingError, models | |
from django.db.models.constants import LOOKUP_SEP | |
from django.db.models.query import normalize_prefetch_lookups | |
from rest_framework import serializers | |
from rest_framework.utils import model_meta | |
class OptimizeModelViewSetMetaclass(type):
    """
    Metaclass that optimizes a REST API view's ``queryset`` at class-creation time.

    When the view class body defines a ``serializer_class`` that is a subclass of
    `serializers.ModelSerializer`, the relational fields exposed by that serializer
    (explicitly declared ones plus those listed through ``Meta``) are collected and
    applied to the view's ``queryset`` attribute via `select_related` (forward,
    single-valued relations) and `prefetch_related` (to-many relations).

    Only attributes present in the class body (``attrs``) are inspected; a
    ``serializer_class`` or ``queryset`` inherited from a base class is not seen.
    """

    @staticmethod
    def get_many_to_many_rel(info, meta_fields):
        """
        Return the entries of `meta_fields` that name a to-many relation.

        `info` must expose a ``relations`` mapping of field name -> relation info
        with a boolean ``to_many`` attribute. The result follows the iteration
        order of `meta_fields`.
        """
        to_many_names = {field_name for field_name, relation_info in info.relations.items()
                         if relation_info.to_many}
        return [lookup for lookup in meta_fields if lookup in to_many_names]

    @staticmethod
    def get_lookups(fields, strict=False):
        """
        Return ``(first_segment, full_lookup)`` pairs for every lookup in `fields`.

        The first segment is the part of the lookup before the first `LOOKUP_SEP`.
        With ``strict=True`` only lookups that actually contain `LOOKUP_SEP`
        (i.e. that span a relation) are kept.
        """
        field_lookups = [(lookup.split(LOOKUP_SEP, 1)[0], lookup) for lookup in fields]
        if strict:
            field_lookups = [f for f in field_lookups if LOOKUP_SEP in f[1]]
        return field_lookups

    @staticmethod
    def get_many_to_one_rel(info, meta_fields):
        """
        Return the entries of `meta_fields` that are forward `ForeignKey` relations.

        `info` must expose a ``forward_relations`` mapping whose values hold the
        model field as their first item. The result follows the iteration order
        of `meta_fields`.
        """
        foreign_key_names = {field_name for field_name, relation_info in info.forward_relations.items()
                             if isinstance(relation_info[0], models.ForeignKey)}
        return [lookup for lookup in meta_fields if lookup in foreign_key_names]

    @staticmethod
    def get_one_to_one_or_one_to_many_rel(info, meta_fields):
        """
        Return the forward relations named in `meta_fields` that are not to-many.

        The result follows the iteration order of ``info.forward_relations``.
        """
        return [field_name for field_name, relation_info in info.forward_relations.items()
                if field_name in meta_fields and not relation_info.to_many]

    @staticmethod
    def _get_meta_field_names(meta, info):
        """
        Resolve the field names a serializer ``Meta`` exposes, as a list.

        An explicit ``fields`` sequence is returned as-is. For ``fields = '__all__'``
        or a ``Meta`` that only declares ``exclude``, there is no explicit list, so
        every forward relation known to `info` (minus the excluded names) is
        returned; non-relational names are irrelevant to the caller, which filters
        the result against the model's relations anyway.
        """
        fields = getattr(meta, 'fields', None)
        if fields is None or fields == '__all__':
            excluded = getattr(meta, 'exclude', None) or ()
            return [field_name for field_name in info.forward_relations
                    if field_name not in excluded]
        return list(fields)

    def __new__(cls, name, bases, attrs):
        """
        Create the view class, replacing ``attrs['queryset']`` with an optimized one.

        If ``serializer_class`` is set but is not a `ModelSerializer` subclass the
        class is created untouched. Otherwise, when a ``queryset`` is present, it is
        extended with `prefetch_related` / `select_related` for the detected
        relations and stored back as ``queryset.all()``.
        """
        serializer_class = attrs.get('serializer_class')
        queryset = attrs.get('queryset')
        prefetch_fields = []
        select_fields = []
        info = None

        if serializer_class:
            if not issubclass(serializer_class, serializers.ModelSerializer):
                return super(OptimizeModelViewSetMetaclass, cls).__new__(cls, name, bases, attrs)

            # Relational fields declared explicitly on the serializer class.
            declared_fields = serializer_class._declared_fields
            prefetch_fields.extend(
                field_name for field_name, field in declared_fields.items()
                if isinstance(field, serializers.ManyRelatedField)
            )
            select_fields.extend(
                field_name for field_name, field in declared_fields.items()
                if isinstance(field, serializers.RelatedField)
            )

            # `Meta` (or `Meta.model`) may be absent, e.g. on an abstract base serializer.
            meta = getattr(serializer_class, 'Meta', None)
            if meta is not None and hasattr(meta, 'model'):
                info = model_meta.get_field_info(meta.model)
                meta_fields = cls._get_meta_field_names(meta, info)
                prefetch_fields.extend(meta_fields)
                select_fields.extend(meta_fields)

        if info is not None:
            # Keep only the names that really are relations of the right kind on the model.
            prefetch_fields = cls.get_many_to_many_rel(info, set(prefetch_fields))
            select_fields = cls.get_one_to_one_or_one_to_many_rel(info, set(select_fields))

        try:
            if queryset is not None:
                if prefetch_fields:
                    queryset = queryset.prefetch_related(*normalize_prefetch_lookups(prefetch_fields))
                if select_fields:
                    queryset = queryset.select_related(*select_fields)
                attrs['queryset'] = queryset.all()
        except ProgrammingError:
            # Best effort: on a database error the original queryset is left as it was.
            pass
        return super(OptimizeModelViewSetMetaclass, cls).__new__(cls, name, bases, attrs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment