diff --git a/netbox/secrets/api/views.py b/netbox/secrets/api/views.py index 6a3593003..1f683cd09 100644 --- a/netbox/secrets/api/views.py +++ b/netbox/secrets/api/views.py @@ -8,7 +8,7 @@ from rest_framework.authentication import BasicAuthentication, SessionAuthentica from rest_framework.permissions import IsAuthenticated from rest_framework.renderers import JSONRenderer from rest_framework.response import Response -from rest_framework.viewsets import ViewSet, ModelViewSet +from rest_framework.viewsets import GenericViewSet, ModelViewSet, ViewSet from extras.api.renderers import FormlessBrowsableAPIRenderer, FreeRADIUSClientsRenderer from secrets.exceptions import InvalidSessionKey @@ -50,34 +50,37 @@ class SecretViewSet(WritableSerializerMixin, ModelViewSet): filter_class = SecretFilter # DRF's BrowsableAPIRenderer can't support passing the secret key as a header, so we disable it. renderer_classes = [FormlessBrowsableAPIRenderer, JSONRenderer, FreeRADIUSClientsRenderer] - # Enabled BasicAuthentication for testing (until we have TokenAuthentication implemented) - authentication_classes = [BasicAuthentication, SessionAuthentication] - permission_classes = [IsAuthenticated] - def _read_session_key(self, request): + master_key = None - # Check for a session key provided as a cookie or header + def initial(self, request, *args, **kwargs): + + super(SecretViewSet, self).initial(request, *args, **kwargs) + + # Read session key from HTTP cookie or header if it has been provided. The session key must be provided in order + # to encrypt/decrypt secrets. if 'session_key' in request.COOKIES: - return base64.b64decode(request.COOKIES['session_key']) + session_key = base64.b64decode(request.COOKIES['session_key']) elif 'HTTP_X_SESSION_KEY' in request.META: - return base64.b64decode(request.META['HTTP_X_SESSION_KEY']) - return None + session_key = base64.b64decode(request.META['HTTP_X_SESSION_KEY']) + else: + session_key = None + + # Attempt to retrieve the master key for encryption/decryption if a session key has been provided. + if session_key is not None: + try: + sk = SessionKey.objects.get(userkey__user=request.user) + self.master_key = sk.get_master_key(session_key) + except (SessionKey.DoesNotExist, InvalidSessionKey): + return HttpResponseBadRequest("Invalid session key.") def retrieve(self, request, *args, **kwargs): secret = self.get_object() - session_key = self._read_session_key(request) - # Retrieve session key cipher (if any) for the current user - if session_key is not None: - try: - sk = SessionKey.objects.get(userkey__user=request.user) - master_key = sk.get_master_key(session_key) - secret.decrypt(master_key) - except SessionKey.DoesNotExist: - return HttpResponseBadRequest("No active session key for current user.") - except InvalidSessionKey: - return HttpResponseBadRequest("Invalid session key.") + # Attempt to decrypt the secret if the master key is known + if self.master_key is not None: + secret.decrypt(self.master_key) serializer = self.get_serializer(secret) return Response(serializer.data) @@ -86,29 +89,19 @@ class SecretViewSet(WritableSerializerMixin, ModelViewSet): queryset = self.filter_queryset(self.get_queryset()) - # Attempt to retrieve the master key for decryption - session_key = self._read_session_key(request) - master_key = None - if session_key is not None: - try: - sk = SessionKey.objects.get(user=request.user) - master_key = sk.get_master_key(session_key) - except SessionKey.DoesNotExist: - return HttpResponseBadRequest("No active session key for current user.") - except InvalidSessionKey: - return HttpResponseBadRequest("Invalid session key.") - - # Pagination page = self.paginate_queryset(queryset) if page is not None: - secrets = [] - if master_key is not None: + + # Attempt to decrypt all secrets if the master key is known + if self.master_key is not None: + secrets = [] for secret in page: - secret.decrypt(master_key) + secret.decrypt(self.master_key) secrets.append(secret) serializer = self.get_serializer(secrets, many=True) else: serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) serializer = self.get_serializer(queryset, many=True)