pagination.py 4.06 KB
Newer Older
1 2 3
import sys

from django.core.paginator import Paginator
4
from django.db.models.query import RawQuerySet
5
from django.db.models import sql
6
from django.utils.functional import cached_property
7
from rest_framework.pagination import PageNumberPagination
8 9 10 11 12 13 14 15 16 17 18 19 20
from rest_framework.response import Response


class CustomPaginatorClass(Paginator):
    """Paginator whose total is a fixed, very large number.

    ``count`` never inspects the underlying object list; it always reports
    ``sys.maxsize`` so that every page of results can be shown.
    """

    @cached_property
    def count(self):
        """Return ``sys.maxsize`` as the reported number of objects."""
        unbounded_total = sys.maxsize
        return unbounded_total


class LargeTablePagination(PageNumberPagination):
    """Use this paginator class to avoid large table count query.

    Pages are built with ``CustomPaginatorClass``, which reports a fixed
    ``sys.maxsize`` count, and the response body carries only the
    ``next``/``previous`` links and the page ``results`` (no total count).
    """
    # Paginator implementation used to build pages for this class.
    django_paginator_class = CustomPaginatorClass
    # Name of the query-string parameter holding the requested page size.
    page_size_query_param = 'page_size'

    def get_paginated_response(self, data):
        """Wrap the serialized page ``data`` in a ``Response``.

        The payload has exactly three keys: ``next``, ``previous`` and
        ``results``.
        """
        payload = {
            'next': self.get_next_link(),
            'previous': self.get_previous_link(),
            'results': data,
        }
        return Response(payload)
29 30 31


class Pagination(PageNumberPagination):
    """
    DRF pagination_class, you use it by saying:

    class MyView(GenericAPIView):
        pagination_class = Pagination

    The only customisation over ``PageNumberPagination`` is that the
    ``page_size`` query parameter is enabled.
    """
    # Name of the query-string parameter holding the requested page size.
    page_size_query_param = 'page_size'


41 42 43 44 45 46
class PaginatedRawQuerySet(RawQuerySet):
    """
    Replacement for a RawQuerySet that handles pagination, stolen from:

    https://stackoverflow.com/questions/32191853/best-way-to-paginate-a-raw-sql-query-in-a-django-rest-listapi-view/43921793#43921793
    https://gist.github.com/eltongo/d3e6bdef17b0b14384ba38edc76f25f6

    Slicing an instance returns a clone whose raw SQL has ``LIMIT``/``OFFSET``
    placeholders appended, so only the requested window is fetched. Results
    are cached on the instance after the first fetch.

    Stopped working after Django migration, but will keep it just in case.
    """
    def __init__(self, raw_query, **kwargs):
        """Store the raw SQL and start with an empty result cache.

        ``raw_query`` and ``kwargs`` are forwarded unchanged to
        ``RawQuerySet.__init__``.
        """
        super(PaginatedRawQuerySet, self).__init__(raw_query, **kwargs)
        # SQL text that set_limits() appends the LIMIT/OFFSET clause to.
        self.original_raw_query = raw_query
        self._result_cache = None

    def __getitem__(self, k):
        """
        Retrieves an item or slice from the set of results.

        A slice returns a new, limited queryset (the slice step is not
        used); an integer returns the single row at that position. When the
        results are already cached, the cache is indexed directly.

        Raises ``TypeError`` for keys that are neither ``int`` nor ``slice``
        and ``AssertionError`` for negative indexes or slice bounds.
        """
        if not isinstance(k, (slice, int)):
            raise TypeError(
                '%s indices must be integers or slices, not %s.'
                % (type(self).__name__, type(k).__name__)
            )

        if isinstance(k, slice):
            has_negative = ((k.start is not None and k.start < 0) or
                            (k.stop is not None and k.stop < 0))
        else:
            has_negative = k < 0
        if has_negative:
            # Raised explicitly (rather than via ``assert``) so the check
            # still runs under ``python -O``; the exception type is the one
            # the assert statement used to produce.
            raise AssertionError("Negative indexing is not supported.")

        if self._result_cache is not None:
            return self._result_cache[k]

        qs = self._clone()
        if isinstance(k, slice):
            start = int(k.start) if k.start is not None else None
            stop = int(k.stop) if k.stop is not None else None
            qs.set_limits(start, stop)
            return qs

        # Single index: fetch a one-row window and unwrap it.
        qs.set_limits(k, k + 1)
        return list(qs)[0]

    def __iter__(self):
        """Iterate over the cached results, fetching them first if needed."""
        self._fetch_all()
        return iter(self._result_cache)

    def count(self):
        """Return the number of rows.

        Uses the length of the result cache when it is populated; otherwise
        delegates to ``self.query.get_count``.
        """
        if self._result_cache is not None:
            return len(self._result_cache)

        # NOTE(review): ``self.query`` is a ``sql.RawQuery`` once set_limits()
        # has run; confirm that it actually provides ``get_count`` on the
        # installed Django version.
        return self.query.get_count(using=self.db)  # Originally was: return self.model.objects.count()

    def set_limits(self, start, stop):
        """Restrict this queryset to the rows in ``[start, stop)``.

        Appends `` LIMIT %s`` and/or `` OFFSET %s`` to the original raw SQL,
        appends the matching values to ``self.params`` and rebuilds
        ``self.query``. ``None`` for ``start`` means 0; ``None`` for ``stop``
        means no upper bound.

        Raises ``TypeError`` when ``self.params`` is a dict, because the
        positional placeholders added here cannot be mixed with named
        parameters.

        NOTE: ``_clone`` passes the current ``raw_query``/``params`` to the
        clone, so limiting a clone of an already limited queryset stacks a
        second clause on top of the first.
        """
        limit_offset = ''
        new_params = ()

        if start is None:
            start = 0
        elif start > 0:
            new_params = (start,)
            limit_offset = ' OFFSET %s'
        if stop is not None:
            # A stop before the start is an empty window, never a negative
            # LIMIT.
            new_params = (max(stop - start, 0),) + new_params
            # Leading space keeps the keyword from being glued onto the last
            # token of the original SQL.
            limit_offset = ' LIMIT %s' + limit_offset

        base_params = self.params
        if base_params is None:
            base_params = ()
        elif isinstance(base_params, dict):
            raise TypeError(
                'PaginatedRawQuerySet does not support named (dict) params.'
            )

        # Normalise to a tuple so list params can be extended as well.
        self.params = tuple(base_params) + new_params
        self.raw_query = self.original_raw_query + limit_offset
        self.query = sql.RawQuery(sql=self.raw_query, using=self.db, params=self.params)

    def _fetch_all(self):
        """Populate the result cache on first use."""
        if self._result_cache is None:
            # NOTE(review): this relies on the parent ``__iter__`` producing
            # rows without calling ``self._fetch_all()`` itself; if the
            # installed Django's ``RawQuerySet.__iter__`` does call it, this
            # recurses. Confirm against the Django version in use.
            self._result_cache = list(super(PaginatedRawQuerySet, self).__iter__())

    def __repr__(self):
        """Return ``<ClassName: ModelName>`` without running the query."""
        return '<%s: %s>' % (self.__class__.__name__, self.model.__name__)

    def __len__(self):
        """Return the number of fetched rows, fetching them if needed."""
        self._fetch_all()
        return len(self._result_cache)

    def _clone(self):
        """Return a new instance built from this queryset's current state.

        The result cache is not copied.
        """
        clone = self.__class__(raw_query=self.raw_query, model=self.model, using=self._db, hints=self._hints,
                               query=self.query, params=self.params, translations=self.translations)
        return clone