mirror of
https://github.com/checktheroads/hyperglass
synced 2024-05-11 05:55:08 +00:00
Add separate hooks for major state objects, add tests
This commit is contained in:
1
.flake8
1
.flake8
@@ -19,6 +19,7 @@ per-file-ignores=
|
||||
# Disable assertion and docstring checks on tests.
|
||||
hyperglass/**/test_*.py:S101,D103
|
||||
hyperglass/api/*.py:B008
|
||||
hyperglass/state/hooks.py:F811
|
||||
ignore=W503,C0330,R504,D202,S403,S301,S404,E731,D402
|
||||
select=B, BLK, C, D, E, F, I, II, N, P, PIE, S, R, W
|
||||
disable-noqa=False
|
||||
|
@@ -4,10 +4,10 @@
|
||||
from hyperglass.state import use_state
|
||||
|
||||
|
||||
def check_redis() -> bool:
|
||||
def check_redis() -> None:
|
||||
"""Ensure Redis is running before starting server."""
|
||||
state = use_state()
|
||||
return state.redis.ping()
|
||||
cache = use_state("cache")
|
||||
cache.check()
|
||||
|
||||
|
||||
on_startup = (check_redis,)
|
||||
|
@@ -3,6 +3,7 @@
|
||||
# Standard Library
|
||||
import json
|
||||
import time
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
# Third Party
|
||||
@@ -16,26 +17,44 @@ from hyperglass.state import HyperglassState, use_state
|
||||
from hyperglass.external import Webhook, bgptools
|
||||
from hyperglass.api.tasks import process_headers
|
||||
from hyperglass.constants import __version__
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.exceptions import HyperglassError
|
||||
from hyperglass.models.api import Query
|
||||
from hyperglass.execution.main import execute
|
||||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
||||
# Local
|
||||
from .fake_output import fake_output
|
||||
|
||||
|
||||
def get_state():
|
||||
def get_state(attr: t.Optional[str] = None):
|
||||
"""Get hyperglass state as a FastAPI dependency."""
|
||||
return use_state()
|
||||
return use_state(attr)
|
||||
|
||||
|
||||
def get_params():
|
||||
"""Get hyperglass params as FastAPI dependency."""
|
||||
return use_state("params")
|
||||
|
||||
|
||||
def get_devices():
|
||||
"""Get hyperglass devices as FastAPI dependency."""
|
||||
return use_state("devices")
|
||||
|
||||
|
||||
def get_ui_params():
|
||||
"""Get hyperglass ui_params as FastAPI dependency."""
|
||||
return use_state("ui_params")
|
||||
|
||||
|
||||
async def send_webhook(
|
||||
query_data: Query, request: Request, timestamp: datetime,
|
||||
):
|
||||
"""If webhooks are enabled, get request info and send a webhook."""
|
||||
state = use_state()
|
||||
params = use_state("params")
|
||||
try:
|
||||
if state.params.logging.http is not None:
|
||||
if params.logging.http is not None:
|
||||
headers = await process_headers(headers=request.headers)
|
||||
|
||||
if headers.get("x-real-ip") is not None:
|
||||
@@ -47,7 +66,7 @@ async def send_webhook(
|
||||
|
||||
network_info = await bgptools.network_info(host)
|
||||
|
||||
async with Webhook(state.params.logging.http) as hook:
|
||||
async with Webhook(params.logging.http) as hook:
|
||||
|
||||
await hook.send(
|
||||
query={
|
||||
@@ -59,7 +78,7 @@ async def send_webhook(
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
log.error("Error sending webhook to {}: {}", state.params.logging.http.provider, str(err))
|
||||
log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err))
|
||||
|
||||
|
||||
async def query(
|
||||
@@ -83,7 +102,7 @@ async def query(
|
||||
|
||||
log.info("Starting query execution for query {}", query_data.summary)
|
||||
|
||||
cache_response = cache.get_dict(cache_key, "output")
|
||||
cache_response = cache.get_map(cache_key, "output")
|
||||
|
||||
json_output = False
|
||||
|
||||
@@ -104,7 +123,7 @@ async def query(
|
||||
|
||||
cached = True
|
||||
runtime = 0
|
||||
timestamp = cache.get_dict(cache_key, "timestamp")
|
||||
timestamp = cache.get_map(cache_key, "timestamp")
|
||||
|
||||
elif not cache_response:
|
||||
log.debug("No existing cache entry for query {}", cache_key)
|
||||
@@ -133,8 +152,8 @@ async def query(
|
||||
raw_output = json.dumps(cache_output)
|
||||
else:
|
||||
raw_output = str(cache_output)
|
||||
cache.set_dict(cache_key, "output", raw_output)
|
||||
cache.set_dict(cache_key, "timestamp", timestamp)
|
||||
cache.set_map_item(cache_key, "output", raw_output)
|
||||
cache.set_map_item(cache_key, "timestamp", timestamp)
|
||||
cache.expire(cache_key, seconds=state.params.cache.timeout)
|
||||
|
||||
log.debug("Added cache entry for query: {}", cache_key)
|
||||
@@ -164,46 +183,46 @@ async def query(
|
||||
}
|
||||
|
||||
|
||||
async def docs(state: "HyperglassState" = Depends(get_state)):
|
||||
async def docs(params: "Params" = Depends(get_params)):
|
||||
"""Serve custom docs."""
|
||||
if state.params.docs.enable:
|
||||
if params.docs.enable:
|
||||
docs_func_map = {"swagger": get_swagger_ui_html, "redoc": get_redoc_html}
|
||||
docs_func = docs_func_map[state.params.docs.mode]
|
||||
docs_func = docs_func_map[params.docs.mode]
|
||||
return docs_func(
|
||||
openapi_url=state.params.docs.openapi_url, title=state.params.site_title + " - API Docs"
|
||||
openapi_url=params.docs.openapi_url, title=params.site_title + " - API Docs"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(detail="Not found", status_code=404)
|
||||
|
||||
|
||||
async def router(id: str, state: "HyperglassState" = Depends(get_state)):
|
||||
async def router(id: str, devices: "Devices" = Depends(get_devices)):
|
||||
"""Get a device's API-facing attributes."""
|
||||
return state.devices[id].export_api()
|
||||
return devices[id].export_api()
|
||||
|
||||
|
||||
async def routers(state: "HyperglassState" = Depends(get_state)):
|
||||
async def routers(devices: "Devices" = Depends(get_devices)):
|
||||
"""Serve list of configured routers and attributes."""
|
||||
return state.devices.export_api()
|
||||
return devices.export_api()
|
||||
|
||||
|
||||
async def queries(state: "HyperglassState" = Depends(get_state)):
|
||||
async def queries(params: "Params" = Depends(get_params)):
|
||||
"""Serve list of enabled query types."""
|
||||
return state.params.queries.list
|
||||
return params.queries.list
|
||||
|
||||
|
||||
async def info(state: "HyperglassState" = Depends(get_state)):
|
||||
async def info(params: "Params" = Depends(get_params)):
|
||||
"""Serve general information about this instance of hyperglass."""
|
||||
return {
|
||||
"name": state.params.site_title,
|
||||
"organization": state.params.org_name,
|
||||
"primary_asn": int(state.params.primary_asn),
|
||||
"name": params.site_title,
|
||||
"organization": params.org_name,
|
||||
"primary_asn": int(params.primary_asn),
|
||||
"version": __version__,
|
||||
}
|
||||
|
||||
|
||||
async def ui_props(state: "HyperglassState" = Depends(get_state)):
|
||||
async def ui_props(ui_params: "UIParameters" = Depends(get_ui_params)):
|
||||
"""Serve UI configration."""
|
||||
return state.ui_params
|
||||
return ui_params
|
||||
|
||||
|
||||
endpoints = [query, docs, routers, info, ui_props]
|
||||
|
@@ -14,9 +14,8 @@ if TYPE_CHECKING:
|
||||
from hyperglass.models.api.query import Query
|
||||
from hyperglass.models.config.devices import Device
|
||||
|
||||
_state = use_state()
|
||||
MESSAGES = _state.params.messages
|
||||
TEXT = _state.params.web.text
|
||||
(MESSAGES := use_state("params").messages)
|
||||
(TEXT := use_state("params").web.text)
|
||||
|
||||
|
||||
class ScrapeError(
|
||||
|
@@ -38,11 +38,11 @@ class AgentConnection(Connection):
|
||||
async def collect(self) -> Iterable: # noqa: C901
|
||||
"""Connect to a device running hyperglass-agent via HTTP."""
|
||||
log.debug("Query parameters: {}", self.query)
|
||||
state = use_state()
|
||||
params = use_state("params")
|
||||
|
||||
client_params = {
|
||||
"headers": {"Content-Type": "application/json"},
|
||||
"timeout": state.params.request_timeout,
|
||||
"timeout": params.request_timeout,
|
||||
}
|
||||
if self.device.ssl is not None and self.device.ssl.enable:
|
||||
with self.device.ssl.cert.open("r") as file:
|
||||
@@ -77,7 +77,7 @@ class AgentConnection(Connection):
|
||||
encoded_query = await jwt_encode(
|
||||
payload=query,
|
||||
secret=self.device.credential.password.get_secret_value(),
|
||||
duration=state.params.request_timeout,
|
||||
duration=params.request_timeout,
|
||||
)
|
||||
log.debug("Encoded JWT: {}", encoded_query)
|
||||
|
||||
|
@@ -24,7 +24,7 @@ class SSHConnection(Connection):
|
||||
"""Return a preconfigured sshtunnel.SSHTunnelForwarder instance."""
|
||||
|
||||
proxy = self.device.proxy
|
||||
state = use_state()
|
||||
params = use_state("params")
|
||||
|
||||
def opener():
|
||||
"""Set up an SSH tunnel according to a device's configuration."""
|
||||
@@ -33,7 +33,7 @@ class SSHConnection(Connection):
|
||||
"remote_bind_address": (self.device._target, self.device.port),
|
||||
"local_bind_address": ("localhost", 0),
|
||||
"skip_tunnel_checkup": False,
|
||||
"gateway_timeout": state.params.request_timeout - 2,
|
||||
"gateway_timeout": params.request_timeout - 2,
|
||||
}
|
||||
if proxy.credential._method == "password":
|
||||
# Use password auth if no key is defined.
|
||||
|
@@ -46,7 +46,7 @@ class NetmikoConnection(SSHConnection):
|
||||
Directly connects to the router via Netmiko library, returns the
|
||||
command output.
|
||||
"""
|
||||
state = use_state()
|
||||
params = use_state("params")
|
||||
if host is not None:
|
||||
log.debug(
|
||||
"Connecting to {} via proxy {} [{}]",
|
||||
@@ -66,9 +66,9 @@ class NetmikoConnection(SSHConnection):
|
||||
"port": port or self.device.port,
|
||||
"device_type": self.device.type,
|
||||
"username": self.device.credential.username,
|
||||
"global_delay_factor": state.params.netmiko_delay_factor,
|
||||
"timeout": math.floor(state.params.request_timeout * 1.25),
|
||||
"session_timeout": math.ceil(state.params.request_timeout - 1),
|
||||
"global_delay_factor": params.netmiko_delay_factor,
|
||||
"timeout": math.floor(params.request_timeout * 1.25),
|
||||
"session_timeout": math.ceil(params.request_timeout - 1),
|
||||
**global_args,
|
||||
}
|
||||
|
||||
|
@@ -71,7 +71,7 @@ class ScrapliConnection(SSHConnection):
|
||||
Directly connects to the router via Netmiko library, returns the
|
||||
command output.
|
||||
"""
|
||||
state = use_state()
|
||||
params = use_state("params")
|
||||
driver = _map_driver(self.device.type)
|
||||
|
||||
if host is not None:
|
||||
@@ -90,7 +90,7 @@ class ScrapliConnection(SSHConnection):
|
||||
"host": host or self.device._target,
|
||||
"port": port or self.device.port,
|
||||
"auth_username": self.device.credential.username,
|
||||
"timeout_ops": math.floor(state.params.request_timeout * 1.25),
|
||||
"timeout_ops": math.floor(params.request_timeout * 1.25),
|
||||
"transport": "asyncssh",
|
||||
"auth_strict_key": False,
|
||||
"ssh_known_hosts_file": False,
|
||||
|
@@ -47,10 +47,10 @@ def handle_timeout(**exc_args: Any) -> Callable:
|
||||
|
||||
async def execute(query: "Query") -> Union["OutputDataModel", str]:
|
||||
"""Initiate query validation and execution."""
|
||||
state = use_state()
|
||||
output = state.params.messages.general
|
||||
params = use_state("params")
|
||||
output = params.messages.general
|
||||
|
||||
log.debug("Received query for {}", query.json())
|
||||
log.debug("Received query {}", query.json())
|
||||
log.debug("Matched device config: {}", query.device)
|
||||
|
||||
mapped_driver = map_driver(query.device.driver)
|
||||
@@ -60,7 +60,7 @@ async def execute(query: "Query") -> Union["OutputDataModel", str]:
|
||||
signal.SIGALRM,
|
||||
handle_timeout(error=TimeoutError("Connection timed out"), device=query.device),
|
||||
)
|
||||
signal.alarm(state.params.request_timeout - 1)
|
||||
signal.alarm(params.request_timeout - 1)
|
||||
|
||||
if query.device.proxy:
|
||||
proxy = driver.setup_proxy()
|
||||
|
13
hyperglass/external/_base.py
vendored
13
hyperglass/external/_base.py
vendored
@@ -4,7 +4,9 @@
|
||||
import re
|
||||
import json as _json
|
||||
import socket
|
||||
import typing as t
|
||||
from json import JSONDecodeError
|
||||
from types import TracebackType
|
||||
from socket import gaierror
|
||||
|
||||
# Third Party
|
||||
@@ -86,10 +88,15 @@ class BaseExternal:
|
||||
else:
|
||||
raise self._exception(f"Unable to create session to {self.name}")
|
||||
|
||||
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: t.Optional[t.Type[BaseException]] = None,
|
||||
exc_value: t.Optional[BaseException] = None,
|
||||
exc_traceback: t.Optional[TracebackType] = None,
|
||||
):
|
||||
"""Close connection on exit."""
|
||||
if exc_type is not None:
|
||||
log.error(traceback)
|
||||
log.error(str(exc_value))
|
||||
self._session.close()
|
||||
|
||||
def __repr__(self):
|
||||
@@ -232,7 +239,7 @@ class BaseExternal:
|
||||
response = await self._asession.request(**request)
|
||||
|
||||
if response.status_code not in range(200, 300):
|
||||
status = StatusCode(response.status_code)
|
||||
status = httpx.codes(response.status_code)
|
||||
error = self._parse_response(response)
|
||||
raise self._exception(
|
||||
f'{status.name.replace("_", " ")}: {error}', level="danger"
|
||||
|
85
hyperglass/external/bgptools.py
vendored
85
hyperglass/external/bgptools.py
vendored
@@ -6,7 +6,6 @@
|
||||
|
||||
# Standard Library
|
||||
import re
|
||||
import socket
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -87,54 +86,24 @@ async def run_whois(targets: List[str]) -> str:
|
||||
return response.decode()
|
||||
|
||||
|
||||
def run_whois_sync(targets: List[str]) -> str:
|
||||
"""Open raw socket to bgp.tools and execute query."""
|
||||
|
||||
# Construct bulk query
|
||||
query = "\n".join(("begin", *targets, "end\n")).encode()
|
||||
|
||||
# Open the socket to bgp.tools
|
||||
log.debug("Opening connection to bgp.tools")
|
||||
sock = socket.socket()
|
||||
sock.connect(("bgp.tools", 43))
|
||||
sock.send(query)
|
||||
|
||||
# Read the response
|
||||
response = b""
|
||||
while True:
|
||||
data = sock.recv(128)
|
||||
if data:
|
||||
response += data
|
||||
|
||||
else:
|
||||
log.debug("Closing connection to bgp.tools")
|
||||
sock.shutdown(1)
|
||||
sock.close()
|
||||
break
|
||||
|
||||
return response.decode()
|
||||
|
||||
|
||||
async def network_info(*targets: str) -> Dict[str, Dict[str, str]]:
|
||||
"""Get ASN, Containing Prefix, and other info about an internet resource."""
|
||||
|
||||
targets = [str(t) for t in targets]
|
||||
(cache := use_state().redis)
|
||||
cache = use_state("cache")
|
||||
|
||||
# Set default data structure.
|
||||
data = {t: {k: "" for k in DEFAULT_KEYS} for t in targets}
|
||||
|
||||
# Get all cached bgp.tools data.
|
||||
cached = cache.hgetall(CACHE_KEY)
|
||||
cached = cache.get_map(CACHE_KEY) or {}
|
||||
|
||||
# Try to use cached data for each of the items in the list of
|
||||
# resources.
|
||||
for t in targets:
|
||||
|
||||
if t in cached:
|
||||
# Reassign the cached network info to the matching resource.
|
||||
data[t] = cached[t]
|
||||
log.debug("Using cached network info for {}", t)
|
||||
for t in (t for t in targets if t in cached):
|
||||
# Reassign the cached network info to the matching resource.
|
||||
data[t] = cached[t]
|
||||
log.debug("Using cached network info for {}", t)
|
||||
|
||||
# Remove cached items from the resource list so they're not queried.
|
||||
targets = [t for t in targets if t not in cached]
|
||||
@@ -149,7 +118,7 @@ async def network_info(*targets: str) -> Dict[str, Dict[str, str]]:
|
||||
|
||||
# Cache the response
|
||||
for t in targets:
|
||||
cache.hset(CACHE_KEY, t, data[t])
|
||||
cache.set_map_item(CACHE_KEY, t, data[t])
|
||||
log.debug("Cached network info for {}", t)
|
||||
|
||||
except Exception as err:
|
||||
@@ -160,42 +129,4 @@ async def network_info(*targets: str) -> Dict[str, Dict[str, str]]:
|
||||
|
||||
def network_info_sync(*targets: str) -> Dict[str, Dict[str, str]]:
|
||||
"""Get ASN, Containing Prefix, and other info about an internet resource."""
|
||||
|
||||
targets = [str(t) for t in targets]
|
||||
(cache := use_state().redis)
|
||||
|
||||
# Set default data structure.
|
||||
data = {t: {k: "" for k in DEFAULT_KEYS} for t in targets}
|
||||
|
||||
# Get all cached bgp.tools data.
|
||||
cached = cache.hgetall(CACHE_KEY)
|
||||
|
||||
# Try to use cached data for each of the items in the list of
|
||||
# resources.
|
||||
for t in targets:
|
||||
|
||||
if t in cached:
|
||||
# Reassign the cached network info to the matching resource.
|
||||
data[t] = cached[t]
|
||||
log.debug("Using cached network info for {}", t)
|
||||
|
||||
# Remove cached items from the resource list so they're not queried.
|
||||
targets = [t for t in targets if t not in cached]
|
||||
|
||||
try:
|
||||
if targets:
|
||||
whoisdata = run_whois_sync(targets)
|
||||
|
||||
if whoisdata:
|
||||
# If the response is not empty, parse it.
|
||||
data.update(parse_whois(whoisdata, targets))
|
||||
|
||||
# Cache the response
|
||||
for t in targets:
|
||||
cache.hset(CACHE_KEY, t, data[t])
|
||||
log.debug("Cached network info for {}", t)
|
||||
|
||||
except Exception as err:
|
||||
log.error(str(err))
|
||||
|
||||
return data
|
||||
return asyncio.run(network_info(*targets))
|
||||
|
29
hyperglass/external/rpki.py
vendored
29
hyperglass/external/rpki.py
vendored
@@ -1,25 +1,32 @@
|
||||
"""Validate RPKI state via Cloudflare GraphQL API."""
|
||||
|
||||
# Standard Library
|
||||
import typing as t
|
||||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.state import use_state
|
||||
from hyperglass.external._base import BaseExternal
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Standard Library
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
|
||||
RPKI_STATE_MAP = {"Invalid": 0, "Valid": 1, "NotFound": 2, "DEFAULT": 3}
|
||||
RPKI_NAME_MAP = {v: k for k, v in RPKI_STATE_MAP.items()}
|
||||
CACHE_KEY = "hyperglass.external.rpki"
|
||||
|
||||
|
||||
def rpki_state(prefix, asn):
|
||||
def rpki_state(prefix: t.Union["IPv4Address", "IPv6Address", str], asn: t.Union[int, str]) -> int:
|
||||
"""Get RPKI state and map to expected integer."""
|
||||
log.debug("Validating RPKI State for {p} via AS{a}", p=prefix, a=asn)
|
||||
|
||||
(cache := use_state().redis)
|
||||
cache = use_state("cache")
|
||||
|
||||
state = 3
|
||||
ro = f"{prefix}@{asn}"
|
||||
ro = f"{prefix!s}@{asn!s}"
|
||||
|
||||
cached = cache.hget(CACHE_KEY, ro)
|
||||
cached = cache.get_map(CACHE_KEY, ro)
|
||||
|
||||
if cached is not None:
|
||||
state = cached
|
||||
@@ -27,17 +34,21 @@ def rpki_state(prefix, asn):
|
||||
|
||||
ql = 'query GetValidation {{ validation(prefix: "{}", asn: {}) {{ state }} }}'
|
||||
query = ql.format(prefix, asn)
|
||||
|
||||
log.debug("Cloudflare RPKI GraphQL Query: {!r}", query)
|
||||
try:
|
||||
with BaseExternal(base_url="https://rpki.cloudflare.com") as client:
|
||||
response = client._post("/api/graphql", data={"query": query})
|
||||
validation_state = (
|
||||
response.get("data", {}).get("validation", {}).get("state", "DEFAULT")
|
||||
)
|
||||
try:
|
||||
validation_state = response["data"]["validation"]["state"]
|
||||
except KeyError as missing:
|
||||
log.error("Response from Cloudflare missing key '{}': {!r}", missing, response)
|
||||
validation_state = 3
|
||||
|
||||
state = RPKI_STATE_MAP[validation_state]
|
||||
cache.hset(CACHE_KEY, ro, state)
|
||||
cache.set_map_item(CACHE_KEY, ro, state)
|
||||
except Exception as err:
|
||||
log.error(str(err))
|
||||
# Don't cache the state when an error produced it.
|
||||
state = 3
|
||||
|
||||
msg = "RPKI Validation State for {} via AS{} is {}".format(prefix, asn, RPKI_NAME_MAP[state])
|
||||
|
1
hyperglass/external/tests/__init__.py
vendored
Normal file
1
hyperglass/external/tests/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
"""External data testing."""
|
44
hyperglass/external/tests/test_bgptools.py
vendored
Normal file
44
hyperglass/external/tests/test_bgptools.py
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Test bgp.tools interactions."""
|
||||
|
||||
# Standard Library
|
||||
import asyncio
|
||||
|
||||
# Third Party
|
||||
import pytest
|
||||
|
||||
# Local
|
||||
from ..bgptools import run_whois, parse_whois, network_info
|
||||
|
||||
WHOIS_OUTPUT = """AS | IP | BGP Prefix | CC | Registry | Allocated | AS Name
|
||||
13335 | 1.1.1.1 | 1.1.1.0/24 | US | ARIN | 2010-07-14 | Cloudflare, Inc."""
|
||||
|
||||
|
||||
# Ignore asyncio deprecation warning about loop
|
||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
||||
def test_network_info():
|
||||
addr = "192.0.2.1"
|
||||
info = asyncio.run(network_info(addr))
|
||||
assert isinstance(info, dict)
|
||||
assert "192.0.2.1" in info, "Address missing"
|
||||
assert "asn" in info[addr], "ASN missing"
|
||||
assert info[addr]["asn"] == "0", "Unexpected ASN"
|
||||
assert info[addr]["rir"] == "Unknown", "Unexpected RIR"
|
||||
|
||||
|
||||
# Ignore asyncio deprecation warning about loop
|
||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
||||
def test_whois():
|
||||
addr = "192.0.2.1"
|
||||
response = asyncio.run(run_whois([addr]))
|
||||
assert isinstance(response, str)
|
||||
assert response != ""
|
||||
|
||||
|
||||
def test_whois_parser():
|
||||
addr = "1.1.1.1"
|
||||
result = parse_whois(WHOIS_OUTPUT, [addr])
|
||||
assert isinstance(result, dict)
|
||||
assert addr in result, "Address missing"
|
||||
assert result[addr]["asn"] == "13335"
|
||||
assert result[addr]["rir"] == "ARIN"
|
||||
assert result[addr]["org"] == "Cloudflare, Inc."
|
25
hyperglass/external/tests/test_rpki.py
vendored
Normal file
25
hyperglass/external/tests/test_rpki.py
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Test RPKI data fetching."""
|
||||
# Third Party
|
||||
import pytest
|
||||
|
||||
# Local
|
||||
from ..rpki import RPKI_NAME_MAP, rpki_state
|
||||
|
||||
TEST_STATES = (
|
||||
("103.21.244.0/24", 13335, 0),
|
||||
("1.1.1.0/24", 13335, 1),
|
||||
("192.0.2.0/24", 65000, 2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.dependency()
|
||||
def test_rpki():
|
||||
for prefix, asn, expected in TEST_STATES:
|
||||
result = rpki_state(prefix, asn)
|
||||
result_name = RPKI_NAME_MAP.get(result, "No Name")
|
||||
expected_name = RPKI_NAME_MAP.get(expected, "No Name")
|
||||
assert (
|
||||
result == expected
|
||||
), "RPKI State for '{}' via AS{!s} '{}' ({}) instead of '{}' ({})".format(
|
||||
prefix, asn, result, result_name, expected, expected_name
|
||||
)
|
@@ -26,7 +26,7 @@ from hyperglass.exceptions.private import InputValidationError
|
||||
from ..config.devices import Device
|
||||
from ..commands.generic import Directive
|
||||
|
||||
(TEXT := use_state().params.web.text)
|
||||
(TEXT := use_state("params").web.text)
|
||||
|
||||
|
||||
class Query(BaseModel):
|
||||
@@ -154,7 +154,7 @@ class Query(BaseModel):
|
||||
@validator("query_type")
|
||||
def validate_query_type(cls, value):
|
||||
"""Ensure a requested query type exists."""
|
||||
(devices := use_state().devices)
|
||||
devices = use_state("devices")
|
||||
directive_ids = [
|
||||
directive.id for device in devices.objects for directive in device.commands
|
||||
]
|
||||
@@ -167,7 +167,7 @@ class Query(BaseModel):
|
||||
def validate_query_location(cls, value):
|
||||
"""Ensure query_location is defined."""
|
||||
|
||||
(devices := use_state().devices)
|
||||
devices = use_state("devices")
|
||||
valid_id = value in devices.ids
|
||||
valid_hostname = value in devices.hostnames
|
||||
|
||||
@@ -179,7 +179,7 @@ class Query(BaseModel):
|
||||
@validator("query_group")
|
||||
def validate_query_group(cls, value):
|
||||
"""Ensure query_group is defined."""
|
||||
(devices := use_state().devices)
|
||||
devices = use_state("devices")
|
||||
groups = {
|
||||
group
|
||||
for device in devices.objects
|
||||
|
@@ -26,8 +26,8 @@ class QueryError(BaseModel):
|
||||
def validate_output(cls: "QueryError", value):
|
||||
"""If no output is specified, use a customizable generic message."""
|
||||
if value is None:
|
||||
state = use_state()
|
||||
return state.params.messages.general
|
||||
(messages := use_state("params").messages)
|
||||
return messages.general
|
||||
return value
|
||||
|
||||
class Config:
|
||||
|
@@ -46,7 +46,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
|
||||
Returns:
|
||||
Union[IPv4Address, IPv6Address] -- Validated IP address object
|
||||
"""
|
||||
(params := use_state().params)
|
||||
params = use_state("params")
|
||||
query_type_params = getattr(params.queries, query_type)
|
||||
try:
|
||||
|
||||
@@ -149,7 +149,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
|
||||
def validate_community_input(value):
|
||||
"""Validate input communities against configured or default regex pattern."""
|
||||
|
||||
(params := use_state().params)
|
||||
params = use_state("params")
|
||||
|
||||
# RFC4360: Extended Communities (New Format)
|
||||
if re.match(params.queries.bgp_community.pattern.extended_as, value):
|
||||
@@ -174,7 +174,7 @@ def validate_community_input(value):
|
||||
|
||||
def validate_community_select(value):
|
||||
"""Validate selected community against configured communities."""
|
||||
(params := use_state().params)
|
||||
params = use_state("params")
|
||||
communities = tuple(c.community for c in params.queries.bgp_community.communities)
|
||||
if value not in communities:
|
||||
raise InputInvalid(
|
||||
@@ -187,7 +187,7 @@ def validate_community_select(value):
|
||||
|
||||
def validate_aspath(value):
|
||||
"""Validate input AS_PATH against configured or default regext pattern."""
|
||||
(params := use_state().params)
|
||||
params = use_state("params")
|
||||
mode = params.queries.bgp_aspath.pattern.mode
|
||||
pattern = getattr(params.queries.bgp_aspath.pattern, mode)
|
||||
|
||||
|
@@ -44,7 +44,7 @@ class BGPRoute(HyperglassModel):
|
||||
deny: only deny matches
|
||||
"""
|
||||
|
||||
(structured := use_state().params.structured)
|
||||
(structured := use_state("params").structured)
|
||||
|
||||
def _permit(comm):
|
||||
"""Only allow matching patterns."""
|
||||
@@ -73,7 +73,7 @@ class BGPRoute(HyperglassModel):
|
||||
def validate_rpki_state(cls, value, values):
|
||||
"""If external RPKI validation is enabled, get validation state."""
|
||||
|
||||
(structured := use_state().params.structured)
|
||||
(structured := use_state("params").structured)
|
||||
|
||||
if structured.rpki.mode == "router":
|
||||
# If router validation is enabled, return the value as-is.
|
||||
|
@@ -6,7 +6,7 @@ from inspect import isclass
|
||||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
from hyperglass.state.redis import use_state
|
||||
from hyperglass.state import use_state
|
||||
from hyperglass.exceptions.private import PluginError
|
||||
|
||||
# Local
|
||||
@@ -16,7 +16,7 @@ from ._output import OutputType, OutputPlugin
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.state.redis import HyperglassState
|
||||
from hyperglass.state import HyperglassState
|
||||
from hyperglass.models.api.query import Query
|
||||
from hyperglass.models.config.devices import Device
|
||||
from hyperglass.models.commands.generic import Directive
|
||||
|
@@ -5,7 +5,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
# Third Party
|
||||
import py
|
||||
import pytest
|
||||
|
||||
# Project
|
||||
from hyperglass.log import log
|
||||
@@ -15,6 +15,11 @@ from hyperglass.models.data.bgp_route import BGPRouteTable
|
||||
# Local
|
||||
from .._builtin.bgp_route_juniper import BGPRoutePluginJuniper
|
||||
|
||||
DEPENDS_KWARGS = {
|
||||
"depends": ["hyperglass/external/tests/test_rpki.py::test_rpki"],
|
||||
"scope": "session",
|
||||
}
|
||||
|
||||
DIRECT = Path(__file__).parent.parent.parent.parent / ".samples" / "juniper_route_direct.xml"
|
||||
INDIRECT = Path(__file__).parent.parent.parent.parent / ".samples" / "juniper_route_indirect.xml"
|
||||
AS_PATH = Path(__file__).parent.parent.parent.parent / ".samples" / "juniper_route_aspath.xml"
|
||||
@@ -42,18 +47,21 @@ def _tester(sample: str):
|
||||
assert result.count > 0, "BGP Table count is 0"
|
||||
|
||||
|
||||
@pytest.mark.dependency(**DEPENDS_KWARGS)
|
||||
def test_juniper_bgp_route_direct():
|
||||
with DIRECT.open("r") as file:
|
||||
sample = file.read()
|
||||
return _tester(sample)
|
||||
|
||||
|
||||
@pytest.mark.dependency(**DEPENDS_KWARGS)
|
||||
def test_juniper_bgp_route_indirect():
|
||||
with INDIRECT.open("r") as file:
|
||||
sample = file.read()
|
||||
return _tester(sample)
|
||||
|
||||
|
||||
@pytest.mark.dependency(**DEPENDS_KWARGS)
|
||||
def test_juniper_bgp_route_aspath():
|
||||
with AS_PATH.open("r") as file:
|
||||
sample = file.read()
|
||||
|
@@ -1,7 +1,8 @@
|
||||
"""hyperglass global state management."""
|
||||
|
||||
# Local
|
||||
from .redis import HyperglassState, use_state
|
||||
from .hooks import use_state
|
||||
from .store import HyperglassState
|
||||
|
||||
__all__ = (
|
||||
"use_state",
|
||||
|
71
hyperglass/state/hooks.py
Normal file
71
hyperglass/state/hooks.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Hooks for accessing hyperglass global state."""
|
||||
|
||||
# Standard Library
|
||||
import typing as t
|
||||
from functools import lru_cache
|
||||
|
||||
# Project
|
||||
from hyperglass.exceptions.private import StateError
|
||||
|
||||
# Local
|
||||
from .store import HyperglassState
|
||||
from ..settings import Settings
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
||||
# Local
|
||||
from .redis import RedisManager
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _use_state(attr: t.Optional[str] = None) -> "HyperglassState":
|
||||
"""Get hyperglass state by property.
|
||||
|
||||
Implemented separately due to typing issues related to lru_cache described here:
|
||||
https://github.com/python/mypy/issues/8356
|
||||
https://github.com/python/mypy/issues/9112
|
||||
"""
|
||||
if attr is None:
|
||||
return HyperglassState(settings=Settings)
|
||||
if attr in ("cache", "redis"):
|
||||
return HyperglassState(settings=Settings).redis
|
||||
if attr in HyperglassState.properties():
|
||||
return getattr(HyperglassState(settings=Settings), attr)
|
||||
raise StateError("'{attr}' does not exist on HyperglassState", attr=attr)
|
||||
|
||||
|
||||
@t.overload
|
||||
def use_state(attr: t.Literal["params"]) -> "Params":
|
||||
"""Access hyperglass configuration parameters from global state."""
|
||||
|
||||
|
||||
@t.overload
|
||||
def use_state(attr: t.Literal["devices"]) -> "Devices":
|
||||
"""Access hyperglass devices from global state."""
|
||||
|
||||
|
||||
@t.overload
|
||||
def use_state(attr: t.Literal["ui_params"]) -> "UIParameters":
|
||||
"""Access hyperglass UI parameters from global state."""
|
||||
|
||||
|
||||
@t.overload
|
||||
def use_state(attr: t.Literal["cache", "redis"]) -> "RedisManager":
|
||||
"""Directly access hyperglass Redis cache manager."""
|
||||
|
||||
|
||||
@t.overload
|
||||
def use_state(attr=None) -> "HyperglassState":
|
||||
"""Access entire global state.
|
||||
|
||||
This overload needs to be defined last since it's a catchall.
|
||||
"""
|
||||
|
||||
|
||||
def use_state(attr: t.Optional[str] = None) -> "HyperglassState":
|
||||
"""Access global hyperglass state."""
|
||||
return _use_state(attr)
|
50
hyperglass/state/manager.py
Normal file
50
hyperglass/state/manager.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""hyperglass global state."""
|
||||
|
||||
# Standard Library
|
||||
import typing as t
|
||||
|
||||
# Third Party
|
||||
from redis import Redis, ConnectionPool
|
||||
|
||||
# Project
|
||||
from hyperglass.configuration import params, devices, ui_params
|
||||
|
||||
# Local
|
||||
from .redis import RedisManager
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.system import HyperglassSystem
|
||||
|
||||
|
||||
class StateManager:
|
||||
"""Global State Manager.
|
||||
|
||||
Maintains configuration objects in Redis cache and accesses them as needed.
|
||||
"""
|
||||
|
||||
settings: "HyperglassSystem"
|
||||
redis: RedisManager
|
||||
_namespace: str = "hyperglass.state"
|
||||
|
||||
def __init__(self, *, settings: "HyperglassSystem") -> None:
|
||||
"""Set up Redis connection and add configuration objects."""
|
||||
|
||||
self.settings = settings
|
||||
connection_pool = ConnectionPool.from_url(**self.settings.redis_connection_pool)
|
||||
redis = Redis(connection_pool=connection_pool)
|
||||
self.redis = RedisManager(instance=redis, namespace=self._namespace)
|
||||
|
||||
# Add configuration objects.
|
||||
self.redis.set("params", params)
|
||||
self.redis.set("devices", devices)
|
||||
self.redis.set("ui_params", ui_params)
|
||||
|
||||
@classmethod
|
||||
def properties(cls: "StateManager") -> t.Tuple[str, ...]:
|
||||
"""Get all read-only properties of the state manager."""
|
||||
return tuple(
|
||||
attr
|
||||
for attr in dir(cls)
|
||||
if not attr.startswith("_") and "fget" in dir(getattr(cls, attr))
|
||||
)
|
@@ -1,133 +1,123 @@
|
||||
"""hyperglass global state."""
|
||||
"""Interact with redis for state management."""
|
||||
|
||||
# Standard Library
|
||||
import codecs
|
||||
import pickle
|
||||
import typing as t
|
||||
from functools import lru_cache
|
||||
|
||||
# Third Party
|
||||
from redis import Redis, ConnectionPool
|
||||
from typing import overload
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Project
|
||||
from hyperglass.configuration import params, devices, ui_params
|
||||
from hyperglass.exceptions.private import StateError
|
||||
|
||||
# Local
|
||||
from ..settings import Settings
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.models.system import HyperglassSystem
|
||||
from hyperglass.plugins._base import HyperglassPlugin
|
||||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
||||
PluginT = t.TypeVar("PluginT", bound="HyperglassPlugin")
|
||||
# Third Party
|
||||
from redis import Redis
|
||||
|
||||
|
||||
class HyperglassState:
|
||||
"""Global State Manager.
|
||||
class RedisManager:
|
||||
"""Convenience wrapper for managing a redis session."""
|
||||
|
||||
Maintains configuration objects in Redis cache and accesses them as needed.
|
||||
"""
|
||||
instance: "Redis"
|
||||
namespace: str
|
||||
|
||||
settings: "HyperglassSystem"
|
||||
redis: Redis
|
||||
_connection_pool: ConnectionPool
|
||||
_namespace: str = "hyperglass.state"
|
||||
|
||||
def __init__(self, *, settings: "HyperglassSystem") -> None:
|
||||
def __init__(self, instance: "Redis", namespace: str) -> None:
|
||||
"""Set up Redis connection and add configuration objects."""
|
||||
self.instance = instance
|
||||
self.namespace = namespace
|
||||
|
||||
self.settings = settings
|
||||
self._connection_pool = ConnectionPool.from_url(**self.settings.redis_connection_pool)
|
||||
self.redis = Redis(connection_pool=self._connection_pool)
|
||||
def __repr__(self) -> str:
|
||||
"""Alias repr to Redis instance's repr."""
|
||||
return repr(self.instance)
|
||||
|
||||
# Add configuration objects.
|
||||
self.set_object("params", params)
|
||||
self.set_object("devices", devices)
|
||||
self.set_object("ui_params", ui_params)
|
||||
|
||||
# Ensure plugins are empty.
|
||||
self.reset_plugins("output")
|
||||
self.reset_plugins("input")
|
||||
|
||||
def key(self, *keys: str) -> str:
|
||||
def _key_join(self, *keys: str) -> str:
|
||||
"""Format keys with state namespace."""
|
||||
return ".".join((*self._namespace.split("."), *keys))
|
||||
key_in_parts = (k for key in keys for k in key.split("."))
|
||||
key_parts = list(dict.fromkeys((*self.namespace.split("."), *key_in_parts)))
|
||||
return ".".join(key_parts)
|
||||
|
||||
def get_object(self, name: str, raise_if_none: bool = False) -> t.Any:
|
||||
"""Get an object (class instance) from the cache."""
|
||||
value = self.redis.get(name)
|
||||
def key(self, key: t.Union[str, t.Sequence[str]]) -> str:
|
||||
"""Format keys with state namespace."""
|
||||
if isinstance(key, (t.List, t.Tuple, t.Generator)):
|
||||
return self._key_join(*key)
|
||||
return self._key_join(key)
|
||||
|
||||
def check(self) -> bool:
|
||||
"""Ensure the redis instance is running and reachable."""
|
||||
result = self.instance.ping()
|
||||
if result is False:
|
||||
raise RuntimeError(
|
||||
"Redis instance {!r} is not running or reachable".format(self.instance)
|
||||
)
|
||||
return result
|
||||
|
||||
def delete(self, key: t.Union[str, t.Sequence[str]]) -> None:
|
||||
"""Delete a key and value from the cache."""
|
||||
self.instance.delete(self.key(key))
|
||||
|
||||
def expire(
|
||||
self,
|
||||
key: t.Union[str, t.Sequence[str]],
|
||||
*,
|
||||
expire_in: t.Optional[t.Union[timedelta, int]] = None,
|
||||
expire_at: t.Optional[t.Union[datetime, int]] = None,
|
||||
) -> None:
|
||||
"""Expire a cache key, either at a time, or in a number of seconds.
|
||||
|
||||
If no at or in time is specified, the key is deleted.
|
||||
"""
|
||||
key = self.key(key)
|
||||
if isinstance(expire_at, (datetime, int)):
|
||||
self.instance.expireat(key, expire_at)
|
||||
return
|
||||
if isinstance(expire_in, (timedelta, int)):
|
||||
self.instance.expire(key, expire_in)
|
||||
return
|
||||
self.instance.delete(key)
|
||||
|
||||
def get(
|
||||
self,
|
||||
key: t.Union[str, t.Sequence[str]],
|
||||
*,
|
||||
raise_if_none: bool = False,
|
||||
value_if_none: t.Any = None,
|
||||
) -> t.Union[None, t.Any]:
|
||||
"""Get and decode a value from the cache."""
|
||||
name = self.key(key)
|
||||
value: t.Optional[bytes] = self.instance.get(name)
|
||||
if isinstance(value, bytes):
|
||||
return pickle.loads(value)
|
||||
if raise_if_none is True:
|
||||
raise StateError("'{key}' ('{name}') does not exist in Redis store", key=key, name=name)
|
||||
if value_if_none is not None:
|
||||
return value_if_none
|
||||
return None
|
||||
|
||||
def set(self, key: t.Union[str, t.Sequence[str]], value: t.Any) -> None:
|
||||
"""Add an object to the cache."""
|
||||
name = self.key(key)
|
||||
self.instance.set(name, pickle.dumps(value))
|
||||
|
||||
@overload
|
||||
def get_map(self, key: str, item: str) -> t.Any:
|
||||
"""Get a single value from a Redis hash map (dict)."""
|
||||
|
||||
@overload
|
||||
def get_map(self, key: str, item=None) -> t.Any:
|
||||
"""Get a single value from a Redis hash map (dict)."""
|
||||
|
||||
def get_map(self, key: str, item: t.Optional[str] = None) -> t.Any:
|
||||
"""Get a Redis hash map or hash map value."""
|
||||
name = self.key(key)
|
||||
if isinstance(item, str):
|
||||
value = self.instance.hget(name, item)
|
||||
else:
|
||||
value = self.instance.hgetall(name)
|
||||
|
||||
if isinstance(value, bytes):
|
||||
return pickle.loads(value)
|
||||
elif isinstance(value, str):
|
||||
return pickle.loads(value.encode())
|
||||
if raise_if_none is True:
|
||||
raise StateError("'{key}' does not exist in Redis store", key=name)
|
||||
return None
|
||||
|
||||
def set_object(self, name: str, obj: t.Any) -> None:
|
||||
"""Add an object (class instance) to the cache."""
|
||||
value = pickle.dumps(obj)
|
||||
self.redis.set(self.key(name), value)
|
||||
|
||||
def add_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
|
||||
"""Add a plugin to its list by type."""
|
||||
current = self.plugins(_type)
|
||||
plugins = {
|
||||
# Create a base64 representation of a picked plugin.
|
||||
codecs.encode(pickle.dumps(p), "base64").decode()
|
||||
# Merge current plugins with the new plugin.
|
||||
for p in [*current, plugin]
|
||||
}
|
||||
self.set_object(self.key("plugins", _type), list(plugins))
|
||||
|
||||
def remove_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
|
||||
"""Remove a plugin from its list by type."""
|
||||
current = self.plugins(_type)
|
||||
plugins = {
|
||||
# Create a base64 representation of a picked plugin.
|
||||
codecs.encode(pickle.dumps(p), "base64").decode()
|
||||
# Merge current plugins with the new plugin.
|
||||
for p in current
|
||||
if p != plugin
|
||||
}
|
||||
self.set_object(self.key("plugins", _type), list(plugins))
|
||||
|
||||
def reset_plugins(self, _type: str) -> None:
|
||||
"""Remove all plugins of `_type`."""
|
||||
self.set_object(self.key("plugins", _type), [])
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete all cache keys."""
|
||||
self.redis.flushdb(asynchronous=True)
|
||||
|
||||
@property
|
||||
def params(self) -> "Params":
|
||||
"""Get hyperglass configuration parameters (`hyperglass.yaml`)."""
|
||||
return self.get_object(self.key("params"), raise_if_none=True)
|
||||
|
||||
@property
|
||||
def devices(self) -> "Devices":
|
||||
"""Get hyperglass devices (`devices.yaml`)."""
|
||||
return self.get_object(self.key("devices"), raise_if_none=True)
|
||||
|
||||
@property
|
||||
def ui_params(self) -> "UIParameters":
|
||||
"""UI parameters, built from params."""
|
||||
return self.get_object(self.key("ui_params"), raise_if_none=True)
|
||||
|
||||
def plugins(self, _type: str) -> t.List[PluginT]:
|
||||
"""Get plugins by type."""
|
||||
current = self.get_object(self.key("plugins", _type), raise_if_none=False) or []
|
||||
return list({pickle.loads(codecs.decode(plugin.encode(), "base64")) for plugin in current})
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def use_state() -> "HyperglassState":
|
||||
"""Access hyperglass global state."""
|
||||
return HyperglassState(settings=Settings)
|
||||
def set_map_item(self, key: str, item: str, value: t.Any) -> None:
|
||||
"""Add a value to a hash map (dict)."""
|
||||
name = self.key(key)
|
||||
self.instance.hset(name, item, pickle.dumps(value))
|
||||
|
82
hyperglass/state/store.py
Normal file
82
hyperglass/state/store.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Primary state container."""
|
||||
|
||||
# Standard Library
|
||||
import codecs
|
||||
import pickle
|
||||
import typing as t
|
||||
|
||||
# Local
|
||||
from .manager import StateManager
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Project
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.models.system import HyperglassSystem
|
||||
from hyperglass.plugins._base import HyperglassPlugin
|
||||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
||||
|
||||
PluginT = t.TypeVar("PluginT", bound="HyperglassPlugin")
|
||||
|
||||
|
||||
class HyperglassState(StateManager):
|
||||
"""Primary hyperglass state container."""
|
||||
|
||||
def __init__(self, *, settings: "HyperglassSystem") -> None:
|
||||
"""Initialize state store and reset plugins."""
|
||||
super().__init__(settings=settings)
|
||||
# Ensure plugins are empty.
|
||||
self.reset_plugins("output")
|
||||
self.reset_plugins("input")
|
||||
|
||||
def add_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
|
||||
"""Add a plugin to its list by type."""
|
||||
current = self.plugins(_type)
|
||||
plugins = {
|
||||
# Create a base64 representation of a picked plugin.
|
||||
codecs.encode(pickle.dumps(p), "base64").decode()
|
||||
# Merge current plugins with the new plugin.
|
||||
for p in [*current, plugin]
|
||||
}
|
||||
self.redis.set(("plugins", _type), list(plugins))
|
||||
|
||||
def remove_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
|
||||
"""Remove a plugin from its list by type."""
|
||||
current = self.plugins(_type)
|
||||
plugins = {
|
||||
# Create a base64 representation of a picked plugin.
|
||||
codecs.encode(pickle.dumps(p), "base64").decode()
|
||||
# Merge current plugins with the new plugin.
|
||||
for p in current
|
||||
if p != plugin
|
||||
}
|
||||
self.redis.set(("plugins", _type), list(plugins))
|
||||
|
||||
def reset_plugins(self, _type: str) -> None:
|
||||
"""Remove all plugins of `_type`."""
|
||||
self.redis.set(("plugins", _type), [])
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete all cache keys."""
|
||||
self.redis.instance.flushdb(asynchronous=True)
|
||||
|
||||
@property
|
||||
def params(self) -> "Params":
|
||||
"""Get hyperglass configuration parameters (`hyperglass.yaml`)."""
|
||||
return self.redis.get("params", raise_if_none=True)
|
||||
|
||||
@property
|
||||
def devices(self) -> "Devices":
|
||||
"""Get hyperglass devices (`devices.yaml`)."""
|
||||
return self.redis.get("devices", raise_if_none=True)
|
||||
|
||||
@property
|
||||
def ui_params(self) -> "UIParameters":
|
||||
"""UI parameters, built from params."""
|
||||
return self.redis.get("ui_params", raise_if_none=True)
|
||||
|
||||
def plugins(self, _type: str) -> t.List[PluginT]:
|
||||
"""Get plugins by type."""
|
||||
current = self.redis.get(("plugins", _type), raise_if_none=False, value_if_none=[])
|
||||
return list({pickle.loads(codecs.decode(plugin.encode(), "base64")) for plugin in current})
|
1
hyperglass/state/tests/__init__.py
Normal file
1
hyperglass/state/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""State tests."""
|
30
hyperglass/state/tests/test_hooks.py
Normal file
30
hyperglass/state/tests/test_hooks.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Test state hooks."""
|
||||
|
||||
# Project
|
||||
from hyperglass.models.ui import UIParameters
|
||||
from hyperglass.models.config.params import Params
|
||||
from hyperglass.models.config.devices import Devices
|
||||
|
||||
# Local
|
||||
from ..hooks import use_state
|
||||
from ..store import HyperglassState
|
||||
|
||||
STATE_ATTRS = (
|
||||
("params", Params),
|
||||
("devices", Devices),
|
||||
("ui_params", UIParameters),
|
||||
(None, HyperglassState),
|
||||
)
|
||||
|
||||
|
||||
def test_use_state_caching():
|
||||
first = None
|
||||
for attr, model in STATE_ATTRS:
|
||||
for i in range(0, 5):
|
||||
instance = use_state(attr)
|
||||
if i == 0:
|
||||
first = instance
|
||||
assert isinstance(
|
||||
instance, model
|
||||
), f"{instance!r} is not an instance of '{model.__name__}'"
|
||||
assert instance == first, f"{instance!r} is not equal to {first!r}"
|
16
poetry.lock
generated
16
poetry.lock
generated
@@ -976,6 +976,17 @@ toml = "*"
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-dependency"
|
||||
version = "0.5.1"
|
||||
description = "Manage dependencies of tests"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=3.6.0"
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "0.17.0"
|
||||
@@ -1391,7 +1402,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "ad65ca60927ff53c41ce10afc0651eafdc707f4bc9f2b70a797a7cb2fdfb7d87"
|
||||
content-hash = "c439e39b6aee8009b444a98905e88c1d16388c9026cf780ee3ca5ffde07434b1"
|
||||
|
||||
[metadata.files]
|
||||
aiofiles = [
|
||||
@@ -1889,6 +1900,9 @@ pytest = [
|
||||
{file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
|
||||
{file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
|
||||
]
|
||||
pytest-dependency = [
|
||||
{file = "pytest-dependency-0.5.1.tar.gz", hash = "sha256:c2a892906192663f85030a6ab91304e508e546cddfe557d692d61ec57a1d946b"},
|
||||
]
|
||||
python-dotenv = [
|
||||
{file = "python-dotenv-0.17.0.tar.gz", hash = "sha256:471b782da0af10da1a80341e8438fca5fadeba2881c54360d5fd8d03d03a4f4a"},
|
||||
{file = "python_dotenv-0.17.0-py2.py3-none-any.whl", hash = "sha256:49782a97c9d641e8a09ae1d9af0856cc587c8d2474919342d5104d85be9890b2"},
|
||||
|
@@ -81,6 +81,7 @@ pre-commit = "^1.21.0"
|
||||
pytest = "^6.2.5"
|
||||
stackprinter = "^0.2.3"
|
||||
taskipy = "^1.8.2"
|
||||
pytest-dependency = "^0.5.1"
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
|
Reference in New Issue
Block a user