1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Extend logic for validating filter class

This commit is contained in:
Jeremy Stretch
2024-03-11 15:35:40 -04:00
parent a136030094
commit 313e63622b
6 changed files with 99 additions and 82 deletions

View File

@ -91,8 +91,9 @@ class EventRuleFilterSet(NetBoxModelFilterSet):
method='search',
label=_('Search'),
)
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
field_name='object_types'
@ -128,14 +129,16 @@ class CustomFieldFilterSet(ChangeLoggedModelFilterSet):
type = django_filters.MultipleChoiceFilter(
choices=CustomFieldTypeChoices
)
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
field_name='object_types'
)
related_object_type_id = MultiValueNumberFilter(
field_name='related_object_type__id'
related_object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='related_object_type'
)
related_object_type = ContentTypeFilter()
choice_set_id = django_filters.ModelMultipleChoiceFilter(
@ -199,8 +202,9 @@ class CustomLinkFilterSet(ChangeLoggedModelFilterSet):
method='search',
label=_('Search'),
)
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
field_name='object_types'
@ -228,8 +232,9 @@ class ExportTemplateFilterSet(ChangeLoggedModelFilterSet):
method='search',
label=_('Search'),
)
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
field_name='object_types'
@ -264,8 +269,9 @@ class SavedFilterFilterSet(ChangeLoggedModelFilterSet):
method='search',
label=_('Search'),
)
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
field_name='object_types'

View File

@ -198,8 +198,7 @@ class VRFTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = VRF.objects.all()
filterset = VRFFilterSet
@staticmethod
def get_m2m_filter_name(field):
def get_m2m_filter_name(self, field):
# Override filter names for import & export RouteTargets
if field.name == 'import_targets':
return 'import_target'
@ -303,8 +302,7 @@ class RouteTargetTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = RouteTarget.objects.all()
filterset = RouteTargetFilterSet
@staticmethod
def get_m2m_filter_name(field):
def get_m2m_filter_name(self, field):
# Override filter names for import & export VRFs and L2VPNs
if field.name == 'importing_vrfs':
return 'importing_vrf'

View File

@ -3,9 +3,10 @@ from django.contrib.auth import get_user_model
from django.db.models import Q
from django.utils.translation import gettext as _
from core.models import ObjectType
from netbox.filtersets import BaseFilterSet
from users.models import Group, ObjectPermission, Token
from utilities.filters import ContentTypeFilter, MultiValueNumberFilter
from utilities.filters import ContentTypeFilter
__all__ = (
'GroupFilterSet',
@ -134,8 +135,9 @@ class ObjectPermissionFilterSet(BaseFilterSet):
method='search',
label=_('Search'),
)
object_type_id = MultiValueNumberFilter(
field_name='object_types__id'
object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
field_name='object_types'

View File

@ -8,7 +8,9 @@ from django.contrib.contenttypes.models import ContentType
from django.db.models import ForeignKey, ManyToManyField, ManyToManyRel, ManyToOneRel, OneToOneRel
from django.utils.module_loading import import_string
from taggit.managers import TaggableManager
from utilities.filters import TreeNodeMultipleChoiceFilter
from extras.filters import TagFilter
from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
from core.models import ObjectType
@ -46,8 +48,7 @@ class BaseFilterSetTests:
filterset = None
ignore_fields = tuple()
@staticmethod
def get_m2m_filter_name(field):
def get_m2m_filter_name(self, field):
"""
Given a ManyToManyField, determine the correct name for its corresponding Filter. Individual test
cases may override this method to prescribe deviations for specific fields.
@ -55,20 +56,50 @@ class BaseFilterSetTests:
related_model_name = field.related_model._meta.verbose_name
return related_model_name.lower().replace(' ', '_')
@staticmethod
def get_filter_class_for_field(field):
def get_filters_for_model_field(self, field):
"""
Given a model field, return an iterable of (name, class) for each filter that should be defined on
the model's FilterSet class. If the appropriate filter class cannot be determined, it will be None.
"""
# ForeignKey & OneToOneField
if issubclass(field.__class__, ForeignKey) or type(field) is OneToOneRel:
# Relationships to ContentType (used as part of a GFK) do not need a filter
if field.related_model is ContentType:
return [(None, None)]
# ForeignKeys to ObjectType need two filters: 'app.model' & PK
if field.related_model is ObjectType:
return [
(field.name, ContentTypeFilter),
(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter),
]
# ForeignKey to an MPTT-enabled model
if issubclass(field.related_model, MPTTModel) and field.model is not field.related_model:
return TreeNodeMultipleChoiceFilter
return [(f'{field.name}_id', TreeNodeMultipleChoiceFilter)]
return django_filters.ModelMultipleChoiceFilter
return [(f'{field.name}_id', django_filters.ModelMultipleChoiceFilter)]
# Many-to-many relationships (forward & backward)
elif type(field) in (ManyToManyField, ManyToManyRel):
filter_name = self.get_m2m_filter_name(field)
# ManyToManyFields to ObjectType need two filters: 'app.model' & PK
if field.related_model is ObjectType:
return [
(filter_name, ContentTypeFilter),
(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
]
return [(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter)]
# Tag manager
if type(field) is TaggableManager:
return [('tag', TagFilter)]
# Unable to determine the correct filter class
return None
return [(field.name, None)]
def test_id(self):
"""
@ -111,57 +142,32 @@ class BaseFilterSetTests:
if type(model_field) is ManyToOneRel:
continue
# One-to-one & one-to-many relationships
if issubclass(model_field.__class__, ForeignKey) or type(model_field) is OneToOneRel:
# Relationships to ContentType (used as part of a GFK) do not need a filter
if model_field.related_model is ContentType:
continue
# Filters to ObjectType use 'app.model' rather than numeric PK, so we omit the _id suffix
if model_field.related_model is ObjectType:
filter_name = model_field.name
else:
filter_name = f'{model_field.name}_id'
self.assertIn(
filter_name,
filters,
f'No filter defined for {filter_name} ({model_field.name})!'
)
if filter_class := self.get_filter_class_for_field(model_field):
self.assertIs(
type(filters[filter_name]),
filter_class,
f"Invalid filter class for {filter_name}!"
)
# Many-to-many relationships (forward & backward)
elif type(model_field) in (ManyToManyField, ManyToManyRel):
filter_name = self.get_m2m_filter_name(model_field)
filter_name = f'{filter_name}_id'
self.assertIn(
filter_name,
filters,
f'No filter defined for {filter_name} ({model_field.name})!'
)
# TODO: Generic relationships
elif type(model_field) in (GenericForeignKey, GenericRelation):
if type(model_field) in (GenericForeignKey, GenericRelation):
continue
# Tags
elif type(model_field) is TaggableManager:
self.assertIn('tag', filters, f'No filter defined for {model_field.name}!')
for filter_name, filter_class in self.get_filters_for_model_field(model_field):
# All other fields
else:
if filter_name is None:
# Field is exempt
continue
# Check that the filter is defined
self.assertIn(
model_field.name,
filters,
f'No defined found for {model_field.name} ({type(model_field)})!'
filter_name,
filters.keys(),
f'No filter defined for {filter_name} ({model_field.name})!'
)
# Check that the filter class is correct
filter = filters[filter_name]
if filter_class is not None:
self.assertIs(
type(filter),
filter_class,
f"Invalid filter class {type(filter)} for {filter_name} (should be {filter_class})!"
)
class ChangeLoggedFilterSetTests(BaseFilterSetTests):

View File

@ -169,11 +169,14 @@ class IKEPolicyFilterSet(NetBoxModelFilterSet):
mode = django_filters.MultipleChoiceFilter(
choices=IKEModeChoices
)
ike_proposal_id = MultiValueNumberFilter(
field_name='proposals__id'
ike_proposal_id = django_filters.ModelMultipleChoiceFilter(
field_name='proposals',
queryset=IKEProposal.objects.all()
)
ike_proposal = MultiValueCharFilter(
field_name='proposals__name'
ike_proposal = django_filters.ModelMultipleChoiceFilter(
field_name='proposals__name',
queryset=IKEProposal.objects.all(),
to_field_name='name'
)
# TODO: Remove in v4.1
@ -231,11 +234,14 @@ class IPSecPolicyFilterSet(NetBoxModelFilterSet):
pfs_group = django_filters.MultipleChoiceFilter(
choices=DHGroupChoices
)
ipsec_proposal_id = MultiValueNumberFilter(
field_name='proposals__id'
ipsec_proposal_id = django_filters.ModelMultipleChoiceFilter(
field_name='proposals',
queryset=IPSecProposal.objects.all()
)
ipsec_proposal = MultiValueCharFilter(
field_name='proposals__name'
ipsec_proposal = django_filters.ModelMultipleChoiceFilter(
field_name='proposals__name',
queryset=IPSecProposal.objects.all(),
to_field_name='name'
)
# TODO: Remove in v4.1

View File

@ -743,8 +743,7 @@ class L2VPNTestCase(TestCase, ChangeLoggedFilterSetTests):
queryset = L2VPN.objects.all()
filterset = L2VPNFilterSet
@staticmethod
def get_m2m_filter_name(field):
def get_m2m_filter_name(self, field):
# Override filter names for import & export RouteTargets
if field.name == 'import_targets':
return 'import_target'