Skip to content

Instantly share code, notes, and snippets.

@mazulo
Forked from jackton1/drf_optimize.py
Created August 29, 2023 12:58
Show Gist options
  • Save mazulo/ef953616c4877e26c319f6c4f86e57b7 to your computer and use it in GitHub Desktop.
Save mazulo/ef953616c4877e26c319f6c4f86e57b7 to your computer and use it in GitHub Desktop.
Optimize Django REST Framework model view queries.
from django.db import ProgrammingError, models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query import normalize_prefetch_lookups
from rest_framework import serializers
from rest_framework.utils import model_meta
class OptimizeModelViewSetMetaclass(type):
    """
    Optimize a REST API view's ``queryset`` with ``prefetch_related`` and ``select_related``.

    When the class body declares a ``serializer_class`` that is a subclass of
    ``serializers.ModelSerializer`` and whose ``Meta`` names a ``model``, the
    serializer's relation fields are resolved against that model:

    * to-many relations (entries of DRF ``model_meta`` ``relations`` with
      ``to_many`` set) are added to ``prefetch_related``;
    * forward single-valued relations (``ForeignKey`` / ``OneToOneField``) are
      added to ``select_related``.

    Candidate field names come from the serializer's declared fields and from
    ``Meta.fields`` (including ``'__all__'``) or ``Meta.exclude``.

    Only the ``serializer_class`` and ``queryset`` defined directly in the class
    body are inspected; values inherited from base classes are left untouched.
    A ``serializer_class`` that is not a ``ModelSerializer`` disables the
    optimization entirely. A ``queryset`` without a serializer is only re-cloned
    with ``.all()``.
    """

    @staticmethod
    def get_many_to_many_rel(info, meta_fields):
        """Return the names in *meta_fields* that are to-many relations of the model.

        :param info: DRF ``FieldInfo`` as returned by ``model_meta.get_field_info``.
        :param meta_fields: candidate field names (any iterable supporting ``in``).
        :return: matching names, in *meta_fields* iteration order.
        """
        to_many = {
            field_name
            for field_name, relation_info in info.relations.items()
            if relation_info.to_many and field_name in meta_fields
        }
        return [lookup for lookup in meta_fields if lookup in to_many]

    @staticmethod
    def get_lookups(fields, strict=False):
        """Pair each lookup in *fields* with its first path segment.

        :return: ``(head, lookup)`` tuples, where ``head`` is the text before the
            first ``LOOKUP_SEP``. With ``strict=True``, only lookups that contain
            ``LOOKUP_SEP`` (multi-segment lookups) are kept.

        Not used by ``__new__``.
        """
        field_lookups = [(lookup.split(LOOKUP_SEP, 1)[0], lookup) for lookup in fields]
        if strict:
            field_lookups = [f for f in field_lookups if LOOKUP_SEP in f[1]]
        return field_lookups

    @staticmethod
    def get_many_to_one_rel(info, meta_fields):
        """Return the names in *meta_fields* that are forward ``ForeignKey`` relations.

        ``OneToOneField`` subclasses ``ForeignKey``, so one-to-one fields match too.
        Not used by ``__new__``: those relations are already covered by
        ``get_one_to_one_or_one_to_many_rel``.

        :return: matching names, in *meta_fields* iteration order.
        """
        foreign_keys = {
            field_name
            for field_name, relation_info in info.forward_relations.items()
            if isinstance(relation_info[0], models.ForeignKey) and field_name in meta_fields
        }
        return [lookup for lookup in meta_fields if lookup in foreign_keys]

    @staticmethod
    def get_one_to_one_or_one_to_many_rel(info, meta_fields):
        """Return forward, single-valued relation names that appear in *meta_fields*.

        These are the forward relations with ``to_many`` false, i.e. the ones
        ``select_related`` can follow. The result follows the order of
        ``info.forward_relations``.
        """
        return [
            field_name
            for field_name, relation_info in info.forward_relations.items()
            if field_name in meta_fields and not relation_info.to_many
        ]

    @staticmethod
    def _get_meta_field_names(meta, info, declared_names=()):
        """Return the field names a ModelSerializer ``Meta`` selects.

        Mirrors DRF's rules:

        * an explicit ``fields`` sequence is used as-is;
        * ``fields = '__all__'`` (DRF's ``ALL_FIELDS`` value), or ``exclude``
          alone, selects the declared fields plus the model's fields and forward
          relations, minus anything in ``exclude``.

        :param meta: the serializer's ``Meta`` class.
        :param info: DRF ``FieldInfo`` for ``meta.model``.
        :param declared_names: names of the serializer's explicitly declared fields.
        :return: a list of names; empty when neither ``fields`` nor ``exclude`` is
            set (a configuration DRF itself rejects).
        """
        fields = getattr(meta, 'fields', None)
        exclude = getattr(meta, 'exclude', None)
        if fields is not None and fields != '__all__':
            return list(fields)
        if fields is None and exclude is None:
            return []
        names = [*declared_names, *info.fields, *info.forward_relations]
        excluded = set(exclude or ())
        return [name for name in names if name not in excluded]

    @classmethod
    def _get_related_fields(cls, serializer_class):
        """Resolve which relations of a ``ModelSerializer`` subclass to optimize.

        :return: ``(prefetch_fields, select_fields)``. The first lists to-many
            relation names for ``prefetch_related``; the second lists forward
            single-valued relation names for ``select_related``. Both are empty
            when the serializer has no ``Meta.model``, because the names could not
            be validated against a model.
        """
        meta = getattr(serializer_class, 'Meta', None)
        model = getattr(meta, 'model', None)
        if model is None:
            return [], []

        info = model_meta.get_field_info(model)
        declared = serializer_class._declared_fields
        meta_fields = cls._get_meta_field_names(meta, info, declared)

        declared_many = [
            name for name, field in declared.items()
            if isinstance(field, serializers.ManyRelatedField)
        ]
        declared_related = [
            name for name, field in declared.items()
            if isinstance(field, serializers.RelatedField)
        ]
        # dict.fromkeys de-duplicates while keeping insertion order, so the generated
        # lookup lists do not depend on set iteration order.
        to_many_candidates = dict.fromkeys([*declared_many, *meta_fields])
        to_one_candidates = dict.fromkeys([*declared_related, *meta_fields])
        return (
            cls.get_many_to_many_rel(info, to_many_candidates),
            cls.get_one_to_one_or_one_to_many_rel(info, to_one_candidates),
        )

    def __new__(cls, name, bases, attrs):
        serializer_class = attrs.get('serializer_class', None)
        queryset = attrs.get('queryset')

        # A non-model serializer gives no model metadata to optimize against.
        if serializer_class and not issubclass(serializer_class, serializers.ModelSerializer):
            return super().__new__(cls, name, bases, attrs)

        if serializer_class:
            prefetch_fields, select_fields = cls._get_related_fields(serializer_class)
        else:
            prefetch_fields, select_fields = [], []

        if queryset is not None:
            try:
                if prefetch_fields:
                    queryset = queryset.prefetch_related(*normalize_prefetch_lookups(prefetch_fields))
                if select_fields:
                    queryset = queryset.select_related(*select_fields)
                attrs['queryset'] = queryset.all()
            except ProgrammingError:
                # Best effort: on failure the class keeps its original queryset.
                # NOTE(review): these QuerySet calls are lazy, so this guard may
                # never trigger; it presumably targets database-less import
                # contexts such as migrations. Confirm before removing it.
                pass
        return super().__new__(cls, name, bases, attrs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment