1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Refactor bulk generic views

This commit is contained in:
jeremystretch
2021-12-16 16:28:23 -05:00
parent 1dd3d2ec48
commit e91a76c936

View File

@ -539,6 +539,31 @@ class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
def get_required_permission(self):
return get_permission_for_model(self.queryset.model, 'add')
def _create_objects(self, form, request):
new_objects = []
# Create objects from the expanded. Abort the transaction on the first validation error.
for value in form.cleaned_data['pattern']:
# Reinstantiate the model form each time to avoid overwriting the same instance. Use a mutable
# copy of the POST QueryDict so that we can update the target field value.
model_form = self.model_form(request.POST.copy())
model_form.data[self.pattern_target] = value
# Validate each new object independently.
if model_form.is_valid():
obj = model_form.save()
new_objects.append(obj)
else:
# Copy any errors on the pattern target field to the pattern form.
errors = model_form.errors.as_data()
if errors.get(self.pattern_target):
form.add_error('pattern', errors[self.pattern_target])
# Raise an IntegrityError to break the for loop and abort the transaction.
raise IntegrityError()
return new_objects
def get(self, request):
# Set initial values for visible form fields from query args
initial = {}
@ -564,45 +589,23 @@ class BulkCreateView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
if form.is_valid():
logger.debug("Form validation was successful")
pattern = form.cleaned_data['pattern']
new_objs = []
try:
with transaction.atomic():
# Create objects from the expanded. Abort the transaction on the first validation error.
for value in pattern:
# Reinstantiate the model form each time to avoid overwriting the same instance. Use a mutable
# copy of the POST QueryDict so that we can update the target field value.
model_form = self.model_form(request.POST.copy())
model_form.data[self.pattern_target] = value
# Validate each new object independently.
if model_form.is_valid():
obj = model_form.save()
logger.debug(f"Created {obj} (PK: {obj.pk})")
new_objs.append(obj)
else:
# Copy any errors on the pattern target field to the pattern form.
errors = model_form.errors.as_data()
if errors.get(self.pattern_target):
form.add_error('pattern', errors[self.pattern_target])
# Raise an IntegrityError to break the for loop and abort the transaction.
raise IntegrityError()
new_objs = self._create_objects(form, request)
# Enforce object-level permissions
if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
raise PermissionsViolation
# If we make it to this point, validation has succeeded on all new objects.
msg = "Added {} {}".format(len(new_objs), model._meta.verbose_name_plural)
logger.info(msg)
messages.success(request, msg)
# If we make it to this point, validation has succeeded on all new objects.
msg = f"Added {len(new_objs)} {model._meta.verbose_name_plural}"
logger.info(msg)
messages.success(request, msg)
if '_addanother' in request.POST:
return redirect(request.path)
return redirect(self.get_return_url(request))
if '_addanother' in request.POST:
return redirect(request.path)
return redirect(self.get_return_url(request))
except IntegrityError:
pass
@ -640,6 +643,45 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
def get_required_permission(self):
return get_permission_for_model(self.queryset.model, 'add')
def _create_object(self, model_form):
# Save the primary object
obj = model_form.save()
# Enforce object-level permissions
if not self.queryset.filter(pk=obj.pk).first():
raise PermissionsViolation()
# Iterate through the related object forms (if any), validating and saving each instance.
for field_name, related_object_form in self.related_object_forms.items():
related_obj_pks = []
for i, rel_obj_data in enumerate(model_form.data.get(field_name, list())):
f = related_object_form(obj, rel_obj_data)
for subfield_name, field in f.fields.items():
if subfield_name not in rel_obj_data and hasattr(field, 'initial'):
f.data[subfield_name] = field.initial
if f.is_valid():
related_obj = f.save()
related_obj_pks.append(related_obj.pk)
else:
# Replicate errors on the related object form to the primary form for display
for subfield_name, errors in f.errors.items():
for err in errors:
err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err)
model_form.add_error(None, err_msg)
raise AbortTransaction()
# Enforce object-level permissions on related objects
model = related_object_form.Meta.model
if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks):
raise ObjectDoesNotExist
return obj
def get(self, request):
form = ImportForm()
@ -673,44 +715,7 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
try:
with transaction.atomic():
# Save the primary object
obj = model_form.save()
# Enforce object-level permissions
if not self.queryset.filter(pk=obj.pk).first():
raise PermissionsViolation()
logger.debug(f"Created {obj} (PK: {obj.pk})")
# Iterate through the related object forms (if any), validating and saving each instance.
for field_name, related_object_form in self.related_object_forms.items():
logger.debug("Processing form for related objects: {related_object_form}")
related_obj_pks = []
for i, rel_obj_data in enumerate(data.get(field_name, list())):
f = related_object_form(obj, rel_obj_data)
for subfield_name, field in f.fields.items():
if subfield_name not in rel_obj_data and hasattr(field, 'initial'):
f.data[subfield_name] = field.initial
if f.is_valid():
related_obj = f.save()
related_obj_pks.append(related_obj.pk)
else:
# Replicate errors on the related object form to the primary form for display
for subfield_name, errors in f.errors.items():
for err in errors:
err_msg = "{}[{}] {}: {}".format(field_name, i, subfield_name, err)
model_form.add_error(None, err_msg)
raise AbortTransaction()
# Enforce object-level permissions on related objects
model = related_object_form.Meta.model
if model.objects.filter(pk__in=related_obj_pks).count() != len(related_obj_pks):
raise ObjectDoesNotExist
obj = self._create_object(model_form)
except AbortTransaction:
clear_webhooks.send(sender=self)
@ -723,9 +728,8 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
if not model_form.errors:
logger.info(f"Import object {obj} (PK: {obj.pk})")
messages.success(request, mark_safe('Imported object: <a href="{}">{}</a>'.format(
obj.get_absolute_url(), obj
)))
msg = f'Imported object: <a href="{obj.get_absolute_url()}">{obj}</a>'
messages.success(request, mark_safe(msg))
if '_addanother' in request.POST:
return redirect(request.get_full_path())
@ -733,8 +737,7 @@ class ObjectImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
return_url = form.cleaned_data.get('return_url')
if return_url is not None and is_safe_url(url=return_url, allowed_hosts=request.get_host()):
return redirect(return_url)
else:
return redirect(self.get_return_url(request, obj))
return redirect(self.get_return_url(request, obj))
else:
logger.debug("Model form validation failed")
@ -799,6 +802,27 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
return ImportForm(*args, **kwargs)
def _create_objects(self, form, request):
new_objs = []
if request.FILES:
headers, records = form.cleaned_data['csv_file']
else:
headers, records = form.cleaned_data['csv']
for row, data in enumerate(records, start=1):
obj_form = self.model_form(data, headers=headers)
restrict_form_fields(obj_form, request.user)
if obj_form.is_valid():
obj = self._save_obj(obj_form, request)
new_objs.append(obj)
else:
for field, err in obj_form.errors.items():
form.add_error('csv', f'Row {row} {field}: {err[0]}')
raise ValidationError("")
return new_objs
def _save_obj(self, obj_form, request):
"""
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
@ -819,7 +843,6 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
def post(self, request):
logger = logging.getLogger('netbox.views.BulkImportView')
new_objs = []
form = self._import_form(request.POST, request.FILES)
if form.is_valid():
@ -828,21 +851,7 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
try:
# Iterate through CSV data and bind each row to a new model form instance.
with transaction.atomic():
if request.FILES:
headers, records = form.cleaned_data['csv_file']
else:
headers, records = form.cleaned_data['csv']
for row, data in enumerate(records, start=1):
obj_form = self.model_form(data, headers=headers)
restrict_form_fields(obj_form, request.user)
if obj_form.is_valid():
obj = self._save_obj(obj_form, request)
new_objs.append(obj)
else:
for field, err in obj_form.errors.items():
form.add_error('csv', "Row {} {}: {}".format(row, field, err[0]))
raise ValidationError("")
new_objs = self._create_objects(form, request)
# Enforce object-level permissions
if self.queryset.filter(pk__in=[obj.pk for obj in new_objs]).count() != len(new_objs):
@ -886,7 +895,7 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
Edit objects in bulk.
queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
filter: FilterSet to apply when deleting by QuerySet
filterset: FilterSet to apply when deleting by QuerySet
table: The table used to display devices being edited
form: The form class used to edit objects in bulk
template_name: The name of the template
@ -900,6 +909,63 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
def get_required_permission(self):
return get_permission_for_model(self.queryset.model, 'change')
def _update_objects(self, form, request):
custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
standard_fields = [
field for field in form.fields if field not in custom_fields + ['pk']
]
nullified_fields = request.POST.getlist('_nullify')
updated_objects = []
for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
# Take a snapshot of change-logged models
if hasattr(obj, 'snapshot'):
obj.snapshot()
# Update standard fields. If a field is listed in _nullify, delete its value.
for name in standard_fields:
try:
model_field = self.queryset.model._meta.get_field(name)
except FieldDoesNotExist:
# This form field is used to modify a field rather than set its value directly
model_field = None
# Handle nullification
if name in form.nullable_fields and name in nullified_fields:
if isinstance(model_field, ManyToManyField):
getattr(obj, name).set([])
else:
setattr(obj, name, None if model_field.null else '')
# ManyToManyFields
elif isinstance(model_field, ManyToManyField):
if form.cleaned_data[name]:
getattr(obj, name).set(form.cleaned_data[name])
# Normal fields
elif name in form.changed_data:
setattr(obj, name, form.cleaned_data[name])
# Update custom fields
for name in custom_fields:
if name in form.nullable_fields and name in nullified_fields:
obj.custom_field_data[name] = None
elif name in form.changed_data:
obj.custom_field_data[name] = form.cleaned_data[name]
obj.full_clean()
obj.save()
updated_objects.append(obj)
# Add/remove tags
if form.cleaned_data.get('add_tags', None):
obj.tags.add(*form.cleaned_data['add_tags'])
if form.cleaned_data.get('remove_tags', None):
obj.tags.remove(*form.cleaned_data['remove_tags'])
return updated_objects
def get(self, request):
return redirect(self.get_return_url(request))
@ -932,78 +998,26 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
if form.is_valid():
logger.debug("Form validation was successful")
custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
standard_fields = [
field for field in form.fields if field not in custom_fields + ['pk']
]
nullified_fields = request.POST.getlist('_nullify')
try:
with transaction.atomic():
updated_objects = []
for obj in self.queryset.filter(pk__in=form.cleaned_data['pk']):
# Take a snapshot of change-logged models
if hasattr(obj, 'snapshot'):
obj.snapshot()
# Update standard fields. If a field is listed in _nullify, delete its value.
for name in standard_fields:
try:
model_field = model._meta.get_field(name)
except FieldDoesNotExist:
# This form field is used to modify a field rather than set its value directly
model_field = None
# Handle nullification
if name in form.nullable_fields and name in nullified_fields:
if isinstance(model_field, ManyToManyField):
getattr(obj, name).set([])
else:
setattr(obj, name, None if model_field.null else '')
# ManyToManyFields
elif isinstance(model_field, ManyToManyField):
if form.cleaned_data[name]:
getattr(obj, name).set(form.cleaned_data[name])
# Normal fields
elif name in form.changed_data:
setattr(obj, name, form.cleaned_data[name])
# Update custom fields
for name in custom_fields:
if name in form.nullable_fields and name in nullified_fields:
obj.custom_field_data[name] = None
elif name in form.changed_data:
obj.custom_field_data[name] = form.cleaned_data[name]
obj.full_clean()
obj.save()
updated_objects.append(obj)
logger.debug(f"Saved {obj} (PK: {obj.pk})")
# Add/remove tags
if form.cleaned_data.get('add_tags', None):
obj.tags.add(*form.cleaned_data['add_tags'])
if form.cleaned_data.get('remove_tags', None):
obj.tags.remove(*form.cleaned_data['remove_tags'])
updated_objects = self._update_objects(form, request)
# Enforce object-level permissions
if self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count() != len(updated_objects):
object_count = self.queryset.filter(pk__in=[obj.pk for obj in updated_objects]).count()
if object_count != len(updated_objects):
raise PermissionsViolation
if updated_objects:
msg = 'Updated {} {}'.format(len(updated_objects), model._meta.verbose_name_plural)
msg = f'Updated {len(updated_objects)} {model._meta.verbose_name_plural}'
logger.info(msg)
messages.success(self.request, msg)
return redirect(self.get_return_url(request))
except ValidationError as e:
messages.error(self.request, "{} failed validation: {}".format(obj, ", ".join(e.messages)))
messages.error(self.request, ", ".join(e.messages))
clear_webhooks.send(sender=self)
except PermissionsViolation:
@ -1016,7 +1030,6 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
logger.debug("Form validation failed")
else:
form = self.form(model, initial=initial_data)
restrict_form_fields(form, request.user)
@ -1037,6 +1050,9 @@ class BulkEditView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
"""
An extendable view for renaming objects in bulk.
queryset: QuerySet of objects being renamed
template_name: The name of the template
"""
queryset = None
template_name = 'generic/object_bulk_rename.html'
@ -1056,6 +1072,29 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
def get_required_permission(self):
return get_permission_for_model(self.queryset.model, 'change')
def _rename_objects(self, form, selected_objects):
renamed_pks = []
for obj in selected_objects:
# Take a snapshot of change-logged models
if hasattr(obj, 'snapshot'):
obj.snapshot()
find = form.cleaned_data['find']
replace = form.cleaned_data['replace']
if form.cleaned_data['use_regex']:
try:
obj.new_name = re.sub(find, replace, obj.name)
# Catch regex group reference errors
except re.error:
obj.new_name = obj.name
else:
obj.new_name = obj.name.replace(find, replace)
renamed_pks.append(obj.pk)
return renamed_pks
def post(self, request):
logger = logging.getLogger('netbox.views.BulkRenameView')
@ -1066,24 +1105,7 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
if form.is_valid():
try:
with transaction.atomic():
renamed_pks = []
for obj in selected_objects:
# Take a snapshot of change-logged models
if hasattr(obj, 'snapshot'):
obj.snapshot()
find = form.cleaned_data['find']
replace = form.cleaned_data['replace']
if form.cleaned_data['use_regex']:
try:
obj.new_name = re.sub(find, replace, obj.name)
# Catch regex group reference errors
except re.error:
obj.new_name = obj.name
else:
obj.new_name = obj.name.replace(find, replace)
renamed_pks.append(obj.pk)
renamed_pks = self._rename_objects(form, selected_objects)
if '_apply' in request.POST:
for obj in selected_objects:
@ -1094,10 +1116,8 @@ class BulkRenameView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
if self.queryset.filter(pk__in=renamed_pks).count() != len(selected_objects):
raise PermissionsViolation
messages.success(request, "Renamed {} {}".format(
len(selected_objects),
self.queryset.model._meta.verbose_name_plural
))
model_name = self.queryset.model._meta.verbose_name_plural
messages.success(request, f"Renamed {len(selected_objects)} {model_name}")
return redirect(self.get_return_url(request))
except PermissionsViolation:
@ -1123,7 +1143,7 @@ class BulkDeleteView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View):
Delete objects in bulk.
queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
filter: FilterSet to apply when deleting by QuerySet
filterset: FilterSet to apply when deleting by QuerySet
table: The table used to display devices being deleted
form: The form class used to delete objects in bulk
template_name: The name of the template