1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Add support for DRF token authentication

This commit is contained in:
jeremystretch
2021-06-25 09:13:08 -04:00
parent 91d39cc0c0
commit d5675a5d4a
4 changed files with 44 additions and 2 deletions

View File

@ -0,0 +1,20 @@
from graphene_django.views import GraphQLView as GraphQLView_
from rest_framework.decorators import authentication_classes, permission_classes, api_view
from rest_framework.permissions import IsAuthenticated
from rest_framework.settings import api_settings
class GraphQLView(GraphQLView_):
"""
Extends grpahene_django's GraphQLView to support DRF's token-based authentication.
"""
@classmethod
def as_view(cls, *args, **kwargs):
view = super(GraphQLView, cls).as_view(*args, **kwargs)
# Apply DRF permission and authentication classes
view = permission_classes((IsAuthenticated,))(view)
view = authentication_classes(api_settings.DEFAULT_AUTHENTICATION_CLASSES)(view)
view = api_view(['GET', 'POST'])(view)
return view

View File

@ -24,7 +24,8 @@ class LoginRequiredMiddleware(object):
if settings.LOGIN_REQUIRED and not request.user.is_authenticated: if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
# Determine exempt paths # Determine exempt paths
exempt_paths = [ exempt_paths = [
reverse('api-root') reverse('api-root'),
reverse('graphql'),
] ]
if settings.METRICS_ENABLED: if settings.METRICS_ENABLED:
exempt_paths.append(reverse('prometheus-django-metrics')) exempt_paths.append(reverse('prometheus-django-metrics'))

View File

@ -4,11 +4,11 @@ from django.urls import path, re_path
from django.views.static import serve from django.views.static import serve
from drf_yasg import openapi from drf_yasg import openapi
from drf_yasg.views import get_schema_view from drf_yasg.views import get_schema_view
from graphene_django.views import GraphQLView
from extras.plugins.urls import plugin_admin_patterns, plugin_patterns, plugin_api_patterns from extras.plugins.urls import plugin_admin_patterns, plugin_patterns, plugin_api_patterns
from netbox.api.views import APIRootView, StatusView from netbox.api.views import APIRootView, StatusView
from netbox.graphql.schema import schema from netbox.graphql.schema import schema
from netbox.graphql.views import GraphQLView
from netbox.views import HomeView, StaticMediaFailureView, SearchView from netbox.views import HomeView, StaticMediaFailureView, SearchView
from users.views import LoginView, LogoutView from users.views import LoginView, LogoutView
from .admin import admin_site from .admin import admin_site

View File

@ -425,6 +425,7 @@ class APIViewTestCases:
class GraphQLTestCase(APITestCase): class GraphQLTestCase(APITestCase):
@override_settings(LOGIN_REQUIRED=True)
def test_graphql_get_object(self): def test_graphql_get_object(self):
url = reverse('graphql') url = reverse('graphql')
object_type = self.model._meta.verbose_name.replace(' ', '_') object_type = self.model._meta.verbose_name.replace(' ', '_')
@ -441,11 +442,21 @@ class APIViewTestCases:
with disable_warnings('django.request'): with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
# Add object-level permission
obj_perm = ObjectPermission(
name='Test permission',
actions=['view']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
response = self.client.post(url, data={'query': query}, **self.header) response = self.client.post(url, data={'query': query}, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content) data = json.loads(response.content)
self.assertNotIn('errors', data) self.assertNotIn('errors', data)
@override_settings(LOGIN_REQUIRED=True)
def test_graphql_list_objects(self): def test_graphql_list_objects(self):
url = reverse('graphql') url = reverse('graphql')
object_type = self.model._meta.verbose_name_plural.replace(' ', '_') object_type = self.model._meta.verbose_name_plural.replace(' ', '_')
@ -461,10 +472,20 @@ class APIViewTestCases:
with disable_warnings('django.request'): with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN) self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
# Add object-level permission
obj_perm = ObjectPermission(
name='Test permission',
actions=['view']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
response = self.client.post(url, data={'query': query}, **self.header) response = self.client.post(url, data={'query': query}, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content) data = json.loads(response.content)
self.assertNotIn('errors', data) self.assertNotIn('errors', data)
self.assertGreater(len(data['data'][object_type]), 0)
class APIViewTestCase( class APIViewTestCase(
GetObjectViewTestCase, GetObjectViewTestCase,