diff --git a/peeringdb_server/rest.py b/peeringdb_server/rest.py index 2791fd36..f383c9ff 100644 --- a/peeringdb_server/rest.py +++ b/peeringdb_server/rest.py @@ -25,6 +25,7 @@ import unidecode from django.apps import apps from django.conf import settings from django.core.exceptions import FieldError, ObjectDoesNotExist, ValidationError +from django.contrib.auth.models import AnonymousUser from django.db import connection, transaction from django.db.models import DateTimeField from django.utils import timezone @@ -175,9 +176,11 @@ class client_check: def __call__(self, fn): compat_check = self.compat_check + auth_check = self.auth_check def wrapped(self, request, *args, **kwargs): try: + auth_check(request) compat_check(request) except ValueError as exc: return Response( @@ -210,6 +213,19 @@ class client_check: """Return the max supported version for the specified backend.""" return self.backends.get(backend, {}).get("max") + def auth_check(self, request): + for header in request.META.keys(): + if header.startswith("HTTP_AUTH") and header != "HTTP_AUTHORIZATION": + if "HTTP_AUTHORIZATION" not in request.META: + raise ValueError("Malformed authorization header") + break + + if "HTTP_AUTHORIZATION" in request.META: + permission_holder = get_permission_holder_from_request(request) + + if isinstance(permission_holder, AnonymousUser): + raise ValueError("Unknown authorization method") + def client_info(self, request): """ Parse the useragent in the request and return client version diff --git a/tests/test_api_compat.py b/tests/test_api_compat.py index 2fd7a1ec..e5395ce7 100644 --- a/tests/test_api_compat.py +++ b/tests/test_api_compat.py @@ -10,6 +10,10 @@ from .util import ClientCase class TestAPIClientCompat(ClientCase): + + expected_unknown_auth_method_err_str = "Unknown authorization method" + expected_malformed_auth_header_err_str = "Malformed authorization header" + @classmethod def setUpTestData(cls): super().setUpTestData() @@ -96,3 +100,18 @@ class TestAPIClientCompat(ClientCase): self._compat("0.6", "0.6.1", False) self._compat("0.6.1", "0.6", False) self._compat(None, None, False) + + def test_auth_header(self): + + # this should return 400 with an unknown authorization method message + r = self.client.get("/api/net", HTTP_AUTHORIZATION="apikey deadbeef") + content = json.loads(r.content) + + assert content["meta"]["error"] == self.expected_unknown_auth_method_err_str + + # this should return 400 with an malformed authorization header message + r = self.client.get("/api/net", HTTP_AUTHORIZATIONS="apikey deadbeef") + content = json.loads(r.content) + + assert content["meta"]["error"] == self.expected_malformed_auth_header_err_str +