1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

Continue output plugin implementation

This commit is contained in:
thatmattlove
2021-09-12 18:27:33 -07:00
parent 560663601d
commit 74fcb5dba4
16 changed files with 124 additions and 129 deletions

View File

@@ -58,7 +58,6 @@ async def send_webhook(query_data: Query, request: Request, timestamp: datetime)
log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err)) log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err))
@log.catch
async def query(query_data: Query, request: Request, background_tasks: BackgroundTasks): async def query(query_data: Query, request: Request, background_tasks: BackgroundTasks):
"""Ingest request data pass it to the backend application to perform the query.""" """Ingest request data pass it to the backend application to perform the query."""

View File

@@ -120,7 +120,7 @@ class PublicHyperglassError(HyperglassError):
"""Format error message with keyword arguments.""" """Format error message with keyword arguments."""
if "error" in kwargs: if "error" in kwargs:
error = kwargs.pop("error") error = kwargs.pop("error")
error = self._safe_format(error, **kwargs) error = self._safe_format(str(error), **kwargs)
kwargs["error"] = error kwargs["error"] = error
self._message = self._safe_format(self._message_template, **kwargs) self._message = self._safe_format(self._message_template, **kwargs)
self._keywords = list(kwargs.values()) self._keywords = list(kwargs.values())
@@ -150,7 +150,7 @@ class PrivateHyperglassError(HyperglassError):
"""Format error message with keyword arguments.""" """Format error message with keyword arguments."""
if "error" in kwargs: if "error" in kwargs:
error = kwargs.pop("error") error = kwargs.pop("error")
error = self._safe_format(error, **kwargs) error = self._safe_format(str(error), **kwargs)
kwargs["error"] = error kwargs["error"] = error
self._message = self._safe_format(message, **kwargs) self._message = self._safe_format(message, **kwargs)
self._keywords = list(kwargs.values()) self._keywords = list(kwargs.values())

View File

@@ -1,7 +1,7 @@
"""User-facing/Public exceptions.""" """User-facing/Public exceptions."""
# Standard Library # Standard Library
from typing import Any, Dict, Optional, ForwardRef from typing import TYPE_CHECKING, Any, Dict, Optional
# Project # Project
from hyperglass.configuration import params from hyperglass.configuration import params
@@ -9,8 +9,10 @@ from hyperglass.configuration import params
# Local # Local
from ._common import PublicHyperglassError from ._common import PublicHyperglassError
Query = ForwardRef("Query") if TYPE_CHECKING:
Device = ForwardRef("Device") # Project
from hyperglass.models.api.query import Query
from hyperglass.models.config.devices import Device
class ScrapeError( class ScrapeError(
@@ -18,7 +20,7 @@ class ScrapeError(
): ):
"""Raised when an SSH driver error occurs.""" """Raised when an SSH driver error occurs."""
def __init__(self, error: BaseException, *, device: Device): def __init__(self, *, error: BaseException, device: "Device"):
"""Initialize parent error.""" """Initialize parent error."""
super().__init__(error=str(error), device=device.name, proxy=device.proxy) super().__init__(error=str(error), device=device.name, proxy=device.proxy)
@@ -28,7 +30,7 @@ class AuthError(
): ):
"""Raised when authentication to a device fails.""" """Raised when authentication to a device fails."""
def __init__(self, error: BaseException, *, device: Device): def __init__(self, *, error: BaseException, device: "Device"):
"""Initialize parent error.""" """Initialize parent error."""
super().__init__(error=str(error), device=device.name, proxy=device.proxy) super().__init__(error=str(error), device=device.name, proxy=device.proxy)
@@ -36,7 +38,7 @@ class AuthError(
class RestError(PublicHyperglassError, template=params.messages.connection_error, level="danger"): class RestError(PublicHyperglassError, template=params.messages.connection_error, level="danger"):
"""Raised upon a rest API client error.""" """Raised upon a rest API client error."""
def __init__(self, error: BaseException, *, device: Device): def __init__(self, *, error: BaseException, device: "Device"):
"""Initialize parent error.""" """Initialize parent error."""
super().__init__(error=str(error), device=device.name) super().__init__(error=str(error), device=device.name)
@@ -46,7 +48,7 @@ class DeviceTimeout(
): ):
"""Raised when the connection to a device times out.""" """Raised when the connection to a device times out."""
def __init__(self, error: BaseException, *, device: Device): def __init__(self, *, error: BaseException, device: "Device"):
"""Initialize parent error.""" """Initialize parent error."""
super().__init__(error=str(error), device=device.name, proxy=device.proxy) super().__init__(error=str(error), device=device.name, proxy=device.proxy)
@@ -55,7 +57,7 @@ class InvalidQuery(PublicHyperglassError, template=params.messages.invalid_query
"""Raised when input validation fails.""" """Raised when input validation fails."""
def __init__( def __init__(
self, error: Optional[str] = None, *, query: "Query", **kwargs: Dict[str, Any] self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any]
) -> None: ) -> None:
"""Initialize parent error.""" """Initialize parent error."""
@@ -107,7 +109,7 @@ class InputInvalid(PublicHyperglassError, template=params.messages.invalid_input
"""Raised when input validation fails.""" """Raised when input validation fails."""
def __init__( def __init__(
self, error: Optional[Any] = None, *, target: str, **kwargs: Dict[str, Any] self, *, error: Optional[Any] = None, target: str, **kwargs: Dict[str, Any]
) -> None: ) -> None:
"""Initialize parent error.""" """Initialize parent error."""
@@ -123,7 +125,7 @@ class InputNotAllowed(PublicHyperglassError, template=params.messages.acl_not_al
"""Raised when input validation fails due to a configured check.""" """Raised when input validation fails due to a configured check."""
def __init__( def __init__(
self, error: Optional[str] = None, *, query: Query, **kwargs: Dict[str, Any] self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any]
) -> None: ) -> None:
"""Initialize parent error.""" """Initialize parent error."""
@@ -143,7 +145,7 @@ class ResponseEmpty(PublicHyperglassError, template=params.messages.no_output):
"""Raised when hyperglass can connect to the device but the response is empty.""" """Raised when hyperglass can connect to the device but the response is empty."""
def __init__( def __init__(
self, error: Optional[str] = None, *, query: Query, **kwargs: Dict[str, Any] self, *, error: Optional[str] = None, query: "Query", **kwargs: Dict[str, Any]
) -> None: ) -> None:
"""Initialize parent error.""" """Initialize parent error."""

View File

@@ -7,23 +7,21 @@ from typing import TYPE_CHECKING, Dict, Union, Sequence
# Project # Project
from hyperglass.log import log from hyperglass.log import log
from hyperglass.plugins import OutputPluginManager from hyperglass.plugins import OutputPluginManager
from hyperglass.models.api import Query
from hyperglass.parsing.nos import scrape_parsers, structured_parsers
from hyperglass.parsing.common import parsers
from hyperglass.models.config.devices import Device
# Local # Local
from ._construct import Construct from ._construct import Construct
if TYPE_CHECKING: if TYPE_CHECKING:
# Project # Project
from hyperglass.models.api import Query
from hyperglass.compat._sshtunnel import SSHTunnelForwarder from hyperglass.compat._sshtunnel import SSHTunnelForwarder
from hyperglass.models.config.devices import Device
class Connection(ABC): class Connection(ABC):
"""Base transport driver class.""" """Base transport driver class."""
def __init__(self, device: Device, query_data: Query) -> None: def __init__(self, device: "Device", query_data: "Query") -> None:
"""Initialize connection to device.""" """Initialize connection to device."""
self.device = device self.device = device
self.query_data = query_data self.query_data = query_data
@@ -38,53 +36,14 @@ class Connection(ABC):
"""Return a preconfigured sshtunnel.SSHTunnelForwarder instance.""" """Return a preconfigured sshtunnel.SSHTunnelForwarder instance."""
pass pass
async def parsed_response( # noqa: C901 ("too complex") async def parsed_response(self, output: Sequence[str]) -> Union[str, Sequence[Dict]]:
self, output: Sequence[str]
) -> Union[str, Sequence[Dict]]:
"""Send output through common parsers.""" """Send output through common parsers."""
log.debug("Pre-parsed responses:\n{}", output) log.debug("Pre-parsed responses:\n{}", output)
parsed = ()
response = None
structured_nos = structured_parsers.keys() response = self.plugin_manager.execute(
structured_query_types = structured_parsers.get(self.device.nos, {}).keys() directive=self.query_data.directive, output=output, device=self.device
)
scrape_nos = scrape_parsers.keys()
scrape_query_types = scrape_parsers.get(self.device.nos, {}).keys()
if not self.device.structured_output:
_parsed = ()
for func in parsers:
for response in output:
_output = func(commands=self.query, output=response)
_parsed += (_output,)
if self.device.nos in scrape_nos and self.query_type in scrape_query_types:
func = scrape_parsers[self.device.nos][self.query_type]
for response in _parsed:
_output = func(response)
parsed += (_output,)
else:
parsed += _parsed
response = "\n\n".join(parsed)
elif (
self.device.structured_output
and self.device.nos in structured_nos
and self.query_type not in structured_query_types
):
for func in parsers:
for response in output:
_output = func(commands=self.query, output=response)
parsed += (_output,)
response = "\n\n".join(parsed)
elif (
self.device.structured_output
and self.device.nos in structured_nos
and self.query_type in structured_query_types
):
func = structured_parsers[self.device.nos][self.query_type]
response = func(output)
if response is None: if response is None:
response = "\n\n".join(output) response = "\n\n".join(output)

View File

@@ -55,7 +55,10 @@ async def execute(query: "Query") -> Union[str, Sequence[Dict]]:
mapped_driver = map_driver(query.device.driver) mapped_driver = map_driver(query.device.driver)
driver: "Connection" = mapped_driver(query.device, query) driver: "Connection" = mapped_driver(query.device, query)
signal.signal(signal.SIGALRM, handle_timeout(error=TimeoutError(), device=query.device)) signal.signal(
signal.SIGALRM,
handle_timeout(error=TimeoutError("Connection timed out"), device=query.device),
)
signal.alarm(params.request_timeout - 1) signal.alarm(params.request_timeout - 1)
if query.device.proxy: if query.device.proxy:

View File

@@ -7,7 +7,6 @@ import shutil
import logging import logging
import platform import platform
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from pathlib import Path
# Third Party # Third Party
from gunicorn.app.base import BaseApplication # type: ignore from gunicorn.app.base import BaseApplication # type: ignore
@@ -124,10 +123,8 @@ def cache_config() -> bool:
def register_all_plugins(devices: "Devices") -> None: def register_all_plugins(devices: "Devices") -> None:
"""Validate and register configured plugins.""" """Validate and register configured plugins."""
for plugin_file in { for plugin_file, directives in devices.directive_plugins().items():
Path(p) for p in (p for d in devices.objects for c in d.commands for p in c.plugins) failures = register_plugin(plugin_file, directives=directives)
}:
failures = register_plugin(plugin_file)
for failure in failures: for failure in failures:
log.warning( log.warning(
"Plugin '{}' is not a valid hyperglass plugin, and was not registered", failure, "Plugin '{}' is not a valid hyperglass plugin, and was not registered", failure,
@@ -203,11 +200,9 @@ class HyperglassWSGI(BaseApplication):
def start(**kwargs): def start(**kwargs):
"""Start hyperglass via gunicorn.""" """Start hyperglass via gunicorn."""
# Project
from hyperglass.api import app
HyperglassWSGI( HyperglassWSGI(
app=app, app="hyperglass.api:app",
options={ options={
"worker_class": "uvicorn.workers.UvicornWorker", "worker_class": "uvicorn.workers.UvicornWorker",
"preload": True, "preload": True,

View File

@@ -1,6 +1,9 @@
"""All Data Models used by hyperglass.""" """All Data Models used by hyperglass."""
# Local # Local
from .main import HyperglassModel from .main import HyperglassModel, HyperglassModelWithId
__all__ = ("HyperglassModel",) __all__ = (
"HyperglassModel",
"HyperglassModelWithId",
)

View File

@@ -168,7 +168,7 @@ class Query(BaseModel):
def validate_query_location(cls, value): def validate_query_location(cls, value):
"""Ensure query_location is defined.""" """Ensure query_location is defined."""
valid_id = value in devices._ids valid_id = value in devices.ids
valid_hostname = value in devices.hostnames valid_hostname = value in devices.hostnames
if not any((valid_id, valid_hostname)): if not any((valid_id, valid_hostname)):

View File

@@ -23,7 +23,7 @@ from hyperglass.log import log
from hyperglass.exceptions.private import InputValidationError from hyperglass.exceptions.private import InputValidationError
# Local # Local
from ..main import HyperglassModel from ..main import HyperglassModel, HyperglassModelWithId
from ..fields import Action from ..fields import Action
from ..config.params import Params from ..config.params import Params
@@ -224,7 +224,7 @@ class RuleWithoutValidation(Rule):
Rules = Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValidation] Rules = Union[RuleWithIPv4, RuleWithIPv6, RuleWithPattern, RuleWithoutValidation]
class Directive(HyperglassModel): class Directive(HyperglassModelWithId):
"""A directive contains commands that can be run on a device, as long as defined rules are met.""" """A directive contains commands that can be run on a device, as long as defined rules are met."""
id: StrictStr id: StrictStr

View File

@@ -3,19 +3,12 @@
# Standard Library # Standard Library
import os import os
import re import re
from typing import Any, Dict, List, Tuple, Union, Optional from typing import Any, Set, Dict, List, Tuple, Union, Optional
from pathlib import Path from pathlib import Path
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
# Third Party # Third Party
from pydantic import ( from pydantic import StrictInt, StrictStr, StrictBool, validator, root_validator
StrictInt,
StrictStr,
StrictBool,
PrivateAttr,
validator,
root_validator,
)
# Project # Project
from hyperglass.log import log from hyperglass.log import log
@@ -26,7 +19,7 @@ from hyperglass.models.commands.generic import Directive
# Local # Local
from .ssl import Ssl from .ssl import Ssl
from ..main import HyperglassModel from ..main import HyperglassModel, HyperglassModelWithId
from .proxy import Proxy from .proxy import Proxy
from .params import Params from .params import Params
from ..fields import SupportedDriver from ..fields import SupportedDriver
@@ -34,10 +27,10 @@ from .network import Network
from .credential import Credential from .credential import Credential
class Device(HyperglassModel, extra="allow"): class Device(HyperglassModelWithId, extra="allow"):
"""Validation model for per-router config in devices.yaml.""" """Validation model for per-router config in devices.yaml."""
_id: StrictStr = PrivateAttr() id: StrictStr
name: StrictStr name: StrictStr
address: Union[IPv4Address, IPv6Address, StrictStr] address: Union[IPv4Address, IPv6Address, StrictStr]
network: Network network: Network
@@ -55,23 +48,9 @@ class Device(HyperglassModel, extra="allow"):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
"""Set the device ID.""" """Set the device ID."""
_id, values = self._generate_id(kwargs) _id, values = self._generate_id(kwargs)
super().__init__(**values) super().__init__(id=_id, **values)
self._id = _id
self._validate_directive_attrs() self._validate_directive_attrs()
def __hash__(self) -> int:
"""Make device object hashable so the object can be deduplicated with set()."""
return hash((self.name,))
def __eq__(self, other: Any) -> bool:
"""Make device object comparable so the object can be deduplicated with set()."""
result = False
if isinstance(other, HyperglassModel):
result = self.name == other.name
return result
@property @property
def _target(self): def _target(self):
return str(self.address) return str(self.address)
@@ -104,7 +83,7 @@ class Device(HyperglassModel, extra="allow"):
def export_api(self) -> Dict[str, Any]: def export_api(self) -> Dict[str, Any]:
"""Export API-facing device fields.""" """Export API-facing device fields."""
return { return {
"id": self._id, "id": self.id,
"name": self.name, "name": self.name,
"network": self.network.display_name, "network": self.network.display_name,
} }
@@ -233,7 +212,7 @@ class Device(HyperglassModel, extra="allow"):
class Devices(HyperglassModel, extra="allow"): class Devices(HyperglassModel, extra="allow"):
"""Validation model for device configurations.""" """Validation model for device configurations."""
_ids: List[StrictStr] = [] ids: List[StrictStr] = []
hostnames: List[StrictStr] = [] hostnames: List[StrictStr] = []
objects: List[Device] = [] objects: List[Device] = []
all_nos: List[StrictStr] = [] all_nos: List[StrictStr] = []
@@ -248,7 +227,7 @@ class Devices(HyperglassModel, extra="allow"):
all_nos = set() all_nos = set()
objects = set() objects = set()
hostnames = set() hostnames = set()
_ids = set() ids = set()
init_kwargs = {} init_kwargs = {}
@@ -261,13 +240,13 @@ class Devices(HyperglassModel, extra="allow"):
# list with `devices.hostnames`, same for all router # list with `devices.hostnames`, same for all router
# classes, for when iteration over all routers is required. # classes, for when iteration over all routers is required.
hostnames.add(device.name) hostnames.add(device.name)
_ids.add(device._id) ids.add(device.id)
objects.add(device) objects.add(device)
all_nos.add(device.nos) all_nos.add(device.nos)
# Convert the de-duplicated sets to a standard list, add lists # Convert the de-duplicated sets to a standard list, add lists
# as class attributes. Sort router list by router name attribute # as class attributes. Sort router list by router name attribute
init_kwargs["_ids"] = list(_ids) init_kwargs["ids"] = list(ids)
init_kwargs["hostnames"] = list(hostnames) init_kwargs["hostnames"] = list(hostnames)
init_kwargs["all_nos"] = list(all_nos) init_kwargs["all_nos"] = list(all_nos)
init_kwargs["objects"] = sorted(objects, key=lambda x: x.name) init_kwargs["objects"] = sorted(objects, key=lambda x: x.name)
@@ -277,7 +256,7 @@ class Devices(HyperglassModel, extra="allow"):
def __getitem__(self, accessor: str) -> Device: def __getitem__(self, accessor: str) -> Device:
"""Get a device by its name.""" """Get a device by its name."""
for device in self.objects: for device in self.objects:
if device._id == accessor: if device.id == accessor:
return device return device
elif device.name == accessor: elif device.name == accessor:
return device return device
@@ -296,7 +275,7 @@ class Devices(HyperglassModel, extra="allow"):
"display_name": name, "display_name": name,
"locations": [ "locations": [
{ {
"id": device._id, "id": device.id,
"name": device.name, "name": device.name,
"network": device.network.display_name, "network": device.network.display_name,
"directives": [c.frontend(params) for c in device.commands], "directives": [c.frontend(params) for c in device.commands],
@@ -307,3 +286,20 @@ class Devices(HyperglassModel, extra="allow"):
} }
for name in names for name in names
] ]
def directive_plugins(self) -> Dict[Path, Tuple[StrictStr]]:
"""Get a mapping of plugin paths to associated directive IDs."""
result: Dict[Path, Set[StrictStr]] = {}
# Unique set of all directives.
directives = {directive for device in self.objects for directive in device.commands}
# Unique set of all plugin file names.
plugin_names = {plugin for directive in directives for plugin in directive.plugins}
for directive in directives:
# Convert each plugin file name to a `Path` object.
for plugin in (Path(p) for p in directive.plugins if p in plugin_names):
if plugin not in result:
result[plugin] = set()
result[plugin].add(directive.id)
# Convert the directive set to a tuple.
return {k: tuple(v) for k, v in result.items()}

View File

@@ -80,3 +80,25 @@ class HyperglassModel(BaseModel):
} }
return yaml.safe_dump(json.loads(self.export_json(**export_kwargs)), *args, **kwargs) return yaml.safe_dump(json.loads(self.export_json(**export_kwargs)), *args, **kwargs)
class HyperglassModelWithId(HyperglassModel):
"""hyperglass model that is unique by its `id` field."""
id: str
def __eq__(self: "HyperglassModelWithId", other: "HyperglassModelWithId") -> bool:
"""Other model is equal to this model."""
if not isinstance(other, self.__class__):
return False
if hasattr(other, "id"):
return other and self.id == other.id
return False
def __ne__(self: "HyperglassModelWithId", other: "HyperglassModelWithId") -> bool:
"""Other model is not equal to this model."""
return not self.__eq__(other)
def __hash__(self: "HyperglassModelWithId") -> int:
"""Create a hashed representation of this model's name."""
return hash(self.id)

View File

@@ -2,7 +2,7 @@
# Standard Library # Standard Library
from abc import ABC from abc import ABC
from typing import Any, Union, Literal, TypeVar from typing import Any, Union, Literal, TypeVar, Sequence
from inspect import Signature from inspect import Signature
# Third Party # Third Party
@@ -52,3 +52,9 @@ class HyperglassPlugin(BaseModel, ABC):
"""Initialize plugin instance.""" """Initialize plugin instance."""
name = kwargs.pop("name", None) or self.__class__.__name__ name = kwargs.pop("name", None) or self.__class__.__name__
super().__init__(name=name, **kwargs) super().__init__(name=name, **kwargs)
class DirectivePlugin(HyperglassPlugin):
"""Plugin associated with directives."""
directives: Sequence[str] = ()

View File

@@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
# Local # Local
from ._base import HyperglassPlugin from ._base import DirectivePlugin
if TYPE_CHECKING: if TYPE_CHECKING:
# Project # Project
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
InputPluginReturn = Union[None, bool] InputPluginReturn = Union[None, bool]
class InputPlugin(HyperglassPlugin): class InputPlugin(DirectivePlugin):
"""Plugin to validate user input prior to running commands.""" """Plugin to validate user input prior to running commands."""
def validate(self, query: "Query") -> InputPluginReturn: def validate(self, query: "Query") -> InputPluginReturn:

View File

@@ -4,7 +4,7 @@
import json import json
import codecs import codecs
import pickle import pickle
from typing import TYPE_CHECKING, List, Generic, TypeVar, Callable, Generator from typing import TYPE_CHECKING, Any, List, Generic, TypeVar, Callable, Generator
from inspect import isclass from inspect import isclass
# Project # Project
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
# Project # Project
from hyperglass.models.api.query import Query from hyperglass.models.api.query import Query
from hyperglass.models.config.devices import Device from hyperglass.models.config.devices import Device
from hyperglass.models.commands.generic import Directive
PluginT = TypeVar("PluginT") PluginT = TypeVar("PluginT")
@@ -73,7 +74,7 @@ class PluginManager(Generic[PluginT]):
def plugins(self: "PluginManager") -> List[PluginT]: def plugins(self: "PluginManager") -> List[PluginT]:
"""Get all plugins, with built-in plugins last.""" """Get all plugins, with built-in plugins last."""
return sorted( return sorted(
self.plugins, self._get_plugins(),
key=lambda p: -1 if p.__hyperglass_builtin__ else 1, # flake8: noqa IF100 key=lambda p: -1 if p.__hyperglass_builtin__ else 1, # flake8: noqa IF100
reverse=True, reverse=True,
) )
@@ -117,12 +118,12 @@ class PluginManager(Generic[PluginT]):
return return
raise PluginError("Plugin '{}' is not a valid hyperglass plugin", repr(plugin)) raise PluginError("Plugin '{}' is not a valid hyperglass plugin", repr(plugin))
def register(self: "PluginManager", plugin: PluginT) -> None: def register(self: "PluginManager", plugin: PluginT, *args: Any, **kwargs: Any) -> None:
"""Add a plugin to currently active plugins.""" """Add a plugin to currently active plugins."""
# Create a set of plugins so duplicate plugins are not mistakenly added. # Create a set of plugins so duplicate plugins are not mistakenly added.
try: try:
if issubclass(plugin, HyperglassPlugin): if issubclass(plugin, HyperglassPlugin):
instance = plugin() instance = plugin(*args, **kwargs)
plugins = { plugins = {
# Create a base64 representation of a picked plugin. # Create a base64 representation of a picked plugin.
codecs.encode(pickle.dumps(p), "base64").decode() codecs.encode(pickle.dumps(p), "base64").decode()
@@ -131,7 +132,10 @@ class PluginManager(Generic[PluginT]):
} }
# Add plugins from cache. # Add plugins from cache.
self._cache.set(f"hyperglass.plugins.{self._type}", json.dumps(list(plugins))) self._cache.set(f"hyperglass.plugins.{self._type}", json.dumps(list(plugins)))
log.success("Registered plugin '{}'", instance.name) if instance.__hyperglass_builtin__ is True:
log.debug("Registered built-in plugin '{}'", instance.name)
else:
log.success("Registered plugin '{}'", instance.name)
return return
except TypeError: except TypeError:
raise PluginError( raise PluginError(
@@ -145,13 +149,15 @@ class PluginManager(Generic[PluginT]):
class InputPluginManager(PluginManager[InputPlugin], type="input"): class InputPluginManager(PluginManager[InputPlugin], type="input"):
"""Manage Input Validation Plugins.""" """Manage Input Validation Plugins."""
def execute(self: "InputPluginManager", query: "Query") -> InputPluginReturn: def execute(
self: "InputPluginManager", *, directive: "Directive", query: "Query"
) -> InputPluginReturn:
"""Execute all input validation plugins. """Execute all input validation plugins.
If any plugin returns `False`, execution is halted. If any plugin returns `False`, execution is halted.
""" """
result = None result = None
for plugin in self.plugins: for plugin in (plugin for plugin in self.plugins if directive.id in plugin.directives):
if result is False: if result is False:
return result return result
result = plugin.validate(query) result = plugin.validate(query)
@@ -161,13 +167,15 @@ class InputPluginManager(PluginManager[InputPlugin], type="input"):
class OutputPluginManager(PluginManager[OutputPlugin], type="output"): class OutputPluginManager(PluginManager[OutputPlugin], type="output"):
"""Manage Output Processing Plugins.""" """Manage Output Processing Plugins."""
def execute(self: "OutputPluginManager", output: str, device: "Device") -> OutputPluginReturn: def execute(
self: "OutputPluginManager", *, directive: "Directive", output: str, device: "Device"
) -> OutputPluginReturn:
"""Execute all output parsing plugins. """Execute all output parsing plugins.
The result of each plugin is passed to the next plugin. The result of each plugin is passed to the next plugin.
""" """
result = output result = output
for plugin in self.plugins: for plugin in (plugin for plugin in self.plugins if directive.id in plugin.directives):
if result is False: if result is False:
return result return result
# Pass the result of each plugin to the next plugin. # Pass the result of each plugin to the next plugin.

View File

@@ -1,10 +1,10 @@
"""Device output plugins.""" """Device output plugins."""
# Standard Library # Standard Library
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union, Sequence
# Local # Local
from ._base import HyperglassPlugin from ._base import DirectivePlugin
if TYPE_CHECKING: if TYPE_CHECKING:
# Project # Project
@@ -14,9 +14,11 @@ if TYPE_CHECKING:
OutputPluginReturn = Union[None, "ParsedRoutes", str] OutputPluginReturn = Union[None, "ParsedRoutes", str]
class OutputPlugin(HyperglassPlugin): class OutputPlugin(DirectivePlugin):
"""Plugin to interact with device command output.""" """Plugin to interact with device command output."""
directive_ids: Sequence[str] = ()
def process(self, output: Union["ParsedRoutes", str], device: "Device") -> OutputPluginReturn: def process(self, output: Union["ParsedRoutes", str], device: "Device") -> OutputPluginReturn:
"""Process or manipulate output from a device.""" """Process or manipulate output from a device."""
return None return None

View File

@@ -23,7 +23,7 @@ def _is_class(module: Any, obj: object) -> bool:
return isclass(obj) and obj.__module__ == module.__name__ return isclass(obj) and obj.__module__ == module.__name__
def _register_from_module(module: Any) -> Tuple[str, ...]: def _register_from_module(module: Any, **kwargs: Any) -> Tuple[str, ...]:
"""Register defined classes from the module.""" """Register defined classes from the module."""
failures = () failures = ()
defs = getmembers(module, lambda o: _is_class(module, o)) defs = getmembers(module, lambda o: _is_class(module, o))
@@ -35,7 +35,7 @@ def _register_from_module(module: Any) -> Tuple[str, ...]:
else: else:
failures += (name,) failures += (name,)
continue continue
manager.register(plugin) manager.register(plugin, **kwargs)
return failures return failures
return failures return failures
@@ -57,10 +57,10 @@ def init_plugins() -> None:
_register_from_module(_builtin) _register_from_module(_builtin)
def register_plugin(plugin_file: Path) -> Tuple[str, ...]: def register_plugin(plugin_file: Path, **kwargs) -> Tuple[str, ...]:
"""Register an external plugin by file path.""" """Register an external plugin by file path."""
if plugin_file.exists(): if plugin_file.exists():
module = _module_from_file(plugin_file) module = _module_from_file(plugin_file)
results = _register_from_module(module) results = _register_from_module(module, **kwargs)
return results return results
raise FileNotFoundError(str(plugin_file)) raise FileNotFoundError(str(plugin_file))