-
-
Save aj07mm/ced0e99ae590eb711bdf4a75e2ac49ef to your computer and use it in GitHub Desktop.
OrderingFilter for Django Rest Framework
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# PYTHON IMPORTS | |
from django.utils import six | |
from django.core.exceptions import ImproperlyConfigured | |
# REST IMPORTS | |
from rest_framework.settings import api_settings | |
from rest_framework.filters import BaseFilterBackend | |
class OrderingFilter(BaseFilterBackend):
    """
    Order a queryset from a comma separated query parameter, translating
    serializer field names into the ORM lookups they are sourced from.

    Clients order by the names the view's serializer exposes
    (e.g. ``?ordering=-author_name``). Each accepted name is rewritten to its
    serializer field's ``source``, with dots turned into ``__`` (``author.name``
    becomes ``author__name``), and the requested direction is kept.

    Which names are accepted depends on the view's ``ordering_fields``
    (falling back to this class's ``ordering_fields``):

    * ``None``: any serializer field that is not write-only and whose source
      is not ``'*'``.
    * ``'__all__'``: any concrete model field or query annotation, used
      verbatim without serializer mapping.
    * an iterable of names (or a single name): only those serializer field
      names.
    """

    # The URL query parameter used for the ordering.
    ordering_param = api_settings.ORDERING_PARAM
    # Fallback whitelist, used when the view has no ``ordering_fields``.
    ordering_fields = None

    def get_ordering(self, request, queryset, view):
        """
        Return the ordering terms to apply, or the view's default ordering.

        Ordering is set by a comma delimited ?ordering=... query parameter.
        The `ordering` query parameter can be overridden by setting
        the `ordering_param` value on the OrderingFilter or by
        specifying an `ORDERING_PARAM` value in the API settings.
        """
        params = request.query_params.get(self.ordering_param)
        if params:
            fields = [param.strip() for param in params.split(',')]
            ordering = self.remove_invalid_fields(queryset, fields, view)
            if ordering:
                return ordering

        # No ordering was included, or all the ordering fields were invalid
        return self.get_default_ordering(view)

    def get_default_ordering(self, view):
        """
        Return the view's ``ordering`` attribute (``None`` if it has none).

        A single string is wrapped in a one-element tuple so the result can
        always be unpacked into ``order_by()``.
        """
        ordering = getattr(view, 'ordering', None)
        if isinstance(ordering, six.string_types):
            return (ordering,)
        return ordering

    def remove_invalid_fields(self, queryset, fields, view):
        """
        Filter the requested ordering terms down to the permitted ones.

        :param queryset: the queryset being ordered.
        :param fields: requested terms, each optionally prefixed with '-'
            for descending order.
        :param view: the view being filtered; its ``ordering_fields``
            decides which names are permitted.
        :returns: the permitted terms in the order they were requested, with
            any '-' prefix preserved. Unless ``ordering_fields`` is
            ``'__all__'``, each name is replaced by its serializer source.
        :raises ImproperlyConfigured: if a serializer is required but the
            view does not provide one.
        """
        ordering_fields = getattr(view, 'ordering_fields', self.ordering_fields)

        if ordering_fields == '__all__':
            # View explicitly allows ordering on any model field or annotation.
            valid_fields = set(field.name for field in queryset.model._meta.fields)
            valid_fields.update(self._get_query_annotation_names(queryset))
            return [term for term in fields if term.lstrip('-') in valid_fields]

        sources = self._get_serializer_field_sources(view)

        if ordering_fields is not None:
            # A bare string would otherwise be treated as a set of characters.
            if isinstance(ordering_fields, six.string_types):
                ordering_fields = (ordering_fields,)
            # Only the serializer field names the view lists are orderable.
            allowed = set(ordering_fields)
            sources = {
                name: source for name, source in sources.items() if name in allowed
            }

        ordering = []
        for term in fields:
            source = sources.get(term.lstrip('-'))
            if source is None:
                # Unknown or disallowed names are ignored, not rejected.
                continue
            # Carry the requested direction over to the mapped source.
            ordering.append('-' + source if term.startswith('-') else source)
        return ordering

    def filter_queryset(self, request, queryset, view):
        """Return ``queryset`` ordered per the request, or unchanged if no ordering applies."""
        ordering = self.get_ordering(request, queryset, view)
        if ordering:
            return queryset.order_by(*ordering)
        return queryset

    def _get_serializer_class(self, view):
        """Return the view's serializer class, or raise ImproperlyConfigured."""
        serializer_class = getattr(view, 'serializer_class', None)
        if serializer_class is None:
            get_serializer_class = getattr(view, 'get_serializer_class', None)
            if get_serializer_class is not None:
                try:
                    serializer_class = get_serializer_class()
                except AssertionError:
                    # DRF's default get_serializer_class() asserts that a
                    # serializer_class is set; report it as a config error.
                    serializer_class = None
        if serializer_class is None:
            msg = ("Cannot use %s on a view which does not have a "
                   "'serializer_class' (or an overriding 'get_serializer_class') "
                   "unless its 'ordering_fields' attribute is '__all__'.")
            raise ImproperlyConfigured(msg % self.__class__.__name__)
        return serializer_class

    def _get_serializer_field_sources(self, view):
        """Map each readable serializer field name to the ORM lookup of its source."""
        serializer_class = self._get_serializer_class(view)
        return {
            # order_by() spans relations with '__', serializers with '.'.
            field_name: field.source.replace('.', '__')
            for field_name, field in serializer_class().fields.items()
            if not getattr(field, 'write_only', False) and field.source != '*'
        }

    @staticmethod
    def _get_query_annotation_names(queryset):
        """Return the names of annotations/aggregates on ``queryset``'s query."""
        query = queryset.query
        # Django 1.8 renamed Query.aggregates to Query.annotations; prefer the
        # new name and fall back to the old one for earlier versions.
        annotations = getattr(query, 'annotations', None)
        if annotations is None:
            annotations = getattr(query, 'aggregates', {})
        return list(annotations)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment