forked from encode/django-rest-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_throttling.py
451 lines (349 loc) · 14.1 KB
/
test_throttling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
"""
Tests for the throttling implementations in the throttling module.
"""
from __future__ import unicode_literals
import pytest
from django.contrib.auth.models import User
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpRequest
from django.test import TestCase
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.throttling import (
AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
UserRateThrottle
)
from rest_framework.views import APIView
class User3SecRateThrottle(UserRateThrottle):
    """User throttle admitting three requests per second, scoped 'seconds'."""
    scope = 'seconds'
    rate = '3/sec'
class User3MinRateThrottle(UserRateThrottle):
    """User throttle admitting three requests per minute, scoped 'minutes'."""
    scope = 'minutes'
    rate = '3/min'
class NonTimeThrottle(BaseThrottle):
    """
    Throttle that does not depend on a clock: only the first request passes.

    The "already called" marker is stored on the class object, so it is
    shared by every instance and survives until somebody removes it.
    """

    def allow_request(self, request, view):
        owner = self.__class__
        if hasattr(owner, 'called'):
            # Every request after the first one is rejected.
            return False
        owner.called = True
        return True
class MockView(APIView):
    """Endpoint limited by User3SecRateThrottle; always answers 'foo'."""
    throttle_classes = (User3SecRateThrottle,)

    def get(self, request):
        body = 'foo'
        return Response(body)
class MockView_MinuteThrottling(APIView):
    """Endpoint limited by User3MinRateThrottle; always answers 'foo'."""
    throttle_classes = (User3MinRateThrottle,)

    def get(self, request):
        body = 'foo'
        return Response(body)
class MockView_NonTimeThrottling(APIView):
    """Endpoint limited by NonTimeThrottle; always answers 'foo'."""
    throttle_classes = (NonTimeThrottle,)

    def get(self, request):
        body = 'foo'
        return Response(body)
class ThrottlingTests(TestCase):
    """
    Rate limiting behaviour of the user-based throttles on the mock views.
    """

    def setUp(self):
        """
        Reset the cache so that no throttles will be active.
        """
        cache.clear()
        self.factory = APIRequestFactory()

    def test_requests_are_throttled(self):
        """
        Ensure request rate is limited.
        """
        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

    def set_throttle_timer(self, view, value):
        """
        Explicitly set the timer, overriding time.time().

        The override is installed on the throttle class object, which is
        shared by the whole module. A cleanup is therefore registered so the
        class is put back as it was once the current test finishes. Without
        it, a frozen clock would leak into every later test that uses the
        same throttle class.
        """
        throttle_class = view.throttle_classes[0]
        own_attrs = vars(throttle_class)
        if 'timer' in own_attrs:
            # An earlier call already patched the class; restore that value.
            # Cleanups run last-in-first-out, so the first-registered
            # cleanup below ends up removing the override entirely.
            self.addCleanup(setattr, throttle_class, 'timer', own_attrs['timer'])
        else:
            # The timer is inherited; deleting the override re-exposes it.
            self.addCleanup(delattr, throttle_class, 'timer')
        throttle_class.timer = lambda throttle: value

    def test_request_throttling_expires(self):
        """
        Ensure request rate is limited for a limited duration only.
        """
        self.set_throttle_timer(MockView, 0)
        request = self.factory.get('/')
        for _ in range(4):
            response = MockView.as_view()(request)
        assert response.status_code == 429

        # Advance the timer by one second
        self.set_throttle_timer(MockView, 1)
        response = MockView.as_view()(request)
        assert response.status_code == 200

    def ensure_is_throttled(self, view, expect):
        """
        Use up user 'a''s allowance with three requests, then check that a
        request made as a new user 'b' gets status code ``expect``.
        """
        request = self.factory.get('/')
        request.user = User.objects.create(username='a')
        for _ in range(3):
            view.as_view()(request)
        request.user = User.objects.create(username='b')
        response = view.as_view()(request)
        assert response.status_code == expect

    def test_request_throttling_is_per_user(self):
        """
        Ensure request rate is only limited per user, not globally for
        PerUserThrottles.
        """
        self.ensure_is_throttled(MockView, 200)

    def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
        """
        Replay ``(timer, expected)`` pairs against ``view``.

        For each pair, the throttle clock is pinned to ``timer`` and one
        request is made. The response's Retry-After header must equal
        ``expected``, or be absent when ``expected`` is None.
        """
        request = self.factory.get('/')
        for timer, expect in expected_headers:
            self.set_throttle_timer(view, timer)
            response = view.as_view()(request)
            if expect is not None:
                assert response['Retry-After'] == expect
            else:
                assert 'Retry-After' not in response

    def test_seconds_fields(self):
        """
        Ensure the Retry-After header is correct for second-based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView, (
                (0, None),
                (0, None),
                (0, None),
                (0, '1')
            )
        )

    def test_minutes_fields(self):
        """
        Ensure the Retry-After header is correct for minute-based throttles.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (0, None),
                (0, None),
                (0, '60')
            )
        )

    def test_next_rate_remains_constant_if_followed(self):
        """
        If a client follows the recommended next request rate,
        the throttling rate should stay constant.
        """
        self.ensure_response_header_contains_proper_throttle_field(
            MockView_MinuteThrottling, (
                (0, None),
                (20, None),
                (40, None),
                (60, None),
                (80, None)
            )
        )

    def test_non_time_throttle(self):
        """
        Ensure a throttle that does not use a timer never emits Retry-After,
        neither when it admits a request nor when it rejects one.
        """
        request = self.factory.get('/')
        throttle_class = MockView_NonTimeThrottling.throttle_classes[0]
        assert not hasattr(throttle_class, 'called')

        response = MockView_NonTimeThrottling.as_view()(request)
        assert response.status_code == 200
        assert 'Retry-After' not in response
        assert throttle_class.called
        # The flag lives on the shared class; drop it so the test can rerun.
        self.addCleanup(delattr, throttle_class, 'called')

        response = MockView_NonTimeThrottling.as_view()(request)
        assert response.status_code == 429
        assert 'Retry-After' not in response
class ScopedRateThrottleTests(TestCase):
    """
    Tests for ScopedRateThrottle.
    """

    def setUp(self):
        # Start from empty throttle histories, like the other throttle test
        # classes; otherwise entries cached by earlier tests could linger.
        cache.clear()
        self.throttle = ScopedRateThrottle()

        class XYScopedRateThrottle(ScopedRateThrottle):
            # Fake clock, advanced explicitly through increment_timer().
            TIMER_SECONDS = 0
            THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}

            def timer(self):
                return self.TIMER_SECONDS

        class XView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'x'

            def get(self, request):
                return Response('x')

        class YView(APIView):
            throttle_classes = (XYScopedRateThrottle,)
            throttle_scope = 'y'

            def get(self, request):
                return Response('y')

        class UnscopedView(APIView):
            throttle_classes = (XYScopedRateThrottle,)

            def get(self, request):
                return Response('y')

        self.throttle_class = XYScopedRateThrottle
        self.factory = APIRequestFactory()
        self.x_view = XView.as_view()
        self.y_view = YView.as_view()
        self.unscoped_view = UnscopedView.as_view()

    def increment_timer(self, seconds=1):
        """Advance the fake clock shared by all XYScopedRateThrottle instances."""
        self.throttle_class.TIMER_SECONDS += seconds

    def assert_response_codes(self, view, request, expected_codes):
        """
        Hit ``view`` once per entry in ``expected_codes`` and check each status.

        The clock is advanced by one second between consecutive requests,
        but not before the first one.
        """
        for position, expected in enumerate(expected_codes):
            if position:
                self.increment_timer()
            status_code = view(request).status_code
            assert status_code == expected, (position, status_code)

    def test_scoped_rate_throttle(self):
        request = self.factory.get('/')
        # Should be able to hit x view 3 times per minute (t=0..3).
        self.assert_response_codes(self.x_view, request, (200, 200, 200, 429))

        # Should be able to hit y view 1 time per minute (t=4..5).
        self.increment_timer()
        self.assert_response_codes(self.y_view, request, (200, 429))

        # Ensure throttles properly reset by advancing the rest of the minute
        self.increment_timer(55)

        # Should still be able to hit x view 3 times per minute (t=60..63).
        self.assert_response_codes(self.x_view, request, (200, 200, 200, 429))

        # Should still be able to hit y view 1 time per minute (t=64..65).
        self.increment_timer()
        self.assert_response_codes(self.y_view, request, (200, 429))

    def test_unscoped_view_not_throttled(self):
        request = self.factory.get('/')
        for _ in range(10):
            self.increment_timer()
            response = self.unscoped_view(request)
            assert response.status_code == 200

    def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
        class DummyView(object):
            throttle_scope = 'user'

        request = Request(HttpRequest())
        user = User.objects.create(username='test')
        force_authenticate(request, user)
        request.user = user
        # allow_request() is what records the view's scope on the throttle.
        self.throttle.allow_request(request, DummyView())
        cache_key = self.throttle.get_cache_key(request, view=DummyView())
        assert cache_key == 'throttle_user_%s' % user.pk
class XffTestingBase(TestCase):
    """
    Shared fixture for tests of client identification via X-Forwarded-For.

    Each test gets a view limited to one request per day, and a request
    whose REMOTE_ADDR is 3.3.3.3 and whose X-Forwarded-For header is
    '0.0.0.0, 1.1.1.1, 2.2.2.2'.
    """

    def setUp(self):
        class Throttle(ScopedRateThrottle):
            THROTTLE_RATES = {'test_limit': '1/day'}
            # The clock is pinned at 0 and never advanced here.
            TIMER_SECONDS = 0

            def timer(self):
                return self.TIMER_SECONDS

        class View(APIView):
            throttle_classes = (Throttle,)
            throttle_scope = 'test_limit'

            def get(self, request):
                return Response('test_limit')

        cache.clear()
        self.throttle = Throttle()
        self.view = View.as_view()
        self.request = APIRequestFactory().get('/some_uri')
        self.request.META['REMOTE_ADDR'] = '3.3.3.3'
        self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'

    def config_proxy(self, num_proxies):
        """
        Set ``api_settings.NUM_PROXIES`` for the duration of the current test.

        ``api_settings`` is a shared module-level object. The previous value
        is therefore restored by a cleanup, so the override does not leak
        into unrelated tests that run afterwards.
        """
        previous = api_settings.NUM_PROXIES
        self.addCleanup(setattr, api_settings, 'NUM_PROXIES', previous)
        api_settings.NUM_PROXIES = num_proxies
class IdWithXffBasicTests(XffTestingBase):
    """Basic accept/deny checks with NUM_PROXIES set to 0."""

    def test_accepts_request_under_limit(self):
        self.config_proxy(0)
        first = self.view(self.request)
        assert first.status_code == 200

    def test_denies_request_over_limit(self):
        self.config_proxy(0)
        self.view(self.request)
        second = self.view(self.request)
        assert second.status_code == 429
class XffSpoofingTests(XffTestingBase):
    """
    Rewriting the leading X-Forwarded-For entries must not make the
    client look like a different machine.
    """

    def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
        self.config_proxy(1)
        self.view(self.request)
        # Only the entries before the final '2.2.2.2' are changed.
        meta = self.request.META
        meta['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
        response = self.view(self.request)
        assert response.status_code == 429

    def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
        self.config_proxy(2)
        self.view(self.request)
        # Only the first entry changes; '1.1.1.1, 2.2.2.2' is kept.
        meta = self.request.META
        meta['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
        response = self.view(self.request)
        assert response.status_code == 429
class XffUniqueMachinesTest(XffTestingBase):
    """
    Changing the X-Forwarded-For entry selected by the proxy count must
    be treated as a distinct client with its own allowance.
    """

    def test_unique_clients_are_counted_independently_with_one_proxy(self):
        self.config_proxy(1)
        self.view(self.request)
        meta = self.request.META
        meta['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
        response = self.view(self.request)
        assert response.status_code == 200

    def test_unique_clients_are_counted_independently_with_two_proxies(self):
        self.config_proxy(2)
        self.view(self.request)
        meta = self.request.META
        meta['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
        response = self.view(self.request)
        assert response.status_code == 200
class BaseThrottleTests(TestCase):
    """allow_request() on the bare base class raises NotImplementedError."""

    def test_allow_request_raises_not_implemented_error(self):
        throttle = BaseThrottle()
        with pytest.raises(NotImplementedError):
            throttle.allow_request(request={}, view={})
class SimpleRateThrottleTests(TestCase):
    """
    Unit tests for SimpleRateThrottle used directly, without a subclass.
    """

    def setUp(self):
        # Give the bare class a scope so that it can be instantiated. The
        # scope is assigned on the shared class object, and one test below
        # reassigns it as well. Register a cleanup restoring the original
        # value so the patched scope does not leak into later tests that
        # rely on the class default.
        self.addCleanup(setattr, SimpleRateThrottle, 'scope', SimpleRateThrottle.scope)
        SimpleRateThrottle.scope = 'anon'

    def test_get_rate_raises_error_if_scope_is_missing(self):
        throttle = SimpleRateThrottle()
        throttle.scope = None
        # Only get_rate() itself is expected to raise.
        with pytest.raises(ImproperlyConfigured):
            throttle.get_rate()

    def test_throttle_raises_error_if_rate_is_missing(self):
        # An unknown scope must be rejected when the throttle is constructed.
        SimpleRateThrottle.scope = 'invalid scope'
        with pytest.raises(ImproperlyConfigured):
            SimpleRateThrottle()

    def test_parse_rate_returns_tuple_with_none_if_rate_not_provided(self):
        rate = SimpleRateThrottle().parse_rate(None)
        assert rate == (None, None)

    def test_allow_request_returns_true_if_rate_is_none(self):
        assert SimpleRateThrottle().allow_request(request={}, view={}) is True

    def test_get_cache_key_raises_not_implemented_error(self):
        with pytest.raises(NotImplementedError):
            SimpleRateThrottle().get_cache_key({}, {})

    def test_allow_request_returns_true_if_key_is_none(self):
        throttle = SimpleRateThrottle()
        throttle.rate = 'some rate'
        throttle.get_cache_key = lambda *args: None
        assert throttle.allow_request(request={}, view={}) is True

    def test_wait_returns_correct_waiting_time_without_history(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.history = []
        waiting_time = throttle.wait()
        assert isinstance(waiting_time, float)
        assert waiting_time == 30.0

    def test_wait_returns_none_if_there_are_no_available_requests(self):
        throttle = SimpleRateThrottle()
        throttle.num_requests = 1
        throttle.duration = 60
        throttle.now = throttle.timer()
        # Three recorded requests against a limit of one: nothing is left.
        throttle.history = [throttle.now] * 3
        assert throttle.wait() is None
class AnonRateThrottleTests(TestCase):
    """Cache-key behaviour of AnonRateThrottle."""

    def setUp(self):
        self.throttle = AnonRateThrottle()

    def test_authenticated_user_not_affected(self):
        user = User.objects.create(username='test')
        request = Request(HttpRequest())
        force_authenticate(request, user)
        request.user = user
        key = self.throttle.get_cache_key(request, view={})
        assert key is None

    def test_get_cache_key_returns_correct_value(self):
        anonymous_request = Request(HttpRequest())
        key = self.throttle.get_cache_key(anonymous_request, view={})
        assert key == 'throttle_anon_None'