1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Merge pull request #4733 from netbox-community/4730-api-test-permissions

Closes #4730: Update REST API tests to enforce ObjectPermissions
This commit is contained in:
Jeremy Stretch
2020-06-08 16:53:22 -04:00
committed by GitHub
20 changed files with 455 additions and 309 deletions

View File

@ -58,6 +58,7 @@ class ProviderTest(APIViewTestCases.APIViewTestCase):
) )
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('circuits.view_provider')
url = reverse('circuits-api:provider-graphs', kwargs={'pk': provider.pk}) url = reverse('circuits-api:provider-graphs', kwargs={'pk': provider.pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)

View File

@ -106,6 +106,7 @@ class SiteTest(APIViewTestCases.APIViewTestCase):
) )
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('dcim.view_site')
url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.first().pk}) url = reverse('dcim-api:site-graphs', kwargs={'pk': Site.objects.first().pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -245,6 +246,7 @@ class RackTest(APIViewTestCases.APIViewTestCase):
def test_get_elevation_rack_units(self): def test_get_elevation_rack_units(self):
rack = Rack.objects.first() rack = Rack.objects.first()
self.add_permissions('dcim.view_rack')
url = '{}?q=3'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk})) url = '{}?q=3'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}))
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -270,6 +272,7 @@ class RackTest(APIViewTestCases.APIViewTestCase):
GET a single rack elevation. GET a single rack elevation.
""" """
rack = Rack.objects.first() rack = Rack.objects.first()
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}) url = reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -280,6 +283,7 @@ class RackTest(APIViewTestCases.APIViewTestCase):
GET a single rack elevation in SVG format. GET a single rack elevation in SVG format.
""" """
rack = Rack.objects.first() rack = Rack.objects.first()
self.add_permissions('dcim.view_rack')
url = '{}?render=svg'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk})) url = '{}?render=svg'.format(reverse('dcim-api:rack-elevation', kwargs={'pk': rack.pk}))
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -784,6 +788,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
) )
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('dcim.view_device')
url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.first().pk}) url = reverse('dcim-api:device-graphs', kwargs={'pk': Device.objects.first().pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -794,6 +799,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
""" """
Check that config context data is included by default in the devices list. Check that config context data is included by default in the devices list.
""" """
self.add_permissions('dcim.view_device')
url = reverse('dcim-api:device-list') + '?slug=device-with-context-data' url = reverse('dcim-api:device-list') + '?slug=device-with-context-data'
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -803,6 +809,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
""" """
Check that config context data can be excluded by passing ?exclude=config_context. Check that config context data can be excluded by passing ?exclude=config_context.
""" """
self.add_permissions('dcim.view_device')
url = reverse('dcim-api:device-list') + '?exclude=config_context' url = reverse('dcim-api:device-list') + '?exclude=config_context'
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -820,6 +827,7 @@ class DeviceTest(APIViewTestCases.APIViewTestCase):
'name': device.name, 'name': device.name,
} }
self.add_permissions('dcim.add_device')
url = reverse('dcim-api:device-list') url = reverse('dcim-api:device-list')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
@ -878,6 +886,7 @@ class ConsolePortTest(APIViewTestCases.APIViewTestCase):
cable = Cable(termination_a=consoleport, termination_b=consoleserverport, label='Cable 1') cable = Cable(termination_a=consoleport, termination_b=consoleserverport, label='Cable 1')
cable.save() cable.save()
self.add_permissions('dcim.view_consoleport')
url = reverse('dcim-api:consoleport-trace', kwargs={'pk': consoleport.pk}) url = reverse('dcim-api:consoleport-trace', kwargs={'pk': consoleport.pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -941,6 +950,7 @@ class ConsoleServerPortTest(APIViewTestCases.APIViewTestCase):
cable = Cable(termination_a=consoleserverport, termination_b=consoleport, label='Cable 1') cable = Cable(termination_a=consoleserverport, termination_b=consoleport, label='Cable 1')
cable.save() cable.save()
self.add_permissions('dcim.view_consoleserverport')
url = reverse('dcim-api:consoleserverport-trace', kwargs={'pk': consoleserverport.pk}) url = reverse('dcim-api:consoleserverport-trace', kwargs={'pk': consoleserverport.pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -1004,6 +1014,7 @@ class PowerPortTest(APIViewTestCases.APIViewTestCase):
cable = Cable(termination_a=powerport, termination_b=poweroutlet, label='Cable 1') cable = Cable(termination_a=powerport, termination_b=poweroutlet, label='Cable 1')
cable.save() cable.save()
self.add_permissions('dcim.view_powerport')
url = reverse('dcim-api:powerport-trace', kwargs={'pk': powerport.pk}) url = reverse('dcim-api:powerport-trace', kwargs={'pk': powerport.pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -1067,6 +1078,7 @@ class PowerOutletTest(APIViewTestCases.APIViewTestCase):
cable = Cable(termination_a=poweroutlet, termination_b=powerport, label='Cable 1') cable = Cable(termination_a=poweroutlet, termination_b=powerport, label='Cable 1')
cable.save() cable.save()
self.add_permissions('dcim.view_poweroutlet')
url = reverse('dcim-api:poweroutlet-trace', kwargs={'pk': poweroutlet.pk}) url = reverse('dcim-api:poweroutlet-trace', kwargs={'pk': poweroutlet.pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -1143,6 +1155,7 @@ class InterfaceTest(APIViewTestCases.APIViewTestCase):
) )
Graph.objects.bulk_create(graphs) Graph.objects.bulk_create(graphs)
self.add_permissions('dcim.view_interface')
url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.first().pk}) url = reverse('dcim-api:interface-graphs', kwargs={'pk': Interface.objects.first().pk})
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -1446,6 +1459,7 @@ class ConnectionTest(APITestCase):
'termination_b_id': consoleserverport1.pk, 'termination_b_id': consoleserverport1.pk,
} }
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
@ -1484,6 +1498,7 @@ class ConnectionTest(APITestCase):
device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2 device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
) )
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
cables = [ cables = [
# Console port to panel1 front # Console port to panel1 front
@ -1539,6 +1554,7 @@ class ConnectionTest(APITestCase):
'termination_b_id': poweroutlet1.pk, 'termination_b_id': poweroutlet1.pk,
} }
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
@ -1574,6 +1590,7 @@ class ConnectionTest(APITestCase):
'termination_b_id': interface2.pk, 'termination_b_id': interface2.pk,
} }
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
@ -1612,6 +1629,7 @@ class ConnectionTest(APITestCase):
device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2 device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
) )
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
cables = [ cables = [
# Interface1 to panel1 front # Interface1 to panel1 front
@ -1676,6 +1694,7 @@ class ConnectionTest(APITestCase):
'termination_b_id': circuittermination1.pk, 'termination_b_id': circuittermination1.pk,
} }
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
@ -1723,6 +1742,7 @@ class ConnectionTest(APITestCase):
device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2 device=self.panel2, name='Test Front Port 2', type=PortTypeChoices.TYPE_8P8C, rear_port=rearport2
) )
self.add_permissions('dcim.add_cable')
url = reverse('dcim-api:cable-list') url = reverse('dcim-api:cable-list')
cables = [ cables = [
# Interface to panel1 front # Interface to panel1 front
@ -1826,6 +1846,9 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase):
Device(name='Device 7', device_type=devicetype, device_role=devicerole, site=site), Device(name='Device 7', device_type=devicetype, device_role=devicerole, site=site),
Device(name='Device 8', device_type=devicetype, device_role=devicerole, site=site), Device(name='Device 8', device_type=devicetype, device_role=devicerole, site=site),
Device(name='Device 9', device_type=devicetype, device_role=devicerole, site=site), Device(name='Device 9', device_type=devicetype, device_role=devicerole, site=site),
Device(name='Device 10', device_type=devicetype, device_role=devicerole, site=site),
Device(name='Device 11', device_type=devicetype, device_role=devicerole, site=site),
Device(name='Device 12', device_type=devicetype, device_role=devicerole, site=site),
) )
Device.objects.bulk_create(devices) Device.objects.bulk_create(devices)
@ -1839,16 +1862,19 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase):
) )
Interface.objects.bulk_create(interfaces) Interface.objects.bulk_create(interfaces)
# Create two VirtualChassis with three members each # Create three VirtualChassis with three members each
virtual_chassis = ( virtual_chassis = (
VirtualChassis(master=devices[0], domain='domain-1'), VirtualChassis(master=devices[0], domain='domain-1'),
VirtualChassis(master=devices[3], domain='domain-2'), VirtualChassis(master=devices[3], domain='domain-2'),
VirtualChassis(master=devices[6], domain='domain-3'),
) )
VirtualChassis.objects.bulk_create(virtual_chassis) VirtualChassis.objects.bulk_create(virtual_chassis)
Device.objects.filter(pk=devices[1].pk).update(virtual_chassis=virtual_chassis[0], vc_position=2) Device.objects.filter(pk=devices[1].pk).update(virtual_chassis=virtual_chassis[0], vc_position=2)
Device.objects.filter(pk=devices[2].pk).update(virtual_chassis=virtual_chassis[0], vc_position=3) Device.objects.filter(pk=devices[2].pk).update(virtual_chassis=virtual_chassis[0], vc_position=3)
Device.objects.filter(pk=devices[4].pk).update(virtual_chassis=virtual_chassis[1], vc_position=2) Device.objects.filter(pk=devices[4].pk).update(virtual_chassis=virtual_chassis[1], vc_position=2)
Device.objects.filter(pk=devices[5].pk).update(virtual_chassis=virtual_chassis[1], vc_position=3) Device.objects.filter(pk=devices[5].pk).update(virtual_chassis=virtual_chassis[1], vc_position=3)
Device.objects.filter(pk=devices[7].pk).update(virtual_chassis=virtual_chassis[2], vc_position=2)
Device.objects.filter(pk=devices[8].pk).update(virtual_chassis=virtual_chassis[2], vc_position=3)
cls.update_data = { cls.update_data = {
'master': devices[1].pk, 'master': devices[1].pk,
@ -1857,17 +1883,17 @@ class VirtualChassisTest(APIViewTestCases.APIViewTestCase):
cls.create_data = [ cls.create_data = [
{ {
'master': devices[6].pk, 'master': devices[9].pk,
'domain': 'domain-3',
},
{
'master': devices[7].pk,
'domain': 'domain-4', 'domain': 'domain-4',
}, },
{ {
'master': devices[8].pk, 'master': devices[10].pk,
'domain': 'domain-5', 'domain': 'domain-5',
}, },
{
'master': devices[11].pk,
'domain': 'domain-6',
},
] ]

View File

@ -232,6 +232,8 @@ class Graph(models.Model):
verbose_name='Link URL' verbose_name='Link URL'
) )
objects = RestrictedQuerySet.as_manager()
class Meta: class Meta:
ordering = ('type', 'weight', 'name', 'pk') # (type, weight, name) may be non-unique ordering = ('type', 'weight', 'name', 'pk') # (type, weight, name) may be non-unique
@ -299,6 +301,8 @@ class ExportTemplate(models.Model):
help_text='Extension to append to the rendered filename' help_text='Extension to append to the rendered filename'
) )
objects = RestrictedQuerySet.as_manager()
class Meta: class Meta:
ordering = ['content_type', 'name'] ordering = ['content_type', 'name']
unique_together = [ unique_together = [

View File

@ -295,6 +295,7 @@ class CreatedUpdatedFilterTest(APITestCase):
) )
def test_get_rack_created(self): def test_get_rack_created(self):
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-list') url = reverse('dcim-api:rack-list')
response = self.client.get('{}?created=2001-02-03'.format(url), **self.header) response = self.client.get('{}?created=2001-02-03'.format(url), **self.header)
@ -302,6 +303,7 @@ class CreatedUpdatedFilterTest(APITestCase):
self.assertEqual(response.data['results'][0]['id'], self.rack2.pk) self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
def test_get_rack_created_gte(self): def test_get_rack_created_gte(self):
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-list') url = reverse('dcim-api:rack-list')
response = self.client.get('{}?created__gte=2001-02-04'.format(url), **self.header) response = self.client.get('{}?created__gte=2001-02-04'.format(url), **self.header)
@ -309,6 +311,7 @@ class CreatedUpdatedFilterTest(APITestCase):
self.assertEqual(response.data['results'][0]['id'], self.rack1.pk) self.assertEqual(response.data['results'][0]['id'], self.rack1.pk)
def test_get_rack_created_lte(self): def test_get_rack_created_lte(self):
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-list') url = reverse('dcim-api:rack-list')
response = self.client.get('{}?created__lte=2001-02-04'.format(url), **self.header) response = self.client.get('{}?created__lte=2001-02-04'.format(url), **self.header)
@ -316,6 +319,7 @@ class CreatedUpdatedFilterTest(APITestCase):
self.assertEqual(response.data['results'][0]['id'], self.rack2.pk) self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
def test_get_rack_last_updated(self): def test_get_rack_last_updated(self):
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-list') url = reverse('dcim-api:rack-list')
response = self.client.get('{}?last_updated=2001-02-03%2001:02:03.000004'.format(url), **self.header) response = self.client.get('{}?last_updated=2001-02-03%2001:02:03.000004'.format(url), **self.header)
@ -323,6 +327,7 @@ class CreatedUpdatedFilterTest(APITestCase):
self.assertEqual(response.data['results'][0]['id'], self.rack2.pk) self.assertEqual(response.data['results'][0]['id'], self.rack2.pk)
def test_get_rack_last_updated_gte(self): def test_get_rack_last_updated_gte(self):
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-list') url = reverse('dcim-api:rack-list')
response = self.client.get('{}?last_updated__gte=2001-02-04%2001:02:03.000004'.format(url), **self.header) response = self.client.get('{}?last_updated__gte=2001-02-04%2001:02:03.000004'.format(url), **self.header)
@ -330,6 +335,7 @@ class CreatedUpdatedFilterTest(APITestCase):
self.assertEqual(response.data['results'][0]['id'], self.rack1.pk) self.assertEqual(response.data['results'][0]['id'], self.rack1.pk)
def test_get_rack_last_updated_lte(self): def test_get_rack_last_updated_lte(self):
self.add_permissions('dcim.view_rack')
url = reverse('dcim-api:rack-list') url = reverse('dcim-api:rack-list')
response = self.client.get('{}?last_updated__lte=2001-02-04%2001:02:03.000004'.format(url), **self.header) response = self.client.get('{}?last_updated__lte=2001-02-04%2001:02:03.000004'.format(url), **self.header)

View File

@ -4,7 +4,6 @@ from rest_framework import status
from dcim.models import Site from dcim.models import Site
from extras.choices import * from extras.choices import *
from extras.constants import *
from extras.models import CustomField, CustomFieldValue, ObjectChange from extras.models import CustomField, CustomFieldValue, ObjectChange
from utilities.testing import APITestCase from utilities.testing import APITestCase
@ -26,7 +25,6 @@ class ChangeLogTest(APITestCase):
cf.obj_type.set([ct]) cf.obj_type.set([ct])
def test_create_object(self): def test_create_object(self):
data = { data = {
'name': 'Test Site 1', 'name': 'Test Site 1',
'slug': 'test-site-1', 'slug': 'test-site-1',
@ -37,10 +35,10 @@ class ChangeLogTest(APITestCase):
'bar', 'foo' 'bar', 'foo'
], ],
} }
self.assertEqual(ObjectChange.objects.count(), 0) self.assertEqual(ObjectChange.objects.count(), 0)
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
@ -55,7 +53,6 @@ class ChangeLogTest(APITestCase):
self.assertListEqual(sorted(oc.object_data['tags']), data['tags']) self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
def test_update_object(self): def test_update_object(self):
site = Site(name='Test Site 1', slug='test-site-1') site = Site(name='Test Site 1', slug='test-site-1')
site.save() site.save()
@ -69,10 +66,10 @@ class ChangeLogTest(APITestCase):
'abc', 'xyz' 'abc', 'xyz'
], ],
} }
self.assertEqual(ObjectChange.objects.count(), 0) self.assertEqual(ObjectChange.objects.count(), 0)
self.add_permissions('dcim.change_site')
url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
response = self.client.put(url, data, format='json', **self.header) response = self.client.put(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
@ -87,7 +84,6 @@ class ChangeLogTest(APITestCase):
self.assertListEqual(sorted(oc.object_data['tags']), data['tags']) self.assertListEqual(sorted(oc.object_data['tags']), data['tags'])
def test_delete_object(self): def test_delete_object(self):
site = Site( site = Site(
name='Test Site 1', name='Test Site 1',
slug='test-site-1' slug='test-site-1'
@ -99,12 +95,11 @@ class ChangeLogTest(APITestCase):
obj=site, obj=site,
value='ABC' value='ABC'
) )
self.assertEqual(ObjectChange.objects.count(), 0) self.assertEqual(ObjectChange.objects.count(), 0)
self.add_permissions('dcim.delete_site')
url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
response = self.client.delete(url, **self.header)
response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
self.assertEqual(Site.objects.count(), 0) self.assertEqual(Site.objects.count(), 0)

View File

@ -182,8 +182,9 @@ class CustomFieldAPITest(APITestCase):
Validate that custom fields are present on an object even if it has no values defined. Validate that custom fields are present on an object even if it has no values defined.
""" """
url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[0].pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[0].pk})
response = self.client.get(url, **self.header) self.add_permissions('dcim.view_site')
response = self.client.get(url, **self.header)
self.assertEqual(response.data['name'], self.sites[0].name) self.assertEqual(response.data['name'], self.sites[0].name)
self.assertEqual(response.data['custom_fields'], { self.assertEqual(response.data['custom_fields'], {
'text_field': None, 'text_field': None,
@ -201,10 +202,10 @@ class CustomFieldAPITest(APITestCase):
site2_cfvs = { site2_cfvs = {
cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all() cfv.field.name: cfv.value for cfv in self.sites[1].custom_field_values.all()
} }
url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
response = self.client.get(url, **self.header) self.add_permissions('dcim.view_site')
response = self.client.get(url, **self.header)
self.assertEqual(response.data['name'], self.sites[1].name) self.assertEqual(response.data['name'], self.sites[1].name)
self.assertEqual(response.data['custom_fields']['text_field'], site2_cfvs['text_field']) self.assertEqual(response.data['custom_fields']['text_field'], site2_cfvs['text_field'])
self.assertEqual(response.data['custom_fields']['number_field'], site2_cfvs['number_field']) self.assertEqual(response.data['custom_fields']['number_field'], site2_cfvs['number_field'])
@ -221,8 +222,9 @@ class CustomFieldAPITest(APITestCase):
'name': 'Site 3', 'name': 'Site 3',
'slug': 'site-3', 'slug': 'site-3',
} }
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
@ -263,8 +265,9 @@ class CustomFieldAPITest(APITestCase):
'choice_field': self.cf_select_choice2.pk, 'choice_field': self.cf_select_choice2.pk,
}, },
} }
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
@ -309,8 +312,9 @@ class CustomFieldAPITest(APITestCase):
'slug': 'site-5', 'slug': 'site-5',
}, },
) )
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(len(response.data), len(data)) self.assertEqual(len(response.data), len(data))
@ -367,8 +371,9 @@ class CustomFieldAPITest(APITestCase):
'custom_fields': custom_field_data, 'custom_fields': custom_field_data,
}, },
) )
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(len(response.data), len(data)) self.assertEqual(len(response.data), len(data))
@ -410,8 +415,9 @@ class CustomFieldAPITest(APITestCase):
'number_field': 1234, 'number_field': 1234,
}, },
} }
url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': self.sites[1].pk})
self.add_permissions('dcim.change_site')
response = self.client.patch(url, data, format='json', **self.header) response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)

View File

@ -15,16 +15,15 @@ class TaggedItemTest(APITestCase):
super().setUp() super().setUp()
def test_create_tagged_item(self): def test_create_tagged_item(self):
data = { data = {
'name': 'Test Site', 'name': 'Test Site',
'slug': 'test-site', 'slug': 'test-site',
'tags': ['Foo', 'Bar', 'Baz'] 'tags': ['Foo', 'Bar', 'Baz']
} }
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(sorted(response.data['tags']), sorted(data['tags'])) self.assertEqual(sorted(response.data['tags']), sorted(data['tags']))
site = Site.objects.get(pk=response.data['id']) site = Site.objects.get(pk=response.data['id'])
@ -32,20 +31,18 @@ class TaggedItemTest(APITestCase):
self.assertEqual(sorted(tags), sorted(data['tags'])) self.assertEqual(sorted(tags), sorted(data['tags']))
def test_update_tagged_item(self): def test_update_tagged_item(self):
site = Site.objects.create( site = Site.objects.create(
name='Test Site', name='Test Site',
slug='test-site' slug='test-site'
) )
site.tags.add('Foo', 'Bar', 'Baz') site.tags.add('Foo', 'Bar', 'Baz')
data = { data = {
'tags': ['Foo', 'Bar', 'New Tag'] 'tags': ['Foo', 'Bar', 'New Tag']
} }
self.add_permissions('dcim.change_site')
url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
response = self.client.patch(url, data, format='json', **self.header)
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(sorted(response.data['tags']), sorted(data['tags'])) self.assertEqual(sorted(response.data['tags']), sorted(data['tags']))
site = Site.objects.get(pk=response.data['id']) site = Site.objects.get(pk=response.data['id'])

View File

@ -42,13 +42,13 @@ class WebhookTest(APITestCase):
webhook.obj_type.set([site_ct]) webhook.obj_type.set([site_ct])
def test_enqueue_webhook_create(self): def test_enqueue_webhook_create(self):
# Create an object via the REST API # Create an object via the REST API
data = { data = {
'name': 'Test Site', 'name': 'Test Site',
'slug': 'test-site', 'slug': 'test-site',
} }
url = reverse('dcim-api:site-list') url = reverse('dcim-api:site-list')
self.add_permissions('dcim.add_site')
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(Site.objects.count(), 1) self.assertEqual(Site.objects.count(), 1)
@ -62,14 +62,13 @@ class WebhookTest(APITestCase):
self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_CREATE)
def test_enqueue_webhook_update(self): def test_enqueue_webhook_update(self):
site = Site.objects.create(name='Site 1', slug='site-1')
# Update an object via the REST API # Update an object via the REST API
site = Site.objects.create(name='Site 1', slug='site-1')
data = { data = {
'comments': 'Updated the site', 'comments': 'Updated the site',
} }
url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
self.add_permissions('dcim.change_site')
response = self.client.patch(url, data, format='json', **self.header) response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
@ -82,11 +81,10 @@ class WebhookTest(APITestCase):
self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_UPDATE) self.assertEqual(job.args[3], ObjectChangeActionChoices.ACTION_UPDATE)
def test_enqueue_webhook_delete(self): def test_enqueue_webhook_delete(self):
site = Site.objects.create(name='Site 1', slug='site-1')
# Delete an object via the REST API # Delete an object via the REST API
site = Site.objects.create(name='Site 1', slug='site-1')
url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk}) url = reverse('dcim-api:site-detail', kwargs={'pk': site.pk})
self.add_permissions('dcim.delete_site')
response = self.client.delete(url, **self.header) response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)

View File

@ -176,6 +176,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
Prefix.objects.create(prefix=IPNetwork('192.0.2.64/26')) Prefix.objects.create(prefix=IPNetwork('192.0.2.64/26'))
Prefix.objects.create(prefix=IPNetwork('192.0.2.192/27')) Prefix.objects.create(prefix=IPNetwork('192.0.2.192/27'))
url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk}) url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
self.add_permissions('ipam.view_prefix')
# Retrieve all available IPs # Retrieve all available IPs
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -190,6 +191,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
vrf = VRF.objects.create(name='Test VRF 1', rd='1234') vrf = VRF.objects.create(name='Test VRF 1', rd='1234')
prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), vrf=vrf, is_pool=True) prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), vrf=vrf, is_pool=True)
url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk}) url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
self.add_permissions('ipam.add_prefix')
# Create four available prefixes with individual requests # Create four available prefixes with individual requests
prefixes_to_be_created = [ prefixes_to_be_created = [
@ -225,6 +227,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
""" """
prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), is_pool=True) prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/28'), is_pool=True)
url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk}) url = reverse('ipam-api:prefix-available-prefixes', kwargs={'pk': prefix.pk})
self.add_permissions('ipam.view_prefix', 'ipam.add_prefix')
# Try to create five /30s (only four are available) # Try to create five /30s (only four are available)
data = [ data = [
@ -240,6 +243,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
# Verify that no prefixes were created (the entire /28 is still available) # Verify that no prefixes were created (the entire /28 is still available)
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(response.data[0]['prefix'], '192.0.2.0/28') self.assertEqual(response.data[0]['prefix'], '192.0.2.0/28')
# Create four /30s in a single request # Create four /30s in a single request
@ -253,6 +257,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
""" """
prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True) prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True)
url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk}) url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
self.add_permissions('ipam.view_prefix', 'ipam.view_ipaddress')
# Retrieve all available IPs # Retrieve all available IPs
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
@ -271,6 +276,8 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
vrf = VRF.objects.create(name='Test VRF 1', rd='1234') vrf = VRF.objects.create(name='Test VRF 1', rd='1234')
prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/30'), vrf=vrf, is_pool=True) prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/30'), vrf=vrf, is_pool=True)
url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk}) url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
# TODO: ipam.add_prefix should not be required
self.add_permissions('ipam.add_prefix', 'ipam.add_ipaddress')
# Create all four available IPs with individual requests # Create all four available IPs with individual requests
for i in range(1, 5): for i in range(1, 5):
@ -293,6 +300,8 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
""" """
prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True) prefix = Prefix.objects.create(prefix=IPNetwork('192.0.2.0/29'), is_pool=True)
url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk}) url = reverse('ipam-api:prefix-available-ips', kwargs={'pk': prefix.pk})
# TODO: ipam.add_prefix, ipam.view_prefix should not be required
self.add_permissions('ipam.add_prefix', 'ipam.view_prefix', 'ipam.view_ipaddress', 'ipam.add_ipaddress')
# Try to create nine IPs (only eight are available) # Try to create nine IPs (only eight are available)
data = [{'description': 'Test IP {}'.format(i)} for i in range(1, 10)] # 9 IPs data = [{'description': 'Test IP {}'.format(i)} for i in range(1, 10)] # 9 IPs
@ -302,6 +311,7 @@ class PrefixTest(APIViewTestCases.APIViewTestCase):
# Verify that no IPs were created (eight are still available) # Verify that no IPs were created (eight are still available)
response = self.client.get(url, **self.header) response = self.client.get(url, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(len(response.data), 8) self.assertEqual(len(response.data), 8)
# Create all eight available IPs in a single request # Create all eight available IPs in a single request
@ -411,6 +421,7 @@ class VLANTest(APIViewTestCases.APIViewTestCase):
vlan = VLAN.objects.first() vlan = VLAN.objects.first()
Prefix.objects.create(prefix=IPNetwork('192.0.2.0/24'), vlan=vlan) Prefix.objects.create(prefix=IPNetwork('192.0.2.0/24'), vlan=vlan)
self.add_permissions('ipam.delete_vlan')
url = reverse('ipam-api:vlan-detail', kwargs={'pk': vlan.pk}) url = reverse('ipam-api:vlan-detail', kwargs={'pk': vlan.pk})
with disable_warnings('django.request'): with disable_warnings('django.request'):
response = self.client.delete(url, **self.header) response = self.client.delete(url, **self.header)

View File

@ -11,7 +11,7 @@ from dcim.models import Site
from ipam.choices import PrefixStatusChoices from ipam.choices import PrefixStatusChoices
from ipam.models import Prefix from ipam.models import Prefix
from users.models import ObjectPermission, Token from users.models import ObjectPermission, Token
from utilities.testing.testcases import TestCase from utilities.testing import TestCase
class ExternalAuthenticationTestCase(TestCase): class ExternalAuthenticationTestCase(TestCase):

View File

@ -29,7 +29,6 @@ class SecretRoleViewSet(ModelViewSet):
secret_count=Count('secrets') secret_count=Count('secrets')
) )
serializer_class = serializers.SecretRoleSerializer serializer_class = serializers.SecretRoleSerializer
permission_classes = [IsAuthenticated]
filterset_class = filters.SecretRoleFilterSet filterset_class = filters.SecretRoleFilterSet

View File

@ -1,16 +1,17 @@
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.contrib.contenttypes.models import ContentType
from rest_framework import serializers
from utilities.api import WritableNestedSerializer from users.models import ObjectPermission
from utilities.api import ContentTypeField, WritableNestedSerializer
_all_ = [ __all__ = [
'NestedGroupSerializer',
'NestedObjectPermissionSerializer',
'NestedUserSerializer', 'NestedUserSerializer',
] ]
#
# Groups and users
#
class NestedGroupSerializer(WritableNestedSerializer): class NestedGroupSerializer(WritableNestedSerializer):
class Meta: class Meta:
@ -23,3 +24,22 @@ class NestedUserSerializer(WritableNestedSerializer):
class Meta: class Meta:
model = User model = User
fields = ['id', 'username'] fields = ['id', 'username']
class NestedObjectPermissionSerializer(WritableNestedSerializer):
object_types = ContentTypeField(
queryset=ContentType.objects.all(),
many=True
)
groups = serializers.SerializerMethodField(read_only=True)
users = serializers.SerializerMethodField(read_only=True)
class Meta:
model = ObjectPermission
fields = ['id', 'object_types', 'groups', 'users', 'actions']
def get_groups(self, obj):
return [g.name for g in obj.groups.all()]
def get_users(self, obj):
return [u.username for u in obj.users.all()]

View File

@ -1,3 +1,4 @@
from django.contrib.auth.models import Group, User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from users.models import ObjectPermission from users.models import ObjectPermission

View File

@ -10,6 +10,7 @@ from django.db.models.signals import post_save
from django.dispatch import receiver from django.dispatch import receiver
from django.utils import timezone from django.utils import timezone
from utilities.querysets import RestrictedQuerySet
from utilities.utils import flatten_dict from utilities.utils import flatten_dict
@ -262,6 +263,8 @@ class ObjectPermission(models.Model):
help_text="Queryset filter matching the applicable objects of the selected type(s)" help_text="Queryset filter matching the applicable objects of the selected type(s)"
) )
objects = RestrictedQuerySet.as_manager()
class Meta: class Meta:
verbose_name = "Permission" verbose_name = "Permission"

View File

@ -1,10 +1,9 @@
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.urls import reverse from django.urls import reverse
from rest_framework import status
from users.models import ObjectPermission from users.models import ObjectPermission
from utilities.testing import APITestCase from utilities.testing import APIViewTestCases, APITestCase
class AppTest(APITestCase): class AppTest(APITestCase):
@ -17,7 +16,9 @@ class AppTest(APITestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class ObjectPermissionTest(APITestCase): class ObjectPermissionTest(APIViewTestCases.APIViewTestCase):
model = ObjectPermission
brief_fields = ['actions', 'groups', 'id', 'object_types', 'users']
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -48,43 +49,7 @@ class ObjectPermissionTest(APITestCase):
objectpermission.groups.add(groups[i]) objectpermission.groups.add(groups[i])
objectpermission.users.add(users[i]) objectpermission.users.add(users[i])
def test_get_objectpermission(self): cls.create_data = [
objectpermission = ObjectPermission.objects.first()
url = reverse('users-api:objectpermission-detail', kwargs={'pk': objectpermission.pk})
response = self.client.get(url, **self.header)
self.assertEqual(response.data['id'], objectpermission.pk)
def test_list_objectpermissions(self):
url = reverse('users-api:objectpermission-list')
response = self.client.get(url, **self.header)
self.assertEqual(response.data['count'], ObjectPermission.objects.count())
def test_create_objectpermission(self):
data = {
'object_types': ['dcim.site'],
'groups': [Group.objects.first().pk],
'users': [User.objects.first().pk],
'actions': ['view', 'add', 'change', 'delete'],
'constraints': {'name': 'TEST4'},
}
url = reverse('users-api:objectpermission-list')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(ObjectPermission.objects.count(), 4)
objectpermission = ObjectPermission.objects.get(pk=response.data['id'])
self.assertEqual(objectpermission.groups.first().pk, data['groups'][0])
self.assertEqual(objectpermission.users.first().pk, data['users'][0])
self.assertEqual(objectpermission.actions, data['actions'])
self.assertEqual(objectpermission.constraints, data['constraints'])
def test_create_objectpermission_bulk(self):
groups = Group.objects.all()[:3]
users = User.objects.all()[:3]
data = [
{ {
'object_types': ['dcim.site'], 'object_types': ['dcim.site'],
'groups': [groups[0].pk], 'groups': [groups[0].pk],
@ -107,38 +72,3 @@ class ObjectPermissionTest(APITestCase):
'constraints': {'name': 'TEST6'}, 'constraints': {'name': 'TEST6'},
}, },
] ]
url = reverse('users-api:objectpermission-list')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(ObjectPermission.objects.count(), 6)
def test_update_objectpermission(self):
objectpermission = ObjectPermission.objects.first()
data = {
'object_types': ['dcim.site', 'dcim.device'],
'groups': [g.pk for g in Group.objects.all()[:2]],
'users': [u.pk for u in User.objects.all()[:2]],
'actions': ['view'],
'constraints': {'name': 'TEST'},
}
url = reverse('users-api:objectpermission-detail', kwargs={'pk': objectpermission.pk})
response = self.client.put(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(ObjectPermission.objects.count(), 3)
objectpermission = ObjectPermission.objects.get(pk=response.data['id'])
self.assertEqual(objectpermission.groups.first().pk, data['groups'][0])
self.assertEqual(objectpermission.users.first().pk, data['users'][0])
self.assertEqual(objectpermission.actions, data['actions'])
self.assertEqual(objectpermission.constraints, data['constraints'])
def test_delete_objectpermission(self):
objectpermission = ObjectPermission.objects.first()
url = reverse('users-api:objectpermission-detail', kwargs={'pk': objectpermission.pk})
response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
self.assertEqual(ObjectPermission.objects.count(), 2)

View File

@ -1,2 +1,3 @@
from .testcases import * from .api import *
from .utils import * from .utils import *
from .views import *

View File

@ -0,0 +1,282 @@
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.urls import reverse
from django.test import override_settings
from rest_framework import status
from rest_framework.test import APIClient
from users.models import ObjectPermission, Token
from .utils import disable_warnings
from .views import TestCase
__all__ = (
'APITestCase',
'APIViewTestCases',
)
#
# REST API Tests
#
class APITestCase(TestCase):
client_class = APIClient
model = None
def setUp(self):
"""
Create a superuser and token for API calls.
"""
# Create the test user and assign permissions
self.user = User.objects.create_user(username='testuser')
self.add_permissions(*self.user_permissions)
self.token = Token.objects.create(user=self.user)
self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
def _get_detail_url(self, instance):
viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail'
return reverse(viewname, kwargs={'pk': instance.pk})
def _get_list_url(self):
viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list'
return reverse(viewname)
class APIViewTestCases:
class GetObjectViewTestCase(APITestCase):
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_get_object_anonymous(self):
"""
GET a single object as an unauthenticated user.
"""
url = self._get_detail_url(self.model.objects.first())
response = self.client.get(url, **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_get_object_without_permission(self):
"""
GET a single object as an authenticated user without the required permission.
"""
url = self._get_detail_url(self.model.objects.first())
# Try GET without permission
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_get_object(self):
"""
GET a single object as an authenticated user with permission to view the object.
"""
self.assertGreaterEqual(self.model.objects.count(), 2,
f"Test requires the creation of at least two {self.model} instances")
instance1, instance2 = self.model.objects.all()[:2]
# Add object-level permission
obj_perm = ObjectPermission(
constraints={'pk': instance1.pk},
actions=['view']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
# Try GET to permitted object
url = self._get_detail_url(instance1)
self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_200_OK)
# Try GET to non-permitted object
url = self._get_detail_url(instance2)
self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_404_NOT_FOUND)
class ListObjectsViewTestCase(APITestCase):
brief_fields = []
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_list_objects_anonymous(self):
"""
GET a list of objects as an unauthenticated user.
"""
url = self._get_list_url()
response = self.client.get(url, **self.header)
self.assertEqual(len(response.data['results']), self.model.objects.count())
self.assertHttpStatus(response, status.HTTP_200_OK)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_list_objects_brief(self):
"""
GET a list of objects using the "brief" parameter as an unauthenticated user.
"""
url = f'{self._get_list_url()}?brief=1'
response = self.client.get(url, **self.header)
self.assertEqual(len(response.data['results']), self.model.objects.count())
self.assertEqual(sorted(response.data['results'][0]), self.brief_fields)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_list_objects_without_permission(self):
"""
GET a list of objects as an authenticated user without the required permission.
"""
url = self._get_list_url()
# Try GET without permission
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.get(url, **self.header), status.HTTP_403_FORBIDDEN)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_list_objects(self):
"""
GET a list of objects as an authenticated user with permission to view the objects.
"""
self.assertGreaterEqual(self.model.objects.count(), 3,
f"Test requires the creation of at least three {self.model} instances")
instance1, instance2 = self.model.objects.all()[:2]
# Add object-level permission
obj_perm = ObjectPermission(
constraints={'pk__in': [instance1.pk, instance2.pk]},
actions=['view']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
# Try GET to permitted objects
response = self.client.get(self._get_list_url(), **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(len(response.data['results']), 2)
class CreateObjectViewTestCase(APITestCase):
create_data = []
def test_create_object_without_permission(self):
"""
POST a single object without permission.
"""
url = self._get_list_url()
# Try POST without permission
with disable_warnings('django.request'):
response = self.client.post(url, self.create_data[0], format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
def test_create_object(self):
"""
POST a single object with permission.
"""
# Add object-level permission
obj_perm = ObjectPermission(
actions=['add']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
initial_count = self.model.objects.count()
response = self.client.post(self._get_list_url(), self.create_data[0], format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(self.model.objects.count(), initial_count + 1)
self.assertInstanceEqual(self.model.objects.get(pk=response.data['id']), self.create_data[0], api=True)
def test_bulk_create_objects(self):
"""
POST a set of objects in a single request.
"""
# Add object-level permission
obj_perm = ObjectPermission(
actions=['add']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
initial_count = self.model.objects.count()
response = self.client.post(self._get_list_url(), self.create_data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(len(response.data), len(self.create_data))
self.assertEqual(self.model.objects.count(), initial_count + len(self.create_data))
for i, obj in enumerate(response.data):
self.assertInstanceEqual(self.model.objects.get(pk=obj['id']), self.create_data[i], api=True)
class UpdateObjectViewTestCase(APITestCase):
update_data = {}
def test_update_object_without_permission(self):
"""
PATCH a single object without permission.
"""
url = self._get_detail_url(self.model.objects.first())
update_data = self.update_data or getattr(self, 'create_data')[0]
# Try PATCH without permission
with disable_warnings('django.request'):
response = self.client.patch(url, update_data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
def test_update_object(self):
"""
PATCH a single object identified by its numeric ID.
"""
instance = self.model.objects.first()
url = self._get_detail_url(instance)
update_data = self.update_data or getattr(self, 'create_data')[0]
# Add object-level permission
obj_perm = ObjectPermission(
actions=['change']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
response = self.client.patch(url, update_data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
instance.refresh_from_db()
self.assertInstanceEqual(instance, self.update_data, api=True)
class DeleteObjectViewTestCase(APITestCase):
def test_delete_object_without_permission(self):
"""
DELETE a single object without permission.
"""
url = self._get_detail_url(self.model.objects.first())
# Try DELETE without permission
with disable_warnings('django.request'):
response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_403_FORBIDDEN)
def test_delete_object(self):
"""
DELETE a single object identified by its numeric ID.
"""
instance = self.model.objects.first()
url = self._get_detail_url(instance)
# Add object-level permission
obj_perm = ObjectPermission(
actions=['delete']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ContentType.objects.get_for_model(self.model))
response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
self.assertFalse(self.model.objects.filter(pk=instance.pk).exists())
class APIViewTestCase(
GetObjectViewTestCase,
ListObjectsViewTestCase,
CreateObjectViewTestCase,
UpdateObjectViewTestCase,
DeleteObjectViewTestCase
):
pass

View File

@ -1,18 +1,24 @@
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db.models import ForeignKey, ManyToManyField
from django.forms.models import model_to_dict from django.forms.models import model_to_dict
from django.test import Client, TestCase as _TestCase, override_settings from django.test import Client, TestCase as _TestCase, override_settings
from django.urls import reverse, NoReverseMatch from django.urls import reverse, NoReverseMatch
from netaddr import IPNetwork from netaddr import IPNetwork
from rest_framework import status
from rest_framework.test import APIClient
from users.models import ObjectPermission, Token from users.models import ObjectPermission
from utilities.permissions import resolve_permission_ct from utilities.permissions import resolve_permission_ct
from .utils import disable_warnings, post_data from .utils import disable_warnings, post_data
__all__ = (
'TestCase',
'ModelViewTestCase',
'ViewTestCases',
)
class TestCase(_TestCase): class TestCase(_TestCase):
user_permissions = () user_permissions = ()
@ -78,12 +84,15 @@ class TestCase(_TestCase):
if api: if api:
# Replace ContentType numeric IDs with <app_label>.<model> # Replace ContentType numeric IDs with <app_label>.<model>
if type(getattr(instance, key)) is ContentType: field = instance._meta.get_field(key)
if type(field) is ForeignKey and field.related_model is ContentType:
ct = ContentType.objects.get(pk=value) ct = ContentType.objects.get(pk=value)
model_dict[key] = f'{ct.app_label}.{ct.model}' model_dict[key] = f'{ct.app_label}.{ct.model}'
elif type(field) is ManyToManyField and field.related_model is ContentType:
model_dict[key] = [f'{ct.app_label}.{ct.model}' for ct in value]
# Convert IPNetwork instances to strings # Convert IPNetwork instances to strings
if type(value) is IPNetwork: elif type(value) is IPNetwork:
model_dict[key] = str(value) model_dict[key] = str(value)
# Omit any dictionary keys which are not instance attributes # Omit any dictionary keys which are not instance attributes
@ -202,13 +211,6 @@ class ViewTestCases:
# Try GET to non-permitted object # Try GET to non-permitted object
self.assertHttpStatus(self.client.get(instance2.get_absolute_url()), 404) self.assertHttpStatus(self.client.get(instance2.get_absolute_url()), 404)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])
def test_get_object_anonymous(self):
# Make the request as an unauthenticated user
self.client.logout()
response = self.client.get(self.model.objects.first().get_absolute_url())
self.assertHttpStatus(response, 200)
class CreateObjectViewTestCase(ModelViewTestCase): class CreateObjectViewTestCase(ModelViewTestCase):
""" """
Create a single new instance. Create a single new instance.
@ -799,129 +801,3 @@ class ViewTestCases:
TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.) TestCase suitable for testing device component models (ConsolePorts, Interfaces, etc.)
""" """
maxDiff = None maxDiff = None
#
# REST API Tests
#
class APITestCase(TestCase):
client_class = APIClient
model = None
def setUp(self):
"""
Create a superuser and token for API calls.
"""
self.user = User.objects.create(username='testuser', is_superuser=True)
self.token = Token.objects.create(user=self.user)
self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
def _get_detail_url(self, instance):
viewname = f'{instance._meta.app_label}-api:{instance._meta.model_name}-detail'
return reverse(viewname, kwargs={'pk': instance.pk})
def _get_list_url(self):
viewname = f'{self.model._meta.app_label}-api:{self.model._meta.model_name}-list'
return reverse(viewname)
class APIViewTestCases:
class GetObjectViewTestCase(APITestCase):
def test_get_object(self):
"""
GET a single object identified by its numeric ID.
"""
instance = self.model.objects.first()
url = self._get_detail_url(instance)
response = self.client.get(url, **self.header)
self.assertEqual(response.data['id'], instance.pk)
class ListObjectsViewTestCase(APITestCase):
brief_fields = []
def test_list_objects(self):
"""
GET a list of objects.
"""
url = self._get_list_url()
response = self.client.get(url, **self.header)
self.assertEqual(len(response.data['results']), self.model.objects.count())
def test_list_objects_brief(self):
"""
GET a list of objects using the "brief" parameter.
"""
url = f'{self._get_list_url()}?brief=1'
response = self.client.get(url, **self.header)
self.assertEqual(len(response.data['results']), self.model.objects.count())
self.assertEqual(sorted(response.data['results'][0]), self.brief_fields)
class CreateObjectViewTestCase(APITestCase):
create_data = []
def test_create_object(self):
"""
POST a single object.
"""
initial_count = self.model.objects.count()
url = self._get_list_url()
response = self.client.post(url, self.create_data[0], format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(self.model.objects.count(), initial_count + 1)
self.assertInstanceEqual(self.model.objects.get(pk=response.data['id']), self.create_data[0], api=True)
def test_bulk_create_object(self):
"""
POST a set of objects in a single request.
"""
initial_count = self.model.objects.count()
url = self._get_list_url()
response = self.client.post(url, self.create_data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(self.model.objects.count(), initial_count + len(self.create_data))
class UpdateObjectViewTestCase(APITestCase):
update_data = {}
def test_update_object(self):
"""
PATCH a single object identified by its numeric ID.
"""
instance = self.model.objects.first()
url = self._get_detail_url(instance)
update_data = self.update_data or getattr(self, 'create_data')[0]
response = self.client.patch(url, update_data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
instance.refresh_from_db()
self.assertInstanceEqual(instance, self.update_data, api=True)
class DeleteObjectViewTestCase(APITestCase):
def test_delete_object(self):
"""
DELETE a single object identified by its numeric ID.
"""
instance = self.model.objects.first()
url = self._get_detail_url(instance)
response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
self.assertFalse(self.model.objects.filter(pk=instance.pk).exists())
class APIViewTestCase(
GetObjectViewTestCase,
ListObjectsViewTestCase,
CreateObjectViewTestCase,
UpdateObjectViewTestCase,
DeleteObjectViewTestCase
):
pass

View File

@ -18,7 +18,6 @@ class WritableNestedSerializerTest(APITestCase):
""" """
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.region_a = Region.objects.create(name='Region A', slug='region-a') self.region_a = Region.objects.create(name='Region A', slug='region-a')
@ -26,39 +25,36 @@ class WritableNestedSerializerTest(APITestCase):
self.site2 = Site.objects.create(region=self.region_a, name='Site 2', slug='site-2') self.site2 = Site.objects.create(region=self.region_a, name='Site 2', slug='site-2')
def test_related_by_pk(self): def test_related_by_pk(self):
data = { data = {
'vid': 100, 'vid': 100,
'name': 'Test VLAN 100', 'name': 'Test VLAN 100',
'site': self.site1.pk, 'site': self.site1.pk,
} }
url = reverse('ipam-api:vlan-list') url = reverse('ipam-api:vlan-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('ipam.add_vlan')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(response.data['site']['id'], self.site1.pk) self.assertEqual(response.data['site']['id'], self.site1.pk)
vlan = VLAN.objects.get(pk=response.data['id']) vlan = VLAN.objects.get(pk=response.data['id'])
self.assertEqual(vlan.site, self.site1) self.assertEqual(vlan.site, self.site1)
def test_related_by_pk_no_match(self): def test_related_by_pk_no_match(self):
data = { data = {
'vid': 100, 'vid': 100,
'name': 'Test VLAN 100', 'name': 'Test VLAN 100',
'site': 999, 'site': 999,
} }
url = reverse('ipam-api:vlan-list') url = reverse('ipam-api:vlan-list')
self.add_permissions('ipam.add_vlan')
with disable_warnings('django.request'): with disable_warnings('django.request'):
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
self.assertEqual(VLAN.objects.count(), 0) self.assertEqual(VLAN.objects.count(), 0)
self.assertTrue(response.data['site'][0].startswith("Related object not found")) self.assertTrue(response.data['site'][0].startswith("Related object not found"))
def test_related_by_attributes(self): def test_related_by_attributes(self):
data = { data = {
'vid': 100, 'vid': 100,
'name': 'Test VLAN 100', 'name': 'Test VLAN 100',
@ -66,17 +62,16 @@ class WritableNestedSerializerTest(APITestCase):
'name': 'Site 1' 'name': 'Site 1'
}, },
} }
url = reverse('ipam-api:vlan-list') url = reverse('ipam-api:vlan-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('ipam.add_vlan')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(response.data['site']['id'], self.site1.pk) self.assertEqual(response.data['site']['id'], self.site1.pk)
vlan = VLAN.objects.get(pk=response.data['id']) vlan = VLAN.objects.get(pk=response.data['id'])
self.assertEqual(vlan.site, self.site1) self.assertEqual(vlan.site, self.site1)
def test_related_by_attributes_no_match(self): def test_related_by_attributes_no_match(self):
data = { data = {
'vid': 100, 'vid': 100,
'name': 'Test VLAN 100', 'name': 'Test VLAN 100',
@ -84,17 +79,16 @@ class WritableNestedSerializerTest(APITestCase):
'name': 'Site X' 'name': 'Site X'
}, },
} }
url = reverse('ipam-api:vlan-list') url = reverse('ipam-api:vlan-list')
self.add_permissions('ipam.add_vlan')
with disable_warnings('django.request'): with disable_warnings('django.request'):
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
self.assertEqual(VLAN.objects.count(), 0) self.assertEqual(VLAN.objects.count(), 0)
self.assertTrue(response.data['site'][0].startswith("Related object not found")) self.assertTrue(response.data['site'][0].startswith("Related object not found"))
def test_related_by_attributes_multiple_matches(self): def test_related_by_attributes_multiple_matches(self):
data = { data = {
'vid': 100, 'vid': 100,
'name': 'Test VLAN 100', 'name': 'Test VLAN 100',
@ -104,27 +98,26 @@ class WritableNestedSerializerTest(APITestCase):
}, },
}, },
} }
url = reverse('ipam-api:vlan-list') url = reverse('ipam-api:vlan-list')
self.add_permissions('ipam.add_vlan')
with disable_warnings('django.request'): with disable_warnings('django.request'):
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
self.assertEqual(VLAN.objects.count(), 0) self.assertEqual(VLAN.objects.count(), 0)
self.assertTrue(response.data['site'][0].startswith("Multiple objects match")) self.assertTrue(response.data['site'][0].startswith("Multiple objects match"))
def test_related_by_invalid(self): def test_related_by_invalid(self):
data = { data = {
'vid': 100, 'vid': 100,
'name': 'Test VLAN 100', 'name': 'Test VLAN 100',
'site': 'XXX', 'site': 'XXX',
} }
url = reverse('ipam-api:vlan-list') url = reverse('ipam-api:vlan-list')
self.add_permissions('ipam.add_vlan')
with disable_warnings('django.request'): with disable_warnings('django.request'):
response = self.client.post(url, data, format='json', **self.header) response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
self.assertEqual(VLAN.objects.count(), 0) self.assertEqual(VLAN.objects.count(), 0)

View File

@ -164,10 +164,10 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
Check that config context data is included by default in the virtual machines list. Check that config context data is included by default in the virtual machines list.
""" """
virtualmachine = VirtualMachine.objects.first() virtualmachine = VirtualMachine.objects.first()
url = reverse('virtualization-api:virtualmachine-list') url = '{}?id={}'.format(reverse('virtualization-api:virtualmachine-list'), virtualmachine.pk)
url = '{}?id={}'.format(url, virtualmachine.pk) self.add_permissions('virtualization.view_virtualmachine')
response = self.client.get(url, **self.header)
response = self.client.get(url, **self.header)
self.assertEqual(response.data['results'][0].get('config_context', {}).get('A'), 1) self.assertEqual(response.data['results'][0].get('config_context', {}).get('A'), 1)
def test_config_context_excluded(self): def test_config_context_excluded(self):
@ -175,8 +175,9 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
Check that config context data can be excluded by passing ?exclude=config_context. Check that config context data can be excluded by passing ?exclude=config_context.
""" """
url = reverse('virtualization-api:virtualmachine-list') + '?exclude=config_context' url = reverse('virtualization-api:virtualmachine-list') + '?exclude=config_context'
response = self.client.get(url, **self.header) self.add_permissions('virtualization.view_virtualmachine')
response = self.client.get(url, **self.header)
self.assertFalse('config_context' in response.data['results'][0]) self.assertFalse('config_context' in response.data['results'][0])
def test_unique_name_per_cluster_constraint(self): def test_unique_name_per_cluster_constraint(self):
@ -188,8 +189,9 @@ class VirtualMachineTest(APIViewTestCases.APIViewTestCase):
'cluster': Cluster.objects.first().pk, 'cluster': Cluster.objects.first().pk,
} }
url = reverse('virtualization-api:virtualmachine-list') url = reverse('virtualization-api:virtualmachine-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('virtualization.add_virtualmachine')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)
@ -224,39 +226,38 @@ class InterfaceTest(APITestCase):
self.vlan3 = VLAN.objects.create(name="Test VLAN 3", vid=3) self.vlan3 = VLAN.objects.create(name="Test VLAN 3", vid=3)
def test_get_interface(self): def test_get_interface(self):
url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk}) url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
response = self.client.get(url, **self.header) self.add_permissions('dcim.view_interface')
response = self.client.get(url, **self.header)
self.assertEqual(response.data['name'], self.interface1.name) self.assertEqual(response.data['name'], self.interface1.name)
def test_list_interfaces(self): def test_list_interfaces(self):
url = reverse('virtualization-api:interface-list') url = reverse('virtualization-api:interface-list')
response = self.client.get(url, **self.header) self.add_permissions('dcim.view_interface')
response = self.client.get(url, **self.header)
self.assertEqual(response.data['count'], 3) self.assertEqual(response.data['count'], 3)
def test_list_interfaces_brief(self): def test_list_interfaces_brief(self):
url = reverse('virtualization-api:interface-list') url = reverse('virtualization-api:interface-list')
response = self.client.get('{}?brief=1'.format(url), **self.header) self.add_permissions('dcim.view_interface')
response = self.client.get('{}?brief=1'.format(url), **self.header)
self.assertEqual( self.assertEqual(
sorted(response.data['results'][0]), sorted(response.data['results'][0]),
['id', 'name', 'url', 'virtual_machine'] ['id', 'name', 'url', 'virtual_machine']
) )
def test_create_interface(self): def test_create_interface(self):
data = { data = {
'virtual_machine': self.virtualmachine.pk, 'virtual_machine': self.virtualmachine.pk,
'name': 'Test Interface 4', 'name': 'Test Interface 4',
} }
url = reverse('virtualization-api:interface-list') url = reverse('virtualization-api:interface-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('dcim.add_interface')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(Interface.objects.count(), 4) self.assertEqual(Interface.objects.count(), 4)
interface4 = Interface.objects.get(pk=response.data['id']) interface4 = Interface.objects.get(pk=response.data['id'])
@ -264,7 +265,6 @@ class InterfaceTest(APITestCase):
self.assertEqual(interface4.name, data['name']) self.assertEqual(interface4.name, data['name'])
def test_create_interface_with_802_1q(self): def test_create_interface_with_802_1q(self):
data = { data = {
'virtual_machine': self.virtualmachine.pk, 'virtual_machine': self.virtualmachine.pk,
'name': 'Test Interface 4', 'name': 'Test Interface 4',
@ -272,10 +272,10 @@ class InterfaceTest(APITestCase):
'untagged_vlan': self.vlan3.id, 'untagged_vlan': self.vlan3.id,
'tagged_vlans': [self.vlan1.id, self.vlan2.id], 'tagged_vlans': [self.vlan1.id, self.vlan2.id],
} }
url = reverse('virtualization-api:interface-list') url = reverse('virtualization-api:interface-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('dcim.add_interface')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(Interface.objects.count(), 4) self.assertEqual(Interface.objects.count(), 4)
self.assertEqual(response.data['virtual_machine']['id'], data['virtual_machine']) self.assertEqual(response.data['virtual_machine']['id'], data['virtual_machine'])
@ -284,7 +284,6 @@ class InterfaceTest(APITestCase):
self.assertEqual([v['id'] for v in response.data['tagged_vlans']], data['tagged_vlans']) self.assertEqual([v['id'] for v in response.data['tagged_vlans']], data['tagged_vlans'])
def test_create_interface_bulk(self): def test_create_interface_bulk(self):
data = [ data = [
{ {
'virtual_machine': self.virtualmachine.pk, 'virtual_machine': self.virtualmachine.pk,
@ -299,10 +298,10 @@ class InterfaceTest(APITestCase):
'name': 'Test Interface 6', 'name': 'Test Interface 6',
}, },
] ]
url = reverse('virtualization-api:interface-list') url = reverse('virtualization-api:interface-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('dcim.add_interface')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(Interface.objects.count(), 6) self.assertEqual(Interface.objects.count(), 6)
self.assertEqual(response.data[0]['name'], data[0]['name']) self.assertEqual(response.data[0]['name'], data[0]['name'])
@ -310,7 +309,6 @@ class InterfaceTest(APITestCase):
self.assertEqual(response.data[2]['name'], data[2]['name']) self.assertEqual(response.data[2]['name'], data[2]['name'])
def test_create_interface_802_1q_bulk(self): def test_create_interface_802_1q_bulk(self):
data = [ data = [
{ {
'virtual_machine': self.virtualmachine.pk, 'virtual_machine': self.virtualmachine.pk,
@ -334,10 +332,10 @@ class InterfaceTest(APITestCase):
'tagged_vlans': [self.vlan1.id], 'tagged_vlans': [self.vlan1.id],
}, },
] ]
url = reverse('virtualization-api:interface-list') url = reverse('virtualization-api:interface-list')
response = self.client.post(url, data, format='json', **self.header) self.add_permissions('dcim.add_interface')
response = self.client.post(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(Interface.objects.count(), 6) self.assertEqual(Interface.objects.count(), 6)
for i in range(0, 3): for i in range(0, 3):
@ -346,24 +344,23 @@ class InterfaceTest(APITestCase):
self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan']) self.assertEqual(response.data[i]['untagged_vlan']['id'], data[i]['untagged_vlan'])
def test_update_interface(self): def test_update_interface(self):
data = { data = {
'virtual_machine': self.virtualmachine.pk, 'virtual_machine': self.virtualmachine.pk,
'name': 'Test Interface X', 'name': 'Test Interface X',
} }
url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk}) url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
response = self.client.put(url, data, format='json', **self.header) self.add_permissions('dcim.change_interface')
response = self.client.put(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK) self.assertHttpStatus(response, status.HTTP_200_OK)
self.assertEqual(Interface.objects.count(), 3) self.assertEqual(Interface.objects.count(), 3)
interface1 = Interface.objects.get(pk=response.data['id']) interface1 = Interface.objects.get(pk=response.data['id'])
self.assertEqual(interface1.name, data['name']) self.assertEqual(interface1.name, data['name'])
def test_delete_interface(self): def test_delete_interface(self):
url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk}) url = reverse('virtualization-api:interface-detail', kwargs={'pk': self.interface1.pk})
response = self.client.delete(url, **self.header) self.add_permissions('dcim.delete_interface')
response = self.client.delete(url, **self.header)
self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT) self.assertHttpStatus(response, status.HTTP_204_NO_CONTENT)
self.assertEqual(Interface.objects.count(), 2) self.assertEqual(Interface.objects.count(), 2)