1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Refactor API views

This commit is contained in:
jeremystretch
2022-03-09 11:09:06 -05:00
parent a11abf87ec
commit efd5a73a18
11 changed files with 329 additions and 323 deletions

View File

@ -4,7 +4,7 @@ from circuits import filtersets
from circuits.models import *
from dcim.api.views import PassThroughPortMixin
from extras.api.views import CustomFieldModelViewSet
from netbox.api.views import ModelViewSet
from netbox.api.viewsets import NetBoxModelViewSet
from utilities.utils import count_related
from . import serializers
@ -57,7 +57,7 @@ class CircuitViewSet(CustomFieldModelViewSet):
# Circuit Terminations
#
class CircuitTerminationViewSet(PassThroughPortMixin, ModelViewSet):
class CircuitTerminationViewSet(PassThroughPortMixin, NetBoxModelViewSet):
queryset = CircuitTermination.objects.prefetch_related(
'circuit', 'site', 'provider_network', 'cable'
)

View File

@ -19,7 +19,7 @@ from ipam.models import Prefix, VLAN
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.exceptions import ServiceUnavailable
from netbox.api.metadata import ContentTypeMetadata
from netbox.api.views import ModelViewSet
from netbox.api.viewsets import NetBoxModelViewSet
from netbox.config import get_config
from utilities.api import get_serializer_for_model
from utilities.utils import count_related
@ -250,7 +250,7 @@ class RackViewSet(CustomFieldModelViewSet):
# Rack reservations
#
class RackReservationViewSet(ModelViewSet):
class RackReservationViewSet(NetBoxModelViewSet):
queryset = RackReservation.objects.prefetch_related('rack', 'user', 'tenant')
serializer_class = serializers.RackReservationSerializer
filterset_class = filtersets.RackReservationFilterSet
@ -296,61 +296,61 @@ class ModuleTypeViewSet(CustomFieldModelViewSet):
# Device type components
#
class ConsolePortTemplateViewSet(ModelViewSet):
class ConsolePortTemplateViewSet(NetBoxModelViewSet):
queryset = ConsolePortTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.ConsolePortTemplateSerializer
filterset_class = filtersets.ConsolePortTemplateFilterSet
class ConsoleServerPortTemplateViewSet(ModelViewSet):
class ConsoleServerPortTemplateViewSet(NetBoxModelViewSet):
queryset = ConsoleServerPortTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.ConsoleServerPortTemplateSerializer
filterset_class = filtersets.ConsoleServerPortTemplateFilterSet
class PowerPortTemplateViewSet(ModelViewSet):
class PowerPortTemplateViewSet(NetBoxModelViewSet):
queryset = PowerPortTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.PowerPortTemplateSerializer
filterset_class = filtersets.PowerPortTemplateFilterSet
class PowerOutletTemplateViewSet(ModelViewSet):
class PowerOutletTemplateViewSet(NetBoxModelViewSet):
queryset = PowerOutletTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.PowerOutletTemplateSerializer
filterset_class = filtersets.PowerOutletTemplateFilterSet
class InterfaceTemplateViewSet(ModelViewSet):
class InterfaceTemplateViewSet(NetBoxModelViewSet):
queryset = InterfaceTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.InterfaceTemplateSerializer
filterset_class = filtersets.InterfaceTemplateFilterSet
class FrontPortTemplateViewSet(ModelViewSet):
class FrontPortTemplateViewSet(NetBoxModelViewSet):
queryset = FrontPortTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.FrontPortTemplateSerializer
filterset_class = filtersets.FrontPortTemplateFilterSet
class RearPortTemplateViewSet(ModelViewSet):
class RearPortTemplateViewSet(NetBoxModelViewSet):
queryset = RearPortTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.RearPortTemplateSerializer
filterset_class = filtersets.RearPortTemplateFilterSet
class ModuleBayTemplateViewSet(ModelViewSet):
class ModuleBayTemplateViewSet(NetBoxModelViewSet):
queryset = ModuleBayTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.ModuleBayTemplateSerializer
filterset_class = filtersets.ModuleBayTemplateFilterSet
class DeviceBayTemplateViewSet(ModelViewSet):
class DeviceBayTemplateViewSet(NetBoxModelViewSet):
queryset = DeviceBayTemplate.objects.prefetch_related('device_type__manufacturer')
serializer_class = serializers.DeviceBayTemplateSerializer
filterset_class = filtersets.DeviceBayTemplateFilterSet
class InventoryItemTemplateViewSet(ModelViewSet):
class InventoryItemTemplateViewSet(NetBoxModelViewSet):
queryset = InventoryItemTemplate.objects.prefetch_related('device_type__manufacturer', 'role')
serializer_class = serializers.InventoryItemTemplateSerializer
filterset_class = filtersets.InventoryItemTemplateFilterSet
@ -544,7 +544,7 @@ class ModuleViewSet(CustomFieldModelViewSet):
# Device components
#
class ConsolePortViewSet(PathEndpointMixin, ModelViewSet):
class ConsolePortViewSet(PathEndpointMixin, NetBoxModelViewSet):
queryset = ConsolePort.objects.prefetch_related(
'device', 'module__module_bay', '_path__destination', 'cable', '_link_peer', 'tags'
)
@ -553,7 +553,7 @@ class ConsolePortViewSet(PathEndpointMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class ConsoleServerPortViewSet(PathEndpointMixin, ModelViewSet):
class ConsoleServerPortViewSet(PathEndpointMixin, NetBoxModelViewSet):
queryset = ConsoleServerPort.objects.prefetch_related(
'device', 'module__module_bay', '_path__destination', 'cable', '_link_peer', 'tags'
)
@ -562,7 +562,7 @@ class ConsoleServerPortViewSet(PathEndpointMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class PowerPortViewSet(PathEndpointMixin, ModelViewSet):
class PowerPortViewSet(PathEndpointMixin, NetBoxModelViewSet):
queryset = PowerPort.objects.prefetch_related(
'device', 'module__module_bay', '_path__destination', 'cable', '_link_peer', 'tags'
)
@ -571,7 +571,7 @@ class PowerPortViewSet(PathEndpointMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class PowerOutletViewSet(PathEndpointMixin, ModelViewSet):
class PowerOutletViewSet(PathEndpointMixin, NetBoxModelViewSet):
queryset = PowerOutlet.objects.prefetch_related(
'device', 'module__module_bay', '_path__destination', 'cable', '_link_peer', 'tags'
)
@ -580,7 +580,7 @@ class PowerOutletViewSet(PathEndpointMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class InterfaceViewSet(PathEndpointMixin, ModelViewSet):
class InterfaceViewSet(PathEndpointMixin, NetBoxModelViewSet):
queryset = Interface.objects.prefetch_related(
'device', 'module__module_bay', 'parent', 'bridge', 'lag', '_path__destination', 'cable', '_link_peer',
'wireless_lans', 'untagged_vlan', 'tagged_vlans', 'vrf', 'ip_addresses', 'fhrp_group_assignments', 'tags'
@ -590,7 +590,7 @@ class InterfaceViewSet(PathEndpointMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class FrontPortViewSet(PassThroughPortMixin, ModelViewSet):
class FrontPortViewSet(PassThroughPortMixin, NetBoxModelViewSet):
queryset = FrontPort.objects.prefetch_related(
'device__device_type__manufacturer', 'module__module_bay', 'rear_port', 'cable', 'tags'
)
@ -599,7 +599,7 @@ class FrontPortViewSet(PassThroughPortMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class RearPortViewSet(PassThroughPortMixin, ModelViewSet):
class RearPortViewSet(PassThroughPortMixin, NetBoxModelViewSet):
queryset = RearPort.objects.prefetch_related(
'device__device_type__manufacturer', 'module__module_bay', 'cable', 'tags'
)
@ -608,21 +608,21 @@ class RearPortViewSet(PassThroughPortMixin, ModelViewSet):
brief_prefetch_fields = ['device']
class ModuleBayViewSet(ModelViewSet):
class ModuleBayViewSet(NetBoxModelViewSet):
queryset = ModuleBay.objects.prefetch_related('tags')
serializer_class = serializers.ModuleBaySerializer
filterset_class = filtersets.ModuleBayFilterSet
brief_prefetch_fields = ['device']
class DeviceBayViewSet(ModelViewSet):
class DeviceBayViewSet(NetBoxModelViewSet):
queryset = DeviceBay.objects.prefetch_related('installed_device', 'tags')
serializer_class = serializers.DeviceBaySerializer
filterset_class = filtersets.DeviceBayFilterSet
brief_prefetch_fields = ['device']
class InventoryItemViewSet(ModelViewSet):
class InventoryItemViewSet(NetBoxModelViewSet):
queryset = InventoryItem.objects.prefetch_related('device', 'manufacturer', 'tags')
serializer_class = serializers.InventoryItemSerializer
filterset_class = filtersets.InventoryItemFilterSet
@ -645,7 +645,7 @@ class InventoryItemRoleViewSet(CustomFieldModelViewSet):
# Cables
#
class CableViewSet(ModelViewSet):
class CableViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = Cable.objects.prefetch_related(
'termination_a', 'termination_b'
@ -658,7 +658,7 @@ class CableViewSet(ModelViewSet):
# Virtual chassis
#
class VirtualChassisViewSet(ModelViewSet):
class VirtualChassisViewSet(NetBoxModelViewSet):
queryset = VirtualChassis.objects.prefetch_related('tags').annotate(
member_count=count_related(Device, 'virtual_chassis')
)
@ -671,7 +671,7 @@ class VirtualChassisViewSet(ModelViewSet):
# Power panels
#
class PowerPanelViewSet(ModelViewSet):
class PowerPanelViewSet(NetBoxModelViewSet):
queryset = PowerPanel.objects.prefetch_related(
'site', 'location'
).annotate(

View File

@ -18,7 +18,7 @@ from extras.reports import get_report, get_reports, run_report
from extras.scripts import get_script, get_scripts, run_script
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.metadata import ContentTypeMetadata
from netbox.api.views import ModelViewSet
from netbox.api.viewsets import NetBoxModelViewSet
from utilities.exceptions import RQWorkerNotRunningException
from utilities.utils import copy_safe_request, count_related
from . import serializers
@ -58,7 +58,7 @@ class ConfigContextQuerySetMixin:
# Webhooks
#
class WebhookViewSet(ModelViewSet):
class WebhookViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = Webhook.objects.all()
serializer_class = serializers.WebhookSerializer
@ -69,14 +69,14 @@ class WebhookViewSet(ModelViewSet):
# Custom fields
#
class CustomFieldViewSet(ModelViewSet):
class CustomFieldViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = CustomField.objects.all()
serializer_class = serializers.CustomFieldSerializer
filterset_class = filtersets.CustomFieldFilterSet
class CustomFieldModelViewSet(ModelViewSet):
class CustomFieldModelViewSet(NetBoxModelViewSet):
"""
Include the applicable set of CustomFields in the ModelViewSet context.
"""
@ -98,7 +98,7 @@ class CustomFieldModelViewSet(ModelViewSet):
# Custom links
#
class CustomLinkViewSet(ModelViewSet):
class CustomLinkViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = CustomLink.objects.all()
serializer_class = serializers.CustomLinkSerializer
@ -109,7 +109,7 @@ class CustomLinkViewSet(ModelViewSet):
# Export templates
#
class ExportTemplateViewSet(ModelViewSet):
class ExportTemplateViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = ExportTemplate.objects.all()
serializer_class = serializers.ExportTemplateSerializer
@ -120,7 +120,7 @@ class ExportTemplateViewSet(ModelViewSet):
# Tags
#
class TagViewSet(ModelViewSet):
class TagViewSet(NetBoxModelViewSet):
queryset = Tag.objects.annotate(
tagged_items=count_related(TaggedItem, 'tag')
)
@ -132,7 +132,7 @@ class TagViewSet(ModelViewSet):
# Image attachments
#
class ImageAttachmentViewSet(ModelViewSet):
class ImageAttachmentViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = ImageAttachment.objects.all()
serializer_class = serializers.ImageAttachmentSerializer
@ -143,7 +143,7 @@ class ImageAttachmentViewSet(ModelViewSet):
# Journal entries
#
class JournalEntryViewSet(ModelViewSet):
class JournalEntryViewSet(NetBoxModelViewSet):
metadata_class = ContentTypeMetadata
queryset = JournalEntry.objects.all()
serializer_class = serializers.JournalEntrySerializer
@ -154,7 +154,7 @@ class JournalEntryViewSet(ModelViewSet):
# Config contexts
#
class ConfigContextViewSet(ModelViewSet):
class ConfigContextViewSet(NetBoxModelViewSet):
queryset = ConfigContext.objects.prefetch_related(
'regions', 'site_groups', 'sites', 'roles', 'platforms', 'tenant_groups', 'tenants',
)

View File

@ -13,7 +13,7 @@ from dcim.models import Site
from extras.api.views import CustomFieldModelViewSet
from ipam import filtersets
from ipam.models import *
from netbox.api.views import ModelViewSet, ObjectValidationMixin
from netbox.api.viewsets.mixins import ObjectValidationMixin
from netbox.config import get_config
from utilities.constants import ADVISORY_LOCK_KEYS
from utilities.utils import count_related

View File

@ -9,7 +9,7 @@ from .nested import *
# Base model serializers
#
class NetBoxModelSerializer(TaggableObjectSerializer, CustomFieldModelSerializer, ValidatedModelSerializer):
class NetBoxModelSerializer(TaggableModelSerializer, CustomFieldModelSerializer, ValidatedModelSerializer):
"""
Adds support for custom fields and tags.
"""

View File

@ -8,7 +8,7 @@ from .nested import NestedTagSerializer
__all__ = (
'CustomFieldModelSerializer',
'TaggableObjectSerializer',
'TaggableModelSerializer',
)
@ -44,7 +44,7 @@ class CustomFieldModelSerializer(serializers.Serializer):
instance.custom_fields[field.name] = instance.cf.get(field.name)
class TaggableObjectSerializer(serializers.Serializer):
class TaggableModelSerializer(serializers.Serializer):
"""
Introduces support for Tag assignment. Adds `tags` serialization, and handles tag assignment
on create() and update().

View File

@ -1,292 +1,17 @@
import logging
import platform
from collections import OrderedDict
from django import __version__ as DJANGO_VERSION
from django.apps import apps
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
from django.db import transaction
from django.db.models import ProtectedError
from django.shortcuts import get_object_or_404
from django_rq.queues import get_connection
from rest_framework import status
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet as ModelViewSet_
from rq.worker import Worker
from extras.models import ExportTemplate
from netbox.api import BulkOperationSerializer
from netbox.api.authentication import IsAuthenticatedOrLoginNotRequired
from netbox.api.exceptions import SerializerNotFound
from utilities.api import get_serializer_for_model
HTTP_ACTIONS = {
'GET': 'view',
'OPTIONS': None,
'HEAD': 'view',
'POST': 'add',
'PUT': 'change',
'PATCH': 'change',
'DELETE': 'delete',
}
#
# Mixins
#
class BulkUpdateModelMixin:
"""
Support bulk modification of objects using the list endpoint for a model. Accepts a PATCH action with a list of one
or more JSON objects, each specifying the numeric ID of an object to be updated as well as the attributes to be set.
For example:
PATCH /api/dcim/sites/
[
{
"id": 123,
"name": "New name"
},
{
"id": 456,
"status": "planned"
}
]
"""
def bulk_update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
serializer = BulkOperationSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
qs = self.get_queryset().filter(
pk__in=[o['id'] for o in serializer.data]
)
# Map update data by object ID
update_data = {
obj.pop('id'): obj for obj in request.data
}
data = self.perform_bulk_update(qs, update_data, partial=partial)
return Response(data, status=status.HTTP_200_OK)
def perform_bulk_update(self, objects, update_data, partial):
with transaction.atomic():
data_list = []
for obj in objects:
data = update_data.get(obj.id)
if hasattr(obj, 'snapshot'):
obj.snapshot()
serializer = self.get_serializer(obj, data=data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
data_list.append(serializer.data)
return data_list
def bulk_partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return self.bulk_update(request, *args, **kwargs)
class BulkDestroyModelMixin:
"""
Support bulk deletion of objects using the list endpoint for a model. Accepts a DELETE action with a list of one
or more JSON objects, each specifying the numeric ID of an object to be deleted. For example:
DELETE /api/dcim/sites/
[
{"id": 123},
{"id": 456}
]
"""
def bulk_destroy(self, request, *args, **kwargs):
serializer = BulkOperationSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
qs = self.get_queryset().filter(
pk__in=[o['id'] for o in serializer.data]
)
self.perform_bulk_destroy(qs)
return Response(status=status.HTTP_204_NO_CONTENT)
def perform_bulk_destroy(self, objects):
with transaction.atomic():
for obj in objects:
if hasattr(obj, 'snapshot'):
obj.snapshot()
self.perform_destroy(obj)
class ObjectValidationMixin:
def _validate_objects(self, instance):
"""
Check that the provided instance or list of instances are matched by the current queryset. This confirms that
any newly created or modified objects abide by the attributes granted by any applicable ObjectPermissions.
"""
if type(instance) is list:
# Check that all instances are still included in the view's queryset
conforming_count = self.queryset.filter(pk__in=[obj.pk for obj in instance]).count()
if conforming_count != len(instance):
raise ObjectDoesNotExist
else:
# Check that the instance is matched by the view's queryset
self.queryset.get(pk=instance.pk)
#
# Viewsets
#
class ModelViewSet(BulkUpdateModelMixin, BulkDestroyModelMixin, ObjectValidationMixin, ModelViewSet_):
"""
Extend DRF's ModelViewSet to support bulk update and delete functions.
"""
brief = False
brief_prefetch_fields = []
def get_object_with_snapshot(self):
"""
Save a pre-change snapshot of the object immediately after retrieving it. This snapshot will be used to
record the "before" data in the changelog.
"""
obj = super().get_object()
if hasattr(obj, 'snapshot'):
obj.snapshot()
return obj
def get_serializer(self, *args, **kwargs):
# If a list of objects has been provided, initialize the serializer with many=True
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True
return super().get_serializer(*args, **kwargs)
def get_serializer_class(self):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
# If using 'brief' mode, find and return the nested serializer for this model, if one exists
if self.brief:
logger.debug("Request is for 'brief' format; initializing nested serializer")
try:
serializer = get_serializer_for_model(self.queryset.model, prefix='Nested')
logger.debug(f"Using serializer {serializer}")
return serializer
except SerializerNotFound:
logger.debug(f"Nested serializer for {self.queryset.model} not found!")
# Fall back to the hard-coded serializer class
logger.debug(f"Using serializer {self.serializer_class}")
return self.serializer_class
def get_queryset(self):
# If using brief mode, clear all prefetches from the queryset and append only brief_prefetch_fields (if any)
if self.brief:
return super().get_queryset().prefetch_related(None).prefetch_related(*self.brief_prefetch_fields)
return super().get_queryset()
def initialize_request(self, request, *args, **kwargs):
# Check if brief=True has been passed
if request.method == 'GET' and request.GET.get('brief'):
self.brief = True
return super().initialize_request(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
if not request.user.is_authenticated:
return
# Restrict the view's QuerySet to allow only the permitted objects
action = HTTP_ACTIONS[request.method]
if action:
self.queryset = self.queryset.restrict(request.user, action)
def dispatch(self, request, *args, **kwargs):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
try:
return super().dispatch(request, *args, **kwargs)
except ProtectedError as e:
protected_objects = list(e.protected_objects)
msg = f'Unable to delete object. {len(protected_objects)} dependent objects were found: '
msg += ', '.join([f'{obj} ({obj.pk})' for obj in protected_objects])
logger.warning(msg)
return self.finalize_response(
request,
Response({'detail': msg}, status=409),
*args,
**kwargs
)
def list(self, request, *args, **kwargs):
"""
Overrides ListModelMixin to allow processing ExportTemplates.
"""
if 'export' in request.GET:
content_type = ContentType.objects.get_for_model(self.get_serializer_class().Meta.model)
et = get_object_or_404(ExportTemplate, content_type=content_type, name=request.GET['export'])
queryset = self.filter_queryset(self.get_queryset())
return et.render_to_response(queryset)
return super().list(request, *args, **kwargs)
def perform_create(self, serializer):
model = self.queryset.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Creating new {model._meta.verbose_name}")
# Enforce object-level permissions on save()
try:
with transaction.atomic():
instance = serializer.save()
self._validate_objects(instance)
except ObjectDoesNotExist:
raise PermissionDenied()
def update(self, request, *args, **kwargs):
# Hotwire get_object() to ensure we save a pre-change snapshot
self.get_object = self.get_object_with_snapshot
return super().update(request, *args, **kwargs)
def perform_update(self, serializer):
model = self.queryset.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Updating {model._meta.verbose_name} {serializer.instance} (PK: {serializer.instance.pk})")
# Enforce object-level permissions on save()
try:
with transaction.atomic():
instance = serializer.save()
self._validate_objects(instance)
except ObjectDoesNotExist:
raise PermissionDenied()
def destroy(self, request, *args, **kwargs):
# Hotwire get_object() to ensure we save a pre-change snapshot
self.get_object = self.get_object_with_snapshot
return super().destroy(request, *args, **kwargs)
def perform_destroy(self, instance):
model = self.queryset.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Deleting {model._meta.verbose_name} {instance} (PK: {instance.pk})")
return super().perform_destroy(instance)
#
# Views
#
class APIRootView(APIView):
"""

View File

@ -0,0 +1,168 @@
import logging
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
from django.db import transaction
from django.db.models import ProtectedError
from django.shortcuts import get_object_or_404
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from extras.models import ExportTemplate
from netbox.api.exceptions import SerializerNotFound
from utilities.api import get_serializer_for_model
from .mixins import *
__all__ = (
'NetBoxModelViewSet',
)
HTTP_ACTIONS = {
'GET': 'view',
'OPTIONS': None,
'HEAD': 'view',
'POST': 'add',
'PUT': 'change',
'PATCH': 'change',
'DELETE': 'delete',
}
class NetBoxModelViewSet(BulkUpdateModelMixin, BulkDestroyModelMixin, ObjectValidationMixin, ModelViewSet):
"""
Extend DRF's ModelViewSet to support bulk update and delete functions.
"""
brief = False
brief_prefetch_fields = []
def get_object_with_snapshot(self):
"""
Save a pre-change snapshot of the object immediately after retrieving it. This snapshot will be used to
record the "before" data in the changelog.
"""
obj = super().get_object()
if hasattr(obj, 'snapshot'):
obj.snapshot()
return obj
def get_serializer(self, *args, **kwargs):
# If a list of objects has been provided, initialize the serializer with many=True
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True
return super().get_serializer(*args, **kwargs)
def get_serializer_class(self):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
# If using 'brief' mode, find and return the nested serializer for this model, if one exists
if self.brief:
logger.debug("Request is for 'brief' format; initializing nested serializer")
try:
serializer = get_serializer_for_model(self.queryset.model, prefix='Nested')
logger.debug(f"Using serializer {serializer}")
return serializer
except SerializerNotFound:
logger.debug(f"Nested serializer for {self.queryset.model} not found!")
# Fall back to the hard-coded serializer class
logger.debug(f"Using serializer {self.serializer_class}")
return self.serializer_class
def get_queryset(self):
# If using brief mode, clear all prefetches from the queryset and append only brief_prefetch_fields (if any)
if self.brief:
return super().get_queryset().prefetch_related(None).prefetch_related(*self.brief_prefetch_fields)
return super().get_queryset()
def initialize_request(self, request, *args, **kwargs):
# Check if brief=True has been passed
if request.method == 'GET' and request.GET.get('brief'):
self.brief = True
return super().initialize_request(request, *args, **kwargs)
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
if not request.user.is_authenticated:
return
# Restrict the view's QuerySet to allow only the permitted objects
action = HTTP_ACTIONS[request.method]
if action:
self.queryset = self.queryset.restrict(request.user, action)
def dispatch(self, request, *args, **kwargs):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
try:
return super().dispatch(request, *args, **kwargs)
except ProtectedError as e:
protected_objects = list(e.protected_objects)
msg = f'Unable to delete object. {len(protected_objects)} dependent objects were found: '
msg += ', '.join([f'{obj} ({obj.pk})' for obj in protected_objects])
logger.warning(msg)
return self.finalize_response(
request,
Response({'detail': msg}, status=409),
*args,
**kwargs
)
def list(self, request, *args, **kwargs):
"""
Overrides ListModelMixin to allow processing ExportTemplates.
"""
if 'export' in request.GET:
content_type = ContentType.objects.get_for_model(self.get_serializer_class().Meta.model)
et = get_object_or_404(ExportTemplate, content_type=content_type, name=request.GET['export'])
queryset = self.filter_queryset(self.get_queryset())
return et.render_to_response(queryset)
return super().list(request, *args, **kwargs)
def perform_create(self, serializer):
model = self.queryset.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Creating new {model._meta.verbose_name}")
# Enforce object-level permissions on save()
try:
with transaction.atomic():
instance = serializer.save()
self._validate_objects(instance)
except ObjectDoesNotExist:
raise PermissionDenied()
def update(self, request, *args, **kwargs):
# Hotwire get_object() to ensure we save a pre-change snapshot
self.get_object = self.get_object_with_snapshot
return super().update(request, *args, **kwargs)
def perform_update(self, serializer):
model = self.queryset.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Updating {model._meta.verbose_name} {serializer.instance} (PK: {serializer.instance.pk})")
# Enforce object-level permissions on save()
try:
with transaction.atomic():
instance = serializer.save()
self._validate_objects(instance)
except ObjectDoesNotExist:
raise PermissionDenied()
def destroy(self, request, *args, **kwargs):
# Hotwire get_object() to ensure we save a pre-change snapshot
self.get_object = self.get_object_with_snapshot
return super().destroy(request, *args, **kwargs)
def perform_destroy(self, instance):
model = self.queryset.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Deleting {model._meta.verbose_name} {instance} (PK: {instance.pk})")
return super().perform_destroy(instance)

View File

@ -0,0 +1,113 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from rest_framework import status
from rest_framework.response import Response
from netbox.api.serializers import BulkOperationSerializer
__all__ = (
'BulkUpdateModelMixin',
'BulkDestroyModelMixin',
'ObjectValidationMixin',
)
class BulkUpdateModelMixin:
"""
Support bulk modification of objects using the list endpoint for a model. Accepts a PATCH action with a list of one
or more JSON objects, each specifying the numeric ID of an object to be updated as well as the attributes to be set.
For example:
PATCH /api/dcim/sites/
[
{
"id": 123,
"name": "New name"
},
{
"id": 456,
"status": "planned"
}
]
"""
def bulk_update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
serializer = BulkOperationSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
qs = self.get_queryset().filter(
pk__in=[o['id'] for o in serializer.data]
)
# Map update data by object ID
update_data = {
obj.pop('id'): obj for obj in request.data
}
data = self.perform_bulk_update(qs, update_data, partial=partial)
return Response(data, status=status.HTTP_200_OK)
def perform_bulk_update(self, objects, update_data, partial):
with transaction.atomic():
data_list = []
for obj in objects:
data = update_data.get(obj.id)
if hasattr(obj, 'snapshot'):
obj.snapshot()
serializer = self.get_serializer(obj, data=data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
data_list.append(serializer.data)
return data_list
def bulk_partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return self.bulk_update(request, *args, **kwargs)
class BulkDestroyModelMixin:
"""
Support bulk deletion of objects using the list endpoint for a model. Accepts a DELETE action with a list of one
or more JSON objects, each specifying the numeric ID of an object to be deleted. For example:
DELETE /api/dcim/sites/
[
{"id": 123},
{"id": 456}
]
"""
def bulk_destroy(self, request, *args, **kwargs):
serializer = BulkOperationSerializer(data=request.data, many=True)
serializer.is_valid(raise_exception=True)
qs = self.get_queryset().filter(
pk__in=[o['id'] for o in serializer.data]
)
self.perform_bulk_destroy(qs)
return Response(status=status.HTTP_204_NO_CONTENT)
def perform_bulk_destroy(self, objects):
with transaction.atomic():
for obj in objects:
if hasattr(obj, 'snapshot'):
obj.snapshot()
self.perform_destroy(obj)
class ObjectValidationMixin:
def _validate_objects(self, instance):
"""
Check that the provided instance or list of instances are matched by the current queryset. This confirms that
any newly created or modified objects abide by the attributes granted by any applicable ObjectPermissions.
"""
if type(instance) is list:
# Check that all instances are still included in the view's queryset
conforming_count = self.queryset.filter(pk__in=[obj.pk for obj in instance]).count()
if conforming_count != len(instance):
raise ObjectDoesNotExist
else:
# Check that the instance is matched by the view's queryset
self.queryset.get(pk=instance.pk)

View File

@ -9,7 +9,7 @@ from rest_framework.status import HTTP_201_CREATED
from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
from netbox.api.views import ModelViewSet
from netbox.api.viewsets import NetBoxModelViewSet
from users import filtersets
from users.models import ObjectPermission, Token, UserConfig
from utilities.querysets import RestrictedQuerySet
@ -29,13 +29,13 @@ class UsersRootView(APIRootView):
# Users and groups
#
class UserViewSet(ModelViewSet):
class UserViewSet(NetBoxModelViewSet):
queryset = RestrictedQuerySet(model=User).prefetch_related('groups').order_by('username')
serializer_class = serializers.UserSerializer
filterset_class = filtersets.UserFilterSet
class GroupViewSet(ModelViewSet):
class GroupViewSet(NetBoxModelViewSet):
queryset = RestrictedQuerySet(model=Group).annotate(user_count=Count('user')).order_by('name')
serializer_class = serializers.GroupSerializer
filterset_class = filtersets.GroupFilterSet
@ -45,7 +45,7 @@ class GroupViewSet(ModelViewSet):
# REST API tokens
#
class TokenViewSet(ModelViewSet):
class TokenViewSet(NetBoxModelViewSet):
queryset = RestrictedQuerySet(model=Token).prefetch_related('user')
serializer_class = serializers.TokenSerializer
filterset_class = filtersets.TokenFilterSet
@ -94,7 +94,7 @@ class TokenProvisionView(APIView):
# ObjectPermissions
#
class ObjectPermissionViewSet(ModelViewSet):
class ObjectPermissionViewSet(NetBoxModelViewSet):
queryset = ObjectPermission.objects.prefetch_related('object_types', 'groups', 'users')
serializer_class = serializers.ObjectPermissionSerializer
filterset_class = filtersets.ObjectPermissionFilterSet

View File

@ -1,7 +1,7 @@
from rest_framework.routers import APIRootView
from dcim.models import Device
from extras.api.views import ConfigContextQuerySetMixin, CustomFieldModelViewSet, ModelViewSet
from extras.api.views import ConfigContextQuerySetMixin, CustomFieldModelViewSet, NetBoxModelViewSet
from utilities.utils import count_related
from virtualization import filtersets
from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMachine, VMInterface
@ -78,7 +78,7 @@ class VirtualMachineViewSet(ConfigContextQuerySetMixin, CustomFieldModelViewSet)
return serializers.VirtualMachineWithConfigContextSerializer
class VMInterfaceViewSet(ModelViewSet):
class VMInterfaceViewSet(NetBoxModelViewSet):
queryset = VMInterface.objects.prefetch_related(
'virtual_machine', 'parent', 'tags', 'untagged_vlan', 'tagged_vlans', 'vrf', 'ip_addresses',
'fhrp_group_assignments',