1
0
mirror of https://github.com/netbox-community/netbox.git synced 2024-05-10 07:54:54 +00:00
Jeremy Stretch d470848b29 Closes #12246: General cleanup of utilities modules
* Clean up base modules

* Clean up forms modules

* Clean up templatetags modules

* Replace custom simplify_decimal filter with floatformat

* Misc cleanup

* Merge ReturnURLForm into ConfirmationForm

* Clean up import statements for utilities.forms

* Fix field class references in docs
2023-04-14 10:33:53 -04:00

253 lines
10 KiB
Python

import functools
from django.core.exceptions import FieldDoesNotExist
from django.db.models import ForeignKey
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.reverse_related import ManyToOneRel
from graphene import InputObjectType
from graphene.types.generic import GenericScalar
from graphene.types.resolver import default_resolver
from graphene_django import DjangoObjectType
from graphql import GraphQLResolveInfo, GraphQLSchema
from graphql.execution.execute import get_field_def
from graphql.language.ast import FragmentSpreadNode, InlineFragmentNode, VariableNode
from graphql.pyutils import Path
from graphql.type.definition import GraphQLInterfaceType, GraphQLUnionType
# Names exported by `from <module> import *`: only the optimizer entry point is public.
__all__ = (
'gql_query_optimizer',
)
def gql_query_optimizer(queryset, info, **options):
    """
    Return `queryset` with prefetches applied for the relations selected in a GraphQL query.

    :param queryset: The queryset to optimize.
    :param info: The GraphQL resolve info for the field being resolved.
    :param options: Extra keyword arguments, forwarded unchanged to `QueryOptimizer`.
    :return: The result of `QueryOptimizer.optimize()` for `queryset`.
    """
    # Forward the options rather than dropping them, so the constructor sees what the caller passed.
    return QueryOptimizer(info, **options).optimize(queryset)
class QueryOptimizer(object):
    """
    Build `prefetch_related()` lookups for a queryset from the fields selected in a GraphQL query.

    The selection set of the field being resolved is walked recursively. Each selected field that
    maps to a relational model field contributes its name, and the lookups found beneath it are
    added with a `__` prefix.
    """

    def __init__(self, info, **options):
        """
        :param info: Resolve info of the field being resolved. Its schema, fragments, parent type
            and field nodes are read while optimizing.
        :param options: Accepted for interface compatibility; not used by this class.
        """
        self.root_info = info

    def optimize(self, queryset):
        """
        Return `queryset` with `prefetch_related()` applied for every lookup found in the query.
        """
        info = self.root_info
        root_node = info.field_nodes[0]
        field_def = get_field_def(info.schema, info.parent_type, root_node)
        field_names = self._optimize_gql_selections(self._get_type(field_def), root_node)
        return queryset.prefetch_related(*field_names)

    def _get_type(self, field_def):
        """
        Return the innermost type of `field_def`, unwrapping every wrapper exposing `of_type`.
        """
        a_type = field_def.type
        while hasattr(a_type, "of_type"):
            a_type = a_type.of_type
        return a_type

    def _get_graphql_schema(self, schema):
        """
        Return `schema` itself if it is a `GraphQLSchema`, otherwise its `graphql_schema` attribute.
        """
        if isinstance(schema, GraphQLSchema):
            return schema
        return schema.graphql_schema

    def _get_possible_types(self, graphql_type):
        """
        Return the concrete types `graphql_type` may resolve to.

        Interfaces and unions are expanded through the schema; any other type is returned as a
        one-element tuple.
        """
        if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
            graphql_schema = self._get_graphql_schema(self.root_info.schema)
            return graphql_schema.get_possible_types(graphql_type)
        return (graphql_type,)

    def _get_base_model(self, graphql_types):
        """
        Return the first model that every model of `graphql_types` is a subclass of, or None.
        """
        models = tuple(t.graphene_type._meta.model for t in graphql_types)
        for model in models:
            if all(issubclass(m, model) for m in models):
                return model
        return None

    @staticmethod
    def _add_lookups(field_names, name, sub_field_names=()):
        """
        Append `name` and one `name__<sub>` lookup per entry of `sub_field_names` to `field_names`,
        skipping lookups that are already present.
        """
        if name not in field_names:
            field_names.append(name)
        for sub_name in sub_field_names:
            prefetch_key = f"{name}__{sub_name}"
            if prefetch_key not in field_names:
                field_names.append(prefetch_key)

    def handle_inline_fragment(self, selection, schema, possible_types, field_names):
        """
        Add lookups for an inline fragment (`... on Type { ... }`) to `field_names` in place.

        For each possible type of the fragment, the path from the common base model of
        `possible_types` to the fragment's model is joined with `LOOKUP_SEP`. A non-empty path is
        added together with the lookups found inside the fragment, prefixed by that path.
        """
        # An inline fragment may omit its type condition; there is then no type name to look up.
        if selection.type_condition is None:
            return
        fragment_type_name = selection.type_condition.name.value
        graphql_schema = self._get_graphql_schema(schema)
        fragment_type = graphql_schema.get_type(fragment_type_name)
        for fragment_possible_type in self._get_possible_types(fragment_type):
            fragment_model = fragment_possible_type.graphene_type._meta.model
            parent_model = self._get_base_model(possible_types)
            if not parent_model:
                continue
            path_from_parent = fragment_model._meta.get_path_from_parent(parent_model)
            select_related_name = LOOKUP_SEP.join(p.join_field.name for p in path_from_parent)
            if not select_related_name:
                continue
            sub_field_names = self._optimize_gql_selections(
                fragment_possible_type,
                selection,
            )
            # Previously the nested lookups were computed and then dropped; include them.
            self._add_lookups(field_names, select_related_name, sub_field_names)

    def handle_fragment_spread(self, field_names, name, field_type):
        """
        Add the lookups selected by the named fragment `name` to `field_names` in place.

        The fragment is optimized against `field_type`, the same type as the enclosing selection,
        so its lookups are merged without a prefix.
        """
        fragment = self.root_info.fragments[name]
        # Previously the fragment's lookups were computed and then dropped; merge them.
        for lookup in self._optimize_gql_selections(field_type, fragment):
            if lookup not in field_names:
                field_names.append(lookup)

    def _optimize_gql_selections(self, field_type, field_ast):
        """
        Return the list of lookups for the selection set of `field_ast`, resolved against
        `field_type`.

        Returns an empty list when `field_ast` has no selection set.
        """
        field_names = []
        selection_set = field_ast.selection_set
        if not selection_set:
            return field_names

        schema = self.root_info.schema
        graphql_schema = self._get_graphql_schema(schema)
        graphql_type = graphql_schema.get_type(field_type.name)
        possible_types = self._get_possible_types(graphql_type)

        # Each field name is optimized at most once per selection set, for the first possible
        # type that defines it and is backed by a model.
        optimized_names = set()
        for selection in selection_set.selections:
            if isinstance(selection, InlineFragmentNode):
                self.handle_inline_fragment(selection, schema, possible_types, field_names)
                continue

            name = selection.name.value
            if isinstance(selection, FragmentSpreadNode):
                self.handle_fragment_spread(field_names, name, field_type)
                continue

            for possible_type in possible_types:
                selection_field_def = possible_type.fields.get(name)
                if not selection_field_def:
                    continue
                model = getattr(possible_type.graphene_type._meta, "model", None)
                if model and name not in optimized_names:
                    optimized_names.add(name)
                    self._optimize_field(
                        field_names,
                        model,
                        selection,
                        selection_field_def,
                        possible_type,
                    )
        return field_names

    def _get_field_info(self, field_names, model, selection, field_def):
        """
        Return `(name, model_field)` for a selected field.

        The name comes from the field's resolver. If that yields nothing and the resolver is a
        callable other than a `functools.partial`, the selection's own name is used. `model_field`
        is None when no name was found or `model` has no matching field.
        """
        model_field = None
        name = self._get_name_from_resolver(field_def.resolve)
        if not name and callable(field_def.resolve) and not isinstance(field_def.resolve, functools.partial):
            name = selection.name.value
        if name:
            model_field = self._get_model_field_from_name(model, name)
        return (name, model_field)

    def _optimize_field(self, field_names, model, selection, field_def, parent_type):
        """
        Add the lookups for one selected field to `field_names` if it maps to a model field.
        """
        name, model_field = self._get_field_info(field_names, model, selection, field_def)
        if model_field:
            self._optimize_field_by_name(field_names, model, selection, field_def, name, model_field)

    def _optimize_field_by_name(self, field_names, model, selection, field_def, name, model_field):
        """
        Add `name` and its nested lookups to `field_names` when `model_field` is relational.

        Non-relational fields add nothing. For a `ManyToOneRel`, the name of `model_field.field`
        is also added beneath the relation.
        """
        if model_field.many_to_one or model_field.one_to_one:
            sub_field_names = self._optimize_gql_selections(
                self._get_type(field_def),
                selection,
            )
            self._add_lookups(field_names, name, sub_field_names)

        if model_field.one_to_many or model_field.many_to_many:
            sub_field_names = self._optimize_gql_selections(
                self._get_type(field_def),
                selection,
            )
            if isinstance(model_field, ManyToOneRel):
                sub_field_names.append(model_field.field.name)
            # Same duplicate check as the to-one branch, so a relation is never listed twice.
            self._add_lookups(field_names, name, sub_field_names)

    def _get_optimization_hints(self, resolver):
        """
        Return the `optimization_hints` attribute of `resolver`, or None if it has none.
        """
        return getattr(resolver, "optimization_hints", None)

    def _get_value(self, info, value):
        """
        Return the Python value of a GraphQL argument node.

        Variables are looked up in `info.variable_values`, `InputObjectType` instances yield their
        `__dict__`, and anything else is parsed with `GenericScalar.parse_literal`.
        """
        if isinstance(value, VariableNode):
            return info.variable_values.get(value.name.value)
        if isinstance(value, InputObjectType):
            return value.__dict__
        return GenericScalar.parse_literal(value)

    def _get_name_from_resolver(self, resolver):
        """
        Return the model field name a resolver reads, as far as it can be determined.

        Checked in order: the resolver's optimization hints, whether it is the ID resolver, and
        for a `functools.partial` the arguments it wraps. Returns None when nothing is found.
        For a partial, the inspected argument itself may be returned, which need not be a string.
        """
        optimization_hints = self._get_optimization_hints(resolver)
        if optimization_hints:
            name_fn = optimization_hints.model_field
            if name_fn:
                return name_fn()
        if self._is_resolver_for_id_field(resolver):
            return "id"
        if not isinstance(resolver, functools.partial):
            return None

        resolver_fn = resolver
        if resolver_fn.func != default_resolver:
            if not resolver_fn.args:
                # No positional argument to inspect; avoid an IndexError on args[0] below.
                return None
            # Some resolvers have the partial function as the second
            # argument.
            for arg in resolver_fn.args:
                if isinstance(arg, (str, functools.partial)):
                    break
            else:
                # No suitable instances found, default to first arg
                arg = resolver_fn.args[0]
            resolver_fn = arg
        if isinstance(resolver_fn, functools.partial) and resolver_fn.func == default_resolver:
            return resolver_fn.args[0]
        if self._is_resolver_for_id_field(resolver_fn):
            return "id"
        return resolver_fn

    def _is_resolver_for_id_field(self, resolver):
        """
        Return True if `resolver` equals `DjangoObjectType.resolve_id`.
        """
        return resolver == DjangoObjectType.resolve_id

    def _get_model_field_from_name(self, model, name):
        """
        Return the field of `model` called `name`.

        Falls back to the `rel` or `related` attribute of a same-named class attribute when the
        model has no such field. Returns None if neither is found.
        """
        try:
            return model._meta.get_field(name)
        except FieldDoesNotExist:
            descriptor = model.__dict__.get(name)
            if not descriptor:
                return None
            return getattr(descriptor, "rel", None) or getattr(descriptor, "related", None)  # Django < 1.9

    def _is_foreign_key_id(self, model_field, name):
        """
        Return True if `name` is the attname of the ForeignKey `model_field` rather than its name.
        """
        return isinstance(model_field, ForeignKey) and model_field.name != name and model_field.get_attname() == name

    def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
        """
        Build a `GraphQLResolveInfo` for a nested field, copying the remaining context from the
        root resolve info.
        """
        return GraphQLResolveInfo(
            field_name,
            field_asts,
            return_type,
            parent_type,
            Path(None, 0, None),
            schema=self.root_info.schema,
            fragments=self.root_info.fragments,
            root_value=self.root_info.root_value,
            operation=self.root_info.operation,
            variable_values=self.root_info.variable_values,
            context=self.root_info.context,
            is_awaitable=self.root_info.is_awaitable,
        )