1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

Refactor devices model

This commit is contained in:
checktheroads
2020-07-30 01:30:01 -07:00
parent e2ae2adcf9
commit e3716784bc
22 changed files with 289 additions and 292 deletions

View File

@@ -242,15 +242,13 @@ app.add_api_route(
# Enable certificate import route only if a device using # Enable certificate import route only if a device using
# hyperglass-agent is defined. # hyperglass-agent is defined.
for device in devices.routers: if [n for n in devices.all_nos if n in TRANSPORT_REST]:
if device.nos in TRANSPORT_REST: app.add_api_route(
app.add_api_route( path="/api/import-agent-certificate/",
path="/api/import-agent-certificate/", endpoint=import_certificate,
endpoint=import_certificate, methods=["POST"],
methods=["POST"], include_in_schema=False,
include_in_schema=False, )
)
break
if params.docs.enable: if params.docs.enable:
app.add_api_route(path=params.docs.uri, endpoint=docs, include_in_schema=False) app.add_api_route(path=params.docs.uri, endpoint=docs, include_in_schema=False)

View File

@@ -124,16 +124,19 @@ class Query(BaseModel):
@property @property
def device(self): def device(self):
"""Get this query's device object by query_location.""" """Get this query's device object by query_location."""
return getattr(devices, self.query_location) return devices[self.query_location]
@property
def query(self):
"""Get this query's configuration object."""
return params.queries[self.query_type]
def export_dict(self, pretty=False): def export_dict(self, pretty=False):
"""Create dictionary representation of instance.""" """Create dictionary representation of instance."""
if pretty: if pretty:
loc = getattr(devices, self.query_location)
query_type = getattr(params.queries, self.query_type)
items = { items = {
"query_location": loc.display_name, "query_location": self.device.display_name,
"query_type": query_type.display_name, "query_type": self.query.display_name,
"query_vrf": self.query_vrf.display_name, "query_vrf": self.query_vrf.display_name,
"query_target": str(self.query_target), "query_target": str(self.query_target),
} }
@@ -163,12 +166,12 @@ class Query(BaseModel):
Returns: Returns:
{str} -- Valid query_type {str} -- Valid query_type
""" """
query_type_obj = getattr(params.queries, value) query = params.queries[value]
if not query_type_obj.enable: if not query.enable:
raise InputInvalid( raise InputInvalid(
params.messages.feature_not_enabled, params.messages.feature_not_enabled,
level="warning", level="warning",
feature=query_type_obj.display_name, feature=query.display_name,
) )
return value return value
@@ -208,7 +211,7 @@ class Query(BaseModel):
{str} -- Valid query_vrf {str} -- Valid query_vrf
""" """
vrf_object = get_vrf_object(value) vrf_object = get_vrf_object(value)
device = getattr(devices, values["query_location"]) device = devices[values["query_location"]]
device_vrf = None device_vrf = None
for vrf in device.vrfs: for vrf in device.vrfs:
if vrf == vrf_object: if vrf == vrf_object:

View File

@@ -13,7 +13,6 @@ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
# Project # Project
from hyperglass.log import log from hyperglass.log import log
from hyperglass.util import clean_name
from hyperglass.cache import AsyncCache from hyperglass.cache import AsyncCache
from hyperglass.encode import jwt_decode from hyperglass.encode import jwt_decode
from hyperglass.external import Webhook, bgptools from hyperglass.external import Webhook, bgptools
@@ -167,14 +166,9 @@ async def import_certificate(encoded_request: EncodedRequest):
"""Import a certificate from hyperglass-agent.""" """Import a certificate from hyperglass-agent."""
# Try to match the requested device name with configured devices # Try to match the requested device name with configured devices
matched_device = None try:
requested_device_name = clean_name(encoded_request.device) matched_device = devices[encoded_request.device]
for device in devices.routers: except AttributeError:
if device.name == requested_device_name:
matched_device = device
break
if matched_device is None:
raise HTTPException( raise HTTPException(
detail=f"Device {str(encoded_request.device)} not found", status_code=404 detail=f"Device {str(encoded_request.device)} not found", status_code=404
) )
@@ -191,10 +185,12 @@ async def import_certificate(encoded_request: EncodedRequest):
try: try:
# Write certificate to file # Write certificate to file
import_public_key( import_public_key(
app_path=APP_PATH, device_name=device.name, keystring=decoded_request app_path=APP_PATH,
device_name=matched_device.name,
keystring=decoded_request,
) )
except RuntimeError as import_error: except RuntimeError as err:
raise HyperglassError(str(import_error), level="danger") raise HyperglassError(str(err), level="danger")
return { return {
"output": f"Added public key for {encoded_request.device}", "output": f"Added public key for {encoded_request.device}",
@@ -226,7 +222,7 @@ async def routers():
"vrfs": {-1: {"name", "display_name"}}, "vrfs": {-1: {"name", "display_name"}},
} }
) )
for d in devices.routers for d in devices.objects
] ]

View File

@@ -28,9 +28,6 @@ from hyperglass.constants import (
__version__, __version__,
) )
from hyperglass.exceptions import ConfigError, ConfigInvalid, ConfigMissing from hyperglass.exceptions import ConfigError, ConfigInvalid, ConfigMissing
from hyperglass.configuration.models import params as _params
from hyperglass.configuration.models import routers as _routers
from hyperglass.configuration.models import commands as _commands
from hyperglass.configuration.defaults import ( from hyperglass.configuration.defaults import (
CREDIT, CREDIT,
DEFAULT_HELP, DEFAULT_HELP,
@@ -38,6 +35,9 @@ from hyperglass.configuration.defaults import (
DEFAULT_DETAILS, DEFAULT_DETAILS,
) )
from hyperglass.configuration.markdown import get_markdown from hyperglass.configuration.markdown import get_markdown
from hyperglass.configuration.models.params import Params
from hyperglass.configuration.models.devices import Devices
from hyperglass.configuration.models.commands import Commands
set_app_path(required=True) set_app_path(required=True)
@@ -165,7 +165,7 @@ set_log_level(logger=log, debug=user_config.get("debug", True))
# Map imported user configuration to expected schema. # Map imported user configuration to expected schema.
log.debug("Unvalidated configuration from {}: {}", CONFIG_MAIN, user_config) log.debug("Unvalidated configuration from {}: {}", CONFIG_MAIN, user_config)
params = _validate_config(config=user_config, importer=_params.Params) params = _validate_config(config=user_config, importer=Params)
# Re-evaluate debug state after config is validated # Re-evaluate debug state after config is validated
log_level = current_log_level(log) log_level = current_log_level(log)
@@ -178,16 +178,12 @@ elif not params.debug and log_level == "debug":
# Map imported user commands to expected schema. # Map imported user commands to expected schema.
_user_commands = _config_optional(CONFIG_COMMANDS) _user_commands = _config_optional(CONFIG_COMMANDS)
log.debug("Unvalidated commands from {}: {}", CONFIG_COMMANDS, _user_commands) log.debug("Unvalidated commands from {}: {}", CONFIG_COMMANDS, _user_commands)
commands = _validate_config( commands = _validate_config(config=_user_commands, importer=Commands.import_params)
config=_user_commands, importer=_commands.Commands.import_params
)
# Map imported user devices to expected schema. # Map imported user devices to expected schema.
_user_devices = _config_required(CONFIG_DEVICES) _user_devices = _config_required(CONFIG_DEVICES)
log.debug("Unvalidated devices from {}: {}", CONFIG_DEVICES, _user_devices) log.debug("Unvalidated devices from {}: {}", CONFIG_DEVICES, _user_devices)
devices = _validate_config( devices = _validate_config(config=_user_devices.get("routers", []), importer=Devices)
config=_user_devices.get("routers", []), importer=_routers.Routers._import,
)
# Validate commands are both supported and properly mapped. # Validate commands are both supported and properly mapped.
_validate_nos_commands(devices.all_nos, commands) _validate_nos_commands(devices.all_nos, commands)
@@ -226,7 +222,7 @@ try:
# If keywords are unmodified (default), add the org name & # If keywords are unmodified (default), add the org name &
# site_title. # site_title.
if _params.Params().site_keywords == params.site_keywords: if Params().site_keywords == params.site_keywords:
params.site_keywords = sorted( params.site_keywords = sorted(
{*params.site_keywords, params.org_name, params.site_title} {*params.site_keywords, params.org_name, params.site_title}
) )
@@ -258,7 +254,7 @@ def _build_frontend_networks():
{dict} -- Frontend networks {dict} -- Frontend networks
""" """
frontend_dict = {} frontend_dict = {}
for device in devices.routers: for device in devices.objects:
if device.network.display_name in frontend_dict: if device.network.display_name in frontend_dict:
frontend_dict[device.network.display_name].update( frontend_dict[device.network.display_name].update(
{ {
@@ -302,7 +298,7 @@ def _build_frontend_devices():
{dict} -- Frontend devices {dict} -- Frontend devices
""" """
frontend_dict = {} frontend_dict = {}
for device in devices.routers: for device in devices.objects:
if device.name in frontend_dict: if device.name in frontend_dict:
frontend_dict[device.name].update( frontend_dict[device.name].update(
{ {
@@ -348,11 +344,11 @@ def _build_networks():
{dict} -- Networks & devices {dict} -- Networks & devices
""" """
networks = [] networks = []
_networks = list(set({device.network.display_name for device in devices.routers})) _networks = list(set({device.network.display_name for device in devices.objects}))
for _network in _networks: for _network in _networks:
network_def = {"display_name": _network, "locations": []} network_def = {"display_name": _network, "locations": []}
for device in devices.routers: for device in devices.objects:
if device.network.display_name == _network: if device.network.display_name == _network:
network_def["locations"].append( network_def["locations"].append(
{ {
@@ -374,7 +370,7 @@ def _build_networks():
def _build_vrfs(): def _build_vrfs():
vrfs = [] vrfs = []
for device in devices.routers: for device in devices.objects:
for vrf in device.vrfs: for vrf in device.vrfs:
vrf_dict = { vrf_dict = {

View File

@@ -0,0 +1,14 @@
"""Validate credential configuration variables."""
# Third Party
from pydantic import SecretStr, StrictStr
# Project
from hyperglass.models import HyperglassModel
class Credential(HyperglassModel):
"""Model for per-credential config in devices.yaml."""
username: StrictStr
password: SecretStr

View File

@@ -1,35 +0,0 @@
"""Validate credential configuration variables."""
# Third Party
from pydantic import SecretStr, StrictStr
# Project
from hyperglass.util import clean_name
from hyperglass.models import HyperglassModel
class Credential(HyperglassModel):
"""Model for per-credential config in devices.yaml."""
username: StrictStr
password: SecretStr
class Credentials(HyperglassModel):
"""Base model for credentials class."""
@classmethod
def import_params(cls, input_params):
"""Import credentials with corrected field names.
Arguments:
input_params {dict} -- Credential definition
Returns:
{object} -- Validated credential object
"""
obj = Credentials()
for (credname, params) in input_params.items():
cred = clean_name(credname)
setattr(Credentials, cred, Credential(**params))
return obj

View File

@@ -3,23 +3,24 @@
# Standard Library # Standard Library
import os import os
import re import re
from typing import List, Optional from typing import Any, Dict, List, Union, Optional
from pathlib import Path from pathlib import Path
from ipaddress import IPv4Address, IPv6Address
# Third Party # Third Party
from pydantic import StrictInt, StrictStr, StrictBool, validator from pydantic import StrictInt, StrictStr, StrictBool, validator
# Project # Project
from hyperglass.log import log from hyperglass.log import log
from hyperglass.util import clean_name, validate_nos from hyperglass.util import validate_nos, resolve_hostname
from hyperglass.models import HyperglassModel, HyperglassModelExtra from hyperglass.models import HyperglassModel, HyperglassModelExtra
from hyperglass.constants import SCRAPE_HELPERS, SUPPORTED_STRUCTURED_OUTPUT from hyperglass.constants import SCRAPE_HELPERS, SUPPORTED_STRUCTURED_OUTPUT
from hyperglass.exceptions import ConfigError, UnsupportedDevice from hyperglass.exceptions import ConfigError, UnsupportedDevice
from hyperglass.configuration.models.ssl import Ssl from hyperglass.configuration.models.ssl import Ssl
from hyperglass.configuration.models.vrfs import Vrf, Info from hyperglass.configuration.models.vrf import Vrf, Info
from hyperglass.configuration.models.proxies import Proxy from hyperglass.configuration.models.proxy import Proxy
from hyperglass.configuration.models.networks import Network from hyperglass.configuration.models.network import Network
from hyperglass.configuration.models.credentials import Credential from hyperglass.configuration.models.credential import Credential
_default_vrf = { _default_vrf = {
"name": "default", "name": "default",
@@ -38,11 +39,11 @@ _default_vrf = {
} }
class Router(HyperglassModel): class Device(HyperglassModel):
"""Validation model for per-router config in devices.yaml.""" """Validation model for per-router config in devices.yaml."""
name: StrictStr name: StrictStr
address: StrictStr address: Union[IPv4Address, IPv6Address, StrictStr]
network: Network network: Network
credential: Credential credential: Credential
proxy: Optional[Proxy] proxy: Optional[Proxy]
@@ -56,6 +57,35 @@ class Router(HyperglassModel):
vrf_names: List[StrictStr] = [] vrf_names: List[StrictStr] = []
structured_output: Optional[StrictBool] structured_output: Optional[StrictBool]
def __hash__(self) -> int:
"""Make device object hashable so the object can be deduplicated with set()."""
return hash((self.name,))
def __eq__(self, other: Any) -> bool:
"""Make device object comparable so the object can be deduplicated with set()."""
result = False
if isinstance(other, HyperglassModel):
result = self.name == other.name
return result
@property
def _target(self):
return str(self.address)
@validator("address")
def validate_address(cls, value, values):
"""Ensure a hostname is resolvable."""
if not isinstance(value, (IPv4Address, IPv6Address)):
if not any(resolve_hostname(value)):
raise ConfigError(
"Device '{d}' has an address of '{a}', which is not resolvable.",
d=values["name"],
a=value,
)
return value
@validator("structured_output", pre=True, always=True) @validator("structured_output", pre=True, always=True)
def validate_structured_output(cls, value, values): def validate_structured_output(cls, value, values):
"""Validate structured output is supported on the device & set a default. """Validate structured output is supported on the device & set a default.
@@ -101,18 +131,6 @@ class Router(HyperglassModel):
return value return value
@validator("name")
def validate_name(cls, value):
"""Remove or replace unsupported characters from field values.
Arguments:
value {str} -- Raw name/location
Returns:
{} -- Valid name/location
"""
return clean_name(value)
@validator("ssl") @validator("ssl")
def validate_ssl(cls, value, values): def validate_ssl(cls, value, values):
"""Set default cert file location if undefined. """Set default cert file location if undefined.
@@ -219,17 +237,18 @@ class Router(HyperglassModel):
return vrfs return vrfs
class Routers(HyperglassModelExtra): class Devices(HyperglassModelExtra):
"""Validation model for device configurations.""" """Validation model for device configurations."""
hostnames: List[StrictStr] = [] hostnames: List[StrictStr] = []
vrfs: List[StrictStr] = [] vrfs: List[StrictStr] = []
display_vrfs: List[StrictStr] = [] display_vrfs: List[StrictStr] = []
routers: List[Router] = [] vrf_objects: List[Vrf] = []
networks: List[StrictStr] = [] objects: List[Device] = []
all_nos: List[StrictStr] = []
default_vrf: Vrf = Vrf(name="default", display_name="Global")
@classmethod def __init__(self, input_params: List[Dict]) -> None:
def _import(cls, input_params):
"""Import loaded YAML, initialize per-network definitions. """Import loaded YAML, initialize per-network definitions.
Remove unsupported characters from device names, dynamically Remove unsupported characters from device names, dynamically
@@ -243,33 +262,27 @@ class Routers(HyperglassModelExtra):
{object} -- Validated routers object {object} -- Validated routers object
""" """
vrfs = set() vrfs = set()
networks = set()
display_vrfs = set() display_vrfs = set()
vrf_objects = set() vrf_objects = set()
all_nos = set() all_nos = set()
router_objects = [] objects = set()
routers = Routers() hostnames = set()
routers.hostnames = []
routers.vrfs = [] init_kwargs = {}
routers.display_vrfs = []
for definition in input_params: for definition in input_params:
# Validate each router config against Router() model/schema # Validate each router config against Router() model/schema
router = Router(**definition) device = Device(**definition)
# Set a class attribute for each router so each router's
# attributes can be accessed with `devices.router_hostname`
setattr(routers, router.name, router)
# Add router-level attributes (assumed to be unique) to # Add router-level attributes (assumed to be unique) to
# class lists, e.g. so all hostnames can be accessed as a # class lists, e.g. so all hostnames can be accessed as a
# list with `devices.hostnames`, same for all router # list with `devices.hostnames`, same for all router
# classes, for when iteration over all routers is required. # classes, for when iteration over all routers is required.
routers.hostnames.append(router.name) hostnames.add(device.name)
router_objects.append(router) objects.add(device)
all_nos.add(router.nos) all_nos.add(device.nos)
for vrf in router.vrfs: for vrf in device.vrfs:
# For each configured router VRF, add its name and # For each configured router VRF, add its name and
# display_name to a class set (for automatic de-duping). # display_name to a class set (for automatic de-duping).
@@ -278,34 +291,43 @@ class Routers(HyperglassModelExtra):
# Also add the names to a router-level list so each # Also add the names to a router-level list so each
# router's VRFs and display VRFs can be easily accessed. # router's VRFs and display VRFs can be easily accessed.
router.display_vrfs.append(vrf.display_name) device.display_vrfs.append(vrf.display_name)
router.vrf_names.append(vrf.name) device.vrf_names.append(vrf.name)
# Add a 'default_vrf' attribute to the devices class # Add a 'default_vrf' attribute to the devices class
# which contains the configured default VRF display name. # which contains the configured default VRF display name.
if vrf.name == "default" and not hasattr(cls, "default_vrf"): if vrf.name == "default" and not hasattr(self, "default_vrf"):
routers.default_vrf = { init_kwargs["default_vrf"] = Vrf(
"name": vrf.name, name=vrf.name, display_name=vrf.display_name
"display_name": vrf.display_name, )
}
# Add the native VRF objects to a set (for automatic # Add the native VRF objects to a set (for automatic
# de-duping), but exlcude device-specific fields. # de-duping), but exlcude device-specific fields.
_copy_params = { vrf_objects.add(
"deep": True, vrf.copy(
"exclude": {"ipv4": {"source_address"}, "ipv6": {"source_address"}}, deep=True,
} exclude={
vrf_objects.add(vrf.copy(**_copy_params)) "ipv4": {"source_address"},
"ipv6": {"source_address"},
},
)
)
# Convert the de-duplicated sets to a standard list, add lists # Convert the de-duplicated sets to a standard list, add lists
# as class attributes. # as class attributes. Sort router list by router name attribute
routers.vrfs = list(vrfs) init_kwargs["hostnames"] = list(hostnames)
routers.display_vrfs = list(display_vrfs) init_kwargs["all_nos"] = list(all_nos)
routers.vrf_objects = list(vrf_objects) init_kwargs["vrfs"] = list(vrfs)
routers.networks = list(networks) init_kwargs["display_vrfs"] = list(vrfs)
routers.all_nos = list(all_nos) init_kwargs["vrf_objects"] = list(vrf_objects)
init_kwargs["objects"] = sorted(objects, key=lambda x: x.display_name)
# Sort router list by router name attribute super().__init__(**init_kwargs)
routers.routers = sorted(router_objects, key=lambda x: x.display_name)
return routers def __getitem__(self, accessor: str) -> Device:
"""Get a device by its name."""
for device in self.objects:
if device.name == accessor:
return device
raise AttributeError(f"No device named '{accessor}'")

View File

@@ -0,0 +1,22 @@
"""Validate network configuration variables."""
# Third Party
from pydantic import Field, StrictStr
# Project
from hyperglass.models import HyperglassModel
class Network(HyperglassModel):
"""Validation Model for per-network/asn config in devices.yaml."""
name: StrictStr = Field(
...,
title="Network Name",
description="Internal name of the device's primary network.",
)
display_name: StrictStr = Field(
...,
title="Network Display Name",
description="Display name of the device's primary network.",
)

View File

@@ -1,50 +0,0 @@
"""Validate network configuration variables."""
# Third Party
from pydantic import Field, StrictStr
# Project
from hyperglass.util import clean_name
from hyperglass.models import HyperglassModel
class Network(HyperglassModel):
"""Validation Model for per-network/asn config in devices.yaml."""
name: StrictStr = Field(
...,
title="Network Name",
description="Internal name of the device's primary network.",
)
display_name: StrictStr = Field(
...,
title="Network Display Name",
description="Display name of the device's primary network.",
)
class Networks(HyperglassModel):
"""Base model for networks class."""
@classmethod
def import_params(cls, input_params):
"""Import loaded YAML, initialize per-network definitions.
Remove unsupported characters from network names, dynamically
set attributes for the networks class. Add cls.networks
attribute so network objects can be accessed inside a dict.
Arguments:
input_params {dict} -- Unvalidated network definitions
Returns:
{object} -- Validated networks object
"""
obj = Networks()
networks = {}
for (netname, params) in input_params.items():
netname = clean_name(netname)
setattr(Networks, netname, Network(**params))
networks.update({netname: Network(**params).dict()})
Networks.networks = networks
return obj

View File

@@ -1,57 +0,0 @@
"""Validate SSH proxy configuration variables."""
# Third Party
from pydantic import StrictInt, StrictStr, validator
# Project
from hyperglass.util import clean_name
from hyperglass.models import HyperglassModel
from hyperglass.exceptions import UnsupportedDevice
from hyperglass.configuration.models.credentials import Credential
class Proxy(HyperglassModel):
"""Validation model for per-proxy config in devices.yaml."""
name: StrictStr
address: StrictStr
port: StrictInt = 22
credential: Credential
nos: StrictStr = "linux_ssh"
@validator("nos")
def supported_nos(cls, value):
"""Verify NOS is supported by hyperglass.
Raises:
UnsupportedDevice: Raised if NOS is not supported.
Returns:
{str} -- Valid NOS name
"""
if not value == "linux_ssh":
raise UnsupportedDevice(f'"{value}" device type is not supported.')
return value
class Proxies(HyperglassModel):
"""Validation model for SSH proxy configuration."""
@classmethod
def import_params(cls, input_params):
"""Import loaded YAML, initialize per-proxy definitions.
Remove unsupported characters from proxy names, dynamically
set attributes for the proxies class.
Arguments:
input_params {dict} -- Unvalidated proxy definitions
Returns:
{object} -- Validated proxies object
"""
obj = Proxies()
for (devname, params) in input_params.items():
dev = clean_name(devname)
setattr(Proxies, dev, Proxy(**params))
return obj

View File

@@ -0,0 +1,56 @@
"""Validate SSH proxy configuration variables."""
# Standard Library
from typing import Union
from ipaddress import IPv4Address, IPv6Address
# Third Party
from pydantic import StrictInt, StrictStr, validator
# Project
from hyperglass.util import resolve_hostname
from hyperglass.models import HyperglassModel
from hyperglass.exceptions import ConfigError, UnsupportedDevice
from hyperglass.configuration.models.credential import Credential
class Proxy(HyperglassModel):
"""Validation model for per-proxy config in devices.yaml."""
name: StrictStr
address: Union[IPv4Address, IPv6Address, StrictStr]
port: StrictInt = 22
credential: Credential
nos: StrictStr = "linux_ssh"
@property
def _target(self):
return str(self.address)
@validator("address")
def validate_address(cls, value, values):
"""Ensure a hostname is resolvable."""
if not isinstance(value, (IPv4Address, IPv6Address)):
if not any(resolve_hostname(value)):
raise ConfigError(
"Device '{d}' has an address of '{a}', which is not resolvable.",
d=values["name"],
a=value,
)
return value
@validator("nos")
def supported_nos(cls, value, values):
"""Verify NOS is supported by hyperglass.
Raises:
UnsupportedDevice: Raised if NOS is not supported.
Returns:
{str} -- Valid NOS name
"""
if not value == "linux_ssh":
raise UnsupportedDevice(
f"Proxy '{values['name']}' uses NOS '{value}', which is currently unsupported."
)
return value

View File

@@ -197,6 +197,13 @@ class Queries(HyperglassModel):
ping: Ping = Ping() ping: Ping = Ping()
traceroute: Traceroute = Traceroute() traceroute: Traceroute = Traceroute()
def __getitem__(self, query_type: str):
"""Get a query's object by name."""
if hasattr(self, query_type):
return getattr(self, query_type)
raise AttributeError(f"Query '{query_type}' is invalid")
class Config: class Config:
"""Pydantic model configuration.""" """Pydantic model configuration."""

View File

@@ -9,13 +9,13 @@ from hyperglass.parsing.nos import nos_parsers
from hyperglass.parsing.common import parsers from hyperglass.parsing.common import parsers
from hyperglass.api.models.query import Query from hyperglass.api.models.query import Query
from hyperglass.execution.construct import Construct from hyperglass.execution.construct import Construct
from hyperglass.configuration.models.routers import Router from hyperglass.configuration.models.devices import Device
class Connection: class Connection:
"""Base transport driver class.""" """Base transport driver class."""
def __init__(self, device: Router, query_data: Query) -> None: def __init__(self, device: Device, query_data: Query) -> None:
"""Initialize connection to device.""" """Initialize connection to device."""
self.device = device self.device = device
self.query_data = query_data self.query_data = query_data

View File

@@ -23,14 +23,7 @@ from hyperglass.execution.drivers._common import Connection
class AgentConnection(Connection): class AgentConnection(Connection):
"""Connect to target device via specified transport. """Connect to target device via hyperglass-agent."""
scrape_direct() directly connects to devices via SSH
scrape_proxied() connects to devices via an SSH proxy
rest() connects to devices via HTTP for RESTful API communication
"""
async def collect(self) -> Iterable: # noqa: C901 async def collect(self) -> Iterable: # noqa: C901
"""Connect to a device running hyperglass-agent via HTTP.""" """Connect to a device running hyperglass-agent via HTTP."""
@@ -60,7 +53,7 @@ class AgentConnection(Connection):
else: else:
http_protocol = "http" http_protocol = "http"
endpoint = "{protocol}://{address}:{port}/query/".format( endpoint = "{protocol}://{address}:{port}/query/".format(
protocol=http_protocol, address=self.device.address, port=self.device.port protocol=http_protocol, address=self.device._target, port=self.device.port
) )
log.debug(f"URL endpoint: {endpoint}") log.debug(f"URL endpoint: {endpoint}")

View File

@@ -23,11 +23,11 @@ class SSHConnection(Connection):
"""Set up an SSH tunnel according to a device's configuration.""" """Set up an SSH tunnel according to a device's configuration."""
try: try:
return open_tunnel( return open_tunnel(
proxy.address, proxy._target,
proxy.port, proxy.port,
ssh_username=proxy.credential.username, ssh_username=proxy.credential.username,
ssh_password=proxy.credential.password.get_secret_value(), ssh_password=proxy.credential.password.get_secret_value(),
remote_bind_address=(self.device.address, self.device.port), remote_bind_address=(self.device._target, self.device.port),
local_bind_address=("localhost", 0), local_bind_address=("localhost", 0),
skip_tunnel_checkup=False, skip_tunnel_checkup=False,
gateway_timeout=params.request_timeout - 2, gateway_timeout=params.request_timeout - 2,

View File

@@ -43,7 +43,7 @@ class NetmikoConnection(SSHConnection):
log.debug("Connecting directly to {}", self.device.name) log.debug("Connecting directly to {}", self.device.name)
netmiko_args = { netmiko_args = {
"host": host or self.device.address, "host": host or self.device._target,
"port": port or self.device.port, "port": port or self.device.port,
"device_type": self.device.nos, "device_type": self.device.nos,
"username": self.device.credential.username, "username": self.device.credential.username,

View File

@@ -72,7 +72,7 @@ class ScrapliConnection(SSHConnection):
log.debug("Connecting directly to {}", self.device.name) log.debug("Connecting directly to {}", self.device.name)
driver_kwargs = { driver_kwargs = {
"host": host or self.device.address, "host": host or self.device._target,
"port": port or self.device.port, "port": port or self.device.port,
"auth_username": self.device.credential.username, "auth_username": self.device.credential.username,
"auth_password": self.device.credential.password.get_secret_value(), "auth_password": self.device.credential.password.get_secret_value(),

View File

@@ -14,7 +14,7 @@ from typing import Any, Dict, Union, Callable
from hyperglass.log import log from hyperglass.log import log
from hyperglass.util import validate_nos from hyperglass.util import validate_nos
from hyperglass.exceptions import DeviceTimeout, ResponseEmpty from hyperglass.exceptions import DeviceTimeout, ResponseEmpty
from hyperglass.configuration import params, devices from hyperglass.configuration import params
from hyperglass.api.models.query import Query from hyperglass.api.models.query import Query
from hyperglass.execution.drivers import ( from hyperglass.execution.drivers import (
AgentConnection, AgentConnection,
@@ -42,29 +42,28 @@ async def execute(query: Query) -> Union[str, Dict]:
"""Initiate query validation and execution.""" """Initiate query validation and execution."""
output = params.messages.general output = params.messages.general
device = getattr(devices, query.query_location)
log.debug(f"Received query for {query}") log.debug(f"Received query for {query}")
log.debug(f"Matched device config: {device}") log.debug(f"Matched device config: {query.device}")
supported, driver_name = validate_nos(device.nos) supported, driver_name = validate_nos(query.device.nos)
mapped_driver = DRIVER_MAP.get(driver_name, NetmikoConnection) mapped_driver = DRIVER_MAP.get(driver_name, NetmikoConnection)
driver = mapped_driver(device, query) driver = mapped_driver(query.device, query)
timeout_args = { timeout_args = {
"unformatted_msg": params.messages.connection_error, "unformatted_msg": params.messages.connection_error,
"device_name": device.display_name, "device_name": query.device.display_name,
"error": params.messages.request_timeout, "error": params.messages.request_timeout,
} }
if device.proxy: if query.device.proxy:
timeout_args["proxy"] = device.proxy.name timeout_args["proxy"] = query.device.proxy.name
signal.signal(signal.SIGALRM, handle_timeout(**timeout_args)) signal.signal(signal.SIGALRM, handle_timeout(**timeout_args))
signal.alarm(params.request_timeout - 1) signal.alarm(params.request_timeout - 1)
if device.proxy: if query.device.proxy:
proxy = driver.setup_proxy() proxy = driver.setup_proxy()
with proxy() as tunnel: with proxy() as tunnel:
response = await driver.collect( response = await driver.collect(
@@ -76,7 +75,9 @@ async def execute(query: Query) -> Union[str, Dict]:
output = await driver.parsed_response(response) output = await driver.parsed_response(response)
if output == "" or output == "\n": if output == "" or output == "\n":
raise ResponseEmpty(params.messages.no_output, device_name=device.display_name) raise ResponseEmpty(
params.messages.no_output, device_name=query.device.display_name
)
log.debug(f"Output for query: {query.json()}:\n{repr(output)}") log.debug(f"Output for query: {query.json()}:\n{repr(output)}")
signal.alarm(0) signal.alarm(0)

View File

@@ -17,7 +17,6 @@ from pydantic import (
# Project # Project
from hyperglass.log import log from hyperglass.log import log
from hyperglass.util import clean_name
IntFloat = TypeVar("IntFloat", StrictInt, StrictFloat) IntFloat = TypeVar("IntFloat", StrictInt, StrictFloat)
@@ -25,6 +24,18 @@ _WEBHOOK_TITLE = "hyperglass received a valid query with the following data"
_ICON_URL = "https://res.cloudinary.com/hyperglass/image/upload/v1593192484/icon.png" _ICON_URL = "https://res.cloudinary.com/hyperglass/image/upload/v1593192484/icon.png"
def clean_name(_name: str) -> str:
"""Remove unsupported characters from field names.
Converts any "desirable" seperators to underscore, then removes all
characters that are unsupported in Python class variable names.
Also removes leading numbers underscores.
"""
_replaced = re.sub(r"[\-|\.|\@|\~|\:\/|\s]", "_", _name)
_scrubbed = "".join(re.findall(r"([a-zA-Z]\w+|\_+)", _replaced))
return _scrubbed.lower()
class HyperglassModel(BaseModel): class HyperglassModel(BaseModel):
"""Base model for all hyperglass configuration models.""" """Base model for all hyperglass configuration models."""

View File

@@ -8,9 +8,9 @@ import math
import shutil import shutil
import asyncio import asyncio
from queue import Queue from queue import Queue
from typing import Dict, Union, Iterable, Optional from typing import Dict, Union, Iterable, Optional, Generator
from pathlib import Path from pathlib import Path
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address, ip_address
from threading import Thread from threading import Thread
# Third Party # Third Party
@@ -18,6 +18,7 @@ from loguru._logger import Logger as LoguruLogger
# Project # Project
from hyperglass.log import log from hyperglass.log import log
from hyperglass.models import HyperglassModel
def cpu_count(multiplier: int = 0): def cpu_count(multiplier: int = 0):
@@ -33,18 +34,6 @@ def cpu_count(multiplier: int = 0):
return multiprocessing.cpu_count() * multiplier return multiprocessing.cpu_count() * multiplier
def clean_name(_name: str) -> str:
"""Remove unsupported characters from field names.
Converts any "desirable" seperators to underscore, then removes all
characters that are unsupported in Python class variable names.
Also removes leading numbers underscores.
"""
_replaced = re.sub(r"[\-|\.|\@|\~|\:\/|\s]", "_", _name)
_scrubbed = "".join(re.findall(r"([a-zA-Z]\w+|\_+)", _replaced))
return _scrubbed.lower()
def check_path( def check_path(
path: Union[Path, str], mode: str = "r", create: bool = False path: Union[Path, str], mode: str = "r", create: bool = False
) -> Optional[Path]: ) -> Optional[Path]:
@@ -925,3 +914,34 @@ def validation_error_message(*errors: Dict) -> str:
errs += (f'Field: {loc}\n Error: {err["msg"]}\n',) errs += (f'Field: {loc}\n Error: {err["msg"]}\n',)
return "\n".join(errs) return "\n".join(errs)
def resolve_hostname(hostname: str) -> Generator:
"""Resolve a hostname via DNS/hostfile."""
from socket import getaddrinfo, gaierror
log.debug("Ensuring '{}' is resolvable...", hostname)
ip4 = None
ip6 = None
try:
res = getaddrinfo(hostname, None)
if len(res) == 2:
addr = ip_address(res[0][4][0])
if addr.version == 6:
ip6 = addr
else:
ip4 = addr
elif len(res) == 4:
addr1 = ip_address(res[0][4][0])
addr2 = ip_address(res[2][4][0])
for a in (addr1, addr2):
if a.version == 4:
ip4 = a
elif a.version == 6:
ip6 = a
except gaierror:
pass
yield ip4
yield ip6

View File

@@ -50,12 +50,12 @@ def _comment_optional_files():
def _validate_devices(): def _validate_devices():
from hyperglass.configuration.models.routers import Routers from hyperglass.configuration.models.devices import Devices
with DEVICES.open() as raw: with DEVICES.open() as raw:
devices_dict = yaml.safe_load(raw.read()) or {} devices_dict = yaml.safe_load(raw.read()) or {}
try: try:
Routers._import(devices_dict.get("routers", [])) Devices(devices_dict.get("routers", []))
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))
return True return True