From aeb32104a46c32797380c80e2549e4583377d58d Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Thu, 14 May 2020 17:44:46 -0400 Subject: [PATCH] Enforce object-level permissions for API views --- netbox/utilities/api.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/netbox/utilities/api.py b/netbox/utilities/api.py index 205055669..405c26878 100644 --- a/netbox/utilities/api.py +++ b/netbox/utilities/api.py @@ -6,15 +6,15 @@ from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist from django.db.models import ManyToManyField, ProtectedError -from django.http import Http404 from django.urls import reverse from rest_framework.exceptions import APIException from rest_framework.permissions import BasePermission from rest_framework.relations import PrimaryKeyRelatedField, RelatedField from rest_framework.response import Response from rest_framework.serializers import Field, ModelSerializer, ValidationError -from rest_framework.viewsets import ModelViewSet as _ModelViewSet, ViewSet +from rest_framework.viewsets import ModelViewSet as _ModelViewSet +from users.models import ObjectPermission from .utils import dict_to_filter_params, dynamic_import @@ -323,6 +323,22 @@ class ModelViewSet(_ModelViewSet): logger.debug(f"Using serializer {self.serializer_class}") return self.serializer_class + def initial(self, request, *args, **kwargs): + super().initial(request, *args, **kwargs) + + if not request.user.is_authenticated or request.user.is_superuser: + return + + permission_required = 'dcim.view_site' + + # Enforce object-level permissions + if permission_required not in self.request.user._perm_cache: + attrs = ObjectPermission.objects.get_attr_constraints(self.request.user, permission_required) + if attrs: + # Update the view's QuerySet to filter only the permitted objects + self.queryset = self.queryset.filter(attrs) + return True + def dispatch(self, request, *args, **kwargs): logger = logging.getLogger('netbox.api.views.ModelViewSet')