diff --git a/netbox/netbox/api/exceptions.py b/netbox/netbox/api/exceptions.py index 8c62eee4c..f552b06b5 100644 --- a/netbox/netbox/api/exceptions.py +++ b/netbox/netbox/api/exceptions.py @@ -8,3 +8,7 @@ class ServiceUnavailable(APIException): class SerializerNotFound(Exception): pass + + +class GraphQLTypeNotFound(Exception): + pass diff --git a/netbox/utilities/api.py b/netbox/utilities/api.py index 09cc7004b..b4bde9b53 100644 --- a/netbox/utilities/api.py +++ b/netbox/utilities/api.py @@ -7,7 +7,7 @@ from django.urls import reverse from rest_framework import status from rest_framework.utils import formatting -from netbox.api.exceptions import SerializerNotFound +from netbox.api.exceptions import GraphQLTypeNotFound, SerializerNotFound from .utils import dynamic_import @@ -24,10 +24,22 @@ def get_serializer_for_model(model, prefix=''): return dynamic_import(serializer_name) except AttributeError: raise SerializerNotFound( - "Could not determine serializer for {}.{} with prefix '{}'".format(app_name, model_name, prefix) + f"Could not determine serializer for {app_name}.{model_name} with prefix '{prefix}'" ) +def get_graphql_type_for_model(model): + """ + Return the GraphQL type class for the given model. + """ + app_name, model_name = model._meta.label.split('.') + class_name = f'{app_name}.graphql.types.{model_name}Type' + try: + return dynamic_import(class_name) + except AttributeError: + raise GraphQLTypeNotFound(f"Could not find GraphQL type for {app_name}.{model_name}") + + def is_api_request(request): """ Return True of the request is being made via the REST API. diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 2549492c4..fd18259d1 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -5,12 +5,14 @@ from django.contrib.auth.models import User from django.contrib.contenttypes.models import ContentType from django.urls import reverse from django.test import override_settings +from graphene.types.dynamic import Dynamic from rest_framework import status from rest_framework.test import APIClient from extras.choices import ObjectChangeActionChoices from extras.models import ObjectChange from users.models import ObjectPermission, Token +from utilities.api import get_graphql_type_for_model from .base import ModelTestCase from .utils import disable_warnings @@ -431,18 +433,42 @@ class APIViewTestCases: self.model._meta.verbose_name_plural.lower().replace(' ', '_')) return getattr(self, 'graphql_base_name', self.model._meta.verbose_name.lower().replace(' ', '_')) + def _build_query(self, name, **filters): + type_class = get_graphql_type_for_model(self.model) + if filters: + filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items()) + filter_string = f'({filter_string})' + else: + filter_string = '' + + # Compile list of fields to include + fields_string = '' + for field_name, field in type_class._meta.fields.items(): + # TODO: Omit "hidden" fields from GraphQL types + if field_name.startswith('_'): + continue + if type(field) is Dynamic: + # Dynamic fields must specify a subselection + fields_string += f'{field_name} {{ id }}\n' + else: + fields_string += f'{field_name}\n' + + query = f""" + {{ + {name}{filter_string} {{ + {fields_string} + }} + }} + """ + + return query + @override_settings(LOGIN_REQUIRED=True) def test_graphql_get_object(self): url = reverse('graphql') object_type = self._get_graphql_base_name() object_id = self._get_queryset().first().pk - query = f""" - {{ - {object_type}(id:{object_id}) {{ - id - }} - }} - """ + query = self._build_query(object_type, id=object_id) # Non-authenticated requests should fail with disable_warnings('django.request'): @@ -466,13 +492,7 @@ class APIViewTestCases: def test_graphql_list_objects(self): url = reverse('graphql') object_type = self._get_graphql_base_name(plural=True) - query = f""" - {{ - {object_type} {{ - id - }} - }} - """ + query = self._build_query(object_type) # Non-authenticated requests should fail with disable_warnings('django.request'):