diff --git a/netbox/utilities/testing/testcases.py b/netbox/utilities/testing/testcases.py index d10bb025a..149ae8000 100644 --- a/netbox/utilities/testing/testcases.py +++ b/netbox/utilities/testing/testcases.py @@ -3,6 +3,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.forms.models import model_to_dict from django.test import Client, TestCase as _TestCase, override_settings from django.urls import reverse, NoReverseMatch +from rest_framework import status from rest_framework.test import APIClient from users.models import Token @@ -57,6 +58,34 @@ class TestCase(_TestCase): expected_status, response.status_code, getattr(response, 'data', 'No data') )) + def assertInstanceEqual(self, instance, data): + """ + Compare a model instance to a dictionary, checking that its attribute values match those specified + in the dictionary. + """ + model_dict = model_to_dict(instance, fields=data.keys()) + + for key in list(model_dict.keys()): + + # TODO: Differentiate between tags assigned to the instance and a M2M field for tags (ex: ConfigContext) + if key == 'tags': + model_dict[key] = ','.join(sorted([tag.name for tag in model_dict['tags']])) + + # Convert ManyToManyField to list of instance PKs + elif model_dict[key] and type(model_dict[key]) in (list, tuple) and hasattr(model_dict[key][0], 'pk'): + model_dict[key] = [obj.pk for obj in model_dict[key]] + + # Omit any dictionary keys which are not instance attributes + relevant_data = { + k: v for k, v in data.items() if hasattr(instance, k) + } + + self.assertDictEqual(model_dict, relevant_data) + + +# +# UI Tests +# class ModelViewTestCase(TestCase): """ @@ -104,42 +133,6 @@ class ModelViewTestCase(TestCase): else: raise Exception("Invalid action for URL resolution: {}".format(action)) - def assertInstanceEqual(self, instance, data): - """ - Compare a model instance to a dictionary, checking that its attribute values match those specified - in the dictionary. - """ - model_dict = model_to_dict(instance, fields=data.keys()) - - for key in list(model_dict.keys()): - - # TODO: Differentiate between tags assigned to the instance and a M2M field for tags (ex: ConfigContext) - if key == 'tags': - model_dict[key] = ','.join(sorted([tag.name for tag in model_dict['tags']])) - - # Convert ManyToManyField to list of instance PKs - elif model_dict[key] and type(model_dict[key]) in (list, tuple) and hasattr(model_dict[key][0], 'pk'): - model_dict[key] = [obj.pk for obj in model_dict[key]] - - # Omit any dictionary keys which are not instance attributes - relevant_data = { - k: v for k, v in data.items() if hasattr(instance, k) - } - - self.assertDictEqual(model_dict, relevant_data) - - -class APITestCase(TestCase): - client_class = APIClient - - def setUp(self): - """ - Create a superuser and token for API calls. - """ - self.user = User.objects.create(username='testuser', is_superuser=True) - self.token = Token.objects.create(user=self.user) - self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} - class ViewTestCases: """ @@ -488,3 +481,129 @@ class ViewTestCases: TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.) """ maxDiff = None + + +# +# REST API Tests +# + +class APITestCase(TestCase): + client_class = APIClient + model = None + + def setUp(self): + """ + Create a superuser and token for API calls. + """ + self.user = User.objects.create(username='testuser', is_superuser=True) + self.token = Token.objects.create(user=self.user) + self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)} + + def _get_detail_url(self, instance): + viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail' + return reverse(viewname, kwargs={'pk': instance.pk}) + + def _get_list_url(self): + viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list' + return reverse(viewname) + + +class APIViewTestCases: + + class GetObjectViewTestCase(APITestCase): + + def test_get_object(self): + """ + GET a single object identified by its numeric ID. + """ + instance = self.model.objects.first() + url = self._get_detail_url(instance) + response = self.client.get(url, **self.header) + + self.assertEqual(response.data['id'], instance.pk) + + class ListObjectsViewTestCase(APITestCase): + brief_fields = [] + + def test_list_objects(self): + """ + GET a list of objects. + """ + url = self._get_list_url() + response = self.client.get(url, **self.header) + + self.assertEqual(len(response.data['results']), self.model.objects.count()) + + def test_list_objects_brief(self): + """ + GET a list of objects using the "brief" parameter. + """ + url = f'{self._get_list_url()}?brief=1' + response = self.client.get(url, **self.header) + + self.assertEqual(len(response.data['results']), self.model.objects.count()) + self.assertEqual(sorted(response.data['results'][0]), self.brief_fields) + + class CreateObjectViewTestCase(APITestCase): + create_data = [] + + def test_create_object(self): + """ + POST a single object. + """ + initial_count = self.model.objects.count() + url = self._get_list_url() + response = self.client.post(url, self.create_data[0], format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(self.model.objects.count(), initial_count + 1) + self.assertInstanceEqual(self.model.objects.get(pk=response.data['id']), self.create_data[0]) + + def test_bulk_create_object(self): + """ + POST a set of objects in a single request. + """ + initial_count = self.model.objects.count() + url = self._get_list_url() + response = self.client.post(url, self.create_data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_201_CREATED) + self.assertEqual(self.model.objects.count(), initial_count + len(self.create_data)) + + class UpdateObjectViewTestCase(APITestCase): + update_data = {} + + def test_update_object(self): + """ + PATCH a single object identified by its numeric ID. + """ + instance = self.model.objects.first() + url = self._get_detail_url(instance) + update_data = self.update_data + response = self.client.patch(url, update_data, format='json', **self.header) + + self.assertHttpStatus(response, status.HTTP_200_OK) + instance.refresh_from_db() + self.assertInstanceEqual(instance, self.update_data) + + class DeleteObjectViewTestCase(APITestCase): + + def test_delete_object(self): + """ + DELETE a single object identified by its numeric ID. + """ + instance = self.model.objects.first() + url = self._get_detail_url(instance) + response = self.client.delete(url, **self.header) + + self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) + self.assertFalse(self.model.objects.filter(pk=instance.pk).exists()) + + class APIViewTestCase( + GetObjectViewTestCase, + ListObjectsViewTestCase, + CreateObjectViewTestCase, + UpdateObjectViewTestCase, + DeleteObjectViewTestCase + ): + pass