1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Extended GraphQL tests to include all fields

This commit is contained in:
jeremystretch
2021-06-29 11:20:54 -04:00
parent bd1e019a42
commit 7deabfe9cd
3 changed files with 52 additions and 16 deletions

View File

@ -8,3 +8,7 @@ class ServiceUnavailable(APIException):
class SerializerNotFound(Exception): class SerializerNotFound(Exception):
pass pass
class GraphQLTypeNotFound(Exception):
pass

View File

@ -7,7 +7,7 @@ from django.urls import reverse
from rest_framework import status from rest_framework import status
from rest_framework.utils import formatting from rest_framework.utils import formatting
from netbox.api.exceptions import SerializerNotFound from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound
from .utils import dynamic_import from .utils import dynamic_import
@ -24,10 +24,22 @@ def get_serializer_for_model(model, prefix=''):
return dynamic_import(serializer_name) return dynamic_import(serializer_name)
except AttributeError: except AttributeError:
raise SerializerNotFound( raise SerializerNotFound(
"Could not determine serializer for {}.{} with prefix '{}'".format(app_name, model_name, prefix) f"Could not determine serializer for {app_name}.{model_name} with prefix '{prefix}'"
) )
def get_graphql_type_for_model(model):
"""
Return the GraphQL type class for the given model.
"""
app_name, model_name = model._meta.label.split('.')
class_name = f'{app_name}.graphql.types.{model_name}Type'
try:
return dynamic_import(class_name)
except AttributeError:
raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_name}.{model_name}")
def is_api_request(request): def is_api_request(request):
""" """
Return True of the request is being made via the REST API. Return True of the request is being made via the REST API.

View File

@ -5,12 +5,14 @@ from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.urls import reverse from django.urls import reverse
from django.test import override_settings from django.test import override_settings
from graphene.types.dynamic import Dynamic
from rest_framework import status from rest_framework import status
from rest_framework.test import APIClient from rest_framework.test import APIClient
from extras.choices import ObjectChangeActionChoices from extras.choices import ObjectChangeActionChoices
from extras.models import ObjectChange from extras.models import ObjectChange
from users.models import ObjectPermission, Token from users.models import ObjectPermission, Token
from utilities.api import get_graphql_type_for_model
from .base import ModelTestCase from .base import ModelTestCase
from .utils import disable_warnings from .utils import disable_warnings
@ -431,18 +433,42 @@ class APIViewTestCases:
self.model._meta.verbose_name_plural.lower().replace(' ', '_')) self.model._meta.verbose_name_plural.lower().replace(' ', '_'))
return getattr(self, 'graphql_base_name', self.model._meta.verbose_name.lower().replace(' ', '_')) return getattr(self, 'graphql_base_name', self.model._meta.verbose_name.lower().replace(' ', '_'))
def _build_query(self, name, **filters):
type_class = get_graphql_type_for_model(self.model)
if filters:
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
filter_string = f'({filter_string})'
else:
filter_string = ''
# Compile list of fields to include
fields_string = ''
for field_name, field in type_class._meta.fields.items():
# TODO: Omit "hidden" fields from GraphQL types
if field_name.startswith('_'):
continue
if type(field) is Dynamic:
# Dynamic fields must specify a subselection
fields_string += f'{field_name} {{ id }}\n'
else:
fields_string += f'{field_name}\n'
query = f"""
{{
{name}{filter_string} {{
{fields_string}
}}
}}
"""
return query
@override_settings(LOGIN_REQUIRED=True) @override_settings(LOGIN_REQUIRED=True)
def test_graphql_get_object(self): def test_graphql_get_object(self):
url = reverse('graphql') url = reverse('graphql')
object_type = self._get_graphql_base_name() object_type = self._get_graphql_base_name()
object_id = self._get_queryset().first().pk object_id = self._get_queryset().first().pk
query = f""" query = self._build_query(object_type, id=object_id)
{{
{object_type}(id:{object_id}) {{
id
}}
}}
"""
# Non-authenticated requests should fail # Non-authenticated requests should fail
with disable_warnings('django.request'): with disable_warnings('django.request'):
@ -466,13 +492,7 @@ class APIViewTestCases:
def test_graphql_list_objects(self): def test_graphql_list_objects(self):
url = reverse('graphql') url = reverse('graphql')
object_type = self._get_graphql_base_name(plural=True) object_type = self._get_graphql_base_name(plural=True)
query = f""" query = self._build_query(object_type)
{{
{object_type} {{
id
}}
}}
"""
# Non-authenticated requests should fail # Non-authenticated requests should fail
with disable_warnings('django.request'): with disable_warnings('django.request'):