1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Merge branch 'develop-2.8' into 2328-external-authentication

This commit is contained in:
Jeremy Stretch
2020-03-10 15:07:19 -04:00
231 changed files with 4320 additions and 3670 deletions

View File

@@ -1,3 +1,4 @@
import logging
from collections import OrderedDict
import pytz
@@ -6,6 +7,7 @@ from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist
from django.db.models import ManyToManyField, ProtectedError
from django.http import Http404
from django.urls import reverse
from rest_framework.exceptions import APIException
from rest_framework.permissions import BasePermission
from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
@@ -41,6 +43,14 @@ def get_serializer_for_model(model, prefix=''):
)
def is_api_request(request):
"""
Return True of the request is being made via the REST API.
"""
api_path = reverse('api-root')
return request.path_info.startswith(api_path)
#
# Authentication
#
@@ -294,25 +304,35 @@ class ModelViewSet(_ModelViewSet):
return super().get_serializer(*args, **kwargs)
def get_serializer_class(self):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
# If 'brief' has been passed as a query param, find and return the nested serializer for this model, if one
# exists
request = self.get_serializer_context()['request']
if request.query_params.get('brief', False):
if request.query_params.get('brief'):
logger.debug("Request is for 'brief' format; initializing nested serializer")
try:
return get_serializer_for_model(self.queryset.model, prefix='Nested')
serializer = get_serializer_for_model(self.queryset.model, prefix='Nested')
logger.debug(f"Using serializer {serializer}")
return serializer
except SerializerNotFound:
pass
# Fall back to the hard-coded serializer class
logger.debug(f"Using serializer {self.serializer_class}")
return self.serializer_class
def dispatch(self, request, *args, **kwargs):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
try:
return super().dispatch(request, *args, **kwargs)
except ProtectedError as e:
models = ['{} ({})'.format(o, o._meta) for o in e.protected_objects.all()]
models = [
'{} ({})'.format(o, o._meta) for o in e.protected_objects.all()
]
msg = 'Unable to delete object. The following dependent objects were found: {}'.format(', '.join(models))
logger.warning(msg)
return self.finalize_response(
request,
Response({'detail': msg}, status=409),
@@ -332,6 +352,26 @@ class ModelViewSet(_ModelViewSet):
"""
return super().retrieve(*args, **kwargs)
#
# Logging
#
def perform_create(self, serializer):
model = serializer.child.Meta.model if hasattr(serializer, 'many') else serializer.Meta.model
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Creating new {model._meta.verbose_name}")
return super().perform_create(serializer)
def perform_update(self, serializer):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Updating {serializer.instance} (PK: {serializer.instance.pk})")
return super().perform_update(serializer)
def perform_destroy(self, instance):
logger = logging.getLogger('netbox.api.views.ModelViewSet')
logger.info(f"Deleting {instance} (PK: {instance.pk})")
return super().perform_destroy(instance)
class FieldChoicesViewSet(ViewSet):
"""

View File

@@ -28,12 +28,47 @@ COLOR_CHOICES = (
('ffffff', 'White'),
)
#
# Filter lookup expressions
#
FILTER_CHAR_BASED_LOOKUP_MAP = dict(
n='exact',
ic='icontains',
nic='icontains',
iew='iendswith',
niew='iendswith',
isw='istartswith',
nisw='istartswith',
ie='iexact',
nie='iexact'
)
FILTER_NUMERIC_BASED_LOOKUP_MAP = dict(
n='exact',
lte='lte',
lt='lt',
gte='gte',
gt='gt'
)
FILTER_NEGATION_LOOKUP_MAP = dict(
n='exact'
)
FILTER_TREENODE_NEGATION_LOOKUP_MAP = dict(
n='in'
)
# Keys for PostgreSQL advisory locks. These are arbitrary bigints used by
# the advisory_lock contextmanager. When a lock is acquired,
# one of these keys will be used to identify said lock.
#
# When adding a new key, pick something arbitrary and unique so
# that it is easily searchable in query logs.
ADVISORY_LOCK_KEYS = {
'available-prefixes': 100100,
'available-ips': 100200,

View File

@@ -1,3 +1,4 @@
from django.contrib.postgres.fields import JSONField
from drf_yasg import openapi
from drf_yasg.inspectors import FieldInspector, NotHandled, PaginatorInspector, FilterInspector, SwaggerAutoSchema
from drf_yasg.utils import get_serializer_ref_name
@@ -75,22 +76,28 @@ class CustomChoiceFieldInspector(FieldInspector):
SwaggerType, _ = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
if isinstance(field, ChoiceField):
value_schema = openapi.Schema(type=openapi.TYPE_STRING)
choices = field._choices
choice_value = list(choices.keys())
choice_label = list(choices.values())
value_schema = openapi.Schema(type=openapi.TYPE_STRING, enum=choice_value)
choices = list(field._choices.keys())
if set([None] + choices) == {None, True, False}:
if set([None] + choice_value) == {None, True, False}:
# DeviceType.subdevice_role, Device.face and InterfaceConnection.connection_status all need to be
# differentiated since they each have subtly different values in their choice keys.
# - subdevice_role and connection_status are booleans, although subdevice_role includes None
# - face is an integer set {0, 1} which is easily confused with {False, True}
schema_type = openapi.TYPE_STRING
if all(type(x) == bool for x in [c for c in choices if c is not None]):
if all(type(x) == bool for x in [c for c in choice_value if c is not None]):
schema_type = openapi.TYPE_BOOLEAN
value_schema = openapi.Schema(type=schema_type)
value_schema = openapi.Schema(type=schema_type, enum=choice_value)
value_schema['x-nullable'] = True
if isinstance(choice_value[0], int):
# Change value_schema for IPAddressFamilyChoices, RackWidthChoices
value_schema = openapi.Schema(type=openapi.TYPE_INTEGER, enum=choice_value)
schema = SwaggerType(type=openapi.TYPE_OBJECT, required=["label", "value"], properties={
"label": openapi.Schema(type=openapi.TYPE_STRING),
"label": openapi.Schema(type=openapi.TYPE_STRING, enum=choice_label),
"value": value_schema
})
@@ -115,13 +122,12 @@ class NullableBooleanFieldInspector(FieldInspector):
return result
class IdInFilterInspector(FilterInspector):
class JSONFieldInspector(FieldInspector):
"""Required because by default, Swagger sees a JSONField as a string and not dict
"""
def process_result(self, result, method_name, obj, **kwargs):
if isinstance(result, list):
params = [p for p in result if isinstance(p, openapi.Parameter) and p.name == 'id__in']
for p in params:
p.type = 'string'
if isinstance(result, openapi.Schema) and isinstance(obj, JSONField):
result.type = 'dict'
return result

View File

@@ -1,9 +1,16 @@
import django_filters
from copy import deepcopy
from dcim.forms import MACAddressField
from django import forms
from django.conf import settings
from django.db import models
from django_filters.utils import get_model_field, resolve_field
from extras.models import Tag
from utilities.constants import (
FILTER_CHAR_BASED_LOOKUP_MAP, FILTER_NEGATION_LOOKUP_MAP, FILTER_TREENODE_NEGATION_LOOKUP_MAP,
FILTER_NUMERIC_BASED_LOOKUP_MAP
)
def multivalue_field_factory(field_class):
@@ -73,13 +80,6 @@ class TreeNodeMultipleChoiceFilter(django_filters.ModelMultipleChoiceFilter):
return super().filter(qs, value)
class NumericInFilter(django_filters.BaseInFilter, django_filters.NumberFilter):
"""
Filters for a set of numeric values. Example: id__in=100,200,300
"""
pass
class NullableCharFieldFilter(django_filters.CharFilter):
"""
Allow matching on null field values by passing a special string used to signify NULL.
@@ -111,6 +111,165 @@ class TagFilter(django_filters.ModelMultipleChoiceFilter):
# FilterSets
#
class BaseFilterSet(django_filters.FilterSet):
"""
A base filterset which provides common functionaly to all NetBox filtersets
"""
FILTER_DEFAULTS = deepcopy(django_filters.filterset.FILTER_FOR_DBFIELD_DEFAULTS)
FILTER_DEFAULTS.update({
models.AutoField: {
'filter_class': MultiValueNumberFilter
},
models.CharField: {
'filter_class': MultiValueCharFilter
},
models.DateField: {
'filter_class': MultiValueDateFilter
},
models.DateTimeField: {
'filter_class': MultiValueDateTimeFilter
},
models.DecimalField: {
'filter_class': MultiValueNumberFilter
},
models.EmailField: {
'filter_class': MultiValueCharFilter
},
models.FloatField: {
'filter_class': MultiValueNumberFilter
},
models.IntegerField: {
'filter_class': MultiValueNumberFilter
},
models.PositiveIntegerField: {
'filter_class': MultiValueNumberFilter
},
models.PositiveSmallIntegerField: {
'filter_class': MultiValueNumberFilter
},
models.SlugField: {
'filter_class': MultiValueCharFilter
},
models.SmallIntegerField: {
'filter_class': MultiValueNumberFilter
},
models.TimeField: {
'filter_class': MultiValueTimeFilter
},
models.URLField: {
'filter_class': MultiValueCharFilter
},
MACAddressField: {
'filter_class': MultiValueMACAddressFilter
},
})
@staticmethod
def _get_filter_lookup_dict(existing_filter):
# Choose the lookup expression map based on the filter type
if isinstance(existing_filter, (
MultiValueDateFilter,
MultiValueDateTimeFilter,
MultiValueNumberFilter,
MultiValueTimeFilter
)):
lookup_map = FILTER_NUMERIC_BASED_LOOKUP_MAP
elif isinstance(existing_filter, (
TreeNodeMultipleChoiceFilter,
)):
# TreeNodeMultipleChoiceFilter only support negation but must maintain the `in` lookup expression
lookup_map = FILTER_TREENODE_NEGATION_LOOKUP_MAP
elif isinstance(existing_filter, (
django_filters.ModelChoiceFilter,
django_filters.ModelMultipleChoiceFilter,
TagFilter
)) or existing_filter.extra.get('choices'):
# These filter types support only negation
lookup_map = FILTER_NEGATION_LOOKUP_MAP
elif isinstance(existing_filter, (
django_filters.filters.CharFilter,
django_filters.MultipleChoiceFilter,
MultiValueCharFilter,
MultiValueMACAddressFilter
)):
lookup_map = FILTER_CHAR_BASED_LOOKUP_MAP
else:
lookup_map = None
return lookup_map
@classmethod
def get_filters(cls):
"""
Override filter generation to support dynamic lookup expressions for certain filter types.
For specific filter types, new filters are created based on defined lookup expressions in
the form `<field_name>__<lookup_expr>`
"""
# TODO: once 3.6 is the minimum required version of python, change this to a bare super() call
# We have to do it this way in py3.5 becuase of django_filters.FilterSet's use of a metaclass
filters = super(django_filters.FilterSet, cls).get_filters()
new_filters = {}
for existing_filter_name, existing_filter in filters.items():
# Loop over existing filters to extract metadata by which to create new filters
# If the filter makes use of a custom filter method or lookup expression skip it
# as we cannot sanely handle these cases in a generic mannor
if existing_filter.method is not None or existing_filter.lookup_expr not in ['exact', 'in']:
continue
# Choose the lookup expression map based on the filter type
lookup_map = cls._get_filter_lookup_dict(existing_filter)
if lookup_map is None:
# Do not augment this filter type with more lookup expressions
continue
# Get properties of the existing filter for later use
field_name = existing_filter.field_name
field = get_model_field(cls._meta.model, field_name)
# Create new filters for each lookup expression in the map
for lookup_name, lookup_expr in lookup_map.items():
new_filter_name = '{}__{}'.format(existing_filter_name, lookup_name)
try:
if existing_filter_name in cls.declared_filters:
# The filter field has been explicity defined on the filterset class so we must manually
# create the new filter with the same type because there is no guarantee the defined type
# is the same as the default type for the field
resolve_field(field, lookup_expr) # Will raise FieldLookupError if the lookup is invalid
new_filter = type(existing_filter)(
field_name=field_name,
lookup_expr=lookup_expr,
label=existing_filter.label,
exclude=existing_filter.exclude,
distinct=existing_filter.distinct,
**existing_filter.extra
)
else:
# The filter field is listed in Meta.fields so we can safely rely on default behaviour
# Will raise FieldLookupError if the lookup is invalid
new_filter = cls.filter_for_field(field, field_name, lookup_expr)
except django_filters.exceptions.FieldLookupError:
# The filter could not be created because the lookup expression is not supported on the field
continue
if lookup_name.startswith('n'):
# This is a negation filter which requires a queryset.exclude() clause
# Of course setting the negation of the existing filter's exclude attribute handles both cases
new_filter.exclude = not existing_filter.exclude
new_filters[new_filter_name] = new_filter
filters.update(new_filters)
return filters
class NameSlugSearchFilterSet(django_filters.FilterSet):
"""
A base class for adding the search method to models which only expose the `name` and `slug` fields
@@ -127,54 +286,3 @@ class NameSlugSearchFilterSet(django_filters.FilterSet):
models.Q(name__icontains=value) |
models.Q(slug__icontains=value)
)
#
# Update default filters
#
FILTER_DEFAULTS = django_filters.filterset.FILTER_FOR_DBFIELD_DEFAULTS
FILTER_DEFAULTS.update({
models.AutoField: {
'filter_class': MultiValueNumberFilter
},
models.CharField: {
'filter_class': MultiValueCharFilter
},
models.DateField: {
'filter_class': MultiValueDateFilter
},
models.DateTimeField: {
'filter_class': MultiValueDateTimeFilter
},
models.DecimalField: {
'filter_class': MultiValueNumberFilter
},
models.EmailField: {
'filter_class': MultiValueCharFilter
},
models.FloatField: {
'filter_class': MultiValueNumberFilter
},
models.IntegerField: {
'filter_class': MultiValueNumberFilter
},
models.PositiveIntegerField: {
'filter_class': MultiValueNumberFilter
},
models.PositiveSmallIntegerField: {
'filter_class': MultiValueNumberFilter
},
models.SlugField: {
'filter_class': MultiValueCharFilter
},
models.SmallIntegerField: {
'filter_class': MultiValueNumberFilter
},
models.TimeField: {
'filter_class': MultiValueTimeFilter
},
models.URLField: {
'filter_class': MultiValueCharFilter
},
})

View File

@@ -2,8 +2,9 @@ import csv
import json
import re
from io import StringIO
import yaml
import django_filters
import yaml
from django import forms
from django.conf import settings
from django.contrib.postgres.forms.jsonb import JSONField as _JSONField, InvalidJSONInput
@@ -497,14 +498,14 @@ class ExpandableIPAddressField(forms.CharField):
class CommentField(forms.CharField):
"""
A textarea with support for GitHub-Flavored Markdown. Exists mostly just to add a standard help_text.
A textarea with support for Markdown rendering. Exists mostly just to add a standard help_text.
"""
widget = forms.Textarea
default_label = ''
# TODO: Port GFM syntax cheat sheet to internal documentation
# TODO: Port Markdown cheat sheet to internal documentation
default_helptext = '<i class="fa fa-info-circle"></i> '\
'<a href="https://github.com/adam-p/markdown-here/wiki/Markdown-Cheatsheet" target="_blank">'\
'GitHub-Flavored Markdown</a> syntax is supported'
'Markdown</a> syntax is supported'
def __init__(self, *args, **kwargs):
required = kwargs.pop('required', False)
@@ -564,18 +565,17 @@ class TagFilterField(forms.MultipleChoiceField):
class DynamicModelChoiceMixin:
field_modifier = ''
filter = django_filters.ModelChoiceFilter
def get_bound_field(self, form, field_name):
bound_field = BoundField(form, self, field_name)
# Modify the QuerySet of the field before we return it. Limit choices to any data already bound: Options
# will be populated on-demand via the APISelect widget.
field_name = '{}{}'.format(self.to_field_name or 'pk', self.field_modifier)
if bound_field.data:
self.queryset = self.queryset.filter(**{field_name: self.prepare_value(bound_field.data)})
elif bound_field.initial:
self.queryset = self.queryset.filter(**{field_name: self.prepare_value(bound_field.initial)})
data = self.prepare_value(bound_field.data or bound_field.initial)
if data:
filter = self.filter(field_name=self.to_field_name or 'pk', queryset=self.queryset)
self.queryset = filter.filter(self.queryset, data)
else:
self.queryset = self.queryset.none()
@@ -594,7 +594,7 @@ class DynamicModelMultipleChoiceField(DynamicModelChoiceMixin, forms.ModelMultip
"""
A multiple-choice version of DynamicModelChoiceField.
"""
field_modifier = '__in'
filter = django_filters.ModelMultipleChoiceFilter
class LaxURLField(forms.URLField):

View File

@@ -6,6 +6,7 @@ from django.db import ProgrammingError
from django.http import Http404, HttpResponseRedirect
from django.urls import reverse
from .api import is_api_request
from .views import server_error
@@ -47,9 +48,8 @@ class APIVersionMiddleware(object):
self.get_response = get_response
def __call__(self, request):
api_path = reverse('api-root')
response = self.get_response(request)
if request.path_info.startswith(api_path):
if is_api_request(request):
response['API-Version'] = settings.REST_FRAMEWORK_VERSION
return response

View File

@@ -4,8 +4,8 @@
<span class="fa fa-upload" aria-hidden="true"></span>
Export <span class="caret"></span>
</button>
<ul class="dropdown-menu">
<li><a href="?{% if url_params %}{{ url_params.urlencode }}&{% endif %}export">CSV (default)</a></li>
<ul class="dropdown-menu dropdown-menu-right">
<li><a href="?{% if url_params %}{{ url_params.urlencode }}&{% endif %}export">Default format</a></li>
<li class="divider"></li>
{% for et in export_templates %}
<li><a href="?{% if url_params %}{{ url_params.urlencode }}&{% endif %}export={{ et.name }}"{% if et.description %} title="{{ et.description }}"{% endif %}>{{ et.name }}</a></li>

View File

@@ -4,6 +4,7 @@ import re
import yaml
from django import template
from django.conf import settings
from django.urls import NoReverseMatch, reverse
from django.utils.html import strip_tags
from django.utils.safestring import mark_safe
@@ -19,15 +20,6 @@ register = template.Library()
# Filters
#
@register.filter()
def oneline(value):
"""
Replace each line break with a single space
"""
value = value.replace('\r', '')
return value.replace('\n', ' ')
@register.filter()
def placeholder(value):
"""
@@ -39,32 +31,16 @@ def placeholder(value):
return mark_safe(placeholder)
@register.filter()
def getlist(value, arg):
"""
Return all values of a QueryDict key
"""
return value.getlist(arg)
@register.filter
def getkey(value, key):
"""
Return a dictionary item specified by key
"""
return value[key]
@register.filter(is_safe=True)
def gfm(value):
def render_markdown(value):
"""
Render text as GitHub-Flavored Markdown
Render text as Markdown
"""
# Strip HTML tags
value = strip_tags(value)
# Render Markdown with GFM extension
html = markdown(value, extensions=['mdx_gfm'])
# Render Markdown
html = markdown(value, extensions=['fenced_code'])
return mark_safe(html)
@@ -86,19 +62,12 @@ def render_yaml(value):
@register.filter()
def model_name(obj):
def meta(obj, attr):
"""
Return the name of the model of the given object
Return the specified Meta attribute of a model. This is needed because Django does not permit templates
to access attributes which begin with an underscore (e.g. _meta).
"""
return obj._meta.verbose_name
@register.filter()
def model_name_plural(obj):
"""
Return the plural name of the model of the given object
"""
return obj._meta.verbose_name_plural
return getattr(obj._meta, attr, '')
@register.filter()
@@ -116,14 +85,6 @@ def url_name(model, action):
return None
@register.filter()
def contains(value, arg):
"""
Test whether a value contains any of a given set of strings. `arg` should be a comma-separated list of strings.
"""
return any(s in value for s in arg.split(','))
@register.filter()
def bettertitle(value):
"""
@@ -216,6 +177,30 @@ def percentage(x, y):
return round(x / y * 100)
@register.filter()
def get_docs(model):
"""
Render and return documentation for the specified model.
"""
path = '{}/models/{}/{}.md'.format(
settings.DOCS_ROOT,
model._meta.app_label,
model._meta.model_name
)
try:
with open(path) as docfile:
content = docfile.read()
except FileNotFoundError:
return "Unable to load documentation, file not found: {}".format(path)
except IOError:
return "Unable to load documentation, error reading file: {}".format(path)
# Render Markdown with the admonition extension
content = markdown(content, extensions=['admonition', 'fenced_code'])
return mark_safe(content)
#
# Tags
#

View File

@@ -172,24 +172,29 @@ class ViewTestCases:
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_create_object(self):
# Try GET without permission
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(self._get_url('add')), 403)
# Try GET with permission
self.add_permissions(
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.get(path=self._get_url('add'))
self.assertHttpStatus(response, 200)
# Try POST with permission
initial_count = self.model.objects.count()
request = {
'path': self._get_url('add'),
'data': post_data(self.form_data),
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
# Validate object creation
self.assertEqual(initial_count + 1, self.model.objects.count())
instance = self.model.objects.order_by('-pk').first()
self.assertInstanceEqual(instance, self.form_data)
@@ -204,23 +209,27 @@ class ViewTestCases:
def test_edit_object(self):
instance = self.model.objects.first()
# Try GET without permission
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(self._get_url('edit', instance)), 403)
# Try GET with permission
self.add_permissions(
'{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.get(path=self._get_url('edit', instance))
self.assertHttpStatus(response, 200)
# Try POST with permission
request = {
'path': self._get_url('edit', instance),
'data': post_data(self.form_data),
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
# Validate object modifications
instance = self.model.objects.get(pk=instance.pk)
self.assertInstanceEqual(instance, self.form_data)
@@ -232,23 +241,26 @@ class ViewTestCases:
def test_delete_object(self):
instance = self.model.objects.first()
# Try GET without permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(self._get_url('delete', instance)), 403)
# Try GET with permission
self.add_permissions(
'{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.get(path=self._get_url('delete', instance))
self.assertHttpStatus(response, 200)
request = {
'path': self._get_url('delete', instance),
'data': {'confirm': True},
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
# Validate object deletion
with self.assertRaises(ObjectDoesNotExist):
self.model.objects.get(pk=instance.pk)
@@ -314,6 +326,20 @@ class ViewTestCases:
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_import_objects(self):
# Test GET without permission
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.get(self._get_url('import')), 403)
# Test GET with permission
self.add_permissions(
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name),
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.get(self._get_url('import'))
self.assertHttpStatus(response, 200)
# Test POST with permission
initial_count = self.model.objects.count()
request = {
'path': self._get_url('import'),
@@ -321,19 +347,10 @@ class ViewTestCases:
'csv': '\n'.join(self.csv_data)
}
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name),
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 200)
# Validate import of new objects
self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1)
class BulkEditObjectsViewTestCase(ModelViewTestCase):

View File

@@ -1,9 +1,21 @@
from django.conf import settings
from django.test import TestCase
import django_filters
from django.conf import settings
from django.db import models
from django.test import TestCase
from mptt.fields import TreeForeignKey
from taggit.managers import TaggableManager
from dcim.models import Region, Site
from utilities.filters import TreeNodeMultipleChoiceFilter
from dcim.choices import *
from dcim.fields import MACAddressField
from dcim.filters import DeviceFilterSet, SiteFilterSet
from dcim.models import (
Device, DeviceRole, DeviceType, Interface, Manufacturer, Platform, Rack, Region, Site
)
from extras.models import TaggedItem
from utilities.filters import (
BaseFilterSet, MACAddressFilter, MultiValueCharFilter, MultiValueDateFilter, MultiValueDateTimeFilter,
MultiValueNumberFilter, MultiValueTimeFilter, TagFilter, TreeNodeMultipleChoiceFilter,
)
class TreeNodeMultipleChoiceFilterTest(TestCase):
@@ -60,3 +72,447 @@ class TreeNodeMultipleChoiceFilterTest(TestCase):
self.assertEqual(qs.count(), 2)
self.assertEqual(qs[0], self.site1)
self.assertEqual(qs[1], self.site3)
class DummyModel(models.Model):
"""
Dummy model used by BaseFilterSetTest for filter validation. Should never appear in a schema migration.
"""
charfield = models.CharField(
max_length=10
)
choicefield = models.IntegerField(
choices=(('A', 1), ('B', 2), ('C', 3))
)
datefield = models.DateField()
datetimefield = models.DateTimeField()
integerfield = models.IntegerField()
macaddressfield = MACAddressField()
timefield = models.TimeField()
treeforeignkeyfield = TreeForeignKey(
to='self',
on_delete=models.CASCADE
)
tags = TaggableManager(through=TaggedItem)
class BaseFilterSetTest(TestCase):
"""
Ensure that a BaseFilterSet automatically creates the expected set of filters for each filter type.
"""
class DummyFilterSet(BaseFilterSet):
charfield = django_filters.CharFilter()
macaddressfield = MACAddressFilter()
modelchoicefield = django_filters.ModelChoiceFilter(
field_name='integerfield', # We're pretending this is a ForeignKey field
queryset=Site.objects.all()
)
modelmultiplechoicefield = django_filters.ModelMultipleChoiceFilter(
field_name='integerfield', # We're pretending this is a ForeignKey field
queryset=Site.objects.all()
)
multiplechoicefield = django_filters.MultipleChoiceFilter(
field_name='choicefield'
)
multivaluecharfield = MultiValueCharFilter(
field_name='charfield'
)
tagfield = TagFilter()
treeforeignkeyfield = TreeNodeMultipleChoiceFilter(
queryset=DummyModel.objects.all()
)
class Meta:
model = DummyModel
fields = (
'charfield',
'choicefield',
'datefield',
'datetimefield',
'integerfield',
'macaddressfield',
'modelchoicefield',
'modelmultiplechoicefield',
'multiplechoicefield',
'tagfield',
'timefield',
'treeforeignkeyfield',
)
@classmethod
def setUpTestData(cls):
cls.filters = cls.DummyFilterSet().filters
def test_char_filter(self):
self.assertIsInstance(self.filters['charfield'], django_filters.CharFilter)
self.assertEqual(self.filters['charfield'].lookup_expr, 'exact')
self.assertEqual(self.filters['charfield'].exclude, False)
self.assertEqual(self.filters['charfield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['charfield__n'].exclude, True)
self.assertEqual(self.filters['charfield__ie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['charfield__ie'].exclude, False)
self.assertEqual(self.filters['charfield__nie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['charfield__nie'].exclude, True)
self.assertEqual(self.filters['charfield__ic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['charfield__ic'].exclude, False)
self.assertEqual(self.filters['charfield__nic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['charfield__nic'].exclude, True)
self.assertEqual(self.filters['charfield__isw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['charfield__isw'].exclude, False)
self.assertEqual(self.filters['charfield__nisw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['charfield__nisw'].exclude, True)
self.assertEqual(self.filters['charfield__iew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['charfield__iew'].exclude, False)
self.assertEqual(self.filters['charfield__niew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['charfield__niew'].exclude, True)
def test_mac_address_filter(self):
self.assertIsInstance(self.filters['macaddressfield'], MACAddressFilter)
self.assertEqual(self.filters['macaddressfield'].lookup_expr, 'exact')
self.assertEqual(self.filters['macaddressfield'].exclude, False)
self.assertEqual(self.filters['macaddressfield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['macaddressfield__n'].exclude, True)
self.assertEqual(self.filters['macaddressfield__ie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['macaddressfield__ie'].exclude, False)
self.assertEqual(self.filters['macaddressfield__nie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['macaddressfield__nie'].exclude, True)
self.assertEqual(self.filters['macaddressfield__ic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['macaddressfield__ic'].exclude, False)
self.assertEqual(self.filters['macaddressfield__nic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['macaddressfield__nic'].exclude, True)
self.assertEqual(self.filters['macaddressfield__isw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['macaddressfield__isw'].exclude, False)
self.assertEqual(self.filters['macaddressfield__nisw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['macaddressfield__nisw'].exclude, True)
self.assertEqual(self.filters['macaddressfield__iew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['macaddressfield__iew'].exclude, False)
self.assertEqual(self.filters['macaddressfield__niew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['macaddressfield__niew'].exclude, True)
def test_model_choice_filter(self):
self.assertIsInstance(self.filters['modelchoicefield'], django_filters.ModelChoiceFilter)
self.assertEqual(self.filters['modelchoicefield'].lookup_expr, 'exact')
self.assertEqual(self.filters['modelchoicefield'].exclude, False)
self.assertEqual(self.filters['modelchoicefield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['modelchoicefield__n'].exclude, True)
def test_model_multiple_choice_filter(self):
self.assertIsInstance(self.filters['modelmultiplechoicefield'], django_filters.ModelMultipleChoiceFilter)
self.assertEqual(self.filters['modelmultiplechoicefield'].lookup_expr, 'exact')
self.assertEqual(self.filters['modelmultiplechoicefield'].exclude, False)
self.assertEqual(self.filters['modelmultiplechoicefield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['modelmultiplechoicefield__n'].exclude, True)
def test_multi_value_char_filter(self):
self.assertIsInstance(self.filters['multivaluecharfield'], MultiValueCharFilter)
self.assertEqual(self.filters['multivaluecharfield'].lookup_expr, 'exact')
self.assertEqual(self.filters['multivaluecharfield'].exclude, False)
self.assertEqual(self.filters['multivaluecharfield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['multivaluecharfield__n'].exclude, True)
self.assertEqual(self.filters['multivaluecharfield__ie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['multivaluecharfield__ie'].exclude, False)
self.assertEqual(self.filters['multivaluecharfield__nie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['multivaluecharfield__nie'].exclude, True)
self.assertEqual(self.filters['multivaluecharfield__ic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['multivaluecharfield__ic'].exclude, False)
self.assertEqual(self.filters['multivaluecharfield__nic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['multivaluecharfield__nic'].exclude, True)
self.assertEqual(self.filters['multivaluecharfield__isw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['multivaluecharfield__isw'].exclude, False)
self.assertEqual(self.filters['multivaluecharfield__nisw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['multivaluecharfield__nisw'].exclude, True)
self.assertEqual(self.filters['multivaluecharfield__iew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['multivaluecharfield__iew'].exclude, False)
self.assertEqual(self.filters['multivaluecharfield__niew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['multivaluecharfield__niew'].exclude, True)
def test_multi_value_date_filter(self):
self.assertIsInstance(self.filters['datefield'], MultiValueDateFilter)
self.assertEqual(self.filters['datefield'].lookup_expr, 'exact')
self.assertEqual(self.filters['datefield'].exclude, False)
self.assertEqual(self.filters['datefield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['datefield__n'].exclude, True)
self.assertEqual(self.filters['datefield__lt'].lookup_expr, 'lt')
self.assertEqual(self.filters['datefield__lt'].exclude, False)
self.assertEqual(self.filters['datefield__lte'].lookup_expr, 'lte')
self.assertEqual(self.filters['datefield__lte'].exclude, False)
self.assertEqual(self.filters['datefield__gt'].lookup_expr, 'gt')
self.assertEqual(self.filters['datefield__gt'].exclude, False)
self.assertEqual(self.filters['datefield__gte'].lookup_expr, 'gte')
self.assertEqual(self.filters['datefield__gte'].exclude, False)
def test_multi_value_datetime_filter(self):
self.assertIsInstance(self.filters['datetimefield'], MultiValueDateTimeFilter)
self.assertEqual(self.filters['datetimefield'].lookup_expr, 'exact')
self.assertEqual(self.filters['datetimefield'].exclude, False)
self.assertEqual(self.filters['datetimefield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['datetimefield__n'].exclude, True)
self.assertEqual(self.filters['datetimefield__lt'].lookup_expr, 'lt')
self.assertEqual(self.filters['datetimefield__lt'].exclude, False)
self.assertEqual(self.filters['datetimefield__lte'].lookup_expr, 'lte')
self.assertEqual(self.filters['datetimefield__lte'].exclude, False)
self.assertEqual(self.filters['datetimefield__gt'].lookup_expr, 'gt')
self.assertEqual(self.filters['datetimefield__gt'].exclude, False)
self.assertEqual(self.filters['datetimefield__gte'].lookup_expr, 'gte')
self.assertEqual(self.filters['datetimefield__gte'].exclude, False)
def test_multi_value_number_filter(self):
self.assertIsInstance(self.filters['integerfield'], MultiValueNumberFilter)
self.assertEqual(self.filters['integerfield'].lookup_expr, 'exact')
self.assertEqual(self.filters['integerfield'].exclude, False)
self.assertEqual(self.filters['integerfield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['integerfield__n'].exclude, True)
self.assertEqual(self.filters['integerfield__lt'].lookup_expr, 'lt')
self.assertEqual(self.filters['integerfield__lt'].exclude, False)
self.assertEqual(self.filters['integerfield__lte'].lookup_expr, 'lte')
self.assertEqual(self.filters['integerfield__lte'].exclude, False)
self.assertEqual(self.filters['integerfield__gt'].lookup_expr, 'gt')
self.assertEqual(self.filters['integerfield__gt'].exclude, False)
self.assertEqual(self.filters['integerfield__gte'].lookup_expr, 'gte')
self.assertEqual(self.filters['integerfield__gte'].exclude, False)
def test_multi_value_time_filter(self):
self.assertIsInstance(self.filters['timefield'], MultiValueTimeFilter)
self.assertEqual(self.filters['timefield'].lookup_expr, 'exact')
self.assertEqual(self.filters['timefield'].exclude, False)
self.assertEqual(self.filters['timefield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['timefield__n'].exclude, True)
self.assertEqual(self.filters['timefield__lt'].lookup_expr, 'lt')
self.assertEqual(self.filters['timefield__lt'].exclude, False)
self.assertEqual(self.filters['timefield__lte'].lookup_expr, 'lte')
self.assertEqual(self.filters['timefield__lte'].exclude, False)
self.assertEqual(self.filters['timefield__gt'].lookup_expr, 'gt')
self.assertEqual(self.filters['timefield__gt'].exclude, False)
self.assertEqual(self.filters['timefield__gte'].lookup_expr, 'gte')
self.assertEqual(self.filters['timefield__gte'].exclude, False)
def test_multiple_choice_filter(self):
self.assertIsInstance(self.filters['multiplechoicefield'], django_filters.MultipleChoiceFilter)
self.assertEqual(self.filters['multiplechoicefield'].lookup_expr, 'exact')
self.assertEqual(self.filters['multiplechoicefield'].exclude, False)
self.assertEqual(self.filters['multiplechoicefield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['multiplechoicefield__n'].exclude, True)
self.assertEqual(self.filters['multiplechoicefield__ie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['multiplechoicefield__ie'].exclude, False)
self.assertEqual(self.filters['multiplechoicefield__nie'].lookup_expr, 'iexact')
self.assertEqual(self.filters['multiplechoicefield__nie'].exclude, True)
self.assertEqual(self.filters['multiplechoicefield__ic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['multiplechoicefield__ic'].exclude, False)
self.assertEqual(self.filters['multiplechoicefield__nic'].lookup_expr, 'icontains')
self.assertEqual(self.filters['multiplechoicefield__nic'].exclude, True)
self.assertEqual(self.filters['multiplechoicefield__isw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['multiplechoicefield__isw'].exclude, False)
self.assertEqual(self.filters['multiplechoicefield__nisw'].lookup_expr, 'istartswith')
self.assertEqual(self.filters['multiplechoicefield__nisw'].exclude, True)
self.assertEqual(self.filters['multiplechoicefield__iew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['multiplechoicefield__iew'].exclude, False)
self.assertEqual(self.filters['multiplechoicefield__niew'].lookup_expr, 'iendswith')
self.assertEqual(self.filters['multiplechoicefield__niew'].exclude, True)
def test_tag_filter(self):
self.assertIsInstance(self.filters['tagfield'], TagFilter)
self.assertEqual(self.filters['tagfield'].lookup_expr, 'exact')
self.assertEqual(self.filters['tagfield'].exclude, False)
self.assertEqual(self.filters['tagfield__n'].lookup_expr, 'exact')
self.assertEqual(self.filters['tagfield__n'].exclude, True)
def test_tree_node_multiple_choice_filter(self):
self.assertIsInstance(self.filters['treeforeignkeyfield'], TreeNodeMultipleChoiceFilter)
# TODO: lookup_expr different for negation?
self.assertEqual(self.filters['treeforeignkeyfield'].lookup_expr, 'exact')
self.assertEqual(self.filters['treeforeignkeyfield'].exclude, False)
self.assertEqual(self.filters['treeforeignkeyfield__n'].lookup_expr, 'in')
self.assertEqual(self.filters['treeforeignkeyfield__n'].exclude, True)
class DynamicFilterLookupExpressionTest(TestCase):
"""
Validate function of automatically generated filters using the Device model as an example.
"""
device_queryset = Device.objects.all()
device_filterset = DeviceFilterSet
site_queryset = Site.objects.all()
site_filterset = SiteFilterSet
@classmethod
def setUpTestData(cls):
manufacturers = (
Manufacturer(name='Manufacturer 1', slug='manufacturer-1'),
Manufacturer(name='Manufacturer 2', slug='manufacturer-2'),
Manufacturer(name='Manufacturer 3', slug='manufacturer-3'),
)
Manufacturer.objects.bulk_create(manufacturers)
device_types = (
DeviceType(manufacturer=manufacturers[0], model='Model 1', slug='model-1', is_full_depth=True),
DeviceType(manufacturer=manufacturers[1], model='Model 2', slug='model-2', is_full_depth=True),
DeviceType(manufacturer=manufacturers[2], model='Model 3', slug='model-3', is_full_depth=False),
)
DeviceType.objects.bulk_create(device_types)
device_roles = (
DeviceRole(name='Device Role 1', slug='device-role-1'),
DeviceRole(name='Device Role 2', slug='device-role-2'),
DeviceRole(name='Device Role 3', slug='device-role-3'),
)
DeviceRole.objects.bulk_create(device_roles)
platforms = (
Platform(name='Platform 1', slug='platform-1'),
Platform(name='Platform 2', slug='platform-2'),
Platform(name='Platform 3', slug='platform-3'),
)
Platform.objects.bulk_create(platforms)
regions = (
Region(name='Region 1', slug='region-1'),
Region(name='Region 2', slug='region-2'),
Region(name='Region 3', slug='region-3'),
)
for region in regions:
region.save()
sites = (
Site(name='Site 1', slug='abc-site-1', region=regions[0], asn=65001),
Site(name='Site 2', slug='def-site-2', region=regions[1], asn=65101),
Site(name='Site 3', slug='ghi-site-3', region=regions[2], asn=65201),
)
Site.objects.bulk_create(sites)
racks = (
Rack(name='Rack 1', site=sites[0]),
Rack(name='Rack 2', site=sites[1]),
Rack(name='Rack 3', site=sites[2]),
)
Rack.objects.bulk_create(racks)
devices = (
Device(name='Device 1', device_type=device_types[0], device_role=device_roles[0], platform=platforms[0], serial='ABC', asset_tag='1001', site=sites[0], rack=racks[0], position=1, face=DeviceFaceChoices.FACE_FRONT, status=DeviceStatusChoices.STATUS_ACTIVE, local_context_data={"foo": 123}),
Device(name='Device 2', device_type=device_types[1], device_role=device_roles[1], platform=platforms[1], serial='DEF', asset_tag='1002', site=sites[1], rack=racks[1], position=2, face=DeviceFaceChoices.FACE_FRONT, status=DeviceStatusChoices.STATUS_STAGED),
Device(name='Device 3', device_type=device_types[2], device_role=device_roles[2], platform=platforms[2], serial='GHI', asset_tag='1003', site=sites[2], rack=racks[2], position=3, face=DeviceFaceChoices.FACE_REAR, status=DeviceStatusChoices.STATUS_FAILED),
)
Device.objects.bulk_create(devices)
interfaces = (
Interface(device=devices[0], name='Interface 1', mac_address='00-00-00-00-00-01'),
Interface(device=devices[0], name='Interface 2', mac_address='aa-00-00-00-00-01'),
Interface(device=devices[1], name='Interface 3', mac_address='00-00-00-00-00-02'),
Interface(device=devices[1], name='Interface 4', mac_address='bb-00-00-00-00-02'),
Interface(device=devices[2], name='Interface 5', mac_address='00-00-00-00-00-03'),
Interface(device=devices[2], name='Interface 6', mac_address='cc-00-00-00-00-03'),
)
Interface.objects.bulk_create(interfaces)
def test_site_name_negation(self):
params = {'name__n': ['Site 1']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_slug_icontains(self):
params = {'slug__ic': ['-1']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 1)
def test_site_slug_icontains_negation(self):
params = {'slug__nic': ['-1']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_slug_startswith(self):
params = {'slug__isw': ['abc']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 1)
def test_site_slug_startswith_negation(self):
params = {'slug__nisw': ['abc']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_slug_endswith(self):
params = {'slug__iew': ['-1']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 1)
def test_site_slug_endswith_negation(self):
params = {'slug__niew': ['-1']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_asn_lt(self):
params = {'asn__lt': [65101]}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 1)
def test_site_asn_lte(self):
params = {'asn__lte': [65101]}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_asn_gt(self):
params = {'asn__lt': [65101]}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 1)
def test_site_asn_gte(self):
params = {'asn__gte': [65101]}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_region_negation(self):
params = {'region__n': ['region-1']}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_site_region_id_negation(self):
params = {'region_id__n': [Region.objects.first().pk]}
self.assertEqual(SiteFilterSet(params, self.site_queryset).qs.count(), 2)
def test_device_name_eq(self):
params = {'name': ['Device 1']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 1)
def test_device_name_negation(self):
params = {'name__n': ['Device 1']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_name_startswith(self):
params = {'name__isw': ['Device']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 3)
def test_device_name_startswith_negation(self):
params = {'name__nisw': ['Device 1']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_name_endswith(self):
params = {'name__iew': [' 1']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 1)
def test_device_name_endswith_negation(self):
params = {'name__niew': [' 1']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_name_icontains(self):
params = {'name__ic': [' 2']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 1)
def test_device_name_icontains_negation(self):
params = {'name__nic': [' ']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 0)
def test_device_mac_address_negation(self):
params = {'mac_address__n': ['00-00-00-00-00-01', 'aa-00-00-00-00-01']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_mac_address_startswith(self):
params = {'mac_address__isw': ['aa:']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 1)
def test_device_mac_address_startswith_negation(self):
params = {'mac_address__nisw': ['aa:']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_mac_address_endswith(self):
params = {'mac_address__iew': [':02']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 1)
def test_device_mac_address_endswith_negation(self):
params = {'mac_address__niew': [':02']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_mac_address_icontains(self):
params = {'mac_address__ic': ['aa:', 'bb']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 2)
def test_device_mac_address_icontains_negation(self):
params = {'mac_address__nic': ['aa:', 'bb']}
self.assertEqual(DeviceFilterSet(params, self.device_queryset).qs.count(), 1)

View File

@@ -33,14 +33,20 @@ class NaturalizationTestCase(TestCase):
# IOS/JunOS-style
('Gi', '9999999999999999Gi000000000000000000'),
('Gi1', '9999999999999999Gi000001000000000000'),
('Gi1.0', '9999999999999999Gi000001000000000000'),
('Gi1.1', '9999999999999999Gi000001000000000001'),
('Gi1:0', '9999999999999999Gi000001000000000000'),
('Gi1:0.0', '9999999999999999Gi000001000000000000'),
('Gi1:0.1', '9999999999999999Gi000001000000000001'),
('Gi1:1', '9999999999999999Gi000001000001000000'),
('Gi1:1.0', '9999999999999999Gi000001000001000000'),
('Gi1:1.1', '9999999999999999Gi000001000001000001'),
('Gi1/2', '0001999999999999Gi000002000000000000'),
('Gi1/2/3', '0001000299999999Gi000003000000000000'),
('Gi1/2/3/4', '0001000200039999Gi000004000000000000'),
('Gi1/2/3/4/5', '0001000200030004Gi000005000000000000'),
('Gi1/2/3/4/5:6', '0001000200030004Gi000005000006000000'),
('Gi1/2/3/4/5:6.7', '0001000200030004Gi000005000006000007'),
('Gi1:2', '9999999999999999Gi000001000002000000'),
('Gi1:2.3', '9999999999999999Gi000001000002000003'),
# Generic
('Interface 1', '9999999999999999Interface 000001000000000000'),
('Interface 1 (other)', '9999999999999999Interface 000001000000000000 (other)'),

View File

@@ -31,8 +31,9 @@ def csv_format(data):
if not isinstance(value, str):
value = '{}'.format(value)
# Double-quote the value if it contains a comma
# Double-quote the value if it contains a comma or line break
if ',' in value or '\n' in value:
value = value.replace('"', '""') # Escape double-quotes
csv.append('"{}"'.format(value))
else:
csv.append('{}'.format(value))
@@ -80,10 +81,12 @@ def get_subquery(model, field):
return subquery
def serialize_object(obj, extra=None):
def serialize_object(obj, extra=None, exclude=None):
"""
Return a generic JSON representation of an object using Django's built-in serializer. (This is used for things like
change logging, not the REST API.) Optionally include a dictionary to supplement the object data.
change logging, not the REST API.) Optionally include a dictionary to supplement the object data. A list of keys
can be provided to exclude them from the returned dictionary. Private fields (prefaced with an underscore) are
implicitly excluded.
"""
json_str = serialize('json', [obj])
data = json.loads(json_str)[0]['fields']
@@ -102,6 +105,16 @@ def serialize_object(obj, extra=None):
if extra is not None:
data.update(extra)
# Copy keys to list to avoid 'dictionary changed size during iteration' exception
for key in list(data):
# Private fields shouldn't be logged in the object change
if isinstance(key, str) and key.startswith('_'):
data.pop(key)
# Explicitly excluded keys
if isinstance(exclude, (list, tuple)) and key in exclude:
data.pop(key)
return data
@@ -212,18 +225,6 @@ def prepare_cloned_fields(instance):
return param_string
def querydict_to_dict(querydict):
"""
Convert a django.http.QueryDict object to a regular Python dictionary, preserving lists of multiple values.
(QueryDict.dict() will return only the last value in a list for each key.)
"""
assert isinstance(querydict, QueryDict)
return {
key: querydict.get(key) if len(value) == 1 and key != 'pk' else querydict.getlist(key)
for key, value in querydict.lists()
}
def shallow_compare_dict(source_dict, destination_dict, exclude=None):
"""
Return a new dictionary of the different keys. The values of `destination_dict` are returned. Only the equality of

View File

@@ -1,6 +1,6 @@
import re
from django.core.validators import _lazy_re_compile, URLValidator
from django.core.validators import _lazy_re_compile, BaseValidator, URLValidator
class EnhancedURLValidator(URLValidator):
@@ -26,3 +26,13 @@ class EnhancedURLValidator(URLValidator):
r'(?:[/?#][^\s]*)?' # Path
r'\Z', re.IGNORECASE)
schemes = AnyURLScheme()
class ExclusionValidator(BaseValidator):
"""
Ensure that a field's value is not equal to any of the specified values.
"""
message = 'This value may not be %(show_value)s.'
def compare(self, a, b):
return a in b

View File

@@ -1,3 +1,4 @@
import logging
import sys
from copy import deepcopy
@@ -25,7 +26,7 @@ from extras.models import CustomField, CustomFieldValue, ExportTemplate
from extras.querysets import CustomFieldQueryset
from utilities.exceptions import AbortTransaction
from utilities.forms import BootstrapMixin, CSVDataField
from utilities.utils import csv_format, prepare_cloned_fields, querydict_to_dict
from utilities.utils import csv_format, prepare_cloned_fields
from .error_handlers import handle_protectederror
from .forms import ConfirmationForm, ImportForm
from .paginator import EnhancedPaginator
@@ -219,35 +220,36 @@ class ObjectEditView(GetReturnURLMixin, View):
# given some parameter from the request URL.
return obj
def get(self, request, *args, **kwargs):
def dispatch(self, request, *args, **kwargs):
self.obj = self.alter_obj(self.get_object(kwargs), request, args, kwargs)
obj = self.get_object(kwargs)
obj = self.alter_obj(obj, request, args, kwargs)
return super().dispatch(request, *args, **kwargs)
def get(self, request, *args, **kwargs):
# Parse initial data manually to avoid setting field values as lists
initial_data = {k: request.GET[k] for k in request.GET}
form = self.model_form(instance=obj, initial=initial_data)
form = self.model_form(instance=self.obj, initial=initial_data)
return render(request, self.template_name, {
'obj': obj,
'obj': self.obj,
'obj_type': self.model._meta.verbose_name,
'form': form,
'return_url': self.get_return_url(request, obj),
'return_url': self.get_return_url(request, self.obj),
})
def post(self, request, *args, **kwargs):
obj = self.get_object(kwargs)
obj = self.alter_obj(obj, request, args, kwargs)
form = self.model_form(request.POST, request.FILES, instance=obj)
logger = logging.getLogger('netbox.views.ObjectEditView')
form = self.model_form(request.POST, request.FILES, instance=self.obj)
if form.is_valid():
obj_created = not form.instance.pk
obj = form.save()
logger.debug("Form validation was successful")
obj = form.save()
msg = '{} {}'.format(
'Created' if obj_created else 'Modified',
'Created' if not form.instance.pk else 'Modified',
self.model._meta.verbose_name
)
logger.info(f"{msg} {obj} (PK: {obj.pk})")
if hasattr(obj, 'get_absolute_url'):
msg = '{} <a href="{}">{}</a>'.format(msg, obj.get_absolute_url(), escape(obj))
else:
@@ -269,11 +271,14 @@ class ObjectEditView(GetReturnURLMixin, View):
else:
return redirect(self.get_return_url(request, obj))
else:
logger.debug("Form validation failed")
return render(request, self.template_name, {
'obj': obj,
'obj': self.obj,
'obj_type': self.model._meta.verbose_name,
'form': form,
'return_url': self.get_return_url(request, obj),
'return_url': self.get_return_url(request, self.obj),
})
@@ -295,7 +300,6 @@ class ObjectDeleteView(GetReturnURLMixin, View):
return get_object_or_404(self.model, pk=kwargs['pk'])
def get(self, request, **kwargs):
obj = self.get_object(kwargs)
form = ConfirmationForm(initial=request.GET)
@@ -307,18 +311,22 @@ class ObjectDeleteView(GetReturnURLMixin, View):
})
def post(self, request, **kwargs):
logger = logging.getLogger('netbox.views.ObjectDeleteView')
obj = self.get_object(kwargs)
form = ConfirmationForm(request.POST)
if form.is_valid():
logger.debug("Form validation was successful")
try:
obj.delete()
except ProtectedError as e:
logger.info("Caught ProtectedError while attempting to delete object")
handle_protectederror(obj, request, e)
return redirect(obj.get_absolute_url())
msg = 'Deleted {} {}'.format(self.model._meta.verbose_name, obj)
logger.info(msg)
messages.success(request, msg)
return_url = form.cleaned_data.get('return_url')
@@ -327,6 +335,9 @@ class ObjectDeleteView(GetReturnURLMixin, View):
else:
return redirect(self.get_return_url(request, obj))
else:
logger.debug("Form validation failed")
return render(request, self.template_name, {
'obj': obj,
'form': form,
@@ -350,7 +361,6 @@ class BulkCreateView(GetReturnURLMixin, View):
template_name = None
def get(self, request):
# Set initial values for visible form fields from query args
initial = {}
for field in getattr(self.model_form._meta, 'fields', []):
@@ -368,13 +378,13 @@ class BulkCreateView(GetReturnURLMixin, View):
})
def post(self, request):
logger = logging.getLogger('netbox.views.BulkCreateView')
model = self.model_form._meta.model
form = self.form(request.POST)
model_form = self.model_form(request.POST)
if form.is_valid():
logger.debug("Form validation was successful")
pattern = form.cleaned_data['pattern']
new_objs = []
@@ -392,6 +402,7 @@ class BulkCreateView(GetReturnURLMixin, View):
# Validate each new object independently.
if model_form.is_valid():
obj = model_form.save()
logger.debug(f"Created {obj} (PK: {obj.pk})")
new_objs.append(obj)
else:
# Copy any errors on the pattern target field to the pattern form.
@@ -403,6 +414,7 @@ class BulkCreateView(GetReturnURLMixin, View):
# If we make it to this point, validation has succeeded on all new objects.
msg = "Added {} {}".format(len(new_objs), model._meta.verbose_name_plural)
logger.info(msg)
messages.success(request, msg)
if '_addanother' in request.POST:
@@ -412,6 +424,9 @@ class BulkCreateView(GetReturnURLMixin, View):
except IntegrityError:
pass
else:
logger.debug("Form validation failed")
return render(request, self.template_name, {
'form': form,
'model_form': model_form,
@@ -430,7 +445,6 @@ class ObjectImportView(GetReturnURLMixin, View):
template_name = 'utilities/obj_import.html'
def get(self, request):
form = ImportForm()
return render(request, self.template_name, {
@@ -440,9 +454,11 @@ class ObjectImportView(GetReturnURLMixin, View):
})
def post(self, request):
logger = logging.getLogger('netbox.views.ObjectImportView')
form = ImportForm(request.POST)
if form.is_valid():
logger.debug("Import form validation was successful")
# Initialize model form
data = form.cleaned_data['data']
@@ -463,9 +479,11 @@ class ObjectImportView(GetReturnURLMixin, View):
# Save the primary object
obj = model_form.save()
logger.debug(f"Created {obj} (PK: {obj.pk})")
# Iterate through the related object forms (if any), validating and saving each instance.
for field_name, related_object_form in self.related_object_forms.items():
logger.debug("Processing form for related objects: {related_object_form}")
for i, rel_obj_data in enumerate(data.get(field_name, list())):
@@ -489,7 +507,7 @@ class ObjectImportView(GetReturnURLMixin, View):
pass
if not model_form.errors:
logger.info(f"Import object {obj} (PK: {obj.pk})")
messages.success(request, mark_safe('Imported object: <a href="{}">{}</a>'.format(
obj.get_absolute_url(), obj
)))
@@ -504,6 +522,7 @@ class ObjectImportView(GetReturnURLMixin, View):
return redirect(self.get_return_url(request, obj))
else:
logger.debug("Model form validation failed")
# Replicate model form errors for display
for field, errors in model_form.errors.items():
@@ -513,6 +532,9 @@ class ObjectImportView(GetReturnURLMixin, View):
else:
form.add_error(None, "{}: {}".format(field, err))
else:
logger.debug("Import form validation failed")
return render(request, self.template_name, {
'form': form,
'obj_type': self.model._meta.verbose_name,
@@ -544,7 +566,7 @@ class BulkImportView(GetReturnURLMixin, View):
return ImportForm(*args, **kwargs)
def _save_obj(self, obj_form):
def _save_obj(self, obj_form, request):
"""
Provide a hook to modify the object immediately before saving it (e.g. to encrypt secret data).
"""
@@ -560,20 +582,20 @@ class BulkImportView(GetReturnURLMixin, View):
})
def post(self, request):
logger = logging.getLogger('netbox.views.BulkImportView')
new_objs = []
form = self._import_form(request.POST)
if form.is_valid():
logger.debug("Form validation was successful")
try:
# Iterate through CSV data and bind each row to a new model form instance.
with transaction.atomic():
for row, data in enumerate(form.cleaned_data['csv'], start=1):
obj_form = self.model_form(data)
if obj_form.is_valid():
obj = self._save_obj(obj_form)
obj = self._save_obj(obj_form, request)
new_objs.append(obj)
else:
for field, err in obj_form.errors.items():
@@ -585,6 +607,7 @@ class BulkImportView(GetReturnURLMixin, View):
if new_objs:
msg = 'Imported {} {}'.format(len(new_objs), new_objs[0]._meta.verbose_name_plural)
logger.info(msg)
messages.success(request, msg)
return render(request, "import_success.html", {
@@ -595,6 +618,9 @@ class BulkImportView(GetReturnURLMixin, View):
except ValidationError:
pass
else:
logger.debug("Form validation failed")
return render(request, self.template_name, {
'form': form,
'fields': self.model_form().fields,
@@ -623,20 +649,22 @@ class BulkEditView(GetReturnURLMixin, View):
return redirect(self.get_return_url(request))
def post(self, request, **kwargs):
logger = logging.getLogger('netbox.views.BulkEditView')
model = self.queryset.model
# Create a mutable copy of the POST data
post_data = request.POST.copy()
# If we are editing *all* objects in the queryset, replace the PK list with all matched objects.
if post_data.get('_all') and self.filterset is not None:
post_data['pk'] = [obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs]
if request.POST.get('_all') and self.filterset is not None:
pk_list = [
obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs
]
else:
pk_list = request.POST.getlist('pk')
if '_apply' in request.POST:
form = self.form(model, request.POST)
if form.is_valid():
if form.is_valid():
logger.debug("Form validation was successful")
custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
standard_fields = [
field for field in form.fields if field not in custom_fields + ['pk']
@@ -676,6 +704,7 @@ class BulkEditView(GetReturnURLMixin, View):
obj.full_clean()
obj.save()
logger.debug(f"Saved {obj} (PK: {obj.pk})")
# Update custom fields
obj_type = ContentType.objects.get_for_model(model)
@@ -696,6 +725,7 @@ class BulkEditView(GetReturnURLMixin, View):
)
cfv.value = form.cleaned_data[name]
cfv.save()
logger.debug(f"Saved custom fields for {obj} (PK: {obj.pk})")
# Add/remove tags
if form.cleaned_data.get('add_tags', None):
@@ -707,6 +737,7 @@ class BulkEditView(GetReturnURLMixin, View):
if updated_count:
msg = 'Updated {} {}'.format(updated_count, model._meta.verbose_name_plural)
logger.info(msg)
messages.success(self.request, msg)
return redirect(self.get_return_url(request))
@@ -714,13 +745,23 @@ class BulkEditView(GetReturnURLMixin, View):
except ValidationError as e:
messages.error(self.request, "{} failed validation: {}".format(obj, e))
else:
logger.debug("Form validation failed")
else:
# Pass the PK list as initial data to avoid binding the form
initial_data = querydict_to_dict(post_data)
# Include the PK list as initial data for the form
initial_data = {'pk': pk_list}
# Check for other contextual data needed for the form. We avoid passing all of request.GET because the
# filter values will conflict with the bulk edit form fields.
# TODO: Find a better way to accomplish this
if 'device' in request.GET:
initial_data['device'] = request.GET.get('device')
form = self.form(model, initial=initial_data)
# Retrieve objects being edited
table = self.table(self.queryset.filter(pk__in=post_data.getlist('pk')), orderable=False)
table = self.table(self.queryset.filter(pk__in=pk_list), orderable=False)
if not table.rows:
messages.warning(request, "No {} were selected.".format(model._meta.verbose_name_plural))
return redirect(self.get_return_url(request))
@@ -753,7 +794,7 @@ class BulkDeleteView(GetReturnURLMixin, View):
return redirect(self.get_return_url(request))
def post(self, request, **kwargs):
logger = logging.getLogger('netbox.views.BulkDeleteView')
model = self.queryset.model
# Are we deleting *all* objects in the queryset or just a selected subset?
@@ -770,19 +811,25 @@ class BulkDeleteView(GetReturnURLMixin, View):
if '_confirm' in request.POST:
form = form_cls(request.POST)
if form.is_valid():
logger.debug("Form validation was successful")
# Delete objects
queryset = model.objects.filter(pk__in=pk_list)
try:
deleted_count = queryset.delete()[1][model._meta.label]
except ProtectedError as e:
logger.info("Caught ProtectedError while attempting to delete objects")
handle_protectederror(list(queryset), request, e)
return redirect(self.get_return_url(request))
msg = 'Deleted {} {}'.format(deleted_count, model._meta.verbose_name_plural)
logger.info(msg)
messages.success(request, msg)
return redirect(self.get_return_url(request))
else:
logger.debug("Form validation failed")
else:
form = form_cls(initial={
'pk': pk_list,
@@ -806,12 +853,12 @@ class BulkDeleteView(GetReturnURLMixin, View):
"""
Provide a standard bulk delete form if none has been specified for the view
"""
class BulkDeleteForm(ConfirmationForm):
pk = ModelMultipleChoiceField(queryset=self.queryset, widget=MultipleHiddenInput)
if self.form:
return self.form
return BulkDeleteForm
@@ -900,7 +947,7 @@ class BulkComponentCreateView(GetReturnURLMixin, View):
template_name = 'utilities/obj_bulk_add_component.html'
def post(self, request):
logger = logging.getLogger('netbox.views.BulkComponentCreateView')
parent_model_name = self.parent_model._meta.verbose_name_plural
model_name = self.model._meta.verbose_name_plural
@@ -918,10 +965,13 @@ class BulkComponentCreateView(GetReturnURLMixin, View):
if '_create' in request.POST:
form = self.form(request.POST)
if form.is_valid():
logger.debug("Form validation was successful")
new_components = []
data = deepcopy(form.cleaned_data)
for obj in data['pk']:
names = data['name_pattern']
@@ -941,15 +991,20 @@ class BulkComponentCreateView(GetReturnURLMixin, View):
if not form.errors:
self.model.objects.bulk_create(new_components)
messages.success(request, "Added {} {} to {} {}.".format(
msg = "Added {} {} to {} {}.".format(
len(new_components),
model_name,
len(form.cleaned_data['pk']),
parent_model_name
))
)
logger.info(msg)
messages.success(request, msg)
return redirect(self.get_return_url(request))
else:
logger.debug("Form validation failed")
else:
form = self.form(initial={'pk': pk_list})