From d5675a5d4ae2ce435dc972417020527764730ae6 Mon Sep 17 00:00:00 2001 From: jeremystretch Date: Fri, 25 Jun 2021 09:13:08 -0400 Subject: [PATCH] Add support for DRF token authentication --- netbox/netbox/graphql/views.py | 20 ++++++++++++++++++++ netbox/netbox/middleware.py | 3 ++- netbox/netbox/urls.py | 2 +- netbox/utilities/testing/api.py | 21 +++++++++++++++++++++ 4 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 netbox/netbox/graphql/views.py diff --git a/netbox/netbox/graphql/views.py b/netbox/netbox/graphql/views.py new file mode 100644 index 000000000..1cde56cd6 --- /dev/null +++ b/netbox/netbox/graphql/views.py @@ -0,0 +1,20 @@ +from graphene_django.views import GraphQLView as GraphQLView_ +from rest_framework.decorators import authentication_classes, permission_classes, api_view +from rest_framework.permissions import IsAuthenticated +from rest_framework.settings import api_settings + + +class GraphQLView(GraphQLView_): + """ + Extends grpahene_django's GraphQLView to support DRF's token-based authentication. + """ + @classmethod + def as_view(cls, *args, **kwargs): + view = super(GraphQLView, cls).as_view(*args, **kwargs) + + # Apply DRF permission and authentication classes + view = permission_classes((IsAuthenticated,))(view) + view = authentication_classes(api_settings.DEFAULT_AUTHENTICATION_CLASSES)(view) + view = api_view(['GET', 'POST'])(view) + + return view diff --git a/netbox/netbox/middleware.py b/netbox/netbox/middleware.py index d3b3dae40..ef50edc4a 100644 --- a/netbox/netbox/middleware.py +++ b/netbox/netbox/middleware.py @@ -24,7 +24,8 @@ class LoginRequiredMiddleware(object): if settings.LOGIN_REQUIRED and not request.user.is_authenticated: # Determine exempt paths exempt_paths = [ - reverse('api-root') + reverse('api-root'), + reverse('graphql'), ] if settings.METRICS_ENABLED: exempt_paths.append(reverse('prometheus-django-metrics')) diff --git a/netbox/netbox/urls.py b/netbox/netbox/urls.py index 4f1ec38d2..06e1eee06 100644 --- a/netbox/netbox/urls.py +++ b/netbox/netbox/urls.py @@ -4,11 +4,11 @@ from django.urls import path, re_path from django.views.static import serve from drf_yasg import openapi from drf_yasg.views import get_schema_view -from graphene_django.views import GraphQLView from extras.plugins.urls import plugin_admin_patterns, plugin_patterns, plugin_api_patterns from netbox.api.views import APIRootView, StatusView from netbox.graphql.schema import schema +from netbox.graphql.views import GraphQLView from netbox.views import HomeView, StaticMediaFailureView, SearchView from users.views import LoginView, LogoutView from .admin import admin_site diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 1a9414dc6..ad14c2fdc 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -425,6 +425,7 @@ class APIViewTestCases: class GraphQLTestCase(APITestCase): + @override_settings(LOGIN_REQUIRED=True) def test_graphql_get_object(self): url = reverse('graphql') object_type = self.model._meta.verbose_name.replace(' ', '_') @@ -441,11 +442,21 @@ class APIViewTestCases: with disable_warnings('django.request'): self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) + # Add object-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + response = self.client.post(url, data={'query': query}, **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) data = json.loads(response.content) self.assertNotIn('errors', data) + @override_settings(LOGIN_REQUIRED=True) def test_graphql_list_objects(self): url = reverse('graphql') object_type = self.model._meta.verbose_name_plural.replace(' ', '_') @@ -461,10 +472,20 @@ class APIViewTestCases: with disable_warnings('django.request'): self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) + # Add object-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ContentType.objects.get_for_model(self.model)) + response = self.client.post(url, data={'query': query}, **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) data = json.loads(response.content) self.assertNotIn('errors', data) + self.assertGreater(len(data['data'][object_type]), 0) class APIViewTestCase( GetObjectViewTestCase,