diff --git a/netbox/dcim/tests/test_api.py b/netbox/dcim/tests/test_api.py index 4b711efe4..22fc0b71c 100644 --- a/netbox/dcim/tests/test_api.py +++ b/netbox/dcim/tests/test_api.py @@ -127,9 +127,6 @@ class SiteTest(APITestCase): self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2') self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site-1') self.site2 = Site.objects.create(region=self.region1, name='Test Site 2', slug='test-site-2') - self.site3 = Site.objects.create(region=self.region2, name='Test Site 3', slug='test-site-3') - self.site_non_region1 = Site.objects.create(name='Test Site Null Region1', slug='test-site-no-region1') - self.site_non_region2 = Site.objects.create(name='Test Site Null Region2', slug='test-site-no-region2') def test_get_site(self): @@ -164,7 +161,7 @@ class SiteTest(APITestCase): url = reverse('dcim-api:site-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 5) + self.assertEqual(response.data['count'], 3) def test_list_sites_brief(self): @@ -176,20 +173,6 @@ class SiteTest(APITestCase): ['id', 'name', 'slug', 'url'] ) - def test_list_sites_null_region(self): - - url = reverse('dcim-api:site-list') - response = self.client.get('{}?region=null'.format(url), **self.header) - - self.assertEqual(response.data['count'], 2) - - def test_list_sites_multiple_regions(self): - - url = reverse('dcim-api:site-list') - response = self.client.get('{}?region=null®ion=test-region-1'.format(url), **self.header) - - self.assertEqual(response.data['count'], 4) - def test_create_site(self): data = { @@ -203,7 +186,7 @@ class SiteTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Site.objects.count(), 6) + self.assertEqual(Site.objects.count(), 4) site4 = Site.objects.get(pk=response.data['id']) self.assertEqual(site4.name, data['name']) self.assertEqual(site4.slug, data['slug']) @@ -236,7 +219,7 @@ class SiteTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Site.objects.count(), 8) + self.assertEqual(Site.objects.count(), 6) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) @@ -253,7 +236,7 @@ class SiteTest(APITestCase): response = self.client.put(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Site.objects.count(), 5) + self.assertEqual(Site.objects.count(), 3) site1 = Site.objects.get(pk=response.data['id']) self.assertEqual(site1.name, data['name']) self.assertEqual(site1.slug, data['slug']) @@ -265,7 +248,7 @@ class SiteTest(APITestCase): response = self.client.delete(url, **self.header) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Site.objects.count(), 4) + self.assertEqual(Site.objects.count(), 2) class RackGroupTest(APITestCase): @@ -1769,8 +1752,7 @@ class DeviceTest(APITestCase): super().setUp() - region = Region.objects.create(name='Test Region', slug='test-region-1') - self.site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1') + self.site1 = Site.objects.create(name='Test Site 1', slug='test-site-1') self.site2 = Site.objects.create(name='Test Site 2', slug='test-site-2') manufacturer = Manufacturer.objects.create(name='Test Manufacturer 1', slug='test-manufacturer-1') self.devicetype1 = DeviceType.objects.create( @@ -1818,20 +1800,6 @@ class DeviceTest(APITestCase): 'B': 2 } ) - self.device_non_region1 = Device.objects.create( - device_type=self.devicetype1, - device_role=self.devicerole1, - name='Test Device Null Region1', - site=self.site2, - cluster=self.cluster1 - ) - self.device_non_region2 = Device.objects.create( - device_type=self.devicetype1, - device_role=self.devicerole1, - name='Test Device Null Region2', - site=self.site2, - cluster=self.cluster1 - ) def test_get_device(self): @@ -1847,7 +1815,7 @@ class DeviceTest(APITestCase): url = reverse('dcim-api:device-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 6) + self.assertEqual(response.data['count'], 4) def test_list_devices_brief(self): @@ -1859,20 +1827,6 @@ class DeviceTest(APITestCase): ['display_name', 'id', 'name', 'url'] ) - def test_list_devices_null_region(self): - - url = reverse('dcim-api:device-list') - response = self.client.get('{}?region=null'.format(url), **self.header) - - self.assertEqual(response.data['count'], 2) - - def test_list_devices_multiple_regions(self): - - url = reverse('dcim-api:device-list') - response = self.client.get('{}?region=null®ion=test-region-1'.format(url), **self.header) - - self.assertEqual(response.data['count'], 6) - def test_create_device(self): data = { @@ -1887,7 +1841,7 @@ class DeviceTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Device.objects.count(), 7) + self.assertEqual(Device.objects.count(), 5) device4 = Device.objects.get(pk=response.data['id']) self.assertEqual(device4.device_type_id, data['device_type']) self.assertEqual(device4.device_role_id, data['device_role']) @@ -1922,7 +1876,7 @@ class DeviceTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(Device.objects.count(), 9) + self.assertEqual(Device.objects.count(), 7) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) @@ -1946,7 +1900,7 @@ class DeviceTest(APITestCase): response = self.client.put(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(Device.objects.count(), 6) + self.assertEqual(Device.objects.count(), 4) device1 = Device.objects.get(pk=response.data['id']) self.assertEqual(device1.device_type_id, data['device_type']) self.assertEqual(device1.device_role_id, data['device_role']) @@ -1961,7 +1915,7 @@ class DeviceTest(APITestCase): response = self.client.delete(url, **self.header) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(Device.objects.count(), 5) + self.assertEqual(Device.objects.count(), 3) def test_config_context_included_by_default_in_list_view(self): diff --git a/netbox/utilities/tests/test_filters.py b/netbox/utilities/tests/test_filters.py new file mode 100644 index 000000000..513e11bca --- /dev/null +++ b/netbox/utilities/tests/test_filters.py @@ -0,0 +1,62 @@ +from django.conf import settings +from django.test import TestCase +import django_filters + +from dcim.models import Region, Site +from utilities.filters import TreeNodeMultipleChoiceFilter + + +class TreeNodeMultipleChoiceFilterTest(TestCase): + + class SiteFilterSet(django_filters.FilterSet): + region = TreeNodeMultipleChoiceFilter( + queryset=Region.objects.all(), + field_name='region__in', + to_field_name='slug', + ) + + def setUp(self): + + super().setUp() + + self.region1 = Region.objects.create(name='Test Region 1', slug='test-region-1') + self.region2 = Region.objects.create(name='Test Region 2', slug='test-region-2') + self.site1 = Site.objects.create(region=self.region1, name='Test Site 1', slug='test-site1') + self.site2 = Site.objects.create(region=self.region2, name='Test Site 2', slug='test-site2') + self.site3 = Site.objects.create(region=None, name='Test Site 3', slug='test-site3') + + self.queryset = Site.objects.all() + + def test_filter_single(self): + + kwargs = {'region': ['test-region-1']} + qs = self.SiteFilterSet(kwargs, self.queryset).qs + + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0], self.site1) + + def test_filter_multiple(self): + + kwargs = {'region': ['test-region-1', 'test-region-2']} + qs = self.SiteFilterSet(kwargs, self.queryset).qs + + self.assertEqual(qs.count(), 2) + self.assertEqual(qs[0], self.site1) + self.assertEqual(qs[1], self.site2) + + def test_filter_null(self): + + kwargs = {'region': [settings.FILTERS_NULL_CHOICE_VALUE]} + qs = self.SiteFilterSet(kwargs, self.queryset).qs + + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0], self.site3) + + def test_filter_combined(self): + + kwargs = {'region': ['test-region-1', settings.FILTERS_NULL_CHOICE_VALUE]} + qs = self.SiteFilterSet(kwargs, self.queryset).qs + + self.assertEqual(qs.count(), 2) + self.assertEqual(qs[0], self.site1) + self.assertEqual(qs[1], self.site3) diff --git a/netbox/virtualization/tests/test_api.py b/netbox/virtualization/tests/test_api.py index 7bbeccbdd..f1e372dd4 100644 --- a/netbox/virtualization/tests/test_api.py +++ b/netbox/virtualization/tests/test_api.py @@ -3,7 +3,7 @@ from netaddr import IPNetwork from rest_framework import status from dcim.constants import IFACE_TYPE_VIRTUAL, IFACE_MODE_TAGGED -from dcim.models import Interface, Region, Site +from dcim.models import Interface from ipam.models import IPAddress, VLAN from utilities.testing import APITestCase from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine @@ -330,14 +330,9 @@ class VirtualMachineTest(APITestCase): super().setUp() - region = Region.objects.create(name='Test Region 1', slug='test-region-1') - site1 = Site.objects.create(region=region, name='Test Site 1', slug='test-site-1') - site2 = Site.objects.create(name='Test Site 2', slug='test-site-2') - cluster_type = ClusterType.objects.create(name='Test Cluster Type 1', slug='test-cluster-type-1') cluster_group = ClusterGroup.objects.create(name='Test Cluster Group 1', slug='test-cluster-group-1') - self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group, site=site1) - self.cluster2 = Cluster.objects.create(name='Test Cluster 2', type=cluster_type, group=cluster_group, site=site2) + self.cluster1 = Cluster.objects.create(name='Test Cluster 1', type=cluster_type, group=cluster_group) self.virtualmachine1 = VirtualMachine.objects.create(name='Test Virtual Machine 1', cluster=self.cluster1) self.virtualmachine2 = VirtualMachine.objects.create(name='Test Virtual Machine 2', cluster=self.cluster1) @@ -350,8 +345,6 @@ class VirtualMachineTest(APITestCase): 'B': 2 } ) - self.virtualmachine_non_region1 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region1', cluster=self.cluster2) - self.virtualmachine_non_region2 = VirtualMachine.objects.create(name='Test Virtual Machine Null Region2', cluster=self.cluster2) def test_get_virtualmachine(self): @@ -365,7 +358,7 @@ class VirtualMachineTest(APITestCase): url = reverse('virtualization-api:virtualmachine-list') response = self.client.get(url, **self.header) - self.assertEqual(response.data['count'], 6) + self.assertEqual(response.data['count'], 4) def test_list_virtualmachines_brief(self): @@ -377,20 +370,6 @@ class VirtualMachineTest(APITestCase): ['id', 'name', 'url'] ) - def test_list_virtualmachines_null_region(self): - - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.get('{}?region=null'.format(url), **self.header) - - self.assertEqual(response.data['count'], 2) - - def test_list_virtualmachines_multiple_regions(self): - - url = reverse('virtualization-api:virtualmachine-list') - response = self.client.get('{}?region=null®ion=test-region-1'.format(url), **self.header) - - self.assertEqual(response.data['count'], 6) - def test_create_virtualmachine(self): data = { @@ -402,7 +381,7 @@ class VirtualMachineTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(VirtualMachine.objects.count(), 7) + self.assertEqual(VirtualMachine.objects.count(), 5) virtualmachine4 = VirtualMachine.objects.get(pk=response.data['id']) self.assertEqual(virtualmachine4.name, data['name']) self.assertEqual(virtualmachine4.cluster.pk, data['cluster']) @@ -417,7 +396,7 @@ class VirtualMachineTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) - self.assertEqual(VirtualMachine.objects.count(), 6) + self.assertEqual(VirtualMachine.objects.count(), 4) def test_create_virtualmachine_bulk(self): @@ -440,7 +419,7 @@ class VirtualMachineTest(APITestCase): response = self.client.post(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_201_CREATED) - self.assertEqual(VirtualMachine.objects.count(), 9) + self.assertEqual(VirtualMachine.objects.count(), 7) self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[1]['name'], data[1]['name']) self.assertEqual(response.data[2]['name'], data[2]['name']) @@ -451,9 +430,14 @@ class VirtualMachineTest(APITestCase): ip4_address = IPAddress.objects.create(address=IPNetwork('192.0.2.1/24'), interface=interface) ip6_address = IPAddress.objects.create(address=IPNetwork('2001:db8::1/64'), interface=interface) + cluster2 = Cluster.objects.create( + name='Test Cluster 2', + type=ClusterType.objects.first(), + group=ClusterGroup.objects.first() + ) data = { 'name': 'Test Virtual Machine X', - 'cluster': self.cluster2.pk, + 'cluster': cluster2.pk, 'primary_ip4': ip4_address.pk, 'primary_ip6': ip6_address.pk, } @@ -462,7 +446,7 @@ class VirtualMachineTest(APITestCase): response = self.client.put(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_200_OK) - self.assertEqual(VirtualMachine.objects.count(), 6) + self.assertEqual(VirtualMachine.objects.count(), 4) virtualmachine1 = VirtualMachine.objects.get(pk=response.data['id']) self.assertEqual(virtualmachine1.name, data['name']) self.assertEqual(virtualmachine1.cluster.pk, data['cluster']) @@ -475,7 +459,7 @@ class VirtualMachineTest(APITestCase): response = self.client.delete(url, **self.header) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) - self.assertEqual(VirtualMachine.objects.count(), 5) + self.assertEqual(VirtualMachine.objects.count(), 3) def test_config_context_included_by_default_in_list_view(self):