1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Merge pull request #4564 from netbox-community/3147-csv-import-fields

Closes #3147: Allow dynamic access to related objects during CSV import
This commit is contained in:
Jeremy Stretch
2020-05-06 10:15:00 -04:00
committed by GitHub
17 changed files with 706 additions and 879 deletions

View File

@@ -8,6 +8,7 @@ import yaml
from django import forms
from django.conf import settings
from django.contrib.postgres.forms.jsonb import JSONField as _JSONField, InvalidJSONInput
from django.core.exceptions import MultipleObjectsReturned
from django.db.models import Count
from django.forms import BoundField
from django.forms.models import fields_for_model
@@ -400,15 +401,22 @@ class TimePicker(forms.TextInput):
class CSVDataField(forms.CharField):
"""
A CharField (rendered as a Textarea) which accepts CSV-formatted data. It returns a list of dictionaries mapping
column headers to values. Each dictionary represents an individual record.
A CharField (rendered as a Textarea) which accepts CSV-formatted data. It returns data as a two-tuple: The first
item is a dictionary of column headers, mapping field names to the attribute by which they match a related object
(where applicable). The second item is a list of dictionaries, each representing a discrete row of CSV data.
:param from_form: The form from which the field derives its validation rules.
"""
widget = forms.Textarea
def __init__(self, fields, required_fields=[], *args, **kwargs):
def __init__(self, from_form, *args, **kwargs):
self.fields = fields
self.required_fields = required_fields
form = from_form()
self.model = form.Meta.model
self.fields = form.fields
self.required_fields = [
name for name, field in form.fields.items() if field.required
]
super().__init__(*args, **kwargs)
@@ -416,7 +424,7 @@ class CSVDataField(forms.CharField):
if not self.label:
self.label = ''
if not self.initial:
self.initial = ','.join(required_fields) + '\n'
self.initial = ','.join(self.required_fields) + '\n'
if not self.help_text:
self.help_text = 'Enter the list of column headers followed by one line per record to be imported, using ' \
'commas to separate values. Multi-line data and values containing commas may be wrapped ' \
@@ -425,36 +433,55 @@ class CSVDataField(forms.CharField):
def to_python(self, value):
records = []
reader = csv.reader(StringIO(value))
reader = csv.reader(StringIO(value.strip()))
# Consume and validate the first line of CSV data as column headers
headers = next(reader)
# Consume the first line of CSV data as column headers. Create a dictionary mapping each header to an optional
# "to" field specifying how the related object is being referenced. For example, importing a Device might use a
# `site.slug` header, to indicate the related site is being referenced by its slug.
headers = {}
for header in next(reader):
if '.' in header:
field, to_field = header.split('.', 1)
headers[field] = to_field
else:
headers[header] = None
# Parse CSV rows into a list of dictionaries mapped from the column headers.
for i, row in enumerate(reader, start=1):
if len(row) != len(headers):
raise forms.ValidationError(
f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
)
row = [col.strip() for col in row]
record = dict(zip(headers.keys(), row))
records.append(record)
return headers, records
def validate(self, value):
headers, records = value
# Validate provided column headers
for field, to_field in headers.items():
if field not in self.fields:
raise forms.ValidationError(f'Unexpected column header "{field}" found.')
if to_field and not hasattr(self.fields[field], 'to_field_name'):
raise forms.ValidationError(f'Column "{field}" is not a related object; cannot use dots')
if to_field and not hasattr(self.fields[field].queryset.model, to_field):
raise forms.ValidationError(f'Invalid related object attribute for column "{field}": {to_field}')
# Validate required fields
for f in self.required_fields:
if f not in headers:
raise forms.ValidationError('Required column header "{}" not found.'.format(f))
for f in headers:
if f not in self.fields:
raise forms.ValidationError('Unexpected column header "{}" found.'.format(f))
raise forms.ValidationError(f'Required column header "{f}" not found.')
# Parse CSV data
for i, row in enumerate(reader, start=1):
if row:
if len(row) != len(headers):
raise forms.ValidationError(
"Row {}: Expected {} columns but found {}".format(i, len(headers), len(row))
)
row = [col.strip() for col in row]
record = dict(zip(headers, row))
records.append(record)
return records
return value
class CSVChoiceField(forms.ChoiceField):
"""
Invert the provided set of choices to take the human-friendly label as input, and return the database value.
"""
def __init__(self, choices, *args, **kwargs):
super().__init__(choices=choices, *args, **kwargs)
self.choices = [(label, label) for value, label in unpack_grouped_choices(choices)]
@@ -469,6 +496,23 @@ class CSVChoiceField(forms.ChoiceField):
return self.choice_values[value]
class CSVModelChoiceField(forms.ModelChoiceField):
"""
Provides additional validation for model choices entered as CSV data.
"""
default_error_messages = {
'invalid_choice': 'Object not found.',
}
def to_python(self, value):
try:
return super().to_python(value)
except MultipleObjectsReturned as e:
raise forms.ValidationError(
f'"{value}" is not a unique value for this field; multiple objects were found'
)
class ExpandableNameField(forms.CharField):
"""
A field which allows for numeric range expansion
@@ -530,27 +574,6 @@ class CommentField(forms.CharField):
super().__init__(required=required, label=label, help_text=help_text, *args, **kwargs)
class FlexibleModelChoiceField(forms.ModelChoiceField):
"""
Allow a model to be reference by either '{ID}' or the field specified by `to_field_name`.
"""
def to_python(self, value):
if value in self.empty_values:
return None
try:
if not self.to_field_name:
key = 'pk'
elif re.match(r'^\{\d+\}$', value):
key = 'pk'
value = value.strip('{}')
else:
key = self.to_field_name
value = self.queryset.get(**{key: value})
except (ValueError, TypeError, self.queryset.model.DoesNotExist):
raise forms.ValidationError(self.error_messages['invalid_choice'], code='invalid_choice')
return value
class SlugField(forms.SlugField):
"""
Extend the built-in SlugField to automatically populate from a field called `name` unless otherwise specified.
@@ -709,6 +732,20 @@ class BulkEditForm(forms.Form):
self.nullable_fields = self.Meta.nullable_fields
class CSVModelForm(forms.ModelForm):
"""
ModelForm used for the import of objects in CSV format.
"""
def __init__(self, *args, headers=None, **kwargs):
super().__init__(*args, **kwargs)
# Modify the model form to accommodate any customized to_field_name properties
if headers:
for field, to_field in headers.items():
if to_field is not None:
self.fields[field].to_field_name = to_field
class ImportForm(BootstrapMixin, forms.Form):
"""
Generic form for creating an object from JSON/YAML data

View File

@@ -116,28 +116,6 @@ def humanize_speed(speed):
return '{} Kbps'.format(speed)
@register.filter()
def example_choices(field, arg=3):
"""
Returns a number (default: 3) of example choices for a ChoiceFiled (useful for CSV import forms).
"""
examples = []
if hasattr(field, 'queryset'):
choices = [
(obj.pk, getattr(obj, field.to_field_name)) for obj in field.queryset[:arg + 1]
]
else:
choices = field.choices
for value, label in unpack_grouped_choices(choices):
if len(examples) == arg:
examples.append('etc.')
break
if not value or not label:
continue
examples.append(label)
return ', '.join(examples) or 'None'
@register.filter()
def tzoffset(value):
"""

View File

@@ -1,6 +1,8 @@
from django import forms
from django.test import TestCase
from ipam.forms import IPAddressCSVForm
from ipam.models import VRF
from utilities.forms import *
@@ -281,3 +283,85 @@ class ExpandAlphanumeric(TestCase):
with self.assertRaises(ValueError):
sorted(expand_alphanumeric_pattern('r[a,,b]a'))
class CSVDataFieldTest(TestCase):
def setUp(self):
self.field = CSVDataField(from_form=IPAddressCSVForm)
def test_clean(self):
input = """
address,status,vrf
192.0.2.1/32,Active,Test VRF
"""
output = (
{'address': None, 'status': None, 'vrf': None},
[{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': 'Test VRF'}]
)
self.assertEqual(self.field.clean(input), output)
def test_clean_invalid_header(self):
input = """
address,status,vrf,xxx
192.0.2.1/32,Active,Test VRF,123
"""
with self.assertRaises(forms.ValidationError):
self.field.clean(input)
def test_clean_missing_required_header(self):
input = """
status,vrf
Active,Test VRF
"""
with self.assertRaises(forms.ValidationError):
self.field.clean(input)
def test_clean_default_to_field(self):
input = """
address,status,vrf.name
192.0.2.1/32,Active,Test VRF
"""
output = (
{'address': None, 'status': None, 'vrf': 'name'},
[{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': 'Test VRF'}]
)
self.assertEqual(self.field.clean(input), output)
def test_clean_pk_to_field(self):
input = """
address,status,vrf.pk
192.0.2.1/32,Active,123
"""
output = (
{'address': None, 'status': None, 'vrf': 'pk'},
[{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': '123'}]
)
self.assertEqual(self.field.clean(input), output)
def test_clean_custom_to_field(self):
input = """
address,status,vrf.rd
192.0.2.1/32,Active,123:456
"""
output = (
{'address': None, 'status': None, 'vrf': 'rd'},
[{'address': '192.0.2.1/32', 'status': 'Active', 'vrf': '123:456'}]
)
self.assertEqual(self.field.clean(input), output)
def test_clean_invalid_to_field(self):
input = """
address,status,vrf.xxx
192.0.2.1/32,Active,123:456
"""
with self.assertRaises(forms.ValidationError):
self.field.clean(input)
def test_clean_to_field_on_non_object(self):
input = """
address,status.foo,vrf
192.0.2.1/32,Bar,Test VRF
"""
with self.assertRaises(forms.ValidationError):
self.field.clean(input)

View File

@@ -575,11 +575,11 @@ class BulkImportView(GetReturnURLMixin, View):
def _import_form(self, *args, **kwargs):
fields = self.model_form().fields.keys()
required_fields = [name for name, field in self.model_form().fields.items() if field.required]
class ImportForm(BootstrapMixin, Form):
csv = CSVDataField(fields=fields, required_fields=required_fields, widget=Textarea(attrs=self.widget_attrs))
csv = CSVDataField(
from_form=self.model_form,
widget=Textarea(attrs=self.widget_attrs)
)
return ImportForm(*args, **kwargs)
@@ -609,8 +609,10 @@ class BulkImportView(GetReturnURLMixin, View):
try:
# Iterate through CSV data and bind each row to a new model form instance.
with transaction.atomic():
for row, data in enumerate(form.cleaned_data['csv'], start=1):
obj_form = self.model_form(data)
headers, records = form.cleaned_data['csv']
for row, data in enumerate(records, start=1):
obj_form = self.model_form(data, headers=headers)
if obj_form.is_valid():
obj = self._save_obj(obj_form, request)
new_objs.append(obj)