mirror of
https://github.com/netbox-community/netbox.git
synced 2024-05-10 07:54:54 +00:00
Add support for DRF token authentication
This commit is contained in:
20
netbox/netbox/graphql/views.py
Normal file
20
netbox/netbox/graphql/views.py
Normal file
@ -0,0 +1,20 @@
|
||||
from graphene_django.views import GraphQLView as GraphQLView_
|
||||
from rest_framework.decorators import authentication_classes, permission_classes, api_view
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
|
||||
class GraphQLView(GraphQLView_):
|
||||
"""
|
||||
Extends grpahene_django's GraphQLView to support DRF's token-based authentication.
|
||||
"""
|
||||
@classmethod
|
||||
def as_view(cls, *args, **kwargs):
|
||||
view = super(GraphQLView, cls).as_view(*args, **kwargs)
|
||||
|
||||
# Apply DRF permission and authentication classes
|
||||
view = permission_classes((IsAuthenticated,))(view)
|
||||
view = authentication_classes(api_settings.DEFAULT_AUTHENTICATION_CLASSES)(view)
|
||||
view = api_view(['GET', 'POST'])(view)
|
||||
|
||||
return view
|
@ -24,7 +24,8 @@ class LoginRequiredMiddleware(object):
|
||||
if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
|
||||
# Determine exempt paths
|
||||
exempt_paths = [
|
||||
reverse('api-root')
|
||||
reverse('api-root'),
|
||||
reverse('graphql'),
|
||||
]
|
||||
if settings.METRICS_ENABLED:
|
||||
exempt_paths.append(reverse('prometheus-django-metrics'))
|
||||
|
@ -4,11 +4,11 @@ from django.urls import path, re_path
|
||||
from django.views.static import serve
|
||||
from drf_yasg import openapi
|
||||
from drf_yasg.views import get_schema_view
|
||||
from graphene_django.views import GraphQLView
|
||||
|
||||
from extras.plugins.urls import plugin_admin_patterns, plugin_patterns, plugin_api_patterns
|
||||
from netbox.api.views import APIRootView, StatusView
|
||||
from netbox.graphql.schema import schema
|
||||
from netbox.graphql.views import GraphQLView
|
||||
from netbox.views import HomeView, StaticMediaFailureView, SearchView
|
||||
from users.views import LoginView, LogoutView
|
||||
from .admin import admin_site
|
||||
|
@ -425,6 +425,7 @@ class APIViewTestCases:
|
||||
|
||||
class GraphQLTestCase(APITestCase):
|
||||
|
||||
@override_settings(LOGIN_REQUIRED=True)
|
||||
def test_graphql_get_object(self):
|
||||
url = reverse('graphql')
|
||||
object_type = self.model._meta.verbose_name.replace(' ', '_')
|
||||
@ -441,11 +442,21 @@ class APIViewTestCases:
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Add object-level permission
|
||||
obj_perm = ObjectPermission(
|
||||
name='Test permission',
|
||||
actions=['view']
|
||||
)
|
||||
obj_perm.save()
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
|
||||
|
||||
response = self.client.post(url, data={'query': query}, **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
|
||||
@override_settings(LOGIN_REQUIRED=True)
|
||||
def test_graphql_list_objects(self):
|
||||
url = reverse('graphql')
|
||||
object_type = self.model._meta.verbose_name_plural.replace(' ', '_')
|
||||
@ -461,10 +472,20 @@ class APIViewTestCases:
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(url, data={'query': query}), status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Add object-level permission
|
||||
obj_perm = ObjectPermission(
|
||||
name='Test permission',
|
||||
actions=['view']
|
||||
)
|
||||
obj_perm.save()
|
||||
obj_perm.users.add(self.user)
|
||||
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
|
||||
|
||||
response = self.client.post(url, data={'query': query}, **self.header)
|
||||
self.assertHttpStatus(response, status.HTTP_200_OK)
|
||||
data = json.loads(response.content)
|
||||
self.assertNotIn('errors', data)
|
||||
self.assertGreater(len(data['data'][object_type]), 0)
|
||||
|
||||
class APIViewTestCase(
|
||||
GetObjectViewTestCase,
|
||||
|
Reference in New Issue
Block a user