diff --git a/shynet/api/mixins.py b/shynet/api/mixins.py index f6c8dba..5ecd19b 100644 --- a/shynet/api/mixins.py +++ b/shynet/api/mixins.py @@ -1,7 +1,10 @@ -from django.http import JsonResponse +from http import HTTPStatus + +from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser +from django.http import JsonResponse -from core.models import User +User = get_user_model() class ApiTokenRequiredMixin: @@ -11,13 +14,13 @@ def _get_user_by_token(self, request): return AnonymousUser() token = token.split(" ")[1] - user = User.objects.filter(api_token=token).first() - - return user if user else AnonymousUser() + user: User = User.objects.filter(api_token=token).first() + return user or AnonymousUser() def dispatch(self, request, *args, **kwargs): request.user = self._get_user_by_token(request) - if not request.user.is_authenticated: - return JsonResponse(data={}, status=403) - - return super().dispatch(request, *args, **kwargs) + return ( + super().dispatch(request, *args, **kwargs) + if request.user.is_authenticated + else JsonResponse(data={}, status=HTTPStatus.FORBIDDEN) + ) diff --git a/shynet/api/tests.py b/shynet/api/tests.py deleted file mode 100644 index 7ce503c..0000000 --- a/shynet/api/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/shynet/api/tests/__init__.py b/shynet/api/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shynet/api/tests/test_mixins.py b/shynet/api/tests/test_mixins.py new file mode 100644 index 0000000..901d3bc --- /dev/null +++ b/shynet/api/tests/test_mixins.py @@ -0,0 +1,77 @@ +from http import HTTPStatus + +from django.test import TestCase, RequestFactory +from django.views import View + +from api.mixins import ApiTokenRequiredMixin +from core.factories import UserFactory +from core.models import _default_api_token, Service + + +class TestApiTokenRequiredMixin(TestCase): + class DummyView(ApiTokenRequiredMixin, View): + model = Service + template_name = "dashboard/pages/service.html" + + def setUp(self): + super().setUp() + self.user = UserFactory() + self.request = RequestFactory().get("/fake-path") + + # Setup request and view. + self.factory = RequestFactory() + self.view = self.DummyView() + + def test_get_user_by_token_without_authorization_token(self): + """ + GIVEN: A request without Authorization header + WHEN: get_user_by_token is called + THEN: It should return AnonymousUser + """ + user = self.view._get_user_by_token(self.request) + + self.assertEqual(user.is_anonymous, True) + + def test_get_user_by_token_with_invalid_authorization_token(self): + """ + GIVEN: A request with invalid Authorization header + WHEN: get_user_by_token is called + THEN: It should return AnonymousUser + """ + self.request.META["HTTP_AUTHORIZATION"] = "Bearer invalid-token" + user = self.view._get_user_by_token(self.request) + + self.assertEqual(user.is_anonymous, True) + + def test_get_user_by_token_with_invalid_token(self): + """ + GIVEN: A request with invalid token + WHEN: get_user_by_token is called + THEN: It should return AnonymousUser + """ + self.request.META["HTTP_AUTHORIZATION"] = f"Token {_default_api_token()}" + user = self.view._get_user_by_token(self.request) + + self.assertEqual(user.is_anonymous, True) + + def test_get_user_by_token_with_valid_token(self): + """ + GIVEN: A request with valid token + WHEN: get_user_by_token is called + THEN: It should return the user + """ + self.request.META["HTTP_AUTHORIZATION"] = f"Token {self.user.api_token}" + user = self.view._get_user_by_token(self.request) + + self.assertEqual(user, self.user) + + def test_dispatch_with_unauthenticated_user(self): + """ + GIVEN: A request with unauthenticated user + WHEN: dispatch is called + THEN: It should return 403 + """ + self.request.META["HTTP_AUTHORIZATION"] = f"Token {_default_api_token()}" + response = self.view.dispatch(self.request) + + self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) diff --git a/shynet/api/tests/test_views.py b/shynet/api/tests/test_views.py new file mode 100644 index 0000000..3e1c0ad --- /dev/null +++ b/shynet/api/tests/test_views.py @@ -0,0 +1,79 @@ +import json +from http import HTTPStatus + +from django.contrib.auth import get_user_model +from django.test import TestCase, RequestFactory +from django.urls import reverse + +from api.views import DashboardApiView +from core.factories import UserFactory, ServiceFactory +from core.models import Service + +User = get_user_model() + + +class TestDashboardApiView(TestCase): + def setUp(self) -> None: + super().setUp() + self.user: User = UserFactory() + self.service_1: Service = ServiceFactory(owner=self.user) + self.service_2: Service = ServiceFactory(owner=self.user) + self.url = reverse("api:services") + self.factory = RequestFactory() + + def test_get_with_unauthenticated_user(self): + """ + GIVEN: An unauthenticated user + WHEN: The user makes a GET request to the dashboard API view + THEN: It should return 403 + """ + response = self.client.get(self.url) + self.assertEqual(response.status_code, HTTPStatus.FORBIDDEN) + + def test_get_returns_400(self): + """ + GIVEN: An authenticated user + WHEN: The user makes a GET request to the dashboard API view with an invalid date format + THEN: It should return 400 + """ + request = self.factory.get(self.url, {"startDate": "01/01/2000"}) + request.META["HTTP_AUTHORIZATION"] = f"Token {self.user.api_token}" + + response = DashboardApiView.as_view()(request) + self.assertEqual(response.status_code, HTTPStatus.BAD_REQUEST) + + data = json.loads(response.content) + self.assertEqual(data["error"], "Invalid date format. Use YYYY-MM-DD.") + + def test_get_with_authenticated_user(self): + """ + GIVEN: An authenticated user + WHEN: The user makes a GET request to the dashboard API view + THEN: It should return 200 + """ + request = self.factory.get(self.url) + request.META["HTTP_AUTHORIZATION"] = f"Token {self.user.api_token}" + + response = DashboardApiView.as_view()(request) + self.assertEqual(response.status_code, HTTPStatus.OK) + + data = json.loads(response.content) + self.assertEqual(len(data["services"]), 2) + + def test_get_with_service_uuid(self): + """ + GIVEN: An authenticated user + WHEN: The user makes a GET request to the dashboard API view with a service UUID + THEN: It should return 200 and a single service + """ + request = self.factory.get(self.url, {"uuid": str(self.service_1.uuid)}) + request.META["HTTP_AUTHORIZATION"] = f"Token {self.user.api_token}" + + response = DashboardApiView.as_view()(request) + self.assertEqual(response.status_code, HTTPStatus.OK) + + data = json.loads(response.content) + self.assertEqual(len(data["services"]), 1) + self.assertEqual(data["services"][0]["uuid"], str(self.service_1.uuid)) + self.assertEqual(data["services"][0]["name"], str(self.service_1.name)) + diff --git a/shynet/api/views.py b/shynet/api/views.py index 44ef6dc..9af2816 100644 --- a/shynet/api/views.py +++ b/shynet/api/views.py @@ -1,54 +1,46 @@ -import uuid -from django.http import JsonResponse +from http import HTTPStatus + from django.db.models import Q from django.db.models.query import QuerySet +from django.http import JsonResponse from django.views.generic import View -from dashboard.mixins import DateRangeMixin from core.models import Service - +from core.utils import is_valid_uuid +from dashboard.mixins import DateRangeMixin from .mixins import ApiTokenRequiredMixin -def is_valid_uuid(value): - try: - uuid.UUID(value) - return True - except ValueError: - return False - - class DashboardApiView(ApiTokenRequiredMixin, DateRangeMixin, View): def get(self, request, *args, **kwargs): - services = Service.objects.filter( - Q(owner=request.user) | Q(collaborators__in=[request.user]) - ).distinct() + services = Service.objects.filter(Q(owner=request.user) | Q(collaborators__in=[request.user])).distinct() - uuid = request.GET.get("uuid") - if uuid and is_valid_uuid(uuid): - services = services.filter(uuid=uuid) + uuid_ = request.GET.get("uuid") + if uuid_ and is_valid_uuid(uuid_): + services = services.filter(uuid=uuid_) try: start = self.get_start_date() end = self.get_end_date() except ValueError: - return JsonResponse(status=400, data={"error": "Invalid date format"}) + return JsonResponse(status=HTTPStatus.BAD_REQUEST, data={"error": "Invalid date format. Use YYYY-MM-DD."}) + service: Service services_data = [ { - "name": s.name, - "uuid": s.uuid, - "link": s.link, - "stats": s.get_core_stats(start, end), + "name": service.name, + "uuid": service.uuid, + "link": service.link, + "stats": service.get_core_stats(start, end), } - for s in services + for service in services ] services_data = self._convert_querysets_to_lists(services_data) return JsonResponse(data={"services": services_data}) - def _convert_querysets_to_lists(self, services_data): + def _convert_querysets_to_lists(self, services_data: list[dict]) -> list[dict]: for service_data in services_data: for key, value in service_data["stats"].items(): if isinstance(value, QuerySet): diff --git a/shynet/core/utils.py b/shynet/core/utils.py new file mode 100644 index 0000000..f845a99 --- /dev/null +++ b/shynet/core/utils.py @@ -0,0 +1,10 @@ +import uuid + + +def is_valid_uuid(value: str) -> bool: + """Check if a string is a valid UUID.""" + try: + uuid.UUID(value) + return True + except ValueError: + return False