1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00
Files
checktheroads-hyperglass/hyperglass/models/api/query.py
2022-12-11 17:30:20 -05:00

119 lines
3.8 KiB
Python

"""Input query validation model."""
# Standard Library
import typing as t
import hashlib
import secrets
from datetime import datetime, timezone

# Third Party
from pydantic import BaseModel, constr, validator

# Project
from hyperglass.log import log
from hyperglass.util import snake_to_camel, repr_from_attrs
from hyperglass.state import use_state
from hyperglass.plugins import InputPluginManager
from hyperglass.exceptions.public import InputInvalid, QueryTypeNotFound, QueryLocationNotFound
from hyperglass.exceptions.private import InputValidationError

# Local
from ..config.devices import Device
# Text settings taken from `params.web.text`, resolved once at import time.
TEXT = use_state("params").web.text

# Constrained string types for the three query fields. Surrounding whitespace
# is stripped and empty values are rejected (min_length=1); location and type
# are additionally strict, i.e. no type coercion is applied to their input.
QueryLocation = constr(min_length=1, strict=True, strip_whitespace=True)
QueryTarget = constr(min_length=1, strip_whitespace=True)
QueryType = constr(min_length=1, strict=True, strip_whitespace=True)
class Query(BaseModel):
    """Validation model for input query parameters."""

    query_location: QueryLocation  # Device `name` field
    query_target: t.Union[t.List[QueryTarget], QueryTarget]
    query_type: QueryType  # Directive `id` field

    class Config:
        """Pydantic model configuration."""

        # Undeclared keys/attributes are allowed. In pydantic v1 this is also
        # what permits `__init__` to attach `timestamp`, `_state` and
        # `directive`, which are not declared fields.
        extra = "allow"
        # Field aliases come from `snake_to_camel`; input may use either the
        # alias or the field name itself.
        alias_generator = snake_to_camel
        allow_population_by_field_name = True

    def __init__(self, **data: t.Any) -> None:
        """Validate the fields, then resolve the directive and validate the target.

        A UTC timestamp (``%Y-%m-%d %H:%M:%S``) is recorded at initialization.

        Raises:
            QueryTypeNotFound: The selected device has no directive matching
                `query_type`.
            InputInvalid: The query target failed directive or plugin validation.
        """
        super().__init__(**data)
        # Timezone-aware "now"; the formatted string is unchanged from utcnow().
        self.timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        # Must be set before `self.device` is accessed (the property reads it).
        self._state = use_state()

        query_directives = self.device.directives.matching(self.query_type)
        if len(query_directives) < 1:
            raise QueryTypeNotFound(query_type=self.query_type)
        # Only the first matching directive is used for this query.
        self.directive = query_directives[0]

        try:
            self.validate_query_target()
        except InputValidationError as err:
            # Surface the public exception while keeping the internal one as __cause__.
            raise InputInvalid(**err.kwargs) from err

    def __repr__(self) -> str:
        """Represent only the declared query fields.

        The names come from the model's declared fields (``__fields__``).
        ``__config__.fields`` is not used because it only holds per-field
        overrides declared on ``Config`` (none are declared here). `digest()`
        hashes this representation, so it must contain every query field.
        """
        return repr_from_attrs(self, self.__fields__.keys())

    def __str__(self) -> str:
        """Alias __str__ to __repr__."""
        return repr(self)

    def digest(self) -> str:
        """Create a SHA256 hex digest of the model representation."""
        return hashlib.sha256(repr(self).encode()).hexdigest()

    def random(self) -> str:
        """Create a random string to prevent client or proxy caching.

        The representation is salted with 8 random bytes on each side, so the
        result differs on every call.
        """
        return hashlib.sha256(
            secrets.token_bytes(8) + repr(self).encode() + secrets.token_bytes(8)
        ).hexdigest()

    def validate_query_target(self) -> None:
        """Validate the query target after all fields/relationships have been initialized.

        Runs the directive's rule-based validation first, then input plugins.
        """
        # Run config/rule-based validations.
        self.directive.validate_target(self.query_target)
        # Run plugin-based validations.
        manager = InputPluginManager()
        manager.execute(query=self)
        log.debug("Validation passed for query {!r}", self)

    def dict(self, **kwargs: t.Any) -> t.Dict[str, t.Union[t.List[str], str]]:
        """Export only the declared query fields.

        Non-field attributes (e.g. ``timestamp``, ``directive``) are never
        included. Other keyword arguments (``by_alias``, ``exclude``, ...) are
        forwarded to ``BaseModel.dict`` so pydantic-aware callers keep working;
        any ``include`` argument is replaced by the declared field names.
        """
        kwargs["include"] = set(self.__fields__)
        return super().dict(**kwargs)

    @property
    def device(self) -> Device:
        """Get this query's device object by query_location."""
        return self._state.devices[self.query_location]

    @validator("query_location")
    def validate_query_location(cls, value: str) -> str:
        """Ensure query_location refers to a known device.

        Raises:
            QueryLocationNotFound: `devices.valid_id_or_name` rejects the value.
        """
        devices = use_state("devices")
        if not devices.valid_id_or_name(value):
            raise QueryLocationNotFound(location=value)
        return value

    @validator("query_type")
    def validate_query_type(cls, value: t.Any) -> t.Any:
        """Ensure at least one device has a directive for the requested query type.

        Raises:
            QueryTypeNotFound: No device reports having the directive.
        """
        devices = use_state("devices")
        if any(device.has_directives(value) for device in devices):
            return value
        raise QueryTypeNotFound(query_type=value)