Last active
October 13, 2024 12:22
-
-
Save bendavis78/b7caaf55062911872613 to your computer and use it in GitHub Desktop.
Example implementation of randomized pagination in Django and Django REST framework
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Adds a `seed` parameter to DRF's `next` and `previous` pagination URLs
""" | |
from rest_framework import serializers | |
from rest_framework import pagination | |
from rest_framework.templatetags.rest_framework import replace_query_param | |
from . import utils | |
class PageSeedFieldMixin:
    """Mixin for DRF page-link fields that appends the ordering seed to the URL.

    The seed is read from the serializer context under the ``'seed'`` key,
    encoded with ``utils.encode_float`` and written to the query parameter
    named by ``seed_field``, so following the link reproduces the same
    random ordering.
    """

    # Name of the query parameter that carries the encoded seed.
    seed_field = 'seed'

    def to_representation(self, value):
        """Return the parent field's page URL with the seed parameter set.

        Returns ``None`` when the parent field produced no URL. When the
        context holds no seed, the URL is returned unchanged rather than
        letting ``utils.encode_float(None)`` raise.
        """
        url = super().to_representation(value)
        if not url:
            return None
        seed = self.context.get('seed')
        if seed is None:
            # Unseeded request (e.g. a page > 1 requested without a seed):
            # there is nothing to propagate.
            return url
        # encode_float returns ASCII bytes; decode so the query string gets the
        # same text value that the serializer's `seed` field exposes.
        encoded = utils.encode_float(seed).decode('ascii')
        return replace_query_param(url, self.seed_field, encoded)
class NextPageSeedField(PageSeedFieldMixin, pagination.NextPageField):
    """Link to the next page, with the ordering seed added to its query string."""
class PreviousPageSeedField(PageSeedFieldMixin, pagination.PreviousPageField):
    """Link to the previous page, with the ordering seed added to its query string."""
class PaginationSeedSerializer(pagination.PaginationSerializer):
    """Pagination serializer whose links and payload carry the ordering seed.

    ``next``/``previous`` links get the seed as a query parameter, and the
    encoded seed is also exposed as a top-level ``seed`` field.
    """

    next = NextPageSeedField(source='*')
    previous = PreviousPageSeedField(source='*')
    seed = serializers.SerializerMethodField()

    def get_seed(self, object):
        """Return the context seed encoded as an ASCII string, or ``None``.

        ``object`` (the page) is unused; the name is kept for interface
        compatibility. ``None`` is returned when the context has no seed,
        instead of letting ``utils.encode_float(None)`` raise.
        """
        seed = self.context.get('seed')
        if seed is None:
            return None
        return utils.encode_float(seed).decode()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
`SeededQuerySet` adds the `set_seed` method, which prepends a `SELECT setseed(...)`
call to the SQL of randomly ordered (`order_by('?')`) querysets (PostgreSQL only).
See http://www.postgresql.org/docs/8.3/static/sql-set.html for more info. | |
""" | |
from django.db import models | |
from django.db import connections | |
from django.db.models.sql import Query | |
from django.db.models.sql.compiler import SQLCompiler | |
class SeededSQLCompiler(SQLCompiler):
    """SQL compiler that seeds ``random()`` before a randomly ordered query.

    When the query has a truthy ``seed`` and ``'?'`` is among its
    ``order_by`` entries, the SQL is prefixed with ``SELECT setseed(%s); ``
    so the random ordering that follows is reproducible.
    """

    def as_sql(self, *args, **kwargs):
        """Compile the query, prepending ``setseed()`` when applicable.

        All arguments are forwarded unchanged to ``SQLCompiler.as_sql``.

        Returns:
            ``(sql, params)``; when a seed is applied, ``params`` is a tuple
            whose first element is the seed formatted as a decimal string.

        Raises:
            ValueError: if the seed is not strictly between 0 and 1.
        """
        sql, params = super().as_sql(*args, **kwargs)
        if self.query.seed and '?' in self.query.order_by:
            # seed must be a float strictly between 0 and 1
            seed = float(self.query.seed)
            if not 0 < seed < 1:
                raise ValueError("Invalid seed value: {}".format(seed))
            # NOTE(review): the "; " makes this a multi-statement string, which
            # presumably only works when the query is run at top level, not
            # when embedded as a subquery — confirm callers never do that.
            sql = 'SELECT setseed(%s); ' + sql
            # The setseed placeholder is the FIRST %s in the SQL, so its value
            # must be the first positional parameter, not the last.
            params = ('{:0.52f}'.format(seed),) + tuple(params)
        return sql, params
class SeededQuery(Query):
    """``Query`` that carries an optional random-ordering ``seed``.

    The seed is passed along on ``clone()``, and the query is always compiled
    with ``SeededSQLCompiler``.
    """

    # Seed for the SQL setseed() call; None means "do not seed".
    seed = None

    def clone(self, *args, **kwargs):
        """Clone the query, carrying ``seed`` over unless one is passed explicitly."""
        # NOTE(review): relies on Query.clone() applying extra keyword
        # arguments as attributes of the clone — confirm for the Django
        # version in use.
        kwargs.setdefault('seed', self.seed)
        return super().clone(*args, **kwargs)

    def get_compiler(self, using=None, connection=None):
        """Return a ``SeededSQLCompiler`` for the given alias or connection.

        ``using`` takes precedence: when given, the connection is looked up
        in ``connections``.

        Raises:
            ValueError: if neither ``using`` nor ``connection`` is given.
        """
        if using is None and connection is None:
            raise ValueError("Need either using or connection")
        if using:
            connection = connections[using]
        # Check that the compiler will be able to execute the query
        for aggregate in self.aggregate_select.values():
            connection.ops.check_aggregate_support(aggregate)
        return SeededSQLCompiler(self, connection, using)
class SeededQuerySet(models.QuerySet):
    """QuerySet whose random ordering can be made reproducible with a seed.

    Example: ``qs.set_seed(0.42).order_by('?')``.
    """

    def __init__(self, model=None, query=None, using=None, hints=None):
        # Default to a SeededQuery so set_seed() has somewhere to store the seed.
        if query is None:
            query = SeededQuery(model)
        super().__init__(model, query, using, hints)

    def set_seed(self, seed):
        """Return a copy of this queryset whose random ordering uses ``seed``.

        The receiver is left untouched, matching other chainable QuerySet
        methods. A shared queryset (such as a class-level ``queryset``
        attribute) therefore never picks up a per-request seed, and results
        already cached on the receiver are not reused.

        Args:
            seed: a value strictly between 0 and 1 (``SeededSQLCompiler``
                rejects anything else), or ``None`` to disable seeding.
        """
        clone = self._clone()
        clone.query.seed = seed
        return clone
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Utility functions for encoding integers and floats into short ASCII strings | |
(ideal for URL parameters) | |
""" | |
import struct | |
import string | |
# 62-symbol alphabet: a-z, then A-Z, then 0-9 (index == digit value).
BASE_ALPH = tuple(string.ascii_lowercase + string.ascii_uppercase + string.digits)
# Reverse lookup: symbol -> digit value.
BASE_DICT = {symbol: value for value, symbol in enumerate(BASE_ALPH)}
def decode_int(encoded):
    """Decode a base-62 string (digits from ``BASE_ALPH``) into an int.

    ``encoded`` may be ``str`` or ASCII ``bytes``, so the bytes returned by
    ``encode_int`` can be passed straight back in. An empty value decodes
    to 0.

    Raises:
        KeyError: if ``encoded`` contains a character outside ``BASE_ALPH``.
        UnicodeDecodeError: if ``encoded`` is bytes that are not ASCII.
    """
    if isinstance(encoded, (bytes, bytearray)):
        # Iterating bytes yields ints, which are not keys of BASE_DICT.
        encoded = encoded.decode('ascii')
    base = len(BASE_ALPH)
    num = 0
    for char in encoded:
        num = num * base + BASE_DICT[char]
    return num
def encode_int(num):
    """Encode a non-negative int as a base-62 ASCII byte string.

    Digits come from ``BASE_ALPH``, most significant first. 0 encodes as
    ``BASE_ALPH[0]`` so the result is never empty.

    Raises:
        ValueError: if ``num`` is negative (the digit loop would never end).
    """
    if num < 0:
        raise ValueError('encode_int() requires a non-negative integer')
    base = len(BASE_ALPH)
    digits = []
    while True:
        num, rem = divmod(num, base)
        digits.append(BASE_ALPH[rem])
        if not num:
            break
    # Digits were produced least-significant first; build the string once.
    return ''.join(reversed(digits)).encode('ascii')
def encode_float(num):
    """Encode a float as a short base-62 ASCII byte string.

    The big-endian IEEE 754 double bytes of ``num`` are read as an unsigned
    64-bit integer, which is then passed to ``encode_int``.
    """
    (bit_pattern,) = struct.unpack('>Q', struct.pack('>d', num))
    return encode_int(bit_pattern)
def decode_float(encoded):
    """Invert ``encode_float``: decode the base-62 text back into a float.

    Raises:
        KeyError: if ``decode_int`` meets a character outside the alphabet.
        OverflowError: if the decoded integer does not fit in 8 bytes.
    """
    as_int = decode_int(encoded)
    packed = as_int.to_bytes(8, byteorder='big')
    value, = struct.unpack('>d', packed)
    return value
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from rest_framework.response import Response | |
from . import models | |
from . import seed_pagination | |
class EntryViewSet(BaseViewSet):
    """List ``Entry`` objects in a random order that stays stable across pages.

    The ordering seed travels in the seed query parameter (encoded with
    ``utils.encode_float``). Page 1 without a seed gets a fresh random seed.
    The same seed is applied to the queryset and handed to the pagination
    serializer, so its ``next``/``previous`` links reproduce the ordering.
    """

    # NOTE(review): `BaseViewSet` and `serializers` are not imported in the
    # code shown here — confirm they are provided elsewhere in this module.
    queryset = models.Entry.objects.all()
    serializer_class = serializers.EntrySerializer
    paginate_by = 10
    paginate_by_param = 'page_size'
    max_paginate_by = 100
    pagination_serializer_class = seed_pagination.PaginationSeedSerializer

    def list(self, request, *args, **kwargs):
        """Return the filtered entries, randomly ordered by this request's seed."""
        qs = self.filter_queryset(self.get_queryset())
        # Randomize the queryset with the seed resolved for this request;
        # get_serializer_context() reads the very same cached value.
        qs = qs.set_seed(self.seed).order_by('?')
        page = self.paginate_queryset(qs)
        if page is not None:
            serializer = self.get_pagination_serializer(page)
        else:
            serializer = self.get_serializer(qs, many=True)
        return Response(serializer.data)

    def get_serializer_context(self):
        """Add the ordering seed to the serializer context under ``'seed'``."""
        context = super().get_serializer_context()
        context.update({
            'seed': self.seed
        })
        return context

    @property
    def seed(self):
        """Ordering seed for this request: a float in ``[0, 1)``, or ``None``.

        Resolved once and cached on the instance. Resolving it again could draw
        a new ``random.random()`` value, and the queryset and the pagination
        links would then disagree about the seed.
        """
        # NOTE(review): assumes one view instance per request (as Django's
        # as_view() creates) — confirm before reusing instances.
        if not hasattr(self, '_seed'):
            self._seed = self._resolve_seed()
        return self._seed

    def _resolve_seed(self):
        """Decode the seed query parameter, or pick a new seed on page 1.

        A missing, undecodable or out-of-range seed yields a fresh
        ``random.random()`` on the first page and ``None`` (unseeded
        ordering) on later pages.
        """
        from . import utils

        field = self.pagination_serializer_class._declared_fields['next']
        params = self.request.query_params
        page = params.get(field.page_field)
        encoded = params.get(field.seed_field)
        if encoded:
            try:
                seed = utils.decode_float(encoded)
            except (KeyError, OverflowError):
                # Characters outside the base-62 alphabet, or too long to be
                # an 8-byte double: treat as if no seed had been sent.
                seed = None
            # SeededSQLCompiler only accepts 0 < seed < 1 (NaN fails this too).
            if seed is not None and 0 < seed < 1:
                return seed
        if not page or page == '1':
            return random.random()
        return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment