From 313e63622b019b34757db6369ea684c8e28113b8 Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Mon, 11 Mar 2024 15:35:40 -0400 Subject: [PATCH] Extend logic for validating filter class --- netbox/extras/filtersets.py | 30 ++++--- netbox/ipam/tests/test_filtersets.py | 6 +- netbox/users/filtersets.py | 8 +- netbox/utilities/testing/filtersets.py | 112 +++++++++++++------------ netbox/vpn/filtersets.py | 22 +++-- netbox/vpn/tests/test_filtersets.py | 3 +- 6 files changed, 99 insertions(+), 82 deletions(-) diff --git a/netbox/extras/filtersets.py b/netbox/extras/filtersets.py index 0be2fde28..1ab6679e2 100644 --- a/netbox/extras/filtersets.py +++ b/netbox/extras/filtersets.py @@ -91,8 +91,9 @@ class EventRuleFilterSet(NetBoxModelFilterSet): method='search', label=_('Search'), ) - object_type_id = MultiValueNumberFilter( - field_name='object_types__id' + object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='object_types' ) object_type = ContentTypeFilter( field_name='object_types' @@ -128,14 +129,16 @@ class CustomFieldFilterSet(ChangeLoggedModelFilterSet): type = django_filters.MultipleChoiceFilter( choices=CustomFieldTypeChoices ) - object_type_id = MultiValueNumberFilter( - field_name='object_types__id' + object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='object_types' ) object_type = ContentTypeFilter( field_name='object_types' ) - related_object_type_id = MultiValueNumberFilter( - field_name='related_object_type__id' + related_object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='related_object_type' ) related_object_type = ContentTypeFilter() choice_set_id = django_filters.ModelMultipleChoiceFilter( @@ -199,8 +202,9 @@ class CustomLinkFilterSet(ChangeLoggedModelFilterSet): method='search', label=_('Search'), ) - object_type_id = MultiValueNumberFilter( - field_name='object_types__id' + object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='object_types' ) object_type = ContentTypeFilter( field_name='object_types' @@ -228,8 +232,9 @@ class ExportTemplateFilterSet(ChangeLoggedModelFilterSet): method='search', label=_('Search'), ) - object_type_id = MultiValueNumberFilter( - field_name='object_types__id' + object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='object_types' ) object_type = ContentTypeFilter( field_name='object_types' @@ -264,8 +269,9 @@ class SavedFilterFilterSet(ChangeLoggedModelFilterSet): method='search', label=_('Search'), ) - object_type_id = MultiValueNumberFilter( - field_name='object_types__id' + object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='object_types' ) object_type = ContentTypeFilter( field_name='object_types' diff --git a/netbox/ipam/tests/test_filtersets.py b/netbox/ipam/tests/test_filtersets.py index 52ef460d5..3a46423a5 100644 --- a/netbox/ipam/tests/test_filtersets.py +++ b/netbox/ipam/tests/test_filtersets.py @@ -198,8 +198,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = VRF.objects.all() filterset = VRFFilterSet - @staticmethod - def get_m2m_filter_name(field): + def get_m2m_filter_name(self, field): # Override filter names for import & export RouteTargets if field.name == 'import_targets': return 'import_target' @@ -303,8 +302,7 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = RouteTarget.objects.all() filterset = RouteTargetFilterSet - @staticmethod - def get_m2m_filter_name(field): + def get_m2m_filter_name(self, field): # Override filter names for import & export VRFs and L2VPNs if field.name == 'importing_vrfs': return 'importing_vrf' diff --git a/netbox/users/filtersets.py b/netbox/users/filtersets.py index 8a770ef34..6e86528dd 100644 --- a/netbox/users/filtersets.py +++ b/netbox/users/filtersets.py @@ -3,9 +3,10 @@ from django.contrib.auth import get_user_model from django.db.models import Q from django.utils.translation import gettext as _ +from core.models import ObjectType from netbox.filtersets import BaseFilterSet from users.models import Group, ObjectPermission, Token -from utilities.filters import ContentTypeFilter, MultiValueNumberFilter +from utilities.filters import ContentTypeFilter __all__ = ( 'GroupFilterSet', @@ -134,8 +135,9 @@ class ObjectPermissionFilterSet(BaseFilterSet): method='search', label=_('Search'), ) - object_type_id = MultiValueNumberFilter( - field_name='object_types__id' + object_type_id = django_filters.ModelMultipleChoiceFilter( + queryset=ObjectType.objects.all(), + field_name='object_types' ) object_type = ContentTypeFilter( field_name='object_types' diff --git a/netbox/utilities/testing/filtersets.py b/netbox/utilities/testing/filtersets.py index 005630c9c..2cfcb3209 100644 --- a/netbox/utilities/testing/filtersets.py +++ b/netbox/utilities/testing/filtersets.py @@ -8,7 +8,9 @@ from django.contrib.contenttypes.models import ContentType from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel from django.utils.module_loading import import_string from taggit.managers import TaggableManager -from utilities.filters import TreeNodeMultipleChoiceFilter + +from extras.filters import TagFilter +from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter from core.models import ObjectType @@ -46,8 +48,7 @@ class BaseFilterSetTests: filterset = None ignore_fields = tuple() - @staticmethod - def get_m2m_filter_name(field): + def get_m2m_filter_name(self, field): """ Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test cases may override this method to prescribe deviations for specific fields. @@ -55,20 +56,50 @@ class BaseFilterSetTests: related_model_name = field.related_model._meta.verbose_name return related_model_name.lower().replace(' ', '_') - @staticmethod - def get_filter_class_for_field(field): - + def get_filters_for_model_field(self, field): + """ + Given a model field, return an iterable of (name, class) for each filter that should be defined on + the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None. + """ # ForeignKey & OneToOneField if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel: + # Relationships to ContentType (used as part of a GFK) do not need a filter + if field.related_model is ContentType: + return [(None, None)] + + # ForeignKeys to ObjectType need two filters: 'app.model' & PK + if field.related_model is ObjectType: + return [ + (field.name, ContentTypeFilter), + (f'{field.name}_id', django_filters.ModelMultipleChoiceFilter), + ] + # ForeignKey to an MPTT-enabled model if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model: - return TreeNodeMultipleChoiceFilter + return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)] - return django_filters.ModelMultipleChoiceFilter + return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)] + + # Many-to-many relationships (forward & backward) + elif type(field) in (ManyToManyField, ManyToManyRel): + filter_name = self.get_m2m_filter_name(field) + + # ManyToManyFields to ObjectType need two filters: 'app.model' & PK + if field.related_model is ObjectType: + return [ + (filter_name, ContentTypeFilter), + (f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter), + ] + + return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)] + + # Tag manager + if type(field) is TaggableManager: + return [('tag', TagFilter)] # Unable to determine the correct filter class - return None + return [(field.name, None)] def test_id(self): """ @@ -111,57 +142,32 @@ class BaseFilterSetTests: if type(model_field) is ManyToOneRel: continue - # One-to-one & one-to-many relationships - if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel: - - # Relationships to ContentType (used as part of a GFK) do not need a filter - if model_field.related_model is ContentType: - continue - - # Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix - if model_field.related_model is ObjectType: - filter_name = model_field.name - else: - filter_name = f'{model_field.name}_id' - - self.assertIn( - filter_name, - filters, - f'No filter defined for {filter_name} ({model_field.name})!' - ) - if filter_class := self.get_filter_class_for_field(model_field): - self.assertIs( - type(filters[filter_name]), - filter_class, - f"Invalid filter class for {filter_name}!" - ) - - # Many-to-many relationships (forward & backward) - elif type(model_field) in (ManyToManyField, ManyToManyRel): - filter_name = self.get_m2m_filter_name(model_field) - filter_name = f'{filter_name}_id' - self.assertIn( - filter_name, - filters, - f'No filter defined for {filter_name} ({model_field.name})!' - ) - # TODO: Generic relationships - elif type(model_field) in (GenericForeignKey, GenericRelation): + if type(model_field) in (GenericForeignKey, GenericRelation): continue - # Tags - elif type(model_field) is TaggableManager: - self.assertIn('tag', filters, f'No filter defined for {model_field.name}!') + for filter_name, filter_class in self.get_filters_for_model_field(model_field): - # All other fields - else: + if filter_name is None: + # Field is exempt + continue + + # Check that the filter is defined self.assertIn( - model_field.name, - filters, - f'No defined found for {model_field.name} ({type(model_field)})!' + filter_name, + filters.keys(), + f'No filter defined for {filter_name} ({model_field.name})!' ) + # Check that the filter class is correct + filter = filters[filter_name] + if filter_class is not None: + self.assertIs( + type(filter), + filter_class, + f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!" + ) + class ChangeLoggedFilterSetTests(BaseFilterSetTests): diff --git a/netbox/vpn/filtersets.py b/netbox/vpn/filtersets.py index 327ce8b27..3c23cb478 100644 --- a/netbox/vpn/filtersets.py +++ b/netbox/vpn/filtersets.py @@ -169,11 +169,14 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet): mode = django_filters.MultipleChoiceFilter( choices=IKEModeChoices ) - ike_proposal_id = MultiValueNumberFilter( - field_name='proposals__id' + ike_proposal_id = django_filters.ModelMultipleChoiceFilter( + field_name='proposals', + queryset=IKEProposal.objects.all() ) - ike_proposal = MultiValueCharFilter( - field_name='proposals__name' + ike_proposal = django_filters.ModelMultipleChoiceFilter( + field_name='proposals__name', + queryset=IKEProposal.objects.all(), + to_field_name='name' ) # TODO: Remove in v4.1 @@ -231,11 +234,14 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet): pfs_group = django_filters.MultipleChoiceFilter( choices=DHGroupChoices ) - ipsec_proposal_id = MultiValueNumberFilter( - field_name='proposals__id' + ipsec_proposal_id = django_filters.ModelMultipleChoiceFilter( + field_name='proposals', + queryset=IPSecProposal.objects.all() ) - ipsec_proposal = MultiValueCharFilter( - field_name='proposals__name' + ipsec_proposal = django_filters.ModelMultipleChoiceFilter( + field_name='proposals__name', + queryset=IPSecProposal.objects.all(), + to_field_name='name' ) # TODO: Remove in v4.1 diff --git a/netbox/vpn/tests/test_filtersets.py b/netbox/vpn/tests/test_filtersets.py index f16db4cb8..d2b893766 100644 --- a/netbox/vpn/tests/test_filtersets.py +++ b/netbox/vpn/tests/test_filtersets.py @@ -743,8 +743,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests): queryset = L2VPN.objects.all() filterset = L2VPNFilterSet - @staticmethod - def get_m2m_filter_name(field): + def get_m2m_filter_name(self, field): # Override filter names for import & export RouteTargets if field.name == 'import_targets': return 'import_target'