From 117da337c79c0da9e7c1272c7c2889acf6e6c53d Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Thu, 27 Apr 2017 12:46:04 -0400 Subject: [PATCH] Corrected tests and improved validation --- netbox/extras/api/customfields.py | 42 ++++++++++++++---------- netbox/extras/tests/test_customfields.py | 6 ++-- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/netbox/extras/api/customfields.py b/netbox/extras/api/customfields.py index 117c6a5fd..dafed750b 100644 --- a/netbox/extras/api/customfields.py +++ b/netbox/extras/api/customfields.py @@ -18,25 +18,29 @@ class CustomFieldsSerializer(serializers.BaseSerializer): def to_internal_value(self, data): - parent_content_type = ContentType.objects.get_for_model(self.parent.Meta.model) + content_type = ContentType.objects.get_for_model(self.parent.Meta.model) + custom_fields = {field.name: field for field in CustomField.objects.filter(obj_type=content_type)} - for custom_field, value in data.items(): + for field_name, value in data.items(): # Validate custom field name - try: - cf = CustomField.objects.get(name=custom_field) - except CustomField.DoesNotExist: - raise ValidationError(u"Unknown custom field: {}".format(custom_field)) - - # Validate custom field content type - if parent_content_type not in cf.obj_type.all(): - raise ValidationError(u"Invalid custom field for {} objects".format(parent_content_type)) + if field_name not in custom_fields: + raise ValidationError(u"Invalid custom field for {} objects: {}".format(content_type, field_name)) # Validate selected choice + cf = custom_fields[field_name] if cf.type == CF_TYPE_SELECT: valid_choices = [c.pk for c in cf.choices.all()] if value not in valid_choices: - raise ValidationError(u"Invalid choice ({}) for field {}".format(value, custom_field)) + raise ValidationError(u"Invalid choice ({}) for field {}".format(value, field_name)) + + # Check for missing required fields + missing_fields = [] + for field_name, field in custom_fields.items(): + if field.required and field_name not in data: + missing_fields.append(field_name) + if missing_fields: + raise ValidationError(u"Missing required fields: {}".format(u", ".join(missing_fields))) return data @@ -45,7 +49,7 @@ class CustomFieldModelSerializer(serializers.ModelSerializer): """ Extends ModelSerializer to render any CustomFields and their values associated with an object. """ - custom_fields = CustomFieldsSerializer() + custom_fields = CustomFieldsSerializer(required=False) def __init__(self, *args, **kwargs): @@ -86,29 +90,31 @@ class CustomFieldModelSerializer(serializers.ModelSerializer): def create(self, validated_data): - custom_fields = validated_data.pop('custom_fields') + custom_fields = validated_data.pop('custom_fields', None) with transaction.atomic(): instance = super(CustomFieldModelSerializer, self).create(validated_data) # Save custom fields - self._save_custom_fields(instance, custom_fields) - instance.custom_fields = custom_fields + if custom_fields is not None: + self._save_custom_fields(instance, custom_fields) + instance.custom_fields = custom_fields return instance def update(self, instance, validated_data): - custom_fields = validated_data.pop('custom_fields') + custom_fields = validated_data.pop('custom_fields', None) with transaction.atomic(): instance = super(CustomFieldModelSerializer, self).update(instance, validated_data) # Save custom fields - self._save_custom_fields(instance, custom_fields) - instance.custom_fields = custom_fields + if custom_fields is not None: + self._save_custom_fields(instance, custom_fields) + instance.custom_fields = custom_fields return instance diff --git a/netbox/extras/tests/test_customfields.py b/netbox/extras/tests/test_customfields.py index 7986431bf..9e475fde8 100644 --- a/netbox/extras/tests/test_customfields.py +++ b/netbox/extras/tests/test_customfields.py @@ -243,7 +243,7 @@ class CustomFieldAPITest(HttpStatusMixin, APITestCase): 'name': 'Test Site 1', 'slug': 'test-site-1', 'custom_fields': { - 'is_magic': False, + 'is_magic': 0, } } @@ -261,7 +261,7 @@ class CustomFieldAPITest(HttpStatusMixin, APITestCase): 'name': 'Test Site 1', 'slug': 'test-site-1', 'custom_fields': { - 'magic_date': date(2017, 4, 25), + 'magic_date': '2017-04-25', } } @@ -271,7 +271,7 @@ class CustomFieldAPITest(HttpStatusMixin, APITestCase): self.assertHttpStatus(response, status.HTTP_200_OK) self.assertEqual(response.data['custom_fields'].get('magic_date'), data['custom_fields']['magic_date']) cfv = self.site.custom_field_values.get(field=self.cf_date) - self.assertEqual(cfv.value, data['custom_fields']['magic_date']) + self.assertEqual(cfv.value.isoformat(), data['custom_fields']['magic_date']) def test_set_custom_field_url(self):