1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

Merge branch 'develop' into 3840-limit-vlan-choices

This commit is contained in:
Saria Hajjar
2020-02-08 16:14:10 +00:00
committed by GitHub
141 changed files with 5394 additions and 9239 deletions

View File

@@ -1,6 +1,7 @@
from django.core.validators import RegexValidator
from django.db import models
from utilities.ordering import naturalize
from .forms import ColorSelect
ColorValidator = RegexValidator(
@@ -35,3 +36,35 @@ class ColorField(models.CharField):
def formfield(self, **kwargs):
kwargs['widget'] = ColorSelect
return super().formfield(**kwargs)
class NaturalOrderingField(models.CharField):
"""
A field which stores a naturalized representation of its target field, to be used for ordering its parent model.
:param target_field: Name of the field of the parent model to be naturalized
:param naturalize_function: The function used to generate a naturalized value (optional)
"""
description = "Stores a representation of its target field suitable for natural ordering"
def __init__(self, target_field, naturalize_function=naturalize, *args, **kwargs):
self.target_field = target_field
self.naturalize_function = naturalize_function
super().__init__(*args, **kwargs)
def pre_save(self, model_instance, add):
"""
Generate a naturalized value from the target field
"""
value = getattr(model_instance, self.target_field)
return self.naturalize_function(value, max_length=self.max_length)
def deconstruct(self):
kwargs = super().deconstruct()[3] # Pass kwargs from CharField
kwargs['naturalize_function'] = self.naturalize_function
return (
self.name,
'utilities.fields.NaturalOrderingField',
['target_field'],
kwargs,
)

View File

@@ -7,6 +7,7 @@ import yaml
from django import forms
from django.conf import settings
from django.contrib.postgres.forms.jsonb import JSONField as _JSONField, InvalidJSONInput
from django.db.models import Count
from mptt.forms import TreeNodeMultipleChoiceField
from .choices import unpack_grouped_choices
@@ -455,12 +456,14 @@ class ExpandableNameField(forms.CharField):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.help_text:
self.help_text = 'Alphanumeric ranges are supported for bulk creation.<br />' \
'Mixed cases and types within a single range are not supported.<br />' \
'Examples:<ul><li><code>ge-0/0/[0-23,25,30]</code></li>' \
'<li><code>e[0-3][a-d,f]</code></li>' \
'<li><code>[xe,ge]-0/0/0</code></li>' \
'<li><code>e[0-3,a-d,f]</code></li></ul>'
self.help_text = """
Alphanumeric ranges are supported for bulk creation. Mixed cases and types within a single range
are not supported. Examples:
<ul>
<li><code>[ge,xe]-0/0/[0-9]</code></li>
<li><code>e[0-3][a-d,f]</code></li>
</ul>
"""
def to_python(self, value):
if re.search(ALPHANUMERIC_EXPANSION_PATTERN, value):
@@ -566,6 +569,23 @@ class SlugField(forms.SlugField):
self.widget.attrs['slug-source'] = slug_source
class TagFilterField(forms.MultipleChoiceField):
"""
A filter field for the tags of a model. Only the tags used by a model are displayed.
:param model: The model of the filter
"""
widget = StaticSelect2Multiple
def __init__(self, model, *args, **kwargs):
def get_choices():
tags = model.tags.annotate(count=Count('extras_taggeditem_items')).order_by('name')
return [(str(tag.slug), '{} ({})'.format(tag.name, tag.count)) for tag in tags]
# Choices are fetched each time the form is initialized
super().__init__(label='Tags', choices=get_choices, required=False, *args, **kwargs)
class FilterChoiceIterator(forms.models.ModelChoiceIterator):
def __iter__(self):
@@ -714,26 +734,13 @@ class ConfirmationForm(BootstrapMixin, ReturnURLForm):
confirm = forms.BooleanField(required=True, widget=forms.HiddenInput(), initial=True)
class ComponentForm(BootstrapMixin, forms.Form):
"""
Allow inclusion of the parent Device/VirtualMachine as context for limiting field choices.
"""
def __init__(self, parent, *args, **kwargs):
self.parent = parent
super().__init__(*args, **kwargs)
def get_iterative_data(self, iteration):
return {}
class BulkEditForm(forms.Form):
"""
Base form for editing multiple objects in bulk
"""
def __init__(self, model, parent_obj=None, *args, **kwargs):
def __init__(self, model, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model = model
self.parent_obj = parent_obj
self.nullable_fields = []
# Copy any nullable fields defined in Meta

View File

@@ -1,7 +1,28 @@
# noinspection PyUnresolvedReferences
from django.core.management.commands.makemigrations import Command
from django.conf import settings
from django.core.management.base import CommandError
from django.core.management.commands.makemigrations import Command as _Command
from django.db import models
from . import custom_deconstruct
models.Field.deconstruct = custom_deconstruct
class Command(_Command):
def handle(self, *args, **kwargs):
"""
This built-in management command enables the creation of new database schema migration files, which should
never be required by and ordinary user. We prevent this command from executing unless the configuration
indicates that the user is a developer (i.e. configuration.DEVELOPER == True).
"""
if not settings.DEVELOPER:
raise CommandError(
"This command is available for development purposes only. It will\n"
"NOT resolve any issues with missing or unapplied migrations. For assistance,\n"
"please post to the NetBox mailing list:\n"
" https://groups.google.com/forum/#!forum/netbox-discuss"
)
super().handle(*args, **kwargs)

View File

@@ -1,45 +0,0 @@
from django.db.models import Manager
from django.db.models.expressions import RawSQL
NAT1 = r"CAST(SUBSTRING({}.{} FROM '^(\d{{1,9}})') AS integer)"
NAT2 = r"SUBSTRING({}.{} FROM '^\d*(.*?)\d*$')"
NAT3 = r"CAST(SUBSTRING({}.{} FROM '(\d{{1,9}})$') AS integer)"
class NaturalOrderingManager(Manager):
"""
Order objects naturally by a designated field (defaults to 'name'). Leading and/or trailing digits of values within
this field will be cast as independent integers and sorted accordingly. For example, "Foo2" will be ordered before
"Foo10", even though the digit 1 is normally ordered before the digit 2.
"""
natural_order_field = 'name'
def get_queryset(self):
queryset = super().get_queryset()
db_table = self.model._meta.db_table
db_field = self.natural_order_field
# Append the three subfields derived from the designated natural ordering field
queryset = (
queryset.annotate(_nat1=RawSQL(NAT1.format(db_table, db_field), ()))
.annotate(_nat2=RawSQL(NAT2.format(db_table, db_field), ()))
.annotate(_nat3=RawSQL(NAT3.format(db_table, db_field), ()))
)
# Replace any instance of the designated natural ordering field with its three subfields
ordering = []
for field in self.model._meta.ordering:
if field == self.natural_order_field:
ordering.append('_nat1')
ordering.append('_nat2')
ordering.append('_nat3')
else:
ordering.append(field)
# Default to using the _nat indexes if Meta.ordering is empty
if not ordering:
ordering = ('_nat1', '_nat2', '_nat3')
return queryset.order_by(*ordering)

View File

@@ -7,9 +7,6 @@ from django.urls import reverse
from .views import server_error
BASE_PATH = getattr(settings, 'BASE_PATH', False)
LOGIN_REQUIRED = getattr(settings, 'LOGIN_REQUIRED', False)
class LoginRequiredMiddleware(object):
"""
@@ -19,7 +16,7 @@ class LoginRequiredMiddleware(object):
self.get_response = get_response
def __call__(self, request):
if LOGIN_REQUIRED and not request.user.is_authenticated:
if settings.LOGIN_REQUIRED and not request.user.is_authenticated:
# Redirect unauthenticated requests to the login page. API requests are exempt from redirection as the API
# performs its own authentication. Also metrics can be read without login.
api_path = reverse('api-root')

View File

@@ -0,0 +1,80 @@
import re
INTERFACE_NAME_REGEX = r'(^(?P<type>[^\d\.:]+)?)' \
r'((?P<slot>\d+)/)?' \
r'((?P<subslot>\d+)/)?' \
r'((?P<position>\d+)/)?' \
r'((?P<subposition>\d+)/)?' \
r'((?P<id>\d+))?' \
r'(:(?P<channel>\d+))?' \
r'(.(?P<vc>\d+)$)?'
def naturalize(value, max_length=None, integer_places=8):
"""
Take an alphanumeric string and prepend all integers to `integer_places` places to ensure the strings
are ordered naturally. For example:
site9router21
site10router4
site10router19
becomes:
site00000009router00000021
site00000010router00000004
site00000010router00000019
:param value: The value to be naturalized
:param max_length: The maximum length of the returned string. Characters beyond this length will be stripped.
:param integer_places: The number of places to which each integer will be expanded. (Default: 8)
"""
if not value:
return value
output = []
for segment in re.split(r'(\d+)', value):
if segment.isdigit():
output.append(segment.rjust(integer_places, '0'))
elif segment:
output.append(segment)
ret = ''.join(output)
return ret[:max_length] if max_length else ret
def naturalize_interface(value, max_length=None):
"""
Similar in nature to naturalize(), but takes into account a particular naming format adapted from the old
InterfaceManager.
:param value: The value to be naturalized
:param max_length: The maximum length of the returned string. Characters beyond this length will be stripped.
"""
output = []
match = re.search(INTERFACE_NAME_REGEX, value)
if match is None:
return value
# First, we order by slot/position, padding each to four digits. If a field is not present,
# set it to 9999 to ensure it is ordered last.
for part_name in ('slot', 'subslot', 'position', 'subposition'):
part = match.group(part_name)
if part is not None:
output.append(part.rjust(4, '0'))
else:
output.append('9999')
# Append the type, if any.
if match.group('type') is not None:
output.append(match.group('type'))
# Finally, append any remaining fields, left-padding to eight digits each.
for part_name in ('id', 'channel', 'vc'):
part = match.group(part_name)
if part is not None:
output.append(part.rjust(6, '0'))
else:
output.append('000000')
ret = ''.join(output)
return ret[:max_length] if max_length else ret

View File

@@ -1,6 +1,7 @@
import datetime
import json
import re
import yaml
from django import template
from django.utils.html import strip_tags
@@ -76,6 +77,14 @@ def render_json(value):
return json.dumps(value, indent=4, sort_keys=True)
@register.filter()
def render_yaml(value):
"""
Render a dictionary as formatted YAML.
"""
return yaml.dump(dict(value))
@register.filter()
def model_name(obj):
"""

View File

@@ -0,0 +1,2 @@
from .testcases import *
from .utils import *

View File

@@ -0,0 +1,396 @@
from django.contrib.auth.models import Permission, User
from django.core.exceptions import ObjectDoesNotExist
from django.forms.models import model_to_dict
from django.test import Client, TestCase as _TestCase, override_settings
from django.urls import reverse, NoReverseMatch
from rest_framework.test import APIClient
from users.models import Token
from .utils import disable_warnings, post_data
class TestCase(_TestCase):
user_permissions = ()
def setUp(self):
# Create the test user and assign permissions
self.user = User.objects.create_user(username='testuser')
self.add_permissions(*self.user_permissions)
# Initialize the test client
self.client = Client()
self.client.force_login(self.user)
#
# Permissions management
#
def add_permissions(self, *names):
"""
Assign a set of permissions to the test user. Accepts permission names in the form <app>.<action>_<model>.
"""
for name in names:
app, codename = name.split('.')
perm = Permission.objects.get(content_type__app_label=app, codename=codename)
self.user.user_permissions.add(perm)
def remove_permissions(self, *names):
"""
Remove a set of permissions from the test user, if assigned.
"""
for name in names:
app, codename = name.split('.')
perm = Permission.objects.get(content_type__app_label=app, codename=codename)
self.user.user_permissions.remove(perm)
#
# Convenience methods
#
def assertHttpStatus(self, response, expected_status):
"""
TestCase method. Provide more detail in the event of an unexpected HTTP response.
"""
err_message = "Expected HTTP status {}; received {}: {}"
self.assertEqual(response.status_code, expected_status, err_message.format(
expected_status, response.status_code, getattr(response, 'data', 'No data')
))
def assertInstanceEqual(self, instance, data):
"""
Compare a model instance to a dictionary, checking that its attribute values match those specified
in the dictionary.
"""
model_dict = model_to_dict(instance, fields=data.keys())
for key in list(model_dict.keys()):
# TODO: Differentiate between tags assigned to the instance and a M2M field for tags (ex: ConfigContext)
if key == 'tags':
model_dict[key] = ','.join(sorted([tag.name for tag in model_dict['tags']]))
# Convert ManyToManyField to list of instance PKs
elif model_dict[key] and type(model_dict[key]) in (list, tuple) and hasattr(model_dict[key][0], 'pk'):
model_dict[key] = [obj.pk for obj in model_dict[key]]
# Omit any dictionary keys which are not instance attributes
relevant_data = {
k: v for k, v in data.items() if hasattr(instance, k)
}
self.assertDictEqual(model_dict, relevant_data)
class APITestCase(TestCase):
client_class = APIClient
def setUp(self):
"""
Create a superuser and token for API calls.
"""
self.user = User.objects.create(username='testuser', is_superuser=True)
self.token = Token.objects.create(user=self.user)
self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
class StandardTestCases:
"""
We keep any TestCases with test_* methods inside a class to prevent unittest from trying to run them.
"""
class Views(TestCase):
"""
Stock TestCase suitable for testing all standard View functions:
- List objects
- View single object
- Create new object
- Modify existing object
- Delete existing object
- Import multiple new objects
"""
model = None
# Data to be sent when creating/editing individual objects
form_data = {}
# CSV lines used for bulk import of new objects
csv_data = ()
# Form data used when creating multiple objects
bulk_create_data = {}
# Form data to be used when editing multiple objects at once
bulk_edit_data = {}
maxDiff = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.model is None:
raise Exception("Test case requires model to be defined")
#
# URL functions
#
def _get_base_url(self):
"""
Return the base format for a URL for the test's model. Override this to test for a model which belongs
to a different app (e.g. testing Interfaces within the virtualization app).
"""
return '{}:{}_{{}}'.format(
self.model._meta.app_label,
self.model._meta.model_name
)
def _get_url(self, action, instance=None):
"""
Return the URL name for a specific action. An instance must be specified for
get/edit/delete views.
"""
url_format = self._get_base_url()
if action in ('list', 'add', 'import', 'bulk_edit', 'bulk_delete'):
return reverse(url_format.format(action))
elif action in ('get', 'edit', 'delete'):
if instance is None:
raise Exception("Resolving {} URL requires specifying an instance".format(action))
# Attempt to resolve using slug first
if hasattr(self.model, 'slug'):
try:
return reverse(url_format.format(action), kwargs={'slug': instance.slug})
except NoReverseMatch:
pass
return reverse(url_format.format(action), kwargs={'pk': instance.pk})
else:
raise Exception("Invalid action for URL resolution: {}".format(action))
#
# Standard view tests
# These methods will run by default. To disable a test, nullify its method on the subclasses TestCase:
#
# test_list_objects = None
#
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_list_objects(self):
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.get(self._get_url('list')), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.get(self._get_url('list'))
self.assertHttpStatus(response, 200)
# Built-in CSV export
if hasattr(self.model, 'csv_headers'):
response = self.client.get('{}?export'.format(self._get_url('list')))
self.assertHttpStatus(response, 200)
self.assertEqual(response.get('Content-Type'), 'text/csv')
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_get_object(self):
instance = self.model.objects.first()
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.get(instance.get_absolute_url()), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.get(instance.get_absolute_url())
self.assertHttpStatus(response, 200)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_create_object(self):
initial_count = self.model.objects.count()
request = {
'path': self._get_url('add'),
'data': post_data(self.form_data),
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
self.assertEqual(initial_count + 1, self.model.objects.count())
instance = self.model.objects.order_by('-pk').first()
self.assertInstanceEqual(instance, self.form_data)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_edit_object(self):
instance = self.model.objects.first()
request = {
'path': self._get_url('edit', instance),
'data': post_data(self.form_data),
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
instance = self.model.objects.get(pk=instance.pk)
self.assertInstanceEqual(instance, self.form_data)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_delete_object(self):
instance = self.model.objects.first()
request = {
'path': self._get_url('delete', instance),
'data': {'confirm': True},
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
with self.assertRaises(ObjectDoesNotExist):
self.model.objects.get(pk=instance.pk)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_import_objects(self):
initial_count = self.model.objects.count()
request = {
'path': self._get_url('import'),
'data': {
'csv': '\n'.join(self.csv_data)
}
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.view_{}'.format(self.model._meta.app_label, self.model._meta.model_name),
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 200)
self.assertEqual(self.model.objects.count(), initial_count + len(self.csv_data) - 1)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_bulk_edit_objects(self):
# Bulk edit the first three objects only
pk_list = self.model.objects.values_list('pk', flat=True)[:3]
request = {
'path': self._get_url('bulk_edit'),
'data': {
'pk': pk_list,
'_apply': True, # Form button
},
'follow': False, # Do not follow 302 redirects
}
# Append the form data to the request
request['data'].update(post_data(self.bulk_edit_data))
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.change_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
for i, instance in enumerate(self.model.objects.filter(pk__in=pk_list)):
self.assertInstanceEqual(instance, self.bulk_edit_data)
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def test_bulk_delete_objects(self):
pk_list = self.model.objects.values_list('pk', flat=True)
request = {
'path': self._get_url('bulk_delete'),
'data': {
'pk': pk_list,
'confirm': True,
'_confirm': True, # Form button
},
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.delete_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
# Check that all objects were deleted
self.assertEqual(self.model.objects.count(), 0)
#
# Optional view tests
# These methods will run only if the required data
#
@override_settings(EXEMPT_VIEW_PERMISSIONS=[])
def _test_bulk_create_objects(self, expected_count):
initial_count = self.model.objects.count()
request = {
'path': self._get_url('add'),
'data': post_data(self.bulk_create_data),
'follow': False, # Do not follow 302 redirects
}
# Attempt to make the request without required permissions
with disable_warnings('django.request'):
self.assertHttpStatus(self.client.post(**request), 403)
# Assign the required permission and submit again
self.add_permissions(
'{}.add_{}'.format(self.model._meta.app_label, self.model._meta.model_name)
)
response = self.client.post(**request)
self.assertHttpStatus(response, 302)
self.assertEqual(initial_count + expected_count, self.model.objects.count())
for instance in self.model.objects.order_by('-pk')[:expected_count]:
self.assertInstanceEqual(instance, self.bulk_create_data)

View File

@@ -2,29 +2,23 @@ import logging
from contextlib import contextmanager
from django.contrib.auth.models import Permission, User
from rest_framework.test import APITestCase as _APITestCase
from users.models import Token
class APITestCase(_APITestCase):
def post_data(data):
"""
Take a dictionary of test data (suitable for comparison to an instance) and return a dict suitable for POSTing.
"""
ret = {}
def setUp(self):
"""
Create a superuser and token for API calls.
"""
self.user = User.objects.create(username='testuser', is_superuser=True)
self.token = Token.objects.create(user=self.user)
self.header = {'HTTP_AUTHORIZATION': 'Token {}'.format(self.token.key)}
for key, value in data.items():
if value is None:
ret[key] = ''
elif type(value) in (list, tuple):
ret[key] = value
else:
ret[key] = str(value)
def assertHttpStatus(self, response, expected_status):
"""
Provide more detail in the event of an unexpected HTTP response.
"""
err_message = "Expected HTTP status {}; received {}: {}"
self.assertEqual(response.status_code, expected_status, err_message.format(
expected_status, response.status_code, getattr(response, 'data', 'No data')
))
return ret
def create_test_user(username='testuser', permissions=list()):

View File

@@ -4,6 +4,7 @@ from collections import OrderedDict
from django.core.serializers import serialize
from django.db.models import Count, OuterRef, Subquery
from django.http import QueryDict
from jinja2 import Environment
from dcim.choices import CableLengthUnitChoices
@@ -209,3 +210,15 @@ def prepare_cloned_fields(instance):
)
return param_string
def querydict_to_dict(querydict):
"""
Convert a django.http.QueryDict object to a regular Python dictionary, preserving lists of multiple values.
(QueryDict.dict() will return only the last value in a list for each key.)
"""
assert isinstance(querydict, QueryDict)
return {
key: querydict.get(key) if len(value) == 1 and key != 'pk' else querydict.getlist(key)
for key, value in querydict.lists()
}

View File

@@ -4,11 +4,10 @@ from copy import deepcopy
from django.conf import settings
from django.contrib import messages
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.db import transaction, IntegrityError
from django.db.models import Count, ProtectedError
from django.db.models.query import QuerySet
from django.forms import CharField, Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea
from django.db.models import ManyToManyField, ProtectedError
from django.forms import Form, ModelMultipleChoiceField, MultipleHiddenInput, Textarea
from django.http import HttpResponse, HttpResponseServerError
from django.shortcuts import get_object_or_404, redirect, render
from django.template import loader
@@ -24,10 +23,9 @@ from django_tables2 import RequestConfig
from extras.models import CustomField, CustomFieldValue, ExportTemplate
from extras.querysets import CustomFieldQueryset
from extras.utils import is_taggable
from utilities.exceptions import AbortTransaction
from utilities.forms import BootstrapMixin, CSVDataField
from utilities.utils import csv_format, prepare_cloned_fields
from utilities.utils import csv_format, prepare_cloned_fields, querydict_to_dict
from .error_handlers import handle_protectederror
from .forms import ConfirmationForm, ImportForm
from .paginator import EnhancedPaginator
@@ -88,15 +86,27 @@ class ObjectListView(View):
Export the queryset of objects as comma-separated value (CSV), using the model's to_csv() method.
"""
csv_data = []
custom_fields = []
# Start with the column headers
headers = ','.join(self.queryset.model.csv_headers)
csv_data.append(headers)
headers = self.queryset.model.csv_headers.copy()
# Add custom field headers, if any
if hasattr(self.queryset.model, 'get_custom_fields'):
for custom_field in self.queryset.model().get_custom_fields():
headers.append(custom_field.name)
custom_fields.append(custom_field.name)
csv_data.append(','.join(headers))
# Iterate through the queryset appending each object
for obj in self.queryset:
data = csv_format(obj.to_csv())
csv_data.append(data)
data = obj.to_csv()
for custom_field in custom_fields:
data += (obj.cf.get(custom_field, ''),)
csv_data.append(csv_format(data))
return '\n'.join(csv_data)
@@ -155,12 +165,6 @@ class ObjectListView(View):
if 'pk' in table.base_columns and (permissions['change'] or permissions['delete']):
table.columns.show('pk')
# Construct queryset for tags list
if is_taggable(model):
tags = model.tags.annotate(count=Count('extras_taggeditem_items')).order_by('name')
else:
tags = None
# Apply the request context
paginate = {
'paginator_class': EnhancedPaginator,
@@ -173,7 +177,6 @@ class ObjectListView(View):
'table': table,
'permissions': permissions,
'filter_form': self.filterset_form(request.GET, label_suffix='') if self.filterset_form else None,
'tags': tags,
}
context.update(self.extra_context())
@@ -601,14 +604,12 @@ class BulkEditView(GetReturnURLMixin, View):
Edit objects in bulk.
queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
parent_model: The model of the parent object (if any)
filter: FilterSet to apply when deleting by QuerySet
table: The table used to display devices being edited
form: The form class used to edit objects in bulk
template_name: The name of the template
"""
queryset = None
parent_model = None
filterset = None
table = None
form = None
@@ -621,24 +622,21 @@ class BulkEditView(GetReturnURLMixin, View):
model = self.queryset.model
# Attempt to derive parent object if a parent class has been given
if self.parent_model:
parent_obj = get_object_or_404(self.parent_model, **kwargs)
else:
parent_obj = None
# Create a mutable copy of the POST data
post_data = request.POST.copy()
# Are we editing *all* objects in the queryset or just a selected subset?
if request.POST.get('_all') and self.filterset is not None:
pk_list = [obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs]
else:
pk_list = [int(pk) for pk in request.POST.getlist('pk')]
# If we are editing *all* objects in the queryset, replace the PK list with all matched objects.
if post_data.get('_all') and self.filterset is not None:
post_data['pk'] = [obj.pk for obj in self.filterset(request.GET, model.objects.only('pk')).qs]
if '_apply' in request.POST:
form = self.form(model, parent_obj, request.POST)
form = self.form(model, request.POST, initial=request.GET)
if form.is_valid():
custom_fields = form.custom_fields if hasattr(form, 'custom_fields') else []
standard_fields = [field for field in form.fields if field not in custom_fields and field != 'pk']
standard_fields = [
field for field in form.fields if field not in custom_fields + ['pk']
]
nullified_fields = request.POST.getlist('_nullify')
try:
@@ -646,18 +644,33 @@ class BulkEditView(GetReturnURLMixin, View):
with transaction.atomic():
updated_count = 0
for obj in model.objects.filter(pk__in=pk_list):
for obj in model.objects.filter(pk__in=form.cleaned_data['pk']):
# Update standard fields. If a field is listed in _nullify, delete its value.
for name in standard_fields:
if name in form.nullable_fields and name in nullified_fields and isinstance(form.cleaned_data[name], QuerySet):
getattr(obj, name).set([])
elif name in form.nullable_fields and name in nullified_fields:
setattr(obj, name, '' if isinstance(form.fields[name], CharField) else None)
elif isinstance(form.cleaned_data[name], QuerySet) and form.cleaned_data[name]:
try:
model_field = model._meta.get_field(name)
except FieldDoesNotExist:
# The form field is used to modify a field rather than set its value directly,
# so we skip it.
continue
# Handle nullification
if name in form.nullable_fields and name in nullified_fields:
if isinstance(model_field, ManyToManyField):
getattr(obj, name).set([])
else:
setattr(obj, name, None if model_field.null else '')
# ManyToManyFields
elif isinstance(model_field, ManyToManyField):
getattr(obj, name).set(form.cleaned_data[name])
elif form.cleaned_data[name] not in (None, '') and not isinstance(form.cleaned_data[name], QuerySet):
# Normal fields
elif form.cleaned_data[name] not in (None, ''):
setattr(obj, name, form.cleaned_data[name])
obj.full_clean()
obj.save()
@@ -699,12 +712,16 @@ class BulkEditView(GetReturnURLMixin, View):
messages.error(self.request, "{} failed validation: {}".format(obj, e))
else:
initial_data = request.POST.copy()
initial_data['pk'] = pk_list
form = self.form(model, parent_obj, initial=initial_data)
# Pass the PK list as initial data to avoid binding the form
initial_data = querydict_to_dict(post_data)
# Append any normal initial data (passed as GET parameters)
initial_data.update(request.GET)
form = self.form(model, initial=initial_data)
# Retrieve objects being edited
table = self.table(self.queryset.filter(pk__in=pk_list), orderable=False)
table = self.table(self.queryset.filter(pk__in=post_data.getlist('pk')), orderable=False)
if not table.rows:
messages.warning(request, "No {} were selected.".format(model._meta.verbose_name_plural))
return redirect(self.get_return_url(request))
@@ -722,14 +739,12 @@ class BulkDeleteView(GetReturnURLMixin, View):
Delete objects in bulk.
queryset: Custom queryset to use when retrieving objects (e.g. to select related objects)
parent_model: The model of the parent object (if any)
filter: FilterSet to apply when deleting by QuerySet
table: The table used to display devices being deleted
form: The form class used to delete objects in bulk
template_name: The name of the template
"""
queryset = None
parent_model = None
filterset = None
table = None
form = None
@@ -742,12 +757,6 @@ class BulkDeleteView(GetReturnURLMixin, View):
model = self.queryset.model
# Attempt to derive parent object if a parent class has been given
if self.parent_model:
parent_obj = get_object_or_404(self.parent_model, **kwargs)
else:
parent_obj = None
# Are we deleting *all* objects in the queryset or just a selected subset?
if request.POST.get('_all'):
if self.filterset is not None:
@@ -789,7 +798,6 @@ class BulkDeleteView(GetReturnURLMixin, View):
return render(request, self.template_name, {
'form': form,
'parent_obj': parent_obj,
'obj_type_plural': model._meta.verbose_name_plural,
'table': table,
'return_url': self.get_return_url(request),
@@ -812,45 +820,40 @@ class BulkDeleteView(GetReturnURLMixin, View):
# Device/VirtualMachine components
#
class ComponentCreateView(View):
# TODO: Replace with BulkCreateView
class ComponentCreateView(GetReturnURLMixin, View):
"""
Add one or more components (e.g. interfaces, console ports, etc.) to a Device or VirtualMachine.
"""
parent_model = None
parent_field = None
model = None
form = None
model_form = None
template_name = None
def get(self, request, pk):
def get(self, request):
parent = get_object_or_404(self.parent_model, pk=pk)
form = self.form(parent, initial=request.GET)
form = self.form(initial=request.GET)
return render(request, self.template_name, {
'parent': parent,
'component_type': self.model._meta.verbose_name,
'form': form,
'return_url': parent.get_absolute_url(),
'return_url': self.get_return_url(request),
})
def post(self, request, pk):
def post(self, request):
parent = get_object_or_404(self.parent_model, pk=pk)
form = self.form(parent, request.POST)
form = self.form(request.POST, initial=request.GET)
if form.is_valid():
new_components = []
data = deepcopy(request.POST)
data[self.parent_field] = parent.pk
for i, name in enumerate(form.cleaned_data['name_pattern']):
# Initialize the individual component form
data['name'] = name
data.update(form.get_iterative_data(i))
if hasattr(form, 'get_iterative_data'):
data.update(form.get_iterative_data(i))
component_form = self.model_form(data)
if component_form.is_valid():
@@ -869,19 +872,18 @@ class ComponentCreateView(View):
for component_form in new_components:
component_form.save()
messages.success(request, "Added {} {} to {}.".format(
len(new_components), self.model._meta.verbose_name_plural, parent
messages.success(request, "Added {} {}".format(
len(new_components), self.model._meta.verbose_name_plural
))
if '_addanother' in request.POST:
return redirect(request.path)
return redirect(request.get_full_path())
else:
return redirect(parent.get_absolute_url())
return redirect(self.get_return_url(request))
return render(request, self.template_name, {
'parent': parent,
'component_type': self.model._meta.verbose_name,
'form': form,
'return_url': parent.get_absolute_url(),
'return_url': self.get_return_url(request),
})