1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

Add separate hooks for major state objects, add tests

This commit is contained in:
thatmattlove
2021-09-16 13:46:50 -07:00
parent c99f98a6f0
commit e06ea5ecb9
30 changed files with 549 additions and 263 deletions

View File

@@ -19,6 +19,7 @@ per-file-ignores=
# Disable assertion and docstring checks on tests.
hyperglass/**/test_*.py:S101,D103
hyperglass/api/*.py:B008
hyperglass/state/hooks.py:F811
ignore=W503,C0330,R504,D202,S403,S301,S404,E731,D402
select=B, BLK, C, D, E, F, I, II, N, P, PIE, S, R, W
disable-noqa=False

View File

@@ -4,10 +4,10 @@
from hyperglass.state import use_state
def check_redis() -> bool:
def check_redis() -> None:
"""Ensure Redis is running before starting server."""
state = use_state()
return state.redis.ping()
cache = use_state("cache")
cache.check()
on_startup = (check_redis,)

View File

@@ -3,6 +3,7 @@
# Standard Library
import json
import time
import typing as t
from datetime import datetime
# Third Party
@@ -16,26 +17,44 @@ from hyperglass.state import HyperglassState, use_state
from hyperglass.external import Webhook, bgptools
from hyperglass.api.tasks import process_headers
from hyperglass.constants import __version__
from hyperglass.models.ui import UIParameters
from hyperglass.exceptions import HyperglassError
from hyperglass.models.api import Query
from hyperglass.execution.main import execute
from hyperglass.models.config.params import Params
from hyperglass.models.config.devices import Devices
# Local
from .fake_output import fake_output
def get_state():
def get_state(attr: t.Optional[str] = None):
"""Get hyperglass state as a FastAPI dependency."""
return use_state()
return use_state(attr)
def get_params():
"""Get hyperglass params as FastAPI dependency."""
return use_state("params")
def get_devices():
"""Get hyperglass devices as FastAPI dependency."""
return use_state("devices")
def get_ui_params():
"""Get hyperglass ui_params as FastAPI dependency."""
return use_state("ui_params")
async def send_webhook(
query_data: Query, request: Request, timestamp: datetime,
):
"""If webhooks are enabled, get request info and send a webhook."""
state = use_state()
params = use_state("params")
try:
if state.params.logging.http is not None:
if params.logging.http is not None:
headers = await process_headers(headers=request.headers)
if headers.get("x-real-ip") is not None:
@@ -47,7 +66,7 @@ async def send_webhook(
network_info = await bgptools.network_info(host)
async with Webhook(state.params.logging.http) as hook:
async with Webhook(params.logging.http) as hook:
await hook.send(
query={
@@ -59,7 +78,7 @@ async def send_webhook(
}
)
except Exception as err:
log.error("Error sending webhook to {}: {}", state.params.logging.http.provider, str(err))
log.error("Error sending webhook to {}: {}", params.logging.http.provider, str(err))
async def query(
@@ -83,7 +102,7 @@ async def query(
log.info("Starting query execution for query {}", query_data.summary)
cache_response = cache.get_dict(cache_key, "output")
cache_response = cache.get_map(cache_key, "output")
json_output = False
@@ -104,7 +123,7 @@ async def query(
cached = True
runtime = 0
timestamp = cache.get_dict(cache_key, "timestamp")
timestamp = cache.get_map(cache_key, "timestamp")
elif not cache_response:
log.debug("No existing cache entry for query {}", cache_key)
@@ -133,8 +152,8 @@ async def query(
raw_output = json.dumps(cache_output)
else:
raw_output = str(cache_output)
cache.set_dict(cache_key, "output", raw_output)
cache.set_dict(cache_key, "timestamp", timestamp)
cache.set_map_item(cache_key, "output", raw_output)
cache.set_map_item(cache_key, "timestamp", timestamp)
cache.expire(cache_key, seconds=state.params.cache.timeout)
log.debug("Added cache entry for query: {}", cache_key)
@@ -164,46 +183,46 @@ async def query(
}
async def docs(state: "HyperglassState" = Depends(get_state)):
async def docs(params: "Params" = Depends(get_params)):
"""Serve custom docs."""
if state.params.docs.enable:
if params.docs.enable:
docs_func_map = {"swagger": get_swagger_ui_html, "redoc": get_redoc_html}
docs_func = docs_func_map[state.params.docs.mode]
docs_func = docs_func_map[params.docs.mode]
return docs_func(
openapi_url=state.params.docs.openapi_url, title=state.params.site_title + " - API Docs"
openapi_url=params.docs.openapi_url, title=params.site_title + " - API Docs"
)
else:
raise HTTPException(detail="Not found", status_code=404)
async def router(id: str, state: "HyperglassState" = Depends(get_state)):
async def router(id: str, devices: "Devices" = Depends(get_devices)):
"""Get a device's API-facing attributes."""
return state.devices[id].export_api()
return devices[id].export_api()
async def routers(state: "HyperglassState" = Depends(get_state)):
async def routers(devices: "Devices" = Depends(get_devices)):
"""Serve list of configured routers and attributes."""
return state.devices.export_api()
return devices.export_api()
async def queries(state: "HyperglassState" = Depends(get_state)):
async def queries(params: "Params" = Depends(get_params)):
"""Serve list of enabled query types."""
return state.params.queries.list
return params.queries.list
async def info(state: "HyperglassState" = Depends(get_state)):
async def info(params: "Params" = Depends(get_params)):
"""Serve general information about this instance of hyperglass."""
return {
"name": state.params.site_title,
"organization": state.params.org_name,
"primary_asn": int(state.params.primary_asn),
"name": params.site_title,
"organization": params.org_name,
"primary_asn": int(params.primary_asn),
"version": __version__,
}
async def ui_props(state: "HyperglassState" = Depends(get_state)):
async def ui_props(ui_params: "UIParameters" = Depends(get_ui_params)):
"""Serve UI configration."""
return state.ui_params
return ui_params
endpoints = [query, docs, routers, info, ui_props]

View File

@@ -14,9 +14,8 @@ if TYPE_CHECKING:
from hyperglass.models.api.query import Query
from hyperglass.models.config.devices import Device
_state = use_state()
MESSAGES = _state.params.messages
TEXT = _state.params.web.text
(MESSAGES := use_state("params").messages)
(TEXT := use_state("params").web.text)
class ScrapeError(

View File

@@ -38,11 +38,11 @@ class AgentConnection(Connection):
async def collect(self) -> Iterable: # noqa: C901
"""Connect to a device running hyperglass-agent via HTTP."""
log.debug("Query parameters: {}", self.query)
state = use_state()
params = use_state("params")
client_params = {
"headers": {"Content-Type": "application/json"},
"timeout": state.params.request_timeout,
"timeout": params.request_timeout,
}
if self.device.ssl is not None and self.device.ssl.enable:
with self.device.ssl.cert.open("r") as file:
@@ -77,7 +77,7 @@ class AgentConnection(Connection):
encoded_query = await jwt_encode(
payload=query,
secret=self.device.credential.password.get_secret_value(),
duration=state.params.request_timeout,
duration=params.request_timeout,
)
log.debug("Encoded JWT: {}", encoded_query)

View File

@@ -24,7 +24,7 @@ class SSHConnection(Connection):
"""Return a preconfigured sshtunnel.SSHTunnelForwarder instance."""
proxy = self.device.proxy
state = use_state()
params = use_state("params")
def opener():
"""Set up an SSH tunnel according to a device's configuration."""
@@ -33,7 +33,7 @@ class SSHConnection(Connection):
"remote_bind_address": (self.device._target, self.device.port),
"local_bind_address": ("localhost", 0),
"skip_tunnel_checkup": False,
"gateway_timeout": state.params.request_timeout - 2,
"gateway_timeout": params.request_timeout - 2,
}
if proxy.credential._method == "password":
# Use password auth if no key is defined.

View File

@@ -46,7 +46,7 @@ class NetmikoConnection(SSHConnection):
Directly connects to the router via Netmiko library, returns the
command output.
"""
state = use_state()
params = use_state("params")
if host is not None:
log.debug(
"Connecting to {} via proxy {} [{}]",
@@ -66,9 +66,9 @@ class NetmikoConnection(SSHConnection):
"port": port or self.device.port,
"device_type": self.device.type,
"username": self.device.credential.username,
"global_delay_factor": state.params.netmiko_delay_factor,
"timeout": math.floor(state.params.request_timeout * 1.25),
"session_timeout": math.ceil(state.params.request_timeout - 1),
"global_delay_factor": params.netmiko_delay_factor,
"timeout": math.floor(params.request_timeout * 1.25),
"session_timeout": math.ceil(params.request_timeout - 1),
**global_args,
}

View File

@@ -71,7 +71,7 @@ class ScrapliConnection(SSHConnection):
Directly connects to the router via Netmiko library, returns the
command output.
"""
state = use_state()
params = use_state("params")
driver = _map_driver(self.device.type)
if host is not None:
@@ -90,7 +90,7 @@ class ScrapliConnection(SSHConnection):
"host": host or self.device._target,
"port": port or self.device.port,
"auth_username": self.device.credential.username,
"timeout_ops": math.floor(state.params.request_timeout * 1.25),
"timeout_ops": math.floor(params.request_timeout * 1.25),
"transport": "asyncssh",
"auth_strict_key": False,
"ssh_known_hosts_file": False,

View File

@@ -47,10 +47,10 @@ def handle_timeout(**exc_args: Any) -> Callable:
async def execute(query: "Query") -> Union["OutputDataModel", str]:
"""Initiate query validation and execution."""
state = use_state()
output = state.params.messages.general
params = use_state("params")
output = params.messages.general
log.debug("Received query for {}", query.json())
log.debug("Received query {}", query.json())
log.debug("Matched device config: {}", query.device)
mapped_driver = map_driver(query.device.driver)
@@ -60,7 +60,7 @@ async def execute(query: "Query") -> Union["OutputDataModel", str]:
signal.SIGALRM,
handle_timeout(error=TimeoutError("Connection timed out"), device=query.device),
)
signal.alarm(state.params.request_timeout - 1)
signal.alarm(params.request_timeout - 1)
if query.device.proxy:
proxy = driver.setup_proxy()

View File

@@ -4,7 +4,9 @@
import re
import json as _json
import socket
import typing as t
from json import JSONDecodeError
from types import TracebackType
from socket import gaierror
# Third Party
@@ -86,10 +88,15 @@ class BaseExternal:
else:
raise self._exception(f"Unable to create session to {self.name}")
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
def __exit__(
self,
exc_type: t.Optional[t.Type[BaseException]] = None,
exc_value: t.Optional[BaseException] = None,
exc_traceback: t.Optional[TracebackType] = None,
):
"""Close connection on exit."""
if exc_type is not None:
log.error(traceback)
log.error(str(exc_value))
self._session.close()
def __repr__(self):
@@ -232,7 +239,7 @@ class BaseExternal:
response = await self._asession.request(**request)
if response.status_code not in range(200, 300):
status = StatusCode(response.status_code)
status = httpx.codes(response.status_code)
error = self._parse_response(response)
raise self._exception(
f'{status.name.replace("_", " ")}: {error}', level="danger"

View File

@@ -6,7 +6,6 @@
# Standard Library
import re
import socket
import asyncio
from typing import Dict, List
@@ -87,54 +86,24 @@ async def run_whois(targets: List[str]) -> str:
return response.decode()
def run_whois_sync(targets: List[str]) -> str:
"""Open raw socket to bgp.tools and execute query."""
# Construct bulk query
query = "\n".join(("begin", *targets, "end\n")).encode()
# Open the socket to bgp.tools
log.debug("Opening connection to bgp.tools")
sock = socket.socket()
sock.connect(("bgp.tools", 43))
sock.send(query)
# Read the response
response = b""
while True:
data = sock.recv(128)
if data:
response += data
else:
log.debug("Closing connection to bgp.tools")
sock.shutdown(1)
sock.close()
break
return response.decode()
async def network_info(*targets: str) -> Dict[str, Dict[str, str]]:
"""Get ASN, Containing Prefix, and other info about an internet resource."""
targets = [str(t) for t in targets]
(cache := use_state().redis)
cache = use_state("cache")
# Set default data structure.
data = {t: {k: "" for k in DEFAULT_KEYS} for t in targets}
# Get all cached bgp.tools data.
cached = cache.hgetall(CACHE_KEY)
cached = cache.get_map(CACHE_KEY) or {}
# Try to use cached data for each of the items in the list of
# resources.
for t in targets:
if t in cached:
# Reassign the cached network info to the matching resource.
data[t] = cached[t]
log.debug("Using cached network info for {}", t)
for t in (t for t in targets if t in cached):
# Reassign the cached network info to the matching resource.
data[t] = cached[t]
log.debug("Using cached network info for {}", t)
# Remove cached items from the resource list so they're not queried.
targets = [t for t in targets if t not in cached]
@@ -149,7 +118,7 @@ async def network_info(*targets: str) -> Dict[str, Dict[str, str]]:
# Cache the response
for t in targets:
cache.hset(CACHE_KEY, t, data[t])
cache.set_map_item(CACHE_KEY, t, data[t])
log.debug("Cached network info for {}", t)
except Exception as err:
@@ -160,42 +129,4 @@ async def network_info(*targets: str) -> Dict[str, Dict[str, str]]:
def network_info_sync(*targets: str) -> Dict[str, Dict[str, str]]:
"""Get ASN, Containing Prefix, and other info about an internet resource."""
targets = [str(t) for t in targets]
(cache := use_state().redis)
# Set default data structure.
data = {t: {k: "" for k in DEFAULT_KEYS} for t in targets}
# Get all cached bgp.tools data.
cached = cache.hgetall(CACHE_KEY)
# Try to use cached data for each of the items in the list of
# resources.
for t in targets:
if t in cached:
# Reassign the cached network info to the matching resource.
data[t] = cached[t]
log.debug("Using cached network info for {}", t)
# Remove cached items from the resource list so they're not queried.
targets = [t for t in targets if t not in cached]
try:
if targets:
whoisdata = run_whois_sync(targets)
if whoisdata:
# If the response is not empty, parse it.
data.update(parse_whois(whoisdata, targets))
# Cache the response
for t in targets:
cache.hset(CACHE_KEY, t, data[t])
log.debug("Cached network info for {}", t)
except Exception as err:
log.error(str(err))
return data
return asyncio.run(network_info(*targets))

View File

@@ -1,25 +1,32 @@
"""Validate RPKI state via Cloudflare GraphQL API."""
# Standard Library
import typing as t
# Project
from hyperglass.log import log
from hyperglass.state import use_state
from hyperglass.external._base import BaseExternal
if t.TYPE_CHECKING:
# Standard Library
from ipaddress import IPv4Address, IPv6Address
RPKI_STATE_MAP = {"Invalid": 0, "Valid": 1, "NotFound": 2, "DEFAULT": 3}
RPKI_NAME_MAP = {v: k for k, v in RPKI_STATE_MAP.items()}
CACHE_KEY = "hyperglass.external.rpki"
def rpki_state(prefix, asn):
def rpki_state(prefix: t.Union["IPv4Address", "IPv6Address", str], asn: t.Union[int, str]) -> int:
"""Get RPKI state and map to expected integer."""
log.debug("Validating RPKI State for {p} via AS{a}", p=prefix, a=asn)
(cache := use_state().redis)
cache = use_state("cache")
state = 3
ro = f"{prefix}@{asn}"
ro = f"{prefix!s}@{asn!s}"
cached = cache.hget(CACHE_KEY, ro)
cached = cache.get_map(CACHE_KEY, ro)
if cached is not None:
state = cached
@@ -27,17 +34,21 @@ def rpki_state(prefix, asn):
ql = 'query GetValidation {{ validation(prefix: "{}", asn: {}) {{ state }} }}'
query = ql.format(prefix, asn)
log.debug("Cloudflare RPKI GraphQL Query: {!r}", query)
try:
with BaseExternal(base_url="https://rpki.cloudflare.com") as client:
response = client._post("/api/graphql", data={"query": query})
validation_state = (
response.get("data", {}).get("validation", {}).get("state", "DEFAULT")
)
try:
validation_state = response["data"]["validation"]["state"]
except KeyError as missing:
log.error("Response from Cloudflare missing key '{}': {!r}", missing, response)
validation_state = 3
state = RPKI_STATE_MAP[validation_state]
cache.hset(CACHE_KEY, ro, state)
cache.set_map_item(CACHE_KEY, ro, state)
except Exception as err:
log.error(str(err))
# Don't cache the state when an error produced it.
state = 3
msg = "RPKI Validation State for {} via AS{} is {}".format(prefix, asn, RPKI_NAME_MAP[state])

1
hyperglass/external/tests/__init__.py vendored Normal file
View File

@@ -0,0 +1 @@
"""External data testing."""

View File

@@ -0,0 +1,44 @@
"""Test bgp.tools interactions."""
# Standard Library
import asyncio
# Third Party
import pytest
# Local
from ..bgptools import run_whois, parse_whois, network_info
WHOIS_OUTPUT = """AS | IP | BGP Prefix | CC | Registry | Allocated | AS Name
13335 | 1.1.1.1 | 1.1.1.0/24 | US | ARIN | 2010-07-14 | Cloudflare, Inc."""
# Ignore asyncio deprecation warning about loop
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_network_info():
addr = "192.0.2.1"
info = asyncio.run(network_info(addr))
assert isinstance(info, dict)
assert "192.0.2.1" in info, "Address missing"
assert "asn" in info[addr], "ASN missing"
assert info[addr]["asn"] == "0", "Unexpected ASN"
assert info[addr]["rir"] == "Unknown", "Unexpected RIR"
# Ignore asyncio deprecation warning about loop
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def test_whois():
addr = "192.0.2.1"
response = asyncio.run(run_whois([addr]))
assert isinstance(response, str)
assert response != ""
def test_whois_parser():
addr = "1.1.1.1"
result = parse_whois(WHOIS_OUTPUT, [addr])
assert isinstance(result, dict)
assert addr in result, "Address missing"
assert result[addr]["asn"] == "13335"
assert result[addr]["rir"] == "ARIN"
assert result[addr]["org"] == "Cloudflare, Inc."

25
hyperglass/external/tests/test_rpki.py vendored Normal file
View File

@@ -0,0 +1,25 @@
"""Test RPKI data fetching."""
# Third Party
import pytest
# Local
from ..rpki import RPKI_NAME_MAP, rpki_state
TEST_STATES = (
("103.21.244.0/24", 13335, 0),
("1.1.1.0/24", 13335, 1),
("192.0.2.0/24", 65000, 2),
)
@pytest.mark.dependency()
def test_rpki():
for prefix, asn, expected in TEST_STATES:
result = rpki_state(prefix, asn)
result_name = RPKI_NAME_MAP.get(result, "No Name")
expected_name = RPKI_NAME_MAP.get(expected, "No Name")
assert (
result == expected
), "RPKI State for '{}' via AS{!s} '{}' ({}) instead of '{}' ({})".format(
prefix, asn, result, result_name, expected, expected_name
)

View File

@@ -26,7 +26,7 @@ from hyperglass.exceptions.private import InputValidationError
from ..config.devices import Device
from ..commands.generic import Directive
(TEXT := use_state().params.web.text)
(TEXT := use_state("params").web.text)
class Query(BaseModel):
@@ -154,7 +154,7 @@ class Query(BaseModel):
@validator("query_type")
def validate_query_type(cls, value):
"""Ensure a requested query type exists."""
(devices := use_state().devices)
devices = use_state("devices")
directive_ids = [
directive.id for device in devices.objects for directive in device.commands
]
@@ -167,7 +167,7 @@ class Query(BaseModel):
def validate_query_location(cls, value):
"""Ensure query_location is defined."""
(devices := use_state().devices)
devices = use_state("devices")
valid_id = value in devices.ids
valid_hostname = value in devices.hostnames
@@ -179,7 +179,7 @@ class Query(BaseModel):
@validator("query_group")
def validate_query_group(cls, value):
"""Ensure query_group is defined."""
(devices := use_state().devices)
devices = use_state("devices")
groups = {
group
for device in devices.objects

View File

@@ -26,8 +26,8 @@ class QueryError(BaseModel):
def validate_output(cls: "QueryError", value):
"""If no output is specified, use a customizable generic message."""
if value is None:
state = use_state()
return state.params.messages.general
(messages := use_state("params").messages)
return messages.general
return value
class Config:

View File

@@ -46,7 +46,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
Returns:
Union[IPv4Address, IPv6Address] -- Validated IP address object
"""
(params := use_state().params)
params = use_state("params")
query_type_params = getattr(params.queries, query_type)
try:
@@ -149,7 +149,7 @@ def validate_ip(value, query_type, query_vrf): # noqa: C901
def validate_community_input(value):
"""Validate input communities against configured or default regex pattern."""
(params := use_state().params)
params = use_state("params")
# RFC4360: Extended Communities (New Format)
if re.match(params.queries.bgp_community.pattern.extended_as, value):
@@ -174,7 +174,7 @@ def validate_community_input(value):
def validate_community_select(value):
"""Validate selected community against configured communities."""
(params := use_state().params)
params = use_state("params")
communities = tuple(c.community for c in params.queries.bgp_community.communities)
if value not in communities:
raise InputInvalid(
@@ -187,7 +187,7 @@ def validate_community_select(value):
def validate_aspath(value):
"""Validate input AS_PATH against configured or default regext pattern."""
(params := use_state().params)
params = use_state("params")
mode = params.queries.bgp_aspath.pattern.mode
pattern = getattr(params.queries.bgp_aspath.pattern, mode)

View File

@@ -44,7 +44,7 @@ class BGPRoute(HyperglassModel):
deny: only deny matches
"""
(structured := use_state().params.structured)
(structured := use_state("params").structured)
def _permit(comm):
"""Only allow matching patterns."""
@@ -73,7 +73,7 @@ class BGPRoute(HyperglassModel):
def validate_rpki_state(cls, value, values):
"""If external RPKI validation is enabled, get validation state."""
(structured := use_state().params.structured)
(structured := use_state("params").structured)
if structured.rpki.mode == "router":
# If router validation is enabled, return the value as-is.

View File

@@ -6,7 +6,7 @@ from inspect import isclass
# Project
from hyperglass.log import log
from hyperglass.state.redis import use_state
from hyperglass.state import use_state
from hyperglass.exceptions.private import PluginError
# Local
@@ -16,7 +16,7 @@ from ._output import OutputType, OutputPlugin
if t.TYPE_CHECKING:
# Project
from hyperglass.state.redis import HyperglassState
from hyperglass.state import HyperglassState
from hyperglass.models.api.query import Query
from hyperglass.models.config.devices import Device
from hyperglass.models.commands.generic import Directive

View File

@@ -5,7 +5,7 @@
from pathlib import Path
# Third Party
import py
import pytest
# Project
from hyperglass.log import log
@@ -15,6 +15,11 @@ from hyperglass.models.data.bgp_route import BGPRouteTable
# Local
from .._builtin.bgp_route_juniper import BGPRoutePluginJuniper
DEPENDS_KWARGS = {
"depends": ["hyperglass/external/tests/test_rpki.py::test_rpki"],
"scope": "session",
}
DIRECT = Path(__file__).parent.parent.parent.parent / ".samples" / "juniper_route_direct.xml"
INDIRECT = Path(__file__).parent.parent.parent.parent / ".samples" / "juniper_route_indirect.xml"
AS_PATH = Path(__file__).parent.parent.parent.parent / ".samples" / "juniper_route_aspath.xml"
@@ -42,18 +47,21 @@ def _tester(sample: str):
assert result.count > 0, "BGP Table count is 0"
@pytest.mark.dependency(**DEPENDS_KWARGS)
def test_juniper_bgp_route_direct():
with DIRECT.open("r") as file:
sample = file.read()
return _tester(sample)
@pytest.mark.dependency(**DEPENDS_KWARGS)
def test_juniper_bgp_route_indirect():
with INDIRECT.open("r") as file:
sample = file.read()
return _tester(sample)
@pytest.mark.dependency(**DEPENDS_KWARGS)
def test_juniper_bgp_route_aspath():
with AS_PATH.open("r") as file:
sample = file.read()

View File

@@ -1,7 +1,8 @@
"""hyperglass global state management."""
# Local
from .redis import HyperglassState, use_state
from .hooks import use_state
from .store import HyperglassState
__all__ = (
"use_state",

71
hyperglass/state/hooks.py Normal file
View File

@@ -0,0 +1,71 @@
"""Hooks for accessing hyperglass global state."""
# Standard Library
import typing as t
from functools import lru_cache
# Project
from hyperglass.exceptions.private import StateError
# Local
from .store import HyperglassState
from ..settings import Settings
if t.TYPE_CHECKING:
# Project
from hyperglass.models.ui import UIParameters
from hyperglass.models.config.params import Params
from hyperglass.models.config.devices import Devices
# Local
from .redis import RedisManager
@lru_cache
def _use_state(attr: t.Optional[str] = None) -> "HyperglassState":
"""Get hyperglass state by property.
Implemented separately due to typing issues related to lru_cache described here:
https://github.com/python/mypy/issues/8356
https://github.com/python/mypy/issues/9112
"""
if attr is None:
return HyperglassState(settings=Settings)
if attr in ("cache", "redis"):
return HyperglassState(settings=Settings).redis
if attr in HyperglassState.properties():
return getattr(HyperglassState(settings=Settings), attr)
raise StateError("'{attr}' does not exist on HyperglassState", attr=attr)
@t.overload
def use_state(attr: t.Literal["params"]) -> "Params":
"""Access hyperglass configuration parameters from global state."""
@t.overload
def use_state(attr: t.Literal["devices"]) -> "Devices":
"""Access hyperglass devices from global state."""
@t.overload
def use_state(attr: t.Literal["ui_params"]) -> "UIParameters":
"""Access hyperglass UI parameters from global state."""
@t.overload
def use_state(attr: t.Literal["cache", "redis"]) -> "RedisManager":
"""Directly access hyperglass Redis cache manager."""
@t.overload
def use_state(attr=None) -> "HyperglassState":
"""Access entire global state.
This overload needs to be defined last since it's a catchall.
"""
def use_state(attr: t.Optional[str] = None) -> "HyperglassState":
"""Access global hyperglass state."""
return _use_state(attr)

View File

@@ -0,0 +1,50 @@
"""hyperglass global state."""
# Standard Library
import typing as t
# Third Party
from redis import Redis, ConnectionPool
# Project
from hyperglass.configuration import params, devices, ui_params
# Local
from .redis import RedisManager
if t.TYPE_CHECKING:
# Project
from hyperglass.models.system import HyperglassSystem
class StateManager:
"""Global State Manager.
Maintains configuration objects in Redis cache and accesses them as needed.
"""
settings: "HyperglassSystem"
redis: RedisManager
_namespace: str = "hyperglass.state"
def __init__(self, *, settings: "HyperglassSystem") -> None:
"""Set up Redis connection and add configuration objects."""
self.settings = settings
connection_pool = ConnectionPool.from_url(**self.settings.redis_connection_pool)
redis = Redis(connection_pool=connection_pool)
self.redis = RedisManager(instance=redis, namespace=self._namespace)
# Add configuration objects.
self.redis.set("params", params)
self.redis.set("devices", devices)
self.redis.set("ui_params", ui_params)
@classmethod
def properties(cls: "StateManager") -> t.Tuple[str, ...]:
"""Get all read-only properties of the state manager."""
return tuple(
attr
for attr in dir(cls)
if not attr.startswith("_") and "fget" in dir(getattr(cls, attr))
)

View File

@@ -1,133 +1,123 @@
"""hyperglass global state."""
"""Interact with redis for state management."""
# Standard Library
import codecs
import pickle
import typing as t
from functools import lru_cache
# Third Party
from redis import Redis, ConnectionPool
from typing import overload
from datetime import datetime, timedelta
# Project
from hyperglass.configuration import params, devices, ui_params
from hyperglass.exceptions.private import StateError
# Local
from ..settings import Settings
if t.TYPE_CHECKING:
# Project
from hyperglass.models.ui import UIParameters
from hyperglass.models.system import HyperglassSystem
from hyperglass.plugins._base import HyperglassPlugin
from hyperglass.models.config.params import Params
from hyperglass.models.config.devices import Devices
PluginT = t.TypeVar("PluginT", bound="HyperglassPlugin")
# Third Party
from redis import Redis
class HyperglassState:
"""Global State Manager.
class RedisManager:
"""Convenience wrapper for managing a redis session."""
Maintains configuration objects in Redis cache and accesses them as needed.
"""
instance: "Redis"
namespace: str
settings: "HyperglassSystem"
redis: Redis
_connection_pool: ConnectionPool
_namespace: str = "hyperglass.state"
def __init__(self, *, settings: "HyperglassSystem") -> None:
def __init__(self, instance: "Redis", namespace: str) -> None:
"""Set up Redis connection and add configuration objects."""
self.instance = instance
self.namespace = namespace
self.settings = settings
self._connection_pool = ConnectionPool.from_url(**self.settings.redis_connection_pool)
self.redis = Redis(connection_pool=self._connection_pool)
def __repr__(self) -> str:
"""Alias repr to Redis instance's repr."""
return repr(self.instance)
# Add configuration objects.
self.set_object("params", params)
self.set_object("devices", devices)
self.set_object("ui_params", ui_params)
# Ensure plugins are empty.
self.reset_plugins("output")
self.reset_plugins("input")
def key(self, *keys: str) -> str:
def _key_join(self, *keys: str) -> str:
"""Format keys with state namespace."""
return ".".join((*self._namespace.split("."), *keys))
key_in_parts = (k for key in keys for k in key.split("."))
key_parts = list(dict.fromkeys((*self.namespace.split("."), *key_in_parts)))
return ".".join(key_parts)
def get_object(self, name: str, raise_if_none: bool = False) -> t.Any:
"""Get an object (class instance) from the cache."""
value = self.redis.get(name)
def key(self, key: t.Union[str, t.Sequence[str]]) -> str:
"""Format keys with state namespace."""
if isinstance(key, (t.List, t.Tuple, t.Generator)):
return self._key_join(*key)
return self._key_join(key)
def check(self) -> bool:
"""Ensure the redis instance is running and reachable."""
result = self.instance.ping()
if result is False:
raise RuntimeError(
"Redis instance {!r} is not running or reachable".format(self.instance)
)
return result
def delete(self, key: t.Union[str, t.Sequence[str]]) -> None:
"""Delete a key and value from the cache."""
self.instance.delete(self.key(key))
def expire(
self,
key: t.Union[str, t.Sequence[str]],
*,
expire_in: t.Optional[t.Union[timedelta, int]] = None,
expire_at: t.Optional[t.Union[datetime, int]] = None,
) -> None:
"""Expire a cache key, either at a time, or in a number of seconds.
If no at or in time is specified, the key is deleted.
"""
key = self.key(key)
if isinstance(expire_at, (datetime, int)):
self.instance.expireat(key, expire_at)
return
if isinstance(expire_in, (timedelta, int)):
self.instance.expire(key, expire_in)
return
self.instance.delete(key)
def get(
self,
key: t.Union[str, t.Sequence[str]],
*,
raise_if_none: bool = False,
value_if_none: t.Any = None,
) -> t.Union[None, t.Any]:
"""Get and decode a value from the cache."""
name = self.key(key)
value: t.Optional[bytes] = self.instance.get(name)
if isinstance(value, bytes):
return pickle.loads(value)
if raise_if_none is True:
raise StateError("'{key}' ('{name}') does not exist in Redis store", key=key, name=name)
if value_if_none is not None:
return value_if_none
return None
def set(self, key: t.Union[str, t.Sequence[str]], value: t.Any) -> None:
"""Add an object to the cache."""
name = self.key(key)
self.instance.set(name, pickle.dumps(value))
@overload
def get_map(self, key: str, item: str) -> t.Any:
"""Get a single value from a Redis hash map (dict)."""
@overload
def get_map(self, key: str, item=None) -> t.Any:
"""Get a single value from a Redis hash map (dict)."""
def get_map(self, key: str, item: t.Optional[str] = None) -> t.Any:
"""Get a Redis hash map or hash map value."""
name = self.key(key)
if isinstance(item, str):
value = self.instance.hget(name, item)
else:
value = self.instance.hgetall(name)
if isinstance(value, bytes):
return pickle.loads(value)
elif isinstance(value, str):
return pickle.loads(value.encode())
if raise_if_none is True:
raise StateError("'{key}' does not exist in Redis store", key=name)
return None
def set_object(self, name: str, obj: t.Any) -> None:
"""Add an object (class instance) to the cache."""
value = pickle.dumps(obj)
self.redis.set(self.key(name), value)
def add_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
"""Add a plugin to its list by type."""
current = self.plugins(_type)
plugins = {
# Create a base64 representation of a picked plugin.
codecs.encode(pickle.dumps(p), "base64").decode()
# Merge current plugins with the new plugin.
for p in [*current, plugin]
}
self.set_object(self.key("plugins", _type), list(plugins))
def remove_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
"""Remove a plugin from its list by type."""
current = self.plugins(_type)
plugins = {
# Create a base64 representation of a picked plugin.
codecs.encode(pickle.dumps(p), "base64").decode()
# Merge current plugins with the new plugin.
for p in current
if p != plugin
}
self.set_object(self.key("plugins", _type), list(plugins))
def reset_plugins(self, _type: str) -> None:
"""Remove all plugins of `_type`."""
self.set_object(self.key("plugins", _type), [])
def clear(self) -> None:
"""Delete all cache keys."""
self.redis.flushdb(asynchronous=True)
@property
def params(self) -> "Params":
"""Get hyperglass configuration parameters (`hyperglass.yaml`)."""
return self.get_object(self.key("params"), raise_if_none=True)
@property
def devices(self) -> "Devices":
"""Get hyperglass devices (`devices.yaml`)."""
return self.get_object(self.key("devices"), raise_if_none=True)
@property
def ui_params(self) -> "UIParameters":
"""UI parameters, built from params."""
return self.get_object(self.key("ui_params"), raise_if_none=True)
def plugins(self, _type: str) -> t.List[PluginT]:
"""Get plugins by type."""
current = self.get_object(self.key("plugins", _type), raise_if_none=False) or []
return list({pickle.loads(codecs.decode(plugin.encode(), "base64")) for plugin in current})
@lru_cache(maxsize=None)
def use_state() -> "HyperglassState":
"""Access hyperglass global state."""
return HyperglassState(settings=Settings)
def set_map_item(self, key: str, item: str, value: t.Any) -> None:
"""Add a value to a hash map (dict)."""
name = self.key(key)
self.instance.hset(name, item, pickle.dumps(value))

82
hyperglass/state/store.py Normal file
View File

@@ -0,0 +1,82 @@
"""Primary state container."""
# Standard Library
import codecs
import pickle
import typing as t
# Local
from .manager import StateManager
if t.TYPE_CHECKING:
# Project
from hyperglass.models.ui import UIParameters
from hyperglass.models.system import HyperglassSystem
from hyperglass.plugins._base import HyperglassPlugin
from hyperglass.models.config.params import Params
from hyperglass.models.config.devices import Devices
PluginT = t.TypeVar("PluginT", bound="HyperglassPlugin")
class HyperglassState(StateManager):
"""Primary hyperglass state container."""
def __init__(self, *, settings: "HyperglassSystem") -> None:
"""Initialize state store and reset plugins."""
super().__init__(settings=settings)
# Ensure plugins are empty.
self.reset_plugins("output")
self.reset_plugins("input")
def add_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
"""Add a plugin to its list by type."""
current = self.plugins(_type)
plugins = {
# Create a base64 representation of a picked plugin.
codecs.encode(pickle.dumps(p), "base64").decode()
# Merge current plugins with the new plugin.
for p in [*current, plugin]
}
self.redis.set(("plugins", _type), list(plugins))
def remove_plugin(self, _type: str, plugin: "HyperglassPlugin") -> None:
"""Remove a plugin from its list by type."""
current = self.plugins(_type)
plugins = {
# Create a base64 representation of a picked plugin.
codecs.encode(pickle.dumps(p), "base64").decode()
# Merge current plugins with the new plugin.
for p in current
if p != plugin
}
self.redis.set(("plugins", _type), list(plugins))
def reset_plugins(self, _type: str) -> None:
"""Remove all plugins of `_type`."""
self.redis.set(("plugins", _type), [])
def clear(self) -> None:
"""Delete all cache keys."""
self.redis.instance.flushdb(asynchronous=True)
@property
def params(self) -> "Params":
"""Get hyperglass configuration parameters (`hyperglass.yaml`)."""
return self.redis.get("params", raise_if_none=True)
@property
def devices(self) -> "Devices":
"""Get hyperglass devices (`devices.yaml`)."""
return self.redis.get("devices", raise_if_none=True)
@property
def ui_params(self) -> "UIParameters":
"""UI parameters, built from params."""
return self.redis.get("ui_params", raise_if_none=True)
def plugins(self, _type: str) -> t.List[PluginT]:
"""Get plugins by type."""
current = self.redis.get(("plugins", _type), raise_if_none=False, value_if_none=[])
return list({pickle.loads(codecs.decode(plugin.encode(), "base64")) for plugin in current})

View File

@@ -0,0 +1 @@
"""State tests."""

View File

@@ -0,0 +1,30 @@
"""Test state hooks."""
# Project
from hyperglass.models.ui import UIParameters
from hyperglass.models.config.params import Params
from hyperglass.models.config.devices import Devices
# Local
from ..hooks import use_state
from ..store import HyperglassState
STATE_ATTRS = (
("params", Params),
("devices", Devices),
("ui_params", UIParameters),
(None, HyperglassState),
)
def test_use_state_caching():
first = None
for attr, model in STATE_ATTRS:
for i in range(0, 5):
instance = use_state(attr)
if i == 0:
first = instance
assert isinstance(
instance, model
), f"{instance!r} is not an instance of '{model.__name__}'"
assert instance == first, f"{instance!r} is not equal to {first!r}"

16
poetry.lock generated
View File

@@ -976,6 +976,17 @@ toml = "*"
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
[[package]]
name = "pytest-dependency"
version = "0.5.1"
description = "Manage dependencies of tests"
category = "dev"
optional = false
python-versions = "*"
[package.dependencies]
pytest = ">=3.6.0"
[[package]]
name = "python-dotenv"
version = "0.17.0"
@@ -1391,7 +1402,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[metadata]
lock-version = "1.1"
python-versions = ">=3.8.1,<4.0"
content-hash = "ad65ca60927ff53c41ce10afc0651eafdc707f4bc9f2b70a797a7cb2fdfb7d87"
content-hash = "c439e39b6aee8009b444a98905e88c1d16388c9026cf780ee3ca5ffde07434b1"
[metadata.files]
aiofiles = [
@@ -1889,6 +1900,9 @@ pytest = [
{file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
{file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
]
pytest-dependency = [
{file = "pytest-dependency-0.5.1.tar.gz", hash = "sha256:c2a892906192663f85030a6ab91304e508e546cddfe557d692d61ec57a1d946b"},
]
python-dotenv = [
{file = "python-dotenv-0.17.0.tar.gz", hash = "sha256:471b782da0af10da1a80341e8438fca5fadeba2881c54360d5fd8d03d03a4f4a"},
{file = "python_dotenv-0.17.0-py2.py3-none-any.whl", hash = "sha256:49782a97c9d641e8a09ae1d9af0856cc587c8d2474919342d5104d85be9890b2"},

View File

@@ -81,6 +81,7 @@ pre-commit = "^1.21.0"
pytest = "^6.2.5"
stackprinter = "^0.2.3"
taskipy = "^1.8.2"
pytest-dependency = "^0.5.1"
[tool.black]
line-length = 100