From 99467e8f66f29338f8c46b4113b42ddcb4c34718 Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Fri, 22 Dec 2023 10:01:05 -0500 Subject: [PATCH] Fixes #12731: Support custom validation for many-to-many fields (#14516) * WIP * Enforce custom validators during bulk edit * Add bulk edit M2M validation test * Clean up tests * Add custom validation test for bulk import * Misc cleanup --- netbox/extras/tests/test_custom_validation.py | 265 ++++++++++++++++++ netbox/extras/validators.py | 26 +- netbox/netbox/api/serializers/base.py | 13 +- netbox/netbox/forms/base.py | 11 + netbox/netbox/views/generic/bulk_views.py | 8 + 5 files changed, 314 insertions(+), 9 deletions(-) create mode 100644 netbox/extras/tests/test_custom_validation.py diff --git a/netbox/extras/tests/test_custom_validation.py b/netbox/extras/tests/test_custom_validation.py new file mode 100644 index 000000000..e375b49f5 --- /dev/null +++ b/netbox/extras/tests/test_custom_validation.py @@ -0,0 +1,265 @@ +from django.test import TestCase +from django.test import override_settings + +from circuits.api.serializers import ProviderSerializer +from circuits.forms import ProviderForm +from circuits.models import Provider +from ipam.models import ASN, RIR +from utilities.choices import CSVDelimiterChoices, ImportFormatChoices +from utilities.testing import APITestCase, ModelViewTestCase, create_tags, post_data + + +class ModelFormCustomValidationTest(TestCase): + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'tags': {'required': True}} + ] + }) + def test_tags_validation(self): + """ + Check that custom validation rules work for tag assignment. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + form = ProviderForm(data) + self.assertFalse(form.is_valid()) + + tags = create_tags('Tag1', 'Tag2', 'Tag3') + data['tags'] = [tag.pk for tag in tags] + form = ProviderForm(data) + self.assertTrue(form.is_valid()) + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'asns': {'required': True}} + ] + }) + def test_m2m_validation(self): + """ + Check that custom validation rules work for many-to-many fields. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + form = ProviderForm(data) + self.assertFalse(form.is_valid()) + + rir = RIR.objects.create(name='RIR 1', slug='rir-1') + asns = ASN.objects.bulk_create(( + ASN(rir=rir, asn=65001), + ASN(rir=rir, asn=65002), + ASN(rir=rir, asn=65003), + )) + data['asns'] = [asn.pk for asn in asns] + form = ProviderForm(data) + self.assertTrue(form.is_valid()) + + +class BulkEditCustomValidationTest(ModelViewTestCase): + model = Provider + + @classmethod + def setUpTestData(cls): + rir = RIR.objects.create(name='RIR 1', slug='rir-1') + asns = ASN.objects.bulk_create(( + ASN(rir=rir, asn=65001), + ASN(rir=rir, asn=65002), + ASN(rir=rir, asn=65003), + )) + + providers = ( + Provider(name='Provider 1', slug='provider-1'), + Provider(name='Provider 2', slug='provider-2'), + Provider(name='Provider 3', slug='provider-3'), + ) + Provider.objects.bulk_create(providers) + for provider in providers: + provider.asns.set(asns) + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'asns': {'required': True}} + ] + }) + def test_bulk_edit_without_m2m(self): + """ + Check that custom validation rules do not interfere with bulk editing. + """ + data = { + 'pk': list(Provider.objects.values_list('pk', flat=True)), + '_apply': '', + 'description': 'New description', + } + self.add_permissions( + 'circuits.view_provider', + 'circuits.change_provider', + ) + + # Bulk edit the description without changing ASN assignments + request = { + 'path': self._get_url('bulk_edit'), + 'data': post_data(data), + } + response = self.client.post(**request) + self.assertHttpStatus(response, 302) + self.assertEqual( + Provider.objects.filter(description=data['description']).count(), + len(data['pk']) + ) + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'asns': {'required': True}} + ] + }) + def test_bulk_edit_m2m(self): + """ + Test that custom validation rules are enforced during bulk editing. + """ + data = { + 'pk': list(Provider.objects.values_list('pk', flat=True)), + '_apply': '', + 'description': 'New description', + } + self.add_permissions( + 'circuits.view_provider', + 'circuits.change_provider', + 'ipam.view_asn', + ) + + # Change the ASN assignments + asn = ASN.objects.first() + data['asns'] = [asn.pk] + request = { + 'path': self._get_url('bulk_edit'), + 'data': post_data(data), + } + response = self.client.post(**request) + self.assertHttpStatus(response, 302) + for provider in Provider.objects.all(): + self.assertEqual(len(provider.asns.all()), 1) + + # Attempt to remove the ASN assignments + data.pop('asns') + data['_nullify'] = 'asns' + request = { + 'path': self._get_url('bulk_edit'), + 'data': post_data(data), + } + response = self.client.post(**request) + self.assertHttpStatus(response, 200) + for provider in Provider.objects.all(): + self.assertTrue(provider.asns.exists()) + + +class BulkImportCustomValidationTest(ModelViewTestCase): + model = Provider + + @classmethod + def setUpTestData(cls): + create_tags('Tag1', 'Tag2', 'Tag3') + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'tags': {'required': True}} + ] + }) + def test_bulk_import_invalid(self): + """ + Test that custom validation rules are enforced during bulk import. + """ + csv_data = ( + "name,slug", + "Provider 1,provider-1", + "Provider 2,provider-2", + "Provider 3,provider-3", + ) + data = { + 'data': '\n'.join(csv_data), + 'format': ImportFormatChoices.CSV, + 'csv_delimiter': CSVDelimiterChoices.COMMA, + } + self.add_permissions( + 'circuits.view_provider', + 'circuits.add_provider', + 'extras.view_tag', + ) + + # Attempt to import providers without tags + request = { + 'path': self._get_url('import'), + 'data': post_data(data), + } + response = self.client.post(**request) + self.assertHttpStatus(response, 200) + self.assertFalse(Provider.objects.exists()) + + # Import providers successfully with tag assignments + csv_data = ( + "name,slug,tags", + "Provider 1,provider-1,tag1", + "Provider 2,provider-2,tag2", + "Provider 3,provider-3,tag3", + ) + data['data'] = '\n'.join(csv_data) + request = { + 'path': self._get_url('import'), + 'data': post_data(data), + } + response = self.client.post(**request) + self.assertHttpStatus(response, 302) + self.assertTrue(Provider.objects.exists()) + + +class APISerializerCustomValidationTest(APITestCase): + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'tags': {'required': True}} + ] + }) + def test_tags_validation(self): + """ + Check that custom validation rules work for tag assignment. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + serializer = ProviderSerializer(data=data) + self.assertFalse(serializer.is_valid()) + + tags = create_tags('Tag1', 'Tag2', 'Tag3') + data['tags'] = [tag.pk for tag in tags] + serializer = ProviderSerializer(data=data) + self.assertTrue(serializer.is_valid()) + + @override_settings(CUSTOM_VALIDATORS={ + 'circuits.provider': [ + {'asns': {'required': True}} + ] + }) + def test_m2m_validation(self): + """ + Check that custom validation rules work for many-to-many fields. + """ + data = { + 'name': 'Provider 1', + 'slug': 'provider-1', + } + serializer = ProviderSerializer(data=data) + self.assertFalse(serializer.is_valid()) + + rir = RIR.objects.create(name='RIR 1', slug='rir-1') + asns = ASN.objects.bulk_create(( + ASN(rir=rir, asn=65001), + ASN(rir=rir, asn=65002), + ASN(rir=rir, asn=65003), + )) + data['asns'] = [asn.pk for asn in asns] + serializer = ProviderSerializer(data=data) + self.assertTrue(serializer.is_valid()) diff --git a/netbox/extras/validators.py b/netbox/extras/validators.py index 686c9b032..366d3a426 100644 --- a/netbox/extras/validators.py +++ b/netbox/extras/validators.py @@ -1,5 +1,6 @@ -from django.core.exceptions import ValidationError from django.core import validators +from django.core.exceptions import ValidationError +from django.utils.translation import gettext_lazy as _ # NOTE: As this module may be imported by configuration.py, we cannot import # anything from NetBox itself. @@ -66,8 +67,7 @@ class CustomValidator: def __call__(self, instance): # Validate instance attributes per validation rules for attr_name, rules in self.validation_rules.items(): - assert hasattr(instance, attr_name), f"Invalid attribute '{attr_name}' for {instance.__class__.__name__}" - attr = getattr(instance, attr_name) + attr = self._getattr(instance, attr_name) for descriptor, value in rules.items(): validator = self.get_validator(descriptor, value) try: @@ -79,6 +79,26 @@ class CustomValidator: # Execute custom validation logic (if any) self.validate(instance) + @staticmethod + def _getattr(instance, name): + # Attempt to resolve many-to-many fields to their stored values + m2m_fields = [f.name for f in instance._meta.local_many_to_many] + if name in m2m_fields: + if name in getattr(instance, '_m2m_values', []): + return instance._m2m_values[name] + if instance.pk: + return list(getattr(instance, name).all()) + return [] + + # Raise a ValidationError for unknown attributes + if not hasattr(instance, name): + raise ValidationError(_('Invalid attribute "{name}" for {model}').format( + name=name, + model=instance.__class__.__name__ + )) + + return getattr(instance, name) + def get_validator(self, descriptor, value): """ Instantiate and return the appropriate validator based on the descriptor given. For diff --git a/netbox/netbox/api/serializers/base.py b/netbox/netbox/api/serializers/base.py index 5ee74bf8c..d513c8000 100644 --- a/netbox/netbox/api/serializers/base.py +++ b/netbox/netbox/api/serializers/base.py @@ -23,16 +23,16 @@ class ValidatedModelSerializer(BaseModelSerializer): validation. (DRF does not do this by default; see https://github.com/encode/django-rest-framework/issues/3144) """ def validate(self, data): - - # Remove custom fields data and tags (if any) prior to model validation attrs = data.copy() + + # Remove custom field data (if any) prior to model validation attrs.pop('custom_fields', None) - attrs.pop('tags', None) # Skip ManyToManyFields - for field in self.Meta.model._meta.get_fields(): - if isinstance(field, ManyToManyField): - attrs.pop(field.name, None) + m2m_values = {} + for field in self.Meta.model._meta.local_many_to_many: + if field.name in attrs: + m2m_values[field.name] = attrs.pop(field.name) # Run clean() on an instance of the model if self.instance is None: @@ -41,6 +41,7 @@ class ValidatedModelSerializer(BaseModelSerializer): instance = self.instance for k, v in attrs.items(): setattr(instance, k, v) + instance._m2m_values = m2m_values instance.full_clean() return data diff --git a/netbox/netbox/forms/base.py b/netbox/netbox/forms/base.py index 51e664a39..070a5d26c 100644 --- a/netbox/netbox/forms/base.py +++ b/netbox/netbox/forms/base.py @@ -57,6 +57,17 @@ class NetBoxModelForm(BootstrapMixin, CheckLastUpdatedMixin, CustomFieldsMixin, return super().clean() + def _post_clean(self): + """ + Override BaseModelForm's _post_clean() to store many-to-many field values on the model instance. + """ + self.instance._m2m_values = {} + for field in self.instance._meta.local_many_to_many: + if field.name in self.cleaned_data: + self.instance._m2m_values[field.name] = list(self.cleaned_data[field.name]) + + return super()._post_clean() + class NetBoxModelImportForm(CSVModelForm, NetBoxModelForm): """ diff --git a/netbox/netbox/views/generic/bulk_views.py b/netbox/netbox/views/generic/bulk_views.py index c5a08c80a..69bb85c41 100644 --- a/netbox/netbox/views/generic/bulk_views.py +++ b/netbox/netbox/views/generic/bulk_views.py @@ -557,6 +557,14 @@ class BulkEditView(GetReturnURLMixin, BaseMultiObjectView): elif name in form.changed_data: obj.custom_field_data[cf_name] = customfield.serialize(form.cleaned_data[name]) + # Store M2M values for validation + obj._m2m_values = {} + for field in obj._meta.local_many_to_many: + if value := form.cleaned_data.get(field.name): + obj._m2m_values[field.name] = list(value) + elif field.name in nullified_fields: + obj._m2m_values[field.name] = [] + obj.full_clean() obj.save() updated_objects.append(obj)