diff --git a/netbox/netbox/views/generic.py b/netbox/netbox/views/generic.py index 0d1734c08..24d459219 100644 --- a/netbox/netbox/views/generic.py +++ b/netbox/netbox/views/generic.py @@ -1,6 +1,5 @@ import logging import re -import csv from copy import deepcopy from django.contrib import messages @@ -8,7 +7,7 @@ from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist, ValidationError from django.db import transaction, IntegrityError from django.db.models import ManyToManyField, ProtectedError -from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea, FileField +from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea from django.http import HttpResponse from django.shortcuts import get_object_or_404, redirect, render from django.utils.html import escape @@ -21,7 +20,7 @@ from extras.models import CustomField, ExportTemplate from utilities.error_handlers import handle_protectederror from utilities.exceptions import AbortTransaction from utilities.forms import ( - BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields, + BootstrapMixin, BulkRenameForm, ConfirmationForm, CSVDataField, ImportForm, TableConfigForm, restrict_form_fields, CSVFileField ) from utilities.permissions import get_permission_for_model from utilities.tables import paginate_table @@ -666,7 +665,8 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): from_form=self.model_form, widget=Textarea(attrs=self.widget_attrs) ) - upload_csv = FileField( + upload_csv = CSVFileField( + from_form=self.model_form, required=False ) return ImportForm(*args, **kwargs) @@ -701,26 +701,9 @@ class BulkImportView(GetReturnURLMixin, ObjectPermissionRequiredMixin, View): # Iterate through CSV data and bind each row to a new model form instance. with transaction.atomic(): if request.FILES: - csv_file = request.FILES["upload_csv"] - csv_file.seek(0) - csv_str = csv_file.read().decode('utf-8') - reader = csv.reader(csv_str.splitlines()) - headers_list = next(reader) - headers = {} - for header in headers_list: - headers[header] = None - records = [] - for row in reader: - row_dict = {} - for i, elt in enumerate(row): - if elt == '': - row_dict[headers_list[i]] = None - else: - row_dict[headers_list[i]] = elt - records.append(row_dict) + headers, records = form.cleaned_data['upload_csv'] else: headers, records = form.cleaned_data['csv'] - print("headers:", headers, "records:", records) for row, data in enumerate(records, start=1): obj_form = self.model_form(data, headers=headers) restrict_form_fields(obj_form, request.user) diff --git a/netbox/utilities/forms/fields.py b/netbox/utilities/forms/fields.py index 93b9e6c44..1d8e299e0 100644 --- a/netbox/utilities/forms/fields.py +++ b/netbox/utilities/forms/fields.py @@ -26,6 +26,7 @@ __all__ = ( 'CSVChoiceField', 'CSVContentTypeField', 'CSVDataField', + 'CSVFileField', 'CSVModelChoiceField', 'CSVTypedChoiceField', 'DynamicModelChoiceField', @@ -221,6 +222,77 @@ class CSVDataField(forms.CharField): return value +class CSVFileField(forms.FileField): + """ + A CharField (rendered as a Textarea) which accepts CSV-formatted data. It returns data as a two-tuple: The first + item is a dictionary of column headers, mapping field names to the attribute by which they match a related object + (where applicable). The second item is a list of dictionaries, each representing a discrete row of CSV data. + + :param from_form: The form from which the field derives its validation rules. + """ + + def __init__(self, from_form, *args, **kwargs): + + form = from_form() + self.model = form.Meta.model + self.fields = form.fields + self.required_fields = [ + name for name, field in form.fields.items() if field.required + ] + + super().__init__(*args, **kwargs) + + def to_python(self, file): + + records = [] + file.seek(0) + csv_str = file.read().decode('utf-8') + reader = csv.reader(csv_str.splitlines()) + + # Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional + # "to" field specifying how the related object is being referenced. For example, importing a Device might use a + # `site.slug` header, to indicate the related site is being referenced by its slug. + + headers = {} + for header in next(reader): + if '.' in header: + field, to_field = header.split('.', 1) + headers[field] = to_field + else: + headers[header] = None + + # Parse CSV rows into a list of dictionaries mapped from the column headers. + for i, row in enumerate(reader, start=1): + if len(row) != len(headers): + raise forms.ValidationError( + f"Row {i}: Expected {len(headers)} columns but found {len(row)}" + ) + row = [col.strip() for col in row] + record = dict(zip(headers.keys(), row)) + records.append(record) + + return headers, records + + def validate(self, value): + headers, records = value + + # Validate provided column headers + for field, to_field in headers.items(): + if field not in self.fields: + raise forms.ValidationError(f'Unexpected column header "{field}" found.') + if to_field and not hasattr(self.fields[field], 'to_field_name'): + raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots') + if to_field and not hasattr(self.fields[field].queryset.model, to_field): + raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}') + + # Validate required fields + for f in self.required_fields: + if f not in headers: + raise forms.ValidationError(f'Required column header "{f}" not found.') + + return value + + class CSVChoiceField(forms.ChoiceField): """ Invert the provided set of choices to take the human-friendly label as input, and return the database value.