mirror of
https://github.com/netbox-community/netbox.git
synced 2024-05-10 07:54:54 +00:00
258 lines
10 KiB
Python
258 lines
10 KiB
Python
import re
|
|
import typing
|
|
from collections import OrderedDict
|
|
|
|
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
|
from drf_spectacular.openapi import AutoSchema
|
|
from drf_spectacular.plumbing import (
|
|
build_basic_type, build_choice_field, build_media_type_object, build_object_type, get_doc,
|
|
)
|
|
from drf_spectacular.types import OpenApiTypes
|
|
from rest_framework import serializers
|
|
from rest_framework.relations import ManyRelatedField
|
|
|
|
from netbox.api.fields import ChoiceField, SerializedPKRelatedField
|
|
from netbox.api.serializers import WritableNestedSerializer
|
|
|
|
# View action names treated as bulk operations by the schema class below
# (compared against ``view.action``); see netbox.api.routers.NetBoxRouter
BULK_ACTIONS = ("bulk_destroy", "bulk_partial_update", "bulk_update")
# HTTP methods for which a "writable" variant of the request serializer is generated
# (compared against the schema's ``method``).
WRITABLE_ACTIONS = ("PATCH", "POST", "PUT")
|
|
|
|
|
|
class FixTimeZoneSerializerField(OpenApiSerializerFieldExtension):
    """
    Schema extension for ``timezone_field.rest_framework.TimeZoneSerializerField``.

    The field is always described as a basic string type.
    """
    target_class = 'timezone_field.rest_framework.TimeZoneSerializerField'

    def map_serializer_field(self, auto_schema, direction):
        """
        Return the schema for the targeted field.

        :param auto_schema: schema instance requesting the mapping (unused)
        :param direction: 'request' or 'response' (unused; the result is the same for both)
        :return: basic string type schema
        """
        string_schema = build_basic_type(OpenApiTypes.STR)
        return string_schema
|
|
|
|
|
|
class ChoiceFieldFix(OpenApiSerializerFieldExtension):
    """
    Schema extension for ``netbox.api.fields.ChoiceField``.

    Requests are described with the plain choice schema, while responses are described
    as an object holding a ``value`` and a ``label`` property.
    """
    target_class = 'netbox.api.fields.ChoiceField'

    def map_serializer_field(self, auto_schema, direction):
        """
        Return the schema for the targeted choice field.

        :param auto_schema: schema instance requesting the mapping (unused)
        :param direction: 'request' or 'response'
        :return: choice schema for requests, ``{value, label}`` object schema for responses,
            ``None`` for any other direction
        """
        choice_schema = build_choice_field(self.target)

        if direction == 'request':
            return choice_schema
        if direction != "response":
            # Neither request nor response: nothing to describe.
            return None

        # The label is a string restricted to the distinct choice labels, in first-seen order.
        label_schema = dict(build_basic_type(OpenApiTypes.STR))
        label_schema["enum"] = list(OrderedDict.fromkeys(self.target.choices.values()))

        object_properties = {"value": choice_schema, "label": label_schema}
        return build_object_type(properties=object_properties)
|
|
|
|
|
|
class NetBoxAutoSchema(AutoSchema):
    """
    Overrides to drf_spectacular.openapi.AutoSchema to fix following issues:
        1. bulk serializers cause operation_id conflicts with non-bulk ones
        2. bulk operations should specify a list
        3. bulk operations don't have filter params
        4. bulk operations don't have pagination
        5. bulk delete should specify input
    """

    # Cache of generated "Writable*" serializer classes, keyed by the original serializer class.
    # Defined on the class, so it is shared by every schema instance.
    writable_serializers = {}

    @property
    def is_bulk_action(self):
        """Return True when the view has an ``action`` that is one of BULK_ACTIONS."""
        return getattr(self.view, "action", None) in BULK_ACTIONS

    def get_operation_id(self):
        """
        bulk serializers cause operation_id conflicts with non-bulk ones
        bulk operations cause id conflicts in spectacular resulting in numerous:
        Warning: operationId "xxx" has collisions [xxx]. "resolving with numeral suffixes"
        code is modified from drf_spectacular.openapi.AutoSchema.get_operation_id

        :return: operation id string; for bulk actions it ends with the bulk action name
        """
        # if not bulk - just return normal id
        if not self.is_bulk_action:
            return super().get_operation_id()

        # replace dashes as they can be problematic later in code generation
        tokenized_path = [token.replace('-', '_') for token in self._tokenize_path()]

        if self.method == 'GET' and self._is_list_view():
            # this shouldn't happen, but keeping it here to follow base code
            action = 'list'
        else:
            # use bulk name so partial_update -> bulk_partial_update
            # (the base code maps the HTTP method through method_mapping instead)
            action = self.view.action.lower()

        if not tokenized_path:
            tokenized_path.append('root')

        if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
            tokenized_path.append('formatted')

        return '_'.join(tokenized_path + [action])

    def get_request_serializer(self) -> typing.Any:
        """
        Return the serializer describing the request body.

        Bulk actions get a ``many=True`` instance of the same serializer class. For
        PATCH/POST/PUT, the serializer is swapped for its generated writable variant when
        get_writable_class() provides one.

        :return: serializer to document, or ``None`` when the parent resolved none
        """
        serializer = super().get_request_serializer()

        # bulk operations should specify a list
        if self.is_bulk_action:
            # type(None)(many=True) would raise TypeError, so pass a missing serializer through.
            if serializer is None:
                return None
            return type(serializer)(many=True)

        # handle mapping for Writable serializers - adapted from dansheps original code
        # for drf-yasg
        if serializer is not None and self.method in WRITABLE_ACTIONS:
            writable_class = self.get_writable_class(serializer)
            if writable_class is not None:
                if hasattr(serializer, "child"):
                    child_serializer = self.get_writable_class(serializer.child)
                    serializer = writable_class(context=serializer.context, child=child_serializer)
                else:
                    serializer = writable_class(context=serializer.context)

        return serializer

    def get_response_serializers(self) -> typing.Any:
        """
        Return the serializer(s) describing the response.

        :return: ``many=True`` instance of the parent's result type for bulk actions,
            otherwise the parent's result unchanged
        """
        response_serializers = super().get_response_serializers()

        # bulk operations should specify a list
        if self.is_bulk_action and response_serializers is not None:
            return type(response_serializers)(many=True)

        # Non-bulk action, or nothing resolved by the parent (None cannot be re-instantiated).
        return response_serializers

    def get_serializer_ref_name(self, serializer):
        # from drf-yasg.utils
        """Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer')

        :param serializer: Serializer instance
        :return: Serializer's ``ref_name`` or ``None`` for inline serializer
        :rtype: str or None
        """
        serializer_meta = getattr(serializer, 'Meta', None)
        serializer_name = type(serializer).__name__
        if hasattr(serializer_meta, 'ref_name'):
            # An explicit Meta.ref_name always wins (it may itself be None).
            ref_name = serializer_meta.ref_name
        elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer):
            ref_name = None
        else:
            ref_name = serializer_name
            if ref_name.endswith('Serializer'):
                ref_name = ref_name[: -len('Serializer')]
        return ref_name

    def get_writable_class(self, serializer):
        """
        Return (creating and caching it on first use) the writable variant of a serializer's class.

        The variant is a subclass named ``Writable<ClassName>`` in which every ChoiceField /
        WritableNestedSerializer field name is overridden with ``None``, and whose ``Meta``
        (when the original has one) drops the read-only fields and carries a ``Writable``-prefixed
        ``ref_name``.

        :param serializer: Serializer instance
        :return: generated serializer class, or ``None`` when the serializer has a ``child`` or
            contains no ChoiceField / WritableNestedSerializer fields
        """
        properties = {}
        # Serializers wrapping a ``child`` are not inspected.
        fields = {} if hasattr(serializer, 'child') else serializer.fields
        read_only_names = set()

        for child_name, child in fields.items():
            # read_only fields don't need to be in writable (write only) serializers
            if getattr(child, 'read_only', False):
                read_only_names.add(child_name)
            if isinstance(child, (ChoiceField, WritableNestedSerializer)):
                properties[child_name] = None

        if not properties:
            return None

        serializer_class = type(serializer)
        if serializer_class not in self.writable_serializers:
            meta_class = getattr(serializer_class, 'Meta', None)
            if meta_class:
                meta_attrs = {}

                # get_serializer_ref_name() may return None; in that case leave ref_name
                # as inherited rather than failing on 'Writable' + None.
                base_ref_name = self.get_serializer_ref_name(serializer)
                if base_ref_name is not None:
                    meta_attrs['ref_name'] = 'Writable' + base_ref_name

                # remove read_only fields from write-only serializers
                # Only an explicit list/tuple can be filtered; a missing ``fields`` or a string
                # such as '__all__' is inherited untouched. Filtering (instead of list.remove)
                # also tolerates read-only names that are not listed in Meta.fields.
                declared_fields = getattr(meta_class, 'fields', None)
                if isinstance(declared_fields, (list, tuple)):
                    meta_attrs['fields'] = [
                        name for name in declared_fields if name not in read_only_names
                    ]

                properties['Meta'] = type('Meta', (meta_class,), meta_attrs)

            writable_name = 'Writable' + serializer_class.__name__
            self.writable_serializers[serializer_class] = type(writable_name, (serializer_class,), properties)

        return self.writable_serializers[serializer_class]

    def get_filter_backends(self):
        """Return the filter backends to document; none for bulk actions."""
        # bulk operations don't have filter params
        if self.is_bulk_action:
            return []
        return super().get_filter_backends()

    def _get_paginator(self):
        """Return the paginator to document; ``None`` for bulk actions."""
        # bulk operations don't have pagination
        if self.is_bulk_action:
            return None
        return super()._get_paginator()

    def _get_request_body(self, direction='request'):
        """
        Build the request body object, additionally documenting a body for bulk DELETE.

        :param direction: passed through to the request/example helpers
        :return: request body dict, or ``None`` when no schema could be resolved
        """
        # bulk delete should specify input
        if (not self.is_bulk_action) or (self.method != 'DELETE'):
            return super()._get_request_body(direction)

        # rest from drf_spectacular.openapi.AutoSchema._get_request_body
        # but remove the unsafe method check

        request_serializer = self.get_request_serializer()

        if isinstance(request_serializer, dict):
            # Mapping of media type -> serializer: resolve each entry separately.
            content = []
            request_body_required = True
            for media_type, serializer in request_serializer.items():
                schema, partial_request_body_required = self._get_request_for_media_type(serializer, direction)
                examples = self._get_examples(serializer, direction, media_type)
                if schema is None:
                    continue
                content.append((media_type, schema, examples))
                # The body is only required if every media type requires it.
                request_body_required &= partial_request_body_required
        else:
            # Single serializer: the same schema is offered for every parser media type.
            schema, request_body_required = self._get_request_for_media_type(request_serializer, direction)
            if schema is None:
                return None
            content = [
                (media_type, schema, self._get_examples(request_serializer, direction, media_type))
                for media_type in self.map_parsers()
            ]

        request_body = {
            'content': {
                media_type: build_media_type_object(schema, examples) for media_type, schema, examples in content
            }
        }
        if request_body_required:
            request_body['required'] = request_body_required
        return request_body

    def get_description(self):
        """
        Return a string description for the ViewSet.
        """

        # If a docstring is provided, use it.
        if self.view.__doc__:
            return get_doc(self.view.__class__)

        # When the action method is decorated with @action, use the docstring of the method.
        # ``view.action`` may be missing or None; getattr() rejects a non-string attribute
        # name even when a default is given, so fall back to the HTTP method name.
        action_name = getattr(self.view, 'action', None) or self.method.lower()
        action_or_method = getattr(self.view, action_name, None)
        if action_or_method and action_or_method.__doc__:
            return get_doc(action_or_method)

        # Else, generate a description from the class name.
        return self._generate_description()

    def _generate_description(self):
        """
        Generate a docstring for the method. It also takes into account whether the method is for list or detail.
        """
        model_name = self.view.queryset.model._meta.verbose_name

        # Determine if the method is for list or detail.
        if '{id}' in self.path:
            return f"{self.method.capitalize()} a {model_name} object."
        return f"{self.method.capitalize()} a list of {model_name} objects."
|