Skip to content

Commit

Permalink
Merge pull request encode#4779 from auvipy/pyts0
Browse files Browse the repository at this point in the history
converted throttling tests asserts to pytest
  • Loading branch information
jpadilla authored Jan 4, 2017
2 parents c524925 + 6ca7f76 commit 559a0a8
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions tests/test_throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_requests_are_throttled(self):
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(429, response.status_code)
assert response.status_code == 429

def set_throttle_timer(self, view, value):
"""
Expand All @@ -87,13 +87,13 @@ def test_request_throttling_expires(self):
request = self.factory.get('/')
for dummy in range(4):
response = MockView.as_view()(request)
self.assertEqual(429, response.status_code)
assert response.status_code == 429

# Advance the timer by one second
self.set_throttle_timer(MockView, 1)

response = MockView.as_view()(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

def ensure_is_throttled(self, view, expect):
request = self.factory.get('/')
Expand All @@ -102,7 +102,7 @@ def ensure_is_throttled(self, view, expect):
view.as_view()(request)
request.user = User.objects.create(username='b')
response = view.as_view()(request)
self.assertEqual(expect, response.status_code)
assert response.status_code == expect

def test_request_throttling_is_per_user(self):
"""
Expand All @@ -121,9 +121,9 @@ def ensure_response_header_contains_proper_throttle_field(self, view, expected_h
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
if expect is not None:
self.assertEqual(response['Retry-After'], expect)
assert response['Retry-After'] == expect
else:
self.assertFalse('Retry-After' in response)
assert not'Retry-After' in response

def test_seconds_fields(self):
"""
Expand Down Expand Up @@ -230,64 +230,63 @@ def test_scoped_rate_throttle(self):

# Should be able to hit x view 3 times per minute.
response = self.x_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)

assert response.status_code == 200
self.increment_timer()
response = self.x_view(request)
self.assertEqual(429, response.status_code)
assert response.status_code == 429

# Should be able to hit y view 1 time per minute.
self.increment_timer()
response = self.y_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.y_view(request)
self.assertEqual(429, response.status_code)
assert response.status_code == 429

# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(55)

# Should still be able to hit x view 3 times per minute.
response = self.x_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.x_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.x_view(request)
self.assertEqual(429, response.status_code)
assert response.status_code == 429

# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = self.y_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200

self.increment_timer()
response = self.y_view(request)
self.assertEqual(429, response.status_code)
assert response.status_code == 429

def test_unscoped_view_not_throttled(self):
request = self.factory.get('/')

for idx in range(10):
self.increment_timer()
response = self.unscoped_view(request)
self.assertEqual(200, response.status_code)
assert response.status_code == 200


class XffTestingBase(TestCase):
Expand Down Expand Up @@ -321,37 +320,37 @@ def config_proxy(self, num_proxies):
class IdWithXffBasicTests(XffTestingBase):
def test_accepts_request_under_limit(self):
self.config_proxy(0)
self.assertEqual(200, self.view(self.request).status_code)
assert self.view(self.request).status_code == 200

def test_denies_request_over_limit(self):
self.config_proxy(0)
self.view(self.request)
self.assertEqual(429, self.view(self.request).status_code)
assert self.view(self.request).status_code == 429


class XffSpoofingTests(XffTestingBase):
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
self.config_proxy(1)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
self.assertEqual(429, self.view(self.request).status_code)
assert self.view(self.request).status_code == 429

def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
self.config_proxy(2)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
self.assertEqual(429, self.view(self.request).status_code)
assert self.view(self.request).status_code == 429


class XffUniqueMachinesTest(XffTestingBase):
def test_unique_clients_are_counted_independently_with_one_proxy(self):
self.config_proxy(1)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
self.assertEqual(200, self.view(self.request).status_code)
assert self.view(self.request).status_code == 200

def test_unique_clients_are_counted_independently_with_two_proxies(self):
self.config_proxy(2)
self.view(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
self.assertEqual(200, self.view(self.request).status_code)
assert self.view(self.request).status_code == 200

0 comments on commit 559a0a8

Please sign in to comment.