mirror of
https://github.com/netbox-community/netbox.git
synced 2024-05-10 07:54:54 +00:00
Merge branch 'develop' into 3840-limit-vlan-choices
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from django.core.validators import RegexValidator
|
||||
from django.db import models
|
||||
|
||||
from utilities.ordering import naturalize
|
||||
from .forms import ColorSelect
|
||||
|
||||
ColorValidator = RegexValidator(
|
||||
@@ -35,3 +36,35 @@ class ColorField(models.CharField):
|
||||
def formfield(self, **kwargs):
|
||||
kwargs['widget'] = ColorSelect
|
||||
return super().formfield(**kwargs)
|
||||
|
||||
|
||||
class NaturalOrderingField(models.CharField):
|
||||
"""
|
||||
A field which stores a naturalized representation of its target field, to be used for ordering its parent model.
|
||||
|
||||
:param target_field: Name of the field of the parent model to be naturalized
|
||||
:param naturalize_function: The function used to generate a naturalized value (optional)
|
||||
"""
|
||||
description = "Stores a representation of its target field suitable for natural ordering"
|
||||
|
||||
def __init__(self, target_field, naturalize_function=naturalize, *args, **kwargs):
|
||||
self.target_field = target_field
|
||||
self.naturalize_function = naturalize_function
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
"""
|
||||
Generate a naturalized value from the target field
|
||||
"""
|
||||
value = getattr(model_instance, self.target_field)
|
||||
return self.naturalize_function(value, max_length=self.max_length)
|
||||
|
||||
def deconstruct(self):
|
||||
kwargs = super().deconstruct()[3] # Pass kwargs from CharField
|
||||
kwargs['naturalize_function'] = self.naturalize_function
|
||||
return (
|
||||
self.name,
|
||||
'utilities.fields.NaturalOrderingField',
|
||||
['target_field'],
|
||||
kwargs,
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import yaml
|
||||
from django import forms
|
||||
from django.conf import settings
|
||||
from django.contrib.postgres.forms.jsonb import JSONField as _JSONField, InvalidJSONInput
|
||||
from django.db.models import Count
|
||||
from mptt.forms import TreeNodeMultipleChoiceField
|
||||
|
||||
from .choices import unpack_grouped_choices
|
||||
@@ -455,12 +456,14 @@ class ExpandableNameField(forms.CharField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if not self.help_text:
|
||||
self.help_text = 'Alphanumeric ranges are supported for bulk creation.<br />' \
|
||||
'Mixed cases and types within a single range are not supported.<br />' \
|
||||
'Examples:<ul><li><code>ge-0/0/[0-23,25,30]</code></li>' \
|
||||
'<li><code>e[0-3][a-d,f]</code></li>' \
|
||||
'<li><code>[xe,ge]-0/0/0</code></li>' \
|
||||
'<li><code>e[0-3,a-d,f]</code></li></ul>'
|
||||
self.help_text = """
|
||||
Alphanumeric ranges are supported for bulk creation. Mixed cases and types within a single range
|
||||
are not supported. Examples:
|
||||
<ul>
|
||||
<li><code>[ge,xe]-0/0/[0-9]</code></li>
|
||||
<li><code>e[0-3][a-d,f]</code></li>
|
||||
</ul>
|
||||
"""
|
||||
|
||||
def to_python(self, value):
|
||||
if re.search(ALPHANUMERIC_EXPANSION_PATTERN, value):
|
||||
@@ -566,6 +569,23 @@ class SlugField(forms.SlugField):
|
||||
self.widget.attrs['slug-source'] = slug_source
|
||||
|
||||
|
||||
class TagFilterField(forms.MultipleChoiceField):
|
||||
"""
|
||||
A filter field for the tags of a model. Only the tags used by a model are displayed.
|
||||
|
||||
:param model: The model of the filter
|
||||
"""
|
||||
widget = StaticSelect2Multiple
|
||||
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
def get_choices():
|
||||
tags = model.tags.annotate(count=Count('extras_taggeditem_items')).order_by('name')
|
||||
return [(str(tag.slug), '{} ({})'.format(tag.name, tag.count)) for tag in tags]
|
||||
|
||||
# Choices are fetched each time the form is initialized
|
||||
super().__init__(label='Tags', choices=get_choices, required=False, *args, **kwargs)
|
||||
|
||||
|
||||
class FilterChoiceIterator(forms.models.ModelChoiceIterator):
|
||||
|
||||
def __iter__(self):
|
||||
@@ -714,26 +734,13 @@ class ConfirmationForm(BootstrapMixin, ReturnURLForm):
|
||||
confirm = forms.BooleanField(required=True, widget=forms.HiddenInput(), initial=True)
|
||||
|
||||
|
||||
class ComponentForm(BootstrapMixin, forms.Form):
|
||||
"""
|
||||
Allow inclusion of the parent Device/VirtualMachine as context for limiting field choices.
|
||||
"""
|
||||
def __init__(self, parent, *args, **kwargs):
|
||||
self.parent = parent
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_iterative_data(self, iteration):
|
||||
return {}
|
||||
|
||||
|
||||
class BulkEditForm(forms.Form):
|
||||
"""
|
||||
Base form for editing multiple objects in bulk
|
||||
"""
|
||||
def __init__(self, model, parent_obj=None, *args, **kwargs):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model = model
|
||||
self.parent_obj = parent_obj
|
||||
self.nullable_fields = []
|
||||
|
||||
# Copy any nullable fields defined in Meta
|
||||
|
||||
@@ -1,7 +1,28 @@
|
||||
# noinspection PyUnresolvedReferences
|
||||
from django.core.management.commands.makemigrations import Command
|
||||
from django.conf import settings
|
||||
from django.core.management.base import CommandError
|
||||
from django.core.management.commands.makemigrations import Command as _Command
|
||||
from django.db import models
|
||||
|
||||
from . import custom_deconstruct
|
||||
|
||||
models.Field.deconstruct = custom_deconstruct
|
||||
|
||||
|
||||
class Command(_Command):
|
||||
|
||||
def handle(self, *args, **kwargs):
|
||||
"""
|
||||
This built-in management command enables the creation of new database schema migration files, which should
|
||||
never be required by and ordinary user. We prevent this command from executing unless the configuration
|
||||
indicates that the user is a developer (i.e. configuration.DEVELOPER == True).
|
||||
"""
|
||||
if not settings.DEVELOPER:
|
||||
raise CommandError(
|
||||
"This command is available for development purposes only. It will\n"
|
||||
"NOT resolve any issues with missing or unapplied migrations. For assistance,\n"
|
||||
"please post to the NetBox mailing list:\n"
|
||||
" https://groups.google.com/forum/#!forum/netbox-discuss"
|
||||
)
|
||||
|
||||
super().handle(*args, **kwargs)
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from django.db.models import Manager
|
||||
from django.db.models.expressions import RawSQL
|
||||
|
||||
NAT1 = r"CAST(SUBSTRING({}.{} FROM '^(\d{{1,9}})') AS integer)"
|
||||
NAT2 = r"SUBSTRING({}.{} FROM '^\d*(.*?)\d*$')"
|
||||
NAT3 = r"CAST(SUBSTRING({}.{} FROM '(\d{{1,9}})$') AS integer)"
|
||||
|
||||
|
||||
class NaturalOrderingManager(Manager):
|
||||
"""
|
||||
Order objects naturally by a designated field (defaults to 'name'). Leading and/or trailing digits of values within
|
||||
this field will be cast as independent integers and sorted accordingly. For example, "Foo2" will be ordered before
|
||||
"Foo10", even though the digit 1 is normally ordered before the digit 2.
|
||||
"""
|
||||
natural_order_field = 'name'
|
||||
|
||||
def get_queryset(self):
|
||||
|
||||
queryset = super().get_queryset()
|
||||
|
||||
db_table = self.model._meta.db_table
|
||||
db_field = self.natural_order_field
|
||||
|
||||
# Append the three subfields derived from the designated natural ordering field
|
||||
queryset = (
|
||||
queryset.annotate(_nat1=RawSQL(NAT1.format(db_table, db_field), ()))
|
||||
.annotate(_nat2=RawSQL(NAT2.format(db_table, db_field), ()))
|
||||
.annotate(_nat3=RawSQL(NAT3.format(db_table, db_field), ()))
|
||||
)
|
||||
|
||||
# Replace any instance of the designated natural ordering field with its three subfields
|
||||
ordering = []
|
||||
for field in self.model._meta.ordering:
|
||||
if field == self.natural_order_field:
|
||||
ordering.append('_nat1')
|
||||
ordering.append('_nat2')
|
||||
ordering.append('_nat3')
|
||||
else:
|
||||
ordering.append(field)
|
||||
|
||||
# Default to using the _nat indexes if Meta.ordering is empty
|
||||
if not ordering:
|
||||
ordering = ('_nat1', '_nat2', '_nat3')
|
||||
|
||||
return queryset.order_by(*ordering)
|
||||
@@ -7,9 +7,6 @@ from django.urls import reverse
|
||||
|
||||
from .views import server_error
|
||||
|
||||
BASE_PATH = getattr(settings, 'BASE_PATH', False)
|
||||
LOGIN_REQUIRED = getattr(settings, 'LOGIN_REQUIRED', False)
|
||||
|
||||
|
||||
class LoginRequiredMiddleware(object):
|
||||
"""
|
||||
@@ -19,7 +16,7 @@ class LoginRequiredMiddleware(object):
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request):
|
||||
if LOGIN_REQUIRED and not request.user.is_authenticated:
|
||||
if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
|
||||
# Redirect unauthenticated requests to the login page. API requests are exempt from redirection as the API
|
||||
# performs its own authentication. Also metrics can be read without login.
|
||||
api_path = reverse('api-root')
|
||||
|
||||
80
netbox/utilities/ordering.py
Normal file
80
netbox/utilities/ordering.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import re
|
||||
|
||||
INTERFACE_NAME_REGEX = r'(^(?P<type>[^\d\.:]+)?)' \
|
||||
r'((?P<slot>\d+)/)?' \
|
||||
r'((?P<subslot>\d+)/)?' \
|
||||
r'((?P<position>\d+)/)?' \
|
||||
r'((?P<subposition>\d+)/)?' \
|
||||
r'((?P<id>\d+))?' \
|
||||
r'(:(?P<channel>\d+))?' \
|
||||
r'(.(?P<vc>\d+)$)?'
|
||||
|
||||
|
||||
def naturalize(value, max_length=None, integer_places=8):
|
||||
"""
|
||||
Take an alphanumeric string and prepend all integers to `integer_places` places to ensure the strings
|
||||
are ordered naturally. For example:
|
||||
|
||||
site9router21
|
||||
site10router4
|
||||
site10router19
|
||||
|
||||
becomes:
|
||||
|
||||
site00000009router00000021
|
||||
site00000010router00000004
|
||||
site00000010router00000019
|
||||
|
||||
:param value: The value to be naturalized
|
||||
:param max_length: The maximum length of the returned string. Characters beyond this length will be stripped.
|
||||
:param integer_places: The number of places to which each integer will be expanded. (Default: 8)
|
||||
"""
|
||||
if not value:
|
||||
return value
|
||||
output = []
|
||||
for segment in re.split(r'(\d+)', value):
|
||||
if segment.isdigit():
|
||||
output.append(segment.rjust(integer_places, '0'))
|
||||
elif segment:
|
||||
output.append(segment)
|
||||
ret = ''.join(output)
|
||||
|
||||
return ret[:max_length] if max_length else ret
|
||||
|
||||
|
||||
def naturalize_interface(value, max_length=None):
|
||||
"""
|
||||
Similar in nature to naturalize(), but takes into account a particular naming format adapted from the old
|
||||
InterfaceManager.
|
||||
|
||||
:param value: The value to be naturalized
|
||||
:param max_length: The maximum length of the returned string. Characters beyond this length will be stripped.
|
||||
"""
|
||||
output = []
|
||||
match = re.search(INTERFACE_NAME_REGEX, value)
|
||||
if match is None:
|
||||
return value
|
||||
|
||||
# First, we order by slot/position, padding each to four digits. If a field is not present,
|
||||
# set it to 9999 to ensure it is ordered last.
|
||||
for part_name in ('slot', 'subslot', 'position', 'subposition'):
|
||||
part = match.group(part_name)
|
||||
if part is not None:
|
||||
output.append(part.rjust(4, '0'))
|
||||
else:
|
||||
output.append('9999')
|
||||
|
||||
# Append the type, if any.
|
||||
if match.group('type') is not None:
|
||||
output.append(match.group('type'))
|
||||
|
||||
# Finally, append any remaining fields, left-padding to eight digits each.
|
||||
for part_name in ('id', 'channel', 'vc'):
|
||||
part = match.group(part_name)
|
||||
if part is not None:
|
||||
output.append(part.rjust(6, '0'))
|
||||
else:
|
||||
output.append('000000')
|
||||
|
||||
ret = ''.join(output)
|
||||
return ret[:max_length] if max_length else ret
|
||||
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import yaml
|
||||
|
||||
from django import template
|
||||
from django.utils.html import strip_tags
|
||||
@@ -76,6 +77,14 @@ def render_json(value):
|
||||
return json.dumps(value, indent=4, sort_keys=True)
|
||||
|
||||
|
||||
@register.filter()
|
||||
def render_yaml(value):
|
||||
"""
|
||||
Render a dictionary as formatted YAML.
|
||||
"""
|
||||
return yaml.dump(dict(value))
|
||||
|
||||
|
||||
@register.filter()
|
||||
def model_name(obj):
|
||||
"""
|
||||
|
||||
2
netbox/utilities/testing/__init__.py
Normal file
2
netbox/utilities/testing/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .testcases import *
|
||||
from .utils import *
|
||||
396
netbox/utilities/testing/testcases.py
Normal file
396
netbox/utilities/testing/testcases.py
Normal file
@@ -0,0 +1,396 @@
|
||||
from django.contrib.auth.models import Permission, User
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.forms.models import model_to_dict
|
||||
from django.test import Client, TestCase as _TestCase, override_settings
|
||||
from django.urls import reverse, NoReverseMatch
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from users.models import Token
|
||||
from .utils import disable_warnings, post_data
|
||||
|
||||
|
||||
class TestCase(_TestCase):
|
||||
user_permissions = ()
|
||||
|
||||
def setUp(self):
|
||||
|
||||
# Create the test user and assign permissions
|
||||
self.user = User.objects.create_user(username='testuser')
|
||||
self.add_permissions(*self.user_permissions)
|
||||
|
||||
# Initialize the test client
|
||||
self.client = Client()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
#
|
||||
# Permissions management
|
||||
#
|
||||
|
||||
def add_permissions(self, *names):
|
||||
"""
|
||||
Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
|
||||
"""
|
||||
for name in names:
|
||||
app, codename = name.split('.')
|
||||
perm = Permission.objects.get(content_type__app_label=app, codename=codename)
|
||||
self.user.user_permissions.add(perm)
|
||||
|
||||
def remove_permissions(self, *names):
|
||||
"""
|
||||
Remove a set of permissions from the test user, if assigned.
|
||||
"""
|
||||
for name in names:
|
||||
app, codename = name.split('.')
|
||||
perm = Permission.objects.get(content_type__app_label=app, codename=codename)
|
||||
self.user.user_permissions.remove(perm)
|
||||
|
||||
#
|
||||
# Convenience methods
|
||||
#
|
||||
|
||||
def assertHttpStatus(self, response, expected_status):
|
||||
"""
|
||||
TestCase method. Provide more detail in the event of an unexpected HTTP response.
|
||||
"""
|
||||
err_message = "Expected HTTP status {}; received {}: {}"
|
||||
self.assertEqual(response.status_code, expected_status, err_message.format(
|
||||
expected_status, response.status_code, getattr(response, 'data', 'No data')
|
||||
))
|
||||
|
||||
def assertInstanceEqual(self, instance, data):
|
||||
"""
|
||||
Compare a model instance to a dictionary, checking that its attribute values match those specified
|
||||
in the dictionary.
|
||||
"""
|
||||
model_dict = model_to_dict(instance, fields=data.keys())
|
||||
|
||||
for key in list(model_dict.keys()):
|
||||
|
||||
# TODO: Differentiate between tags assigned to the instance and a M2M field for tags (ex: ConfigContext)
|
||||
if key == 'tags':
|
||||
model_dict[key] = ','.join(sorted([tag.name for tag in model_dict['tags']]))
|
||||
|
||||
# Convert ManyToManyField to list of instance PKs
|
||||
elif model_dict[key] and type(model_dict[key]) in (list, tuple) and hasattr(model_dict[key][0], 'pk'):
|
||||
model_dict[key] = [obj.pk for obj in model_dict[key]]
|
||||
|
||||
# Omit any dictionary keys which are not instance attributes
|
||||
relevant_data = {
|
||||
k: v for k, v in data.items() if hasattr(instance, k)
|
||||
}
|
||||
|
||||
self.assertDictEqual(model_dict, relevant_data)
|
||||
|
||||
|
||||
class APITestCase(TestCase):
|
||||
client_class = APIClient
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a superuser and token for API calls.
|
||||
"""
|
||||
self.user = User.objects.create(username='testuser', is_superuser=True)
|
||||
self.token = Token.objects.create(user=self.user)
|
||||
self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
|
||||
|
||||
|
||||
class StandardTestCases:
|
||||
"""
|
||||
We keep any TestCases with test_* methods inside a class to prevent unittest from trying to run them.
|
||||
"""
|
||||
|
||||
class Views(TestCase):
|
||||
"""
|
||||
Stock TestCase suitable for testing all standard View functions:
|
||||
- List objects
|
||||
- View single object
|
||||
- Create new object
|
||||
- Modify existing object
|
||||
- Delete existing object
|
||||
- Import multiple new objects
|
||||
"""
|
||||
model = None
|
||||
|
||||
# Data to be sent when creating/editing individual objects
|
||||
form_data = {}
|
||||
|
||||
# CSV lines used for bulk import of new objects
|
||||
csv_data = ()
|
||||
|
||||
# Form data used when creating multiple objects
|
||||
bulk_create_data = {}
|
||||
|
||||
# Form data to be used when editing multiple objects at once
|
||||
bulk_edit_data = {}
|
||||
|
||||
maxDiff = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.model is None:
|
||||
raise Exception("Test case requires model to be defined")
|
||||
|
||||
#
|
||||
# URL functions
|
||||
#
|
||||
|
||||
def _get_base_url(self):
|
||||
"""
|
||||
Return the base format for a URL for the test's model. Override this to test for a model which belongs
|
||||
to a different app (e.g. testing Interfaces within the virtualization app).
|
||||
"""
|
||||
return '{}:{}_{{}}'.format(
|
||||
self.model._meta.app_label,
|
||||
self.model._meta.model_name
|
||||
)
|
||||
|
||||
def _get_url(self, action, instance=None):
|
||||
"""
|
||||
Return the URL name for a specific action. An instance must be specified for
|
||||
get/edit/delete views.
|
||||
"""
|
||||
url_format = self._get_base_url()
|
||||
|
||||
if action in ('list', 'add', 'import', 'bulk_edit', 'bulk_delete'):
|
||||
return reverse(url_format.format(action))
|
||||
|
||||
elif action in ('get', 'edit', 'delete'):
|
||||
if instance is None:
|
||||
raise Exception("Resolving {} URL requires specifying an instance".format(action))
|
||||
# Attempt to resolve using slug first
|
||||
if hasattr(self.model, 'slug'):
|
||||
try:
|
||||
return reverse(url_format.format(action), kwargs={'slug': instance.slug})
|
||||
except NoReverseMatch:
|
||||
pass
|
||||
return reverse(url_format.format(action), kwargs={'pk': instance.pk})
|
||||
|
||||
else:
|
||||
raise Exception("Invalid action for URL resolution: {}".format(action))
|
||||
|
||||
#
|
||||
# Standard view tests
|
||||
# These methods will run by default. To disable a test, nullify its method on the subclasses TestCase:
|
||||
#
|
||||
# test_list_objects = None
|
||||
#
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_list_objects(self):
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.get(self._get_url('list')), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.get(self._get_url('list'))
|
||||
self.assertHttpStatus(response, 200)
|
||||
|
||||
# Built-in CSV export
|
||||
if hasattr(self.model, 'csv_headers'):
|
||||
response = self.client.get('{}?export'.format(self._get_url('list')))
|
||||
self.assertHttpStatus(response, 200)
|
||||
self.assertEqual(response.get('Content-Type'), 'text/csv')
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_get_object(self):
|
||||
instance = self.model.objects.first()
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.get(instance.get_absolute_url()), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.get(instance.get_absolute_url())
|
||||
self.assertHttpStatus(response, 200)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_create_object(self):
|
||||
initial_count = self.model.objects.count()
|
||||
request = {
|
||||
'path': self._get_url('add'),
|
||||
'data': post_data(self.form_data),
|
||||
'follow': False, # Do not follow 302 redirects
|
||||
}
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 302)
|
||||
|
||||
self.assertEqual(initial_count + 1, self.model.objects.count())
|
||||
instance = self.model.objects.order_by('-pk').first()
|
||||
self.assertInstanceEqual(instance, self.form_data)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_edit_object(self):
|
||||
instance = self.model.objects.first()
|
||||
|
||||
request = {
|
||||
'path': self._get_url('edit', instance),
|
||||
'data': post_data(self.form_data),
|
||||
'follow': False, # Do not follow 302 redirects
|
||||
}
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 302)
|
||||
|
||||
instance = self.model.objects.get(pk=instance.pk)
|
||||
self.assertInstanceEqual(instance, self.form_data)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_delete_object(self):
|
||||
instance = self.model.objects.first()
|
||||
|
||||
request = {
|
||||
'path': self._get_url('delete', instance),
|
||||
'data': {'confirm': True},
|
||||
'follow': False, # Do not follow 302 redirects
|
||||
}
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 302)
|
||||
|
||||
with self.assertRaises(ObjectDoesNotExist):
|
||||
self.model.objects.get(pk=instance.pk)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_import_objects(self):
|
||||
initial_count = self.model.objects.count()
|
||||
request = {
|
||||
'path': self._get_url('import'),
|
||||
'data': {
|
||||
'csv': '\n'.join(self.csv_data)
|
||||
}
|
||||
}
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name),
|
||||
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 200)
|
||||
|
||||
self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_bulk_edit_objects(self):
|
||||
# Bulk edit the first three objects only
|
||||
pk_list = self.model.objects.values_list('pk', flat=True)[:3]
|
||||
|
||||
request = {
|
||||
'path': self._get_url('bulk_edit'),
|
||||
'data': {
|
||||
'pk': pk_list,
|
||||
'_apply': True, # Form button
|
||||
},
|
||||
'follow': False, # Do not follow 302 redirects
|
||||
}
|
||||
|
||||
# Append the form data to the request
|
||||
request['data'].update(post_data(self.bulk_edit_data))
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 302)
|
||||
|
||||
for i, instance in enumerate(self.model.objects.filter(pk__in=pk_list)):
|
||||
self.assertInstanceEqual(instance, self.bulk_edit_data)
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def test_bulk_delete_objects(self):
|
||||
pk_list = self.model.objects.values_list('pk', flat=True)
|
||||
|
||||
request = {
|
||||
'path': self._get_url('bulk_delete'),
|
||||
'data': {
|
||||
'pk': pk_list,
|
||||
'confirm': True,
|
||||
'_confirm': True, # Form button
|
||||
},
|
||||
'follow': False, # Do not follow 302 redirects
|
||||
}
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 302)
|
||||
|
||||
# Check that all objects were deleted
|
||||
self.assertEqual(self.model.objects.count(), 0)
|
||||
|
||||
#
|
||||
# Optional view tests
|
||||
# These methods will run only if the required data
|
||||
#
|
||||
|
||||
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
|
||||
def _test_bulk_create_objects(self, expected_count):
|
||||
initial_count = self.model.objects.count()
|
||||
request = {
|
||||
'path': self._get_url('add'),
|
||||
'data': post_data(self.bulk_create_data),
|
||||
'follow': False, # Do not follow 302 redirects
|
||||
}
|
||||
|
||||
# Attempt to make the request without required permissions
|
||||
with disable_warnings('django.request'):
|
||||
self.assertHttpStatus(self.client.post(**request), 403)
|
||||
|
||||
# Assign the required permission and submit again
|
||||
self.add_permissions(
|
||||
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
|
||||
)
|
||||
response = self.client.post(**request)
|
||||
self.assertHttpStatus(response, 302)
|
||||
|
||||
self.assertEqual(initial_count + expected_count, self.model.objects.count())
|
||||
for instance in self.model.objects.order_by('-pk')[:expected_count]:
|
||||
self.assertInstanceEqual(instance, self.bulk_create_data)
|
||||
@@ -2,29 +2,23 @@ import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
from django.contrib.auth.models import Permission, User
|
||||
from rest_framework.test import APITestCase as _APITestCase
|
||||
|
||||
from users.models import Token
|
||||
|
||||
|
||||
class APITestCase(_APITestCase):
|
||||
def post_data(data):
|
||||
"""
|
||||
Take a dictionary of test data (suitable for comparison to an instance) and return a dict suitable for POSTing.
|
||||
"""
|
||||
ret = {}
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a superuser and token for API calls.
|
||||
"""
|
||||
self.user = User.objects.create(username='testuser', is_superuser=True)
|
||||
self.token = Token.objects.create(user=self.user)
|
||||
self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
ret[key] = ''
|
||||
elif type(value) in (list, tuple):
|
||||
ret[key] = value
|
||||
else:
|
||||
ret[key] = str(value)
|
||||
|
||||
def assertHttpStatus(self, response, expected_status):
|
||||
"""
|
||||
Provide more detail in the event of an unexpected HTTP response.
|
||||
"""
|
||||
err_message = "Expected HTTP status {}; received {}: {}"
|
||||
self.assertEqual(response.status_code, expected_status, err_message.format(
|
||||
expected_status, response.status_code, getattr(response, 'data', 'No data')
|
||||
))
|
||||
return ret
|
||||
|
||||
|
||||
def create_test_user(username='testuser', permissions=list()):
|
||||
@@ -4,6 +4,7 @@ from collections import OrderedDict
|
||||
|
||||
from django.core.serializers import serialize
|
||||
from django.db.models import Count, OuterRef, Subquery
|
||||
from django.http import QueryDict
|
||||
from jinja2 import Environment
|
||||
|
||||
from dcim.choices import CableLengthUnitChoices
|
||||
@@ -209,3 +210,15 @@ def prepare_cloned_fields(instance):
|
||||
)
|
||||
|
||||
return param_string
|
||||
|
||||
|
||||
def querydict_to_dict(querydict):
|
||||
"""
|
||||
Convert a django.http.QueryDict object to a regular Python dictionary, preserving lists of multiple values.
|
||||
(QueryDict.dict() will return only the last value in a list for each key.)
|
||||
"""
|
||||
assert isinstance(querydict, QueryDict)
|
||||
return {
|
||||
key: querydict.get(key) if len(value) == 1 and key != 'pk' else querydict.getlist(key)
|
||||
for key, value in querydict.lists()
|
||||
}
|
||||
|
||||
@@ -4,11 +4,10 @@ from copy import deepcopy
|
||||
from django.conf import settings
|
||||
from django.contrib import messages
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.exceptions import FieldDoesNotExist, ValidationError
|
||||
from django.db import transaction, IntegrityError
|
||||
from django.db.models import Count, ProtectedError
|
||||
from django.db.models.query import QuerySet
|
||||
from django.forms import CharField, Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea
|
||||
from django.db.models import ManyToManyField, ProtectedError
|
||||
from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea
|
||||
from django.http import HttpResponse, HttpResponseServerError
|
||||
from django.shortcuts import get_object_or_404, redirect, render
|
||||
from django.template import loader
|
||||
@@ -24,10 +23,9 @@ from django_tables2 import RequestConfig
|
||||
|
||||
from extras.models import CustomField, CustomFieldValue, ExportTemplate
|
||||
from extras.querysets import CustomFieldQueryset
|
||||
from extras.utils import is_taggable
|
||||
from utilities.exceptions import AbortTransaction
|
||||
from utilities.forms import BootstrapMixin, CSVDataField
|
||||
from utilities.utils import csv_format, prepare_cloned_fields
|
||||
from utilities.utils import csv_format, prepare_cloned_fields, querydict_to_dict
|
||||
from .error_handlers import handle_protectederror
|
||||
from .forms import ConfirmationForm, ImportForm
|
||||
from .paginator import EnhancedPaginator
|
||||
@@ -88,15 +86,27 @@ class ObjectListView(View):
|
||||
Export the queryset of objects as comma-separated value (CSV), using the model's to_csv() method.
|
||||
"""
|
||||
csv_data = []
|
||||
custom_fields = []
|
||||
|
||||
# Start with the column headers
|
||||
headers = ','.join(self.queryset.model.csv_headers)
|
||||
csv_data.append(headers)
|
||||
headers = self.queryset.model.csv_headers.copy()
|
||||
|
||||
# Add custom field headers, if any
|
||||
if hasattr(self.queryset.model, 'get_custom_fields'):
|
||||
for custom_field in self.queryset.model().get_custom_fields():
|
||||
headers.append(custom_field.name)
|
||||
custom_fields.append(custom_field.name)
|
||||
|
||||
csv_data.append(','.join(headers))
|
||||
|
||||
# Iterate through the queryset appending each object
|
||||
for obj in self.queryset:
|
||||
data = csv_format(obj.to_csv())
|
||||
csv_data.append(data)
|
||||
data = obj.to_csv()
|
||||
|
||||
for custom_field in custom_fields:
|
||||
data += (obj.cf.get(custom_field, ''),)
|
||||
|
||||
csv_data.append(csv_format(data))
|
||||
|
||||
return '\n'.join(csv_data)
|
||||
|
||||
@@ -155,12 +165,6 @@ class ObjectListView(View):
|
||||
if 'pk' in table.base_columns and (permissions['change'] or permissions['delete']):
|
||||
table.columns.show('pk')
|
||||
|
||||
# Construct queryset for tags list
|
||||
if is_taggable(model):
|
||||
tags = model.tags.annotate(count=Count('extras_taggeditem_items')).order_by('name')
|
||||
else:
|
||||
tags = None
|
||||
|
||||
# Apply the request context
|
||||
paginate = {
|
||||
'paginator_class': EnhancedPaginator,
|
||||
@@ -173,7 +177,6 @@ class ObjectListView(View):
|
||||
'table': table,
|
||||
'permissions': permissions,
|
||||
'filter_form': self.filterset_form(request.GET, label_suffix='') if self.filterset_form else None,
|
||||
'tags': tags,
|
||||
}
|
||||
context.update(self.extra_context())
|
||||
|
||||
@@ -601,14 +604,12 @@ class BulkEditView(GetReturnURLMixin, View):
|
||||
Edit objects in bulk.
|
||||
|
||||
queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
|
||||
parent_model: The model of the parent object (if any)
|
||||
filter: FilterSet to apply when deleting by QuerySet
|
||||
table: The table used to display devices being edited
|
||||
form: The form class used to edit objects in bulk
|
||||
template_name: The name of the template
|
||||
"""
|
||||
queryset = None
|
||||
parent_model = None
|
||||
filterset = None
|
||||
table = None
|
||||
form = None
|
||||
@@ -621,24 +622,21 @@ class BulkEditView(GetReturnURLMixin, View):
|
||||
|
||||
model = self.queryset.model
|
||||
|
||||
# Attempt to derive parent object if a parent class has been given
|
||||
if self.parent_model:
|
||||
parent_obj = get_object_or_404(self.parent_model, **kwargs)
|
||||
else:
|
||||
parent_obj = None
|
||||
# Create a mutable copy of the POST data
|
||||
post_data = request.POST.copy()
|
||||
|
||||
# Are we editing *all* objects in the queryset or just a selected subset?
|
||||
if request.POST.get('_all') and self.filterset is not None:
|
||||
pk_list = [obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs]
|
||||
else:
|
||||
pk_list = [int(pk) for pk in request.POST.getlist('pk')]
|
||||
# If we are editing *all* objects in the queryset, replace the PK list with all matched objects.
|
||||
if post_data.get('_all') and self.filterset is not None:
|
||||
post_data['pk'] = [obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs]
|
||||
|
||||
if '_apply' in request.POST:
|
||||
form = self.form(model, parent_obj, request.POST)
|
||||
form = self.form(model, request.POST, initial=request.GET)
|
||||
if form.is_valid():
|
||||
|
||||
custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
|
||||
standard_fields = [field for field in form.fields if field not in custom_fields and field != 'pk']
|
||||
standard_fields = [
|
||||
field for field in form.fields if field not in custom_fields + ['pk']
|
||||
]
|
||||
nullified_fields = request.POST.getlist('_nullify')
|
||||
|
||||
try:
|
||||
@@ -646,18 +644,33 @@ class BulkEditView(GetReturnURLMixin, View):
|
||||
with transaction.atomic():
|
||||
|
||||
updated_count = 0
|
||||
for obj in model.objects.filter(pk__in=pk_list):
|
||||
for obj in model.objects.filter(pk__in=form.cleaned_data['pk']):
|
||||
|
||||
# Update standard fields. If a field is listed in _nullify, delete its value.
|
||||
for name in standard_fields:
|
||||
if name in form.nullable_fields and name in nullified_fields and isinstance(form.cleaned_data[name], QuerySet):
|
||||
getattr(obj, name).set([])
|
||||
elif name in form.nullable_fields and name in nullified_fields:
|
||||
setattr(obj, name, '' if isinstance(form.fields[name], CharField) else None)
|
||||
elif isinstance(form.cleaned_data[name], QuerySet) and form.cleaned_data[name]:
|
||||
|
||||
try:
|
||||
model_field = model._meta.get_field(name)
|
||||
except FieldDoesNotExist:
|
||||
# The form field is used to modify a field rather than set its value directly,
|
||||
# so we skip it.
|
||||
continue
|
||||
|
||||
# Handle nullification
|
||||
if name in form.nullable_fields and name in nullified_fields:
|
||||
if isinstance(model_field, ManyToManyField):
|
||||
getattr(obj, name).set([])
|
||||
else:
|
||||
setattr(obj, name, None if model_field.null else '')
|
||||
|
||||
# ManyToManyFields
|
||||
elif isinstance(model_field, ManyToManyField):
|
||||
getattr(obj, name).set(form.cleaned_data[name])
|
||||
elif form.cleaned_data[name] not in (None, '') and not isinstance(form.cleaned_data[name], QuerySet):
|
||||
|
||||
# Normal fields
|
||||
elif form.cleaned_data[name] not in (None, ''):
|
||||
setattr(obj, name, form.cleaned_data[name])
|
||||
|
||||
obj.full_clean()
|
||||
obj.save()
|
||||
|
||||
@@ -699,12 +712,16 @@ class BulkEditView(GetReturnURLMixin, View):
|
||||
messages.error(self.request, "{} failed validation: {}".format(obj, e))
|
||||
|
||||
else:
|
||||
initial_data = request.POST.copy()
|
||||
initial_data['pk'] = pk_list
|
||||
form = self.form(model, parent_obj, initial=initial_data)
|
||||
# Pass the PK list as initial data to avoid binding the form
|
||||
initial_data = querydict_to_dict(post_data)
|
||||
|
||||
# Append any normal initial data (passed as GET parameters)
|
||||
initial_data.update(request.GET)
|
||||
|
||||
form = self.form(model, initial=initial_data)
|
||||
|
||||
# Retrieve objects being edited
|
||||
table = self.table(self.queryset.filter(pk__in=pk_list), orderable=False)
|
||||
table = self.table(self.queryset.filter(pk__in=post_data.getlist('pk')), orderable=False)
|
||||
if not table.rows:
|
||||
messages.warning(request, "No {} were selected.".format(model._meta.verbose_name_plural))
|
||||
return redirect(self.get_return_url(request))
|
||||
@@ -722,14 +739,12 @@ class BulkDeleteView(GetReturnURLMixin, View):
|
||||
Delete objects in bulk.
|
||||
|
||||
queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
|
||||
parent_model: The model of the parent object (if any)
|
||||
filter: FilterSet to apply when deleting by QuerySet
|
||||
table: The table used to display devices being deleted
|
||||
form: The form class used to delete objects in bulk
|
||||
template_name: The name of the template
|
||||
"""
|
||||
queryset = None
|
||||
parent_model = None
|
||||
filterset = None
|
||||
table = None
|
||||
form = None
|
||||
@@ -742,12 +757,6 @@ class BulkDeleteView(GetReturnURLMixin, View):
|
||||
|
||||
model = self.queryset.model
|
||||
|
||||
# Attempt to derive parent object if a parent class has been given
|
||||
if self.parent_model:
|
||||
parent_obj = get_object_or_404(self.parent_model, **kwargs)
|
||||
else:
|
||||
parent_obj = None
|
||||
|
||||
# Are we deleting *all* objects in the queryset or just a selected subset?
|
||||
if request.POST.get('_all'):
|
||||
if self.filterset is not None:
|
||||
@@ -789,7 +798,6 @@ class BulkDeleteView(GetReturnURLMixin, View):
|
||||
|
||||
return render(request, self.template_name, {
|
||||
'form': form,
|
||||
'parent_obj': parent_obj,
|
||||
'obj_type_plural': model._meta.verbose_name_plural,
|
||||
'table': table,
|
||||
'return_url': self.get_return_url(request),
|
||||
@@ -812,45 +820,40 @@ class BulkDeleteView(GetReturnURLMixin, View):
|
||||
# Device/VirtualMachine components
|
||||
#
|
||||
|
||||
class ComponentCreateView(View):
|
||||
# TODO: Replace with BulkCreateView
|
||||
class ComponentCreateView(GetReturnURLMixin, View):
|
||||
"""
|
||||
Add one or more components (e.g. interfaces, console ports, etc.) to a Device or VirtualMachine.
|
||||
"""
|
||||
parent_model = None
|
||||
parent_field = None
|
||||
model = None
|
||||
form = None
|
||||
model_form = None
|
||||
template_name = None
|
||||
|
||||
def get(self, request, pk):
|
||||
def get(self, request):
|
||||
|
||||
parent = get_object_or_404(self.parent_model, pk=pk)
|
||||
form = self.form(parent, initial=request.GET)
|
||||
form = self.form(initial=request.GET)
|
||||
|
||||
return render(request, self.template_name, {
|
||||
'parent': parent,
|
||||
'component_type': self.model._meta.verbose_name,
|
||||
'form': form,
|
||||
'return_url': parent.get_absolute_url(),
|
||||
'return_url': self.get_return_url(request),
|
||||
})
|
||||
|
||||
def post(self, request, pk):
|
||||
def post(self, request):
|
||||
|
||||
parent = get_object_or_404(self.parent_model, pk=pk)
|
||||
|
||||
form = self.form(parent, request.POST)
|
||||
form = self.form(request.POST, initial=request.GET)
|
||||
if form.is_valid():
|
||||
|
||||
new_components = []
|
||||
data = deepcopy(request.POST)
|
||||
data[self.parent_field] = parent.pk
|
||||
|
||||
for i, name in enumerate(form.cleaned_data['name_pattern']):
|
||||
|
||||
# Initialize the individual component form
|
||||
data['name'] = name
|
||||
data.update(form.get_iterative_data(i))
|
||||
if hasattr(form, 'get_iterative_data'):
|
||||
data.update(form.get_iterative_data(i))
|
||||
component_form = self.model_form(data)
|
||||
|
||||
if component_form.is_valid():
|
||||
@@ -869,19 +872,18 @@ class ComponentCreateView(View):
|
||||
for component_form in new_components:
|
||||
component_form.save()
|
||||
|
||||
messages.success(request, "Added {} {} to {}.".format(
|
||||
len(new_components), self.model._meta.verbose_name_plural, parent
|
||||
messages.success(request, "Added {} {}".format(
|
||||
len(new_components), self.model._meta.verbose_name_plural
|
||||
))
|
||||
if '_addanother' in request.POST:
|
||||
return redirect(request.path)
|
||||
return redirect(request.get_full_path())
|
||||
else:
|
||||
return redirect(parent.get_absolute_url())
|
||||
return redirect(self.get_return_url(request))
|
||||
|
||||
return render(request, self.template_name, {
|
||||
'parent': parent,
|
||||
'component_type': self.model._meta.verbose_name,
|
||||
'form': form,
|
||||
'return_url': parent.get_absolute_url(),
|
||||
'return_url': self.get_return_url(request),
|
||||
})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user