1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00

351 lines
11 KiB
Python
Raw Normal View History

from collections import OrderedDict
2017-05-24 11:33:11 -04:00
2018-11-02 15:20:08 -04:00
import pytz
2017-03-07 17:17:39 -05:00
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import FieldError, MultipleObjectsReturned, ObjectDoesNotExist
from django.db.models import ManyToManyField, ProtectedError
from django.http import Http404
2016-03-01 11:23:03 -05:00
from rest_framework.exceptions import APIException
from rest_framework.permissions import BasePermission
from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
from rest_framework.response import Response
from rest_framework.serializers import Field, ModelSerializer, ValidationError
2018-07-30 12:49:08 -04:00
from rest_framework.viewsets import ModelViewSet as _ModelViewSet, ViewSet
2017-03-07 17:17:39 -05:00
from .utils import dict_to_filter_params, dynamic_import
2018-06-19 14:57:03 -04:00
2016-03-01 11:23:03 -05:00
class ServiceUnavailable(APIException):
    """
    API exception reported to the client with HTTP status 503.
    """
    default_detail = "Service temporarily unavailable, please try again later."
    status_code = 503
2017-01-27 14:36:13 -05:00
class SerializerNotFound(Exception):
    """
    Raised when a serializer class cannot be resolved for a given model.
    """
2018-06-19 14:57:03 -04:00
def get_serializer_for_model(model, prefix=''):
    """
    Dynamically resolve and return the serializer class for a model.

    The model's `_meta.label` ("<app>.<ModelName>") is used to build the dotted path
    "<app>.api.serializers.<prefix><ModelName>Serializer", which is then imported.

    :param model: Model class (or instance) exposing `_meta.label`.
    :param prefix: Optional serializer name prefix (e.g. 'Nested').
    :raises SerializerNotFound: if importing the dotted path raises AttributeError.
    """
    app_name, model_name = model._meta.label.split('.')
    class_name = '{}{}Serializer'.format(prefix, model_name)
    dotted_path = '.'.join([app_name, 'api', 'serializers', class_name])

    try:
        serializer_class = dynamic_import(dotted_path)
    except AttributeError:
        raise SerializerNotFound(
            "Could not determine serializer for {}.{} with prefix '{}'".format(app_name, model_name, prefix)
        )
    return serializer_class
2018-06-19 14:57:03 -04:00
#
# Authentication
#
class IsAuthenticatedOrLoginNotRequired(BasePermission):
    """
    Permission granting access to everyone when settings.LOGIN_REQUIRED is False, and only to
    authenticated users otherwise.
    """
    def has_permission(self, request, view):
        # Short-circuit: the user is only inspected when login is actually required
        return not settings.LOGIN_REQUIRED or request.user.is_authenticated
#
2018-04-04 15:39:14 -04:00
# Fields
#
class ChoiceField(Field):
    """
    Represent a ChoiceField as {'value': <DB value>, 'label': <string>}.

    :param choices: Iterable of (value, label) pairs. If a label is itself a list/tuple, it is treated
        as a named group of (value, label) pairs and flattened into its member choices.
    """
    def __init__(self, choices, **kwargs):
        self._choices = dict()
        for k, v in choices:
            # Unpack grouped choices
            if isinstance(v, (list, tuple)):
                for k2, v2 in v:
                    self._choices[k2] = v2
            else:
                self._choices[k] = v

        super().__init__(**kwargs)

    def to_representation(self, obj):
        """
        Return {'value': obj, 'label': <label>} for a stored value, or None for an empty string.

        Raises KeyError if `obj` is not one of the known choices.
        """
        # Compare by equality: the previous `obj is ''` depended on string interning and triggers a
        # SyntaxWarning ("is" with a literal) on Python 3.8+.
        if obj == '':
            return None
        return OrderedDict([
            ('value', obj),
            ('label', self._choices[obj]),
        ])

    def to_internal_value(self, data):
        """
        Validate and return a raw choice value.

        Strings "true"/"false" (case-insensitive) become booleans and numeric strings become ints
        before the membership check. Raises ValidationError for dicts/lists or unknown values.
        """
        # Provide an explicit error message if the request is trying to write a dict or list
        if isinstance(data, (dict, list)):
            raise ValidationError('Value must be passed directly (e.g. "foo": 123); do not use a dictionary or list.')

        # Check for string representations of boolean/integer values
        if hasattr(data, 'lower'):
            lowered = data.lower()
            if lowered == 'true':
                data = True
            elif lowered == 'false':
                data = False
            else:
                try:
                    data = int(data)
                except ValueError:
                    pass

        try:
            if data in self._choices:
                return data
        except TypeError:  # Input is an unhashable type
            pass

        raise ValidationError("{} is not a valid choice.".format(data))

    @property
    def choices(self):
        """The flattened {value: label} mapping built in __init__."""
        return self._choices
class ContentTypeField(RelatedField):
    """
    Represent a ContentType as '<app_label>.<model>'
    """
    default_error_messages = {
        "does_not_exist": "Invalid content type: {content_type}",
        "invalid": "Invalid value. Specify a content type as '<app_label>.<model_name>'.",
    }

    def to_internal_value(self, data):
        """
        Resolve an '<app_label>.<model>' string to its ContentType.

        Fails with 'does_not_exist' when no matching ContentType exists. Fails with 'invalid' when
        `data` is not a string, or does not split into exactly two dot-separated parts.
        """
        try:
            app_label, model = data.split('.')
            return ContentType.objects.get_by_natural_key(app_label=app_label, model=model)
        except ObjectDoesNotExist:
            self.fail('does_not_exist', content_type=data)
        except (AttributeError, TypeError, ValueError):
            # AttributeError: non-string input (e.g. an int or dict) has no .split() -- previously uncaught.
            # ValueError: the string did not unpack into exactly (app_label, model).
            self.fail('invalid')

    def to_representation(self, obj):
        """Return the ContentType as '<app_label>.<model>'."""
        return "{}.{}".format(obj.app_label, obj.model)
class TimeZoneField(Field):
    """
    Serialize a pytz time zone as its zone name and deserialize a zone name back into a pytz time zone.
    """
    def to_representation(self, obj):
        # Falsy values (e.g. no time zone set) are rendered as null
        if not obj:
            return None
        return obj.zone

    def to_internal_value(self, data):
        # An empty/falsy input clears the value
        if not data:
            return ""
        # Only names listed in pytz.common_timezones are accepted
        if data in pytz.common_timezones:
            return pytz.timezone(data)
        raise ValidationError('Unknown time zone "{}" (see pytz.common_timezones for all options)'.format(data))
class SerializedPKRelatedField(PrimaryKeyRelatedField):
    """
    Extends PrimaryKeyRelatedField to return a serialized object on read. This is useful for representing related
    objects in a ManyToManyField while still allowing a set of primary keys to be written.

    :param serializer: Serializer class used to render each related object on read.
    """
    def __init__(self, serializer, **kwargs):
        self.serializer = serializer
        pk_field = kwargs.pop('pk_field', None)
        super().__init__(**kwargs)
        # Assign after the parent initializer has run. DRF's PrimaryKeyRelatedField.__init__ sets self.pk_field
        # from its own kwargs. Because the key was removed above, it would otherwise overwrite the stored value
        # with None and silently drop a caller-supplied pk_field.
        self.pk_field = pk_field

    def to_representation(self, value):
        # Requires a 'request' entry in this field's context; it is forwarded to the nested serializer
        return self.serializer(value, context={'request': self.context['request']}).data
2018-04-04 15:39:14 -04:00
#
# Serializers
#
# TODO: We should probably take a fresh look at exactly what we're doing with this. There might be a more elegant
# way to enforce model validation on the serializer.
2018-04-04 15:39:14 -04:00
class ValidatedModelSerializer(ModelSerializer):
    """
    ModelSerializer that also runs the model's clean() during validation.

    Custom field data, tags, and many-to-many fields are left out of the attributes applied to the model
    instance before clean() is called. The validated data itself is returned unchanged.
    """
    def validate(self, data):
        model = self.Meta.model

        # Keys that must not be applied to the model instance prior to clean()
        skipped = {'custom_fields', 'tags'}
        skipped.update(
            f.name for f in model._meta.get_fields() if isinstance(f, ManyToManyField)
        )
        attrs = {name: value for name, value in data.items() if name not in skipped}

        # Validate either the existing instance (updated in place with the new values) or a new, unsaved one
        if self.instance is not None:
            instance = self.instance
            for name, value in attrs.items():
                setattr(instance, name, value)
        else:
            instance = model(**attrs)
        instance.clean()

        return data
class WritableNestedSerializer(ModelSerializer):
    """
    Returns a nested representation of an object on read, but accepts only a primary key on write.

    On write, the related object may be referenced by:
      * None, which is returned unchanged;
      * a dictionary of attributes, converted via dict_to_filter_params() and required to match exactly
        one object; or
      * a numeric primary key (an int, or any value int() accepts, such as the string "123").
    Every lookup failure is reported as a ValidationError.
    """
    def to_internal_value(self, data):

        if data is None:
            return None

        # Dictionary of related object attributes
        if isinstance(data, dict):
            params = dict_to_filter_params(data)
            try:
                return self.Meta.model.objects.get(**params)
            except ObjectDoesNotExist:
                raise ValidationError(
                    "Related object not found using the provided attributes: {}".format(params)
                )
            except MultipleObjectsReturned:
                raise ValidationError(
                    "Multiple objects match the provided attributes: {}".format(params)
                )
            except FieldError as e:
                raise ValidationError(e)
            except (TypeError, ValueError) as e:
                # A value of the wrong type for its field (e.g. {"id": "abc"}) is a client error.
                # Report it as a validation failure instead of letting it surface as a server error.
                raise ValidationError(
                    "Invalid value in the provided attributes {}: {}".format(params, e)
                )

        # Integer PK of related object
        if isinstance(data, int):
            pk = data
        else:
            try:
                # PK might have been mistakenly passed as a string
                pk = int(data)
            except (TypeError, ValueError):
                raise ValidationError(
                    "Related objects must be referenced by numeric ID or by dictionary of attributes. Received an "
                    "unrecognized value: {}".format(data)
                )

        # Look up object by PK, reusing the value converted above
        try:
            return self.Meta.model.objects.get(pk=pk)
        except ObjectDoesNotExist:
            raise ValidationError(
                "Related object not found using the provided numeric ID: {}".format(pk)
            )
2018-04-04 15:39:14 -04:00
#
# Viewsets
#
2018-07-30 12:49:08 -04:00
class ModelViewSet(_ModelViewSet):
    """
    Accept either a single object or a list of objects to create.

    Also supports a `brief` query parameter, which selects the model's Nested serializer when one exists.
    A ProtectedError raised while handling a request is turned into an HTTP 409 response.
    """
    def get_serializer(self, *args, **kwargs):

        # If a list of objects has been provided, initialize the serializer with many=True
        if isinstance(kwargs.get('data', {}), list):
            kwargs['many'] = True

        return super().get_serializer(*args, **kwargs)

    def get_serializer_class(self):

        # If 'brief' has been passed as a query param, find and return the nested serializer for this model, if one
        # exists. Any non-empty value of the parameter enables brief mode.
        request = self.get_serializer_context()['request']
        if request.query_params.get('brief', False):
            try:
                return get_serializer_for_model(self.queryset.model, prefix='Nested')
            except SerializerNotFound:
                pass

        # Fall back to the hard-coded serializer class
        return self.serializer_class

    def dispatch(self, request, *args, **kwargs):
        try:
            return super().dispatch(request, *args, **kwargs)
        except ProtectedError as e:
            # Iterate protected_objects directly instead of calling .all().
            # Depending on the Django release it is either a QuerySet or a plain set, and a set has no .all().
            models = ['{} ({})'.format(o, o._meta) for o in e.protected_objects]
            msg = 'Unable to delete object. The following dependent objects were found: {}'.format(', '.join(models))
            return self.finalize_response(
                request,
                Response({'detail': msg}, status=409),
                *args,
                **kwargs
            )

    def list(self, *args, **kwargs):
        """
        Call to super to allow for caching
        """
        return super().list(*args, **kwargs)

    def retrieve(self, *args, **kwargs):
        """
        Call to super to allow for caching
        """
        return super().retrieve(*args, **kwargs)
class FieldChoicesViewSet(ViewSet):
    """
    Expose the built-in numeric values which represent static choices for a model's field.
    """
    permission_classes = [IsAuthenticatedOrLoginNotRequired]
    # Sequence of (model class, iterable of field names) pairs; populated by subclasses
    fields = []

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Compile a dict of all fields in this view, keyed as "<model-verbose-name>:<field_name>"
        self._fields = OrderedDict()
        for cls, field_list in self.fields:
            if not field_list:
                continue

            # These depend only on the model, so build them once per model rather than once per field
            model_name = cls._meta.verbose_name.lower().replace(' ', '-')
            serializer_fields = get_serializer_for_model(cls)().get_fields()

            for field_name in field_list:
                key = ':'.join([model_name, field_name])
                self._fields[key] = self._flatten_choices(serializer_fields[field_name].choices)

    @staticmethod
    def _flatten_choices(choices):
        """Convert a {value: label} mapping into a list of {'value', 'label'} dicts, unpacking grouped labels."""
        flattened = []
        for k, v in choices.items():
            if isinstance(v, (list, tuple)):
                # Grouped choices: the label is itself a sequence of (value, label) pairs
                flattened.extend({'value': k2, 'label': v2} for k2, v2 in v)
            else:
                flattened.append({'value': k, 'label': v})
        return flattened

    def list(self, request):
        """Return the choices for every field exposed by this view."""
        return Response(self._fields)

    def retrieve(self, request, pk):
        """Return the choices for a single '<model>:<field>' key, or raise Http404 if it is unknown."""
        if pk not in self._fields:
            raise Http404

        return Response(self._fields[pk])

    def get_view_name(self):
        return "Field Choices"