From e91a76c936dc88ade43e541af0cf6e5e09181d05 Mon Sep 17 00:00:00 2001 From: jeremystretch Date: Thu, 16 Dec 2021 16:28:23 -0500 Subject: [PATCH] Refactor bulk generic views --- netbox/netbox/views/generic.py | 362 +++++++++++++++++---------------- 1 file changed, 191 insertions(+), 171 deletions(-) diff --git a/netbox/netbox/views/generic.py b/netbox/netbox/views/generic.py index c6f6305cb..3096b86fc 100644 --- a/netbox/netbox/views/generic.py +++ b/netbox/netbox/views/generic.py @@ -539,6 +539,31 @@ class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): def get_required_permission(self): return get_permission_for_model(self.queryset.model, 'add') + def _create_objects(self, form, request): + new_objects = [] + + # Create objects from the expanded. Abort the transaction on the first validation error. + for value in form.cleaned_data['pattern']: + + # Reinstantiate the model form each time to avoid overwriting the same instance. Use a mutable + # copy of the POST QueryDict so that we can update the target field value. + model_form = self.model_form(request.POST.copy()) + model_form.data[self.pattern_target] = value + + # Validate each new object independently. + if model_form.is_valid(): + obj = model_form.save() + new_objects.append(obj) + else: + # Copy any errors on the pattern target field to the pattern form. + errors = model_form.errors.as_data() + if errors.get(self.pattern_target): + form.add_error('pattern', errors[self.pattern_target]) + # Raise an IntegrityError to break the for loop and abort the transaction. + raise IntegrityError() + + return new_objects + def get(self, request): # Set initial values for visible form fields from query args initial = {} @@ -564,45 +589,23 @@ class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): if form.is_valid(): logger.debug("Form validation was successful") - pattern = form.cleaned_data['pattern'] - new_objs = [] try: with transaction.atomic(): - - # Create objects from the expanded. Abort the transaction on the first validation error. - for value in pattern: - - # Reinstantiate the model form each time to avoid overwriting the same instance. Use a mutable - # copy of the POST QueryDict so that we can update the target field value. - model_form = self.model_form(request.POST.copy()) - model_form.data[self.pattern_target] = value - - # Validate each new object independently. - if model_form.is_valid(): - obj = model_form.save() - logger.debug(f"Created {obj} (PK: {obj.pk})") - new_objs.append(obj) - else: - # Copy any errors on the pattern target field to the pattern form. - errors = model_form.errors.as_data() - if errors.get(self.pattern_target): - form.add_error('pattern', errors[self.pattern_target]) - # Raise an IntegrityError to break the for loop and abort the transaction. - raise IntegrityError() + new_objs = self._create_objects(form, request) # Enforce object-level permissions if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): raise PermissionsViolation - # If we make it to this point, validation has succeeded on all new objects. - msg = "Added {} {}".format(len(new_objs), model._meta.verbose_name_plural) - logger.info(msg) - messages.success(request, msg) + # If we make it to this point, validation has succeeded on all new objects. + msg = f"Added {len(new_objs)} {model._meta.verbose_name_plural}" + logger.info(msg) + messages.success(request, msg) - if '_addanother' in request.POST: - return redirect(request.path) - return redirect(self.get_return_url(request)) + if '_addanother' in request.POST: + return redirect(request.path) + return redirect(self.get_return_url(request)) except IntegrityError: pass @@ -640,6 +643,45 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): def get_required_permission(self): return get_permission_for_model(self.queryset.model, 'add') + def _create_object(self, model_form): + + # Save the primary object + obj = model_form.save() + + # Enforce object-level permissions + if not self.queryset.filter(pk=obj.pk).first(): + raise PermissionsViolation() + + # Iterate through the related object forms (if any), validating and saving each instance. + for field_name, related_object_form in self.related_object_forms.items(): + + related_obj_pks = [] + for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())): + + f = related_object_form(obj, rel_obj_data) + + for subfield_name, field in f.fields.items(): + if subfield_name not in rel_obj_data and hasattr(field, 'initial'): + f.data[subfield_name] = field.initial + + if f.is_valid(): + related_obj = f.save() + related_obj_pks.append(related_obj.pk) + else: + # Replicate errors on the related object form to the primary form for display + for subfield_name, errors in f.errors.items(): + for err in errors: + err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err) + model_form.add_error(None, err_msg) + raise AbortTransaction() + + # Enforce object-level permissions on related objects + model = related_object_form.Meta.model + if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): + raise ObjectDoesNotExist + + return obj + def get(self, request): form = ImportForm() @@ -673,44 +715,7 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): try: with transaction.atomic(): - - # Save the primary object - obj = model_form.save() - - # Enforce object-level permissions - if not self.queryset.filter(pk=obj.pk).first(): - raise PermissionsViolation() - - logger.debug(f"Created {obj} (PK: {obj.pk})") - - # Iterate through the related object forms (if any), validating and saving each instance. - for field_name, related_object_form in self.related_object_forms.items(): - logger.debug("Processing form for related objects: {related_object_form}") - - related_obj_pks = [] - for i, rel_obj_data in enumerate(data.get(field_name, list())): - - f = related_object_form(obj, rel_obj_data) - - for subfield_name, field in f.fields.items(): - if subfield_name not in rel_obj_data and hasattr(field, 'initial'): - f.data[subfield_name] = field.initial - - if f.is_valid(): - related_obj = f.save() - related_obj_pks.append(related_obj.pk) - else: - # Replicate errors on the related object form to the primary form for display - for subfield_name, errors in f.errors.items(): - for err in errors: - err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err) - model_form.add_error(None, err_msg) - raise AbortTransaction() - - # Enforce object-level permissions on related objects - model = related_object_form.Meta.model - if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks): - raise ObjectDoesNotExist + obj = self._create_object(model_form) except AbortTransaction: clear_webhooks.send(sender=self) @@ -723,9 +728,8 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): if not model_form.errors: logger.info(f"Import object {obj} (PK: {obj.pk})") - messages.success(request, mark_safe('Imported object: {}'.format( - obj.get_absolute_url(), obj - ))) + msg = f'Imported object: {obj}' + messages.success(request, mark_safe(msg)) if '_addanother' in request.POST: return redirect(request.get_full_path()) @@ -733,8 +737,7 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): return_url = form.cleaned_data.get('return_url') if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()): return redirect(return_url) - else: - return redirect(self.get_return_url(request, obj)) + return redirect(self.get_return_url(request, obj)) else: logger.debug("Model form validation failed") @@ -799,6 +802,27 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): return ImportForm(*args, **kwargs) + def _create_objects(self, form, request): + new_objs = [] + if request.FILES: + headers, records = form.cleaned_data['csv_file'] + else: + headers, records = form.cleaned_data['csv'] + + for row, data in enumerate(records, start=1): + obj_form = self.model_form(data, headers=headers) + restrict_form_fields(obj_form, request.user) + + if obj_form.is_valid(): + obj = self._save_obj(obj_form, request) + new_objs.append(obj) + else: + for field, err in obj_form.errors.items(): + form.add_error('csv', f'Row {row} {field}: {err[0]}') + raise ValidationError("") + + return new_objs + def _save_obj(self, obj_form, request): """ Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data). @@ -819,7 +843,6 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): def post(self, request): logger = logging.getLogger('netbox.views.BulkImportView') - new_objs = [] form = self._import_form(request.POST, request.FILES) if form.is_valid(): @@ -828,21 +851,7 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): try: # Iterate through CSV data and bind each row to a new model form instance. with transaction.atomic(): - if request.FILES: - headers, records = form.cleaned_data['csv_file'] - else: - headers, records = form.cleaned_data['csv'] - for row, data in enumerate(records, start=1): - obj_form = self.model_form(data, headers=headers) - restrict_form_fields(obj_form, request.user) - - if obj_form.is_valid(): - obj = self._save_obj(obj_form, request) - new_objs.append(obj) - else: - for field, err in obj_form.errors.items(): - form.add_error('csv', "Row {} {}: {}".format(row, field, err[0])) - raise ValidationError("") + new_objs = self._create_objects(form, request) # Enforce object-level permissions if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs): @@ -886,7 +895,7 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): Edit objects in bulk. queryset: Custom queryset to use when retrieving objects (e.g. to select related objects) - filter: FilterSet to apply when deleting by QuerySet + filterset: FilterSet to apply when deleting by QuerySet table: The table used to display devices being edited form: The form class used to edit objects in bulk template_name: The name of the template @@ -900,6 +909,63 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): def get_required_permission(self): return get_permission_for_model(self.queryset.model, 'change') + def _update_objects(self, form, request): + custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else [] + standard_fields = [ + field for field in form.fields if field not in custom_fields + ['pk'] + ] + nullified_fields = request.POST.getlist('_nullify') + updated_objects = [] + + for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']): + + # Take a snapshot of change-logged models + if hasattr(obj, 'snapshot'): + obj.snapshot() + + # Update standard fields. If a field is listed in _nullify, delete its value. + for name in standard_fields: + + try: + model_field = self.queryset.model._meta.get_field(name) + except FieldDoesNotExist: + # This form field is used to modify a field rather than set its value directly + model_field = None + + # Handle nullification + if name in form.nullable_fields and name in nullified_fields: + if isinstance(model_field, ManyToManyField): + getattr(obj, name).set([]) + else: + setattr(obj, name, None if model_field.null else '') + + # ManyToManyFields + elif isinstance(model_field, ManyToManyField): + if form.cleaned_data[name]: + getattr(obj, name).set(form.cleaned_data[name]) + # Normal fields + elif name in form.changed_data: + setattr(obj, name, form.cleaned_data[name]) + + # Update custom fields + for name in custom_fields: + if name in form.nullable_fields and name in nullified_fields: + obj.custom_field_data[name] = None + elif name in form.changed_data: + obj.custom_field_data[name] = form.cleaned_data[name] + + obj.full_clean() + obj.save() + updated_objects.append(obj) + + # Add/remove tags + if form.cleaned_data.get('add_tags', None): + obj.tags.add(*form.cleaned_data['add_tags']) + if form.cleaned_data.get('remove_tags', None): + obj.tags.remove(*form.cleaned_data['remove_tags']) + + return updated_objects + def get(self, request): return redirect(self.get_return_url(request)) @@ -932,78 +998,26 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): if form.is_valid(): logger.debug("Form validation was successful") - custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else [] - standard_fields = [ - field for field in form.fields if field not in custom_fields + ['pk'] - ] - nullified_fields = request.POST.getlist('_nullify') try: with transaction.atomic(): - - updated_objects = [] - for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']): - - # Take a snapshot of change-logged models - if hasattr(obj, 'snapshot'): - obj.snapshot() - - # Update standard fields. If a field is listed in _nullify, delete its value. - for name in standard_fields: - - try: - model_field = model._meta.get_field(name) - except FieldDoesNotExist: - # This form field is used to modify a field rather than set its value directly - model_field = None - - # Handle nullification - if name in form.nullable_fields and name in nullified_fields: - if isinstance(model_field, ManyToManyField): - getattr(obj, name).set([]) - else: - setattr(obj, name, None if model_field.null else '') - - # ManyToManyFields - elif isinstance(model_field, ManyToManyField): - if form.cleaned_data[name]: - getattr(obj, name).set(form.cleaned_data[name]) - # Normal fields - elif name in form.changed_data: - setattr(obj, name, form.cleaned_data[name]) - - # Update custom fields - for name in custom_fields: - if name in form.nullable_fields and name in nullified_fields: - obj.custom_field_data[name] = None - elif name in form.changed_data: - obj.custom_field_data[name] = form.cleaned_data[name] - - obj.full_clean() - obj.save() - updated_objects.append(obj) - logger.debug(f"Saved {obj} (PK: {obj.pk})") - - # Add/remove tags - if form.cleaned_data.get('add_tags', None): - obj.tags.add(*form.cleaned_data['add_tags']) - if form.cleaned_data.get('remove_tags', None): - obj.tags.remove(*form.cleaned_data['remove_tags']) + updated_objects = self._update_objects(form, request) # Enforce object-level permissions - if self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count() != len(updated_objects): + object_count = self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count() + if object_count != len(updated_objects): raise PermissionsViolation if updated_objects: - msg = 'Updated {} {}'.format(len(updated_objects), model._meta.verbose_name_plural) + msg = f'Updated {len(updated_objects)} {model._meta.verbose_name_plural}' logger.info(msg) messages.success(self.request, msg) return redirect(self.get_return_url(request)) except ValidationError as e: - messages.error(self.request, "{} failed validation: {}".format(obj, ", ".join(e.messages))) + messages.error(self.request, ", ".join(e.messages)) clear_webhooks.send(sender=self) except PermissionsViolation: @@ -1016,7 +1030,6 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): logger.debug("Form validation failed") else: - form = self.form(model, initial=initial_data) restrict_form_fields(form, request.user) @@ -1037,6 +1050,9 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): """ An extendable view for renaming objects in bulk. + + queryset: QuerySet of objects being renamed + template_name: The name of the template """ queryset = None template_name = 'generic/object_bulk_rename.html' @@ -1056,6 +1072,29 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): def get_required_permission(self): return get_permission_for_model(self.queryset.model, 'change') + def _rename_objects(self, form, selected_objects): + renamed_pks = [] + + for obj in selected_objects: + + # Take a snapshot of change-logged models + if hasattr(obj, 'snapshot'): + obj.snapshot() + + find = form.cleaned_data['find'] + replace = form.cleaned_data['replace'] + if form.cleaned_data['use_regex']: + try: + obj.new_name = re.sub(find, replace, obj.name) + # Catch regex group reference errors + except re.error: + obj.new_name = obj.name + else: + obj.new_name = obj.name.replace(find, replace) + renamed_pks.append(obj.pk) + + return renamed_pks + def post(self, request): logger = logging.getLogger('netbox.views.BulkRenameView') @@ -1066,24 +1105,7 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): if form.is_valid(): try: with transaction.atomic(): - renamed_pks = [] - for obj in selected_objects: - - # Take a snapshot of change-logged models - if hasattr(obj, 'snapshot'): - obj.snapshot() - - find = form.cleaned_data['find'] - replace = form.cleaned_data['replace'] - if form.cleaned_data['use_regex']: - try: - obj.new_name = re.sub(find, replace, obj.name) - # Catch regex group reference errors - except re.error: - obj.new_name = obj.name - else: - obj.new_name = obj.name.replace(find, replace) - renamed_pks.append(obj.pk) + renamed_pks = self._rename_objects(form, selected_objects) if '_apply' in request.POST: for obj in selected_objects: @@ -1094,10 +1116,8 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): if self.queryset.filter(pk__in=renamed_pks).count() != len(selected_objects): raise PermissionsViolation - messages.success(request, "Renamed {} {}".format( - len(selected_objects), - self.queryset.model._meta.verbose_name_plural - )) + model_name = self.queryset.model._meta.verbose_name_plural + messages.success(request, f"Renamed {len(selected_objects)} {model_name}") return redirect(self.get_return_url(request)) except PermissionsViolation: @@ -1123,7 +1143,7 @@ class BulkDeleteView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): Delete objects in bulk. queryset: Custom queryset to use when retrieving objects (e.g. to select related objects) - filter: FilterSet to apply when deleting by QuerySet + filterset: FilterSet to apply when deleting by QuerySet table: The table used to display devices being deleted form: The form class used to delete objects in bulk template_name: The name of the template