8000 Demo broken pagination · encode/django-rest-framework@2375f6c · GitHub
[go: up one dir, main page]

Skip to content

Commit 2375f6c

Browse files
committed
Demo broken pagination
1 parent 0323d6f commit 2375f6c

File tree

2 files changed

+206
-1
lines changed

2 files changed

+206
-1
lines changed

tests/models.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class RESTFrameworkModel(models.Model):
1010
"""
1111

1212
class Meta:
13-
app_label = 'tests'
13+
app_label = "tests"
1414
abstract = True
1515

1616

@@ -119,3 +119,11 @@ class OneToOnePKSource(RESTFrameworkModel):
119119
target = models.OneToOneField(
120120
OneToOneTarget, primary_key=True,
121121
related_name='required_source', on_delete=models.CASCADE)
122+
123+
124+
class ExamplePaginationModel(models.Model):
125+
# Don't use an auto field because we can't reset
126+
# sequences and that's needed for this test
127+
id = models.IntegerField(primary_key=True)
128+
field = models.IntegerField()
129+
timestamp = models.IntegerField()

tests/test_cursor_pagination.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import base64
2+
import itertools
3+
import re
4+
from base64 import b64encode
5+
from urllib import parse
6+
7+
import pytest
8+
from django.db import models
9+
from rest_framework import generics
10+
from rest_framework.pagination import Cursor, CursorPagination
11+
from rest_framework.filters import OrderingFilter
12+
from rest_framework.permissions import AllowAny
13+
from rest_framework.serializers import ModelSerializer
14+
from rest_framework.test import APIRequestFactory
15+
from .models import ExamplePaginationModel
16+
17+
18+
factory = APIRequestFactory()
19+
20+
21+
class SerializerCls(ModelSerializer):
22+
class Meta:
23+
model = ExamplePaginationModel
24+
fields = "__all__"
25+
26+
27+
def create_cursor(offset, reverse, position):
28+
# Taken from rest_framework.pagination
29+
cursor = Cursor(offset=offset, reverse=reverse, position=position)
30+
31+
tokens = {}
32+
if cursor.offset != 0:
33+
tokens["o"] = str(cursor.offset)
34+
if cursor.reverse:
35+
tokens["r"] = "1"
36+
if cursor.position is not None:
37+
tokens["p"] = cursor.position
38+
39+
querystring = parse.urlencode(tokens, doseq=True)
40+
return b64encode(querystring.encode("ascii")).decode("ascii")
41+
42+
43+
def decode_cursor(response):
44+
45+
links = {
46+
'next': response.data.get('next'),
47+
'prev': response.data.get('prev'),
48+
}
49+
50+
cursors = {}
51+
52+
for rel, link in links.items():
53+
if link:
54+
# Don't hate my laziness - copied from an IPDB prompt
55+
cursor_dict = dict(
56+
parse.parse_qsl(
57+
base64.decodebytes(
58+
(parse.parse_qs(parse.urlparse(link).query)["cursor"][0]).encode()
59+
)
60+
)
61+
)
62+
63+
offset = cursor_dict.get(b"o", 0)
64+
if offset:
65+
offset = int(offset)
66+
67+
reverse = cursor_dict.get(b"r", False)
68+
if reverse:
69+
reverse = int(reverse)
70+
71+
position = cursor_dict.get(b"p", None)
72+
73+
cursors[rel] = Cursor(
74+
offset=offset,
75+
reverse=reverse,
76+
position=position,
77+
)
78+
79+
return type(
80+
"prev_next_stuct",
81+
(object,),
82+
{"next": cursors.get("next"), "prev": cursors.get("previous")},
83+
)
84+
85+
86+
@pytest.mark.django_db
87+
@pytest.mark.parametrize("page_size,offset", [
88+
(6, 2), (2, 6), (5, 3), (3, 5), (5, 5)
89+
],
90+
ids=[
91+
'page_size_divisor_of_offset',
92+
'page_size_multiple_of_offset',
93+
'page_size_uneven_divisor_of_offset',
94+
'page_size_uneven_multiple_of_offset',
95+
'page_size_same_as_offset',
96+
]
97+
)
98+
def test_filtered_items_are_paginated(page_size, offset):
99+
100+
PaginationCls = type('PaginationCls', (CursorPagination,), dict(
101+
page_size=page_size,
102+
offset_cutoff=offset,
103+
max_page_size=20,
104+
))
105+
106+
example_models = []
107+
108+
for id_, (field_1, field_2) in enumerate(
109+
itertools.product(range(1, 11), range(1, 3))
110+
):
111+
# field_1 is a unique range from 1-10 inclusive
112+
# field_2 is the 'timestamp' field. 1 or 2
113+
example_models.append(
114+
ExamplePaginationModel(
115+
# manual primary key
116+
id=id_ + 1,
117+
field=field_1,
118+
timestamp=field_2,
119+
)
120+
)
121+
122+
ExamplePaginationModel.objects.bulk_create(example_models)
123+
124+
view = generics.ListAPIView.as_view(
125+
serializer_class=SerializerCls,
126+
queryset=ExamplePaginationModel.objects.all(),
127+
pagination_class=PaginationCls,
128+
permission_classes=(AllowAny,),
129+
filter_backends=[OrderingFilter],
130+
)
131+
132+
def _request(offset, reverse, position):
133+
return view(
134+
factory.get(
135+
"/",
136+
{
137+
PaginationCls.cursor_query_param: create_cursor(
138+
offset, reverse, position
139+
),
140+
"ordering": "timestamp,id",
141+
},
142+
)
143+
)
144+
145+
# This is the result we would expect
146+
expected_result = list(
147+
ExamplePaginationModel.objects.order_by("timestamp", "id").values(
148+
"timestamp",
149+
"id",
150+
"field",
151+
)
152+
)
153+
assert expected_result == [
154+
{"field": 1, "id": 1, "timestamp": 1},
155+
{"field": 2, "id": 3, "timestamp": 1},
156+
{"field": 3, "id": 5, "timestamp": 1},
157+
{"field": 4, "id": 7, "timestamp": 1},
158+
{"field": 5, "id": 9, "timestamp": 1},
159+
{"field": 6, "id": 11, "timestamp": 1},
160+
{"field": 7, "id": 13, "timestamp": 1},
161+
{"field": 8, "id": 15, "timestamp": 1},
162+
{"field": 9, "id": 17, "timestamp": 1},
163+
{"field": 10, "id": 19, "timestamp": 1},
164+
{"field": 1, "id": 2, "timestamp": 2},
165+
{"field": 2, "id": 4, "timestamp": 2},
166+
{"field": 3, "id": 6, "timestamp": 2},
167+
{"field": 4, "id": 8, "timestamp": 2},
168+
{"field": 5, "id": 10, "timestamp": 2},
169+
{"field": 6, "id": 12, "timestamp": 2},
170+
{"field": 7, "id": 14, "timestamp": 2},
171+
{"field": 8, "id": 16, "timestamp": 2},
172+
{"field": 9, "id": 18, "timestamp": 2},
173+
{"field": 10, "id": 20, "timestamp": 2},
174+
]
175+
176+
response = _request(0, False, None)
177+
next_cursor = decode_cursor(response).next
178+
position = 0
179+
180+
while next_cursor:
181+
assert (
182+
expected_result[position : position + len(response.data['results'])] == response.data['results']
183+
)
184+
position += len(response.data['results'])
185+
response = _request(*next_cursor)
186+
next_cursor = decode_cursor(response).next
187+
188+
prev_cursor = decode_cursor(response).prev
189+
position = 20
190+
191+
while prev_cursor:
192+
assert (
193+
expected_result[position - len(response.data['results']) : position] == response.data['results']
194+
)
195+
position -= len(response.data['results'])
196+
response = _request(*prev_cursor)
197+
prev_cursor = decode_cursor(response).prev

0 commit comments

Comments
 (0)
0