Mirror of https://github.com/checktheroads/hyperglass (synced 2024-05-11 05:55:08 +00:00)

Commit: Restructure utilities, add tests
hyperglass/external/_base.py (vendored): 4 changes
@@ -13,7 +13,7 @@ import httpx

# Project
from hyperglass.log import log
from hyperglass.util import make_repr, parse_exception
from hyperglass.util import parse_exception, repr_from_attrs
from hyperglass.constants import __version__
from hyperglass.models.fields import JsonValue, HttpMethod, Primitives
from hyperglass.exceptions.private import ExternalError

@@ -124,7 +124,7 @@ class BaseExternal:

    def __repr__(self: "BaseExternal") -> str:
        """Return user friendly representation of instance."""
        return make_repr(self)
        return repr_from_attrs(self, ("name", "base_url", "config", "parse"))

    def _exception(
        self: "BaseExternal",
hyperglass/external/msteams.py (vendored): 2 changes
@@ -1,5 +1,6 @@
"""Session handler for Microsoft Teams API."""

# Standard Library
import typing as t

# Project
@@ -8,6 +9,7 @@ from hyperglass.external._base import BaseExternal
from hyperglass.models.webhook import Webhook

if t.TYPE_CHECKING:
    # Project
    from hyperglass.models.config.logging import Http
@@ -1,311 +1,48 @@
"""Utility functions."""

# Standard Library
import os
import sys
import json
import string
import typing as t
import platform
from asyncio import iscoroutine
from pathlib import Path
from ipaddress import IPv4Address, IPv6Address, ip_address

# Third Party
from loguru._logger import Logger as LoguruLogger
from netmiko.ssh_dispatcher import CLASS_MAPPER  # type: ignore

# Project
from hyperglass.types import Series
from hyperglass.constants import DRIVER_MAP

ALL_DEVICE_TYPES = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()}
ALL_DRIVERS = {*DRIVER_MAP.values(), "netmiko"}

DeepConvert = t.TypeVar("DeepConvert", bound=t.Dict[str, t.Any])


def cpu_count(multiplier: int = 0) -> int:
    """Get server's CPU core count.

    Used to determine the number of web server workers.
    """
    # Standard Library
    import multiprocessing

    return multiprocessing.cpu_count() * multiplier


def check_python() -> str:
    """Verify Python Version."""
    # Project
    from hyperglass.constants import MIN_PYTHON_VERSION

    pretty_version = ".".join(tuple(str(v) for v in MIN_PYTHON_VERSION))
    if sys.version_info < MIN_PYTHON_VERSION:
        raise RuntimeError(f"Python {pretty_version}+ is required.")
    return platform.python_version()


def split_on_uppercase(s: str) -> t.List[str]:
    """Split characters by uppercase letters.

    From: https://stackoverflow.com/a/40382663
    """
    string_length = len(s)
    is_lower_around = lambda: s[i - 1].islower() or string_length > (i + 1) and s[i + 1].islower()

    start = 0
    parts = []
    for i in range(1, string_length):
        if s[i].isupper() and is_lower_around():
            parts.append(s[start:i])
            start = i
    parts.append(s[start:])

    return parts


def parse_exception(exc: BaseException) -> str:
    """Parse an exception and its direct cause."""

    if not isinstance(exc, BaseException):
        raise TypeError(f"'{repr(exc)}' is not an exception.")

    def get_exc_name(exc):
        return " ".join(split_on_uppercase(exc.__class__.__name__))

    def get_doc_summary(doc):
        return doc.strip().split("\n")[0].strip(".")

    name = get_exc_name(exc)
    parsed = []
    if exc.__doc__:
        detail = get_doc_summary(exc.__doc__)
        parsed.append(f"{name} ({detail})")
    else:
        parsed.append(name)

    if exc.__cause__:
        cause = get_exc_name(exc.__cause__)
        if exc.__cause__.__doc__:
            cause_detail = get_doc_summary(exc.__cause__.__doc__)
            parsed.append(f"{cause} ({cause_detail})")
        else:
            parsed.append(cause)
    return ", caused by ".join(parsed)
def make_repr(_class):
    """Create a user-friendly representation of an object."""

    def _process_attrs(_dir):
        for attr in _dir:
            if not attr.startswith("_"):
                attr_val = getattr(_class, attr)

                if callable(attr_val):
                    yield f'{attr}=<function name="{attr_val.__name__}">'

                elif iscoroutine(attr_val):
                    yield f'{attr}=<coroutine name="{attr_val.__name__}">'

                elif isinstance(attr_val, str):
                    yield f'{attr}="{attr_val}"'

                else:
                    yield f"{attr}={str(attr_val)}"

    return f'{_class.__name__}({", ".join(_process_attrs(dir(_class)))})'
def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str:
    """Generate a `__repr__()` value from a specific set of attribute names.

    Useful for complex models/objects where `__repr__()` should only display specific fields.
    """
    # Check the object to ensure each attribute actually exists, and deduplicate
    attr_names = {a for a in attrs if hasattr(obj, a)}
    # Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a
    # `__repr__` method.
    attr_values = {
        f if strip is None else f.strip(strip): v  # noqa: IF100
        for f in attr_names
        if hasattr((v := getattr(obj, f)), "__repr__")
    }
    pairs = (f"{k}={v!r}" for k, v in attr_values.items())
    return f"{obj.__class__.__name__}({', '.join(pairs)})"


def validate_platform(_type: str) -> t.Tuple[bool, t.Union[None, str]]:
    """Validate device type is supported."""

    result = (False, None)

    if _type in ALL_DEVICE_TYPES:
        result = (True, DRIVER_MAP.get(_type, "netmiko"))

    return result


def get_driver(_type: str, driver: t.Optional[str]) -> str:
    """Determine the appropriate driver for a device."""

    if driver is None:
        # If no driver is set, use the driver map with netmiko as
        # fallback.
        return DRIVER_MAP.get(_type, "netmiko")
    elif driver in ALL_DRIVERS:
        # If a driver is set and it is valid, allow it.
        return driver
    else:
        # Otherwise, fail validation.
        raise ValueError("{} is not a supported driver.".format(driver))


def current_log_level(logger: LoguruLogger) -> str:
    """Get the current log level of a logger instance."""

    try:
        handler = list(logger._core.handlers.values())[0]
        levels = {v.no: k for k, v in logger._core.levels.items()}
        current_level = levels[handler.levelno].lower()

    except Exception as err:
        logger.error(err)
        current_level = "info"

    return current_level


def resolve_hostname(hostname: str) -> t.Generator[t.Union[IPv4Address, IPv6Address], None, None]:
    """Resolve a hostname via DNS/hostfile."""
    # Standard Library
    from socket import gaierror, getaddrinfo

    # Project
    from hyperglass.log import log

    log.debug("Ensuring '{}' is resolvable...", hostname)

    ip4 = None
    ip6 = None
    try:
        res = getaddrinfo(hostname, None)
        for sock in res:
            if sock[0].value == 2 and ip4 is None:
                ip4 = ip_address(sock[4][0])
            elif sock[0].value in (10, 30) and ip6 is None:
                ip6 = ip_address(sock[4][0])
    except (gaierror, ValueError, IndexError) as err:
        log.debug(str(err))
        pass

    yield ip4
    yield ip6


def snake_to_camel(value: str) -> str:
    """Convert a string from snake_case to camelCase."""
    parts = value.split("_")
    humps = (hump.capitalize() for hump in parts[1:])
    return "".join((parts[0], *humps))


def get_fmt_keys(template: str) -> t.List[str]:
    """Get a list of str.format keys.

    For example, string `"The value of {key} is {value}"` returns
    `["key", "value"]`.
    """
    keys = []
    for block in (b for b in string.Formatter.parse("", template) if isinstance(template, str)):
        key = block[1]
        if key:
            keys.append(key)
    return keys


def deep_convert_keys(_dict: t.Type[DeepConvert], predicate: t.Callable[[str], str]) -> DeepConvert:
    """Convert all dictionary keys and nested dictionary keys."""
    converted = {}

    def get_value(value: t.Any):
        if isinstance(value, t.Dict):
            return {predicate(k): get_value(v) for k, v in value.items()}
        elif isinstance(value, t.List):
            return [get_value(v) for v in value]
        elif isinstance(value, t.Tuple):
            return tuple(get_value(v) for v in value)
        return value

    for key, value in _dict.items():
        converted[predicate(key)] = get_value(value)

    return converted


def at_least(
    minimum: int,
    value: int,
) -> int:
    """Get a number value that is at least a specified minimum."""
    if value < minimum:
        return minimum
    return value
def compare_dicts(dict_a: t.Dict[t.Any, t.Any], dict_b: t.Dict[t.Any, t.Any]) -> bool:
    """Determine if two dictionaries are (mostly) equal."""
    if isinstance(dict_a, t.Dict) and isinstance(dict_b, t.Dict):
        dict_a_keys, dict_a_values = set(dict_a.keys()), set(dict_a.values())
        dict_b_keys, dict_b_values = set(dict_b.keys()), set(dict_b.values())
        return all((dict_a_keys == dict_b_keys, dict_a_values == dict_b_values))
    return False


def compare_init(obj_a: object, obj_b: object) -> bool:
    """Compare the `__init__` annotations of two objects."""

    def _check_obj(obj: object):
        """Ensure `__annotations__` exists on the `__init__` method."""
        if hasattr(obj, "__init__") and isinstance(getattr(obj, "__init__", None), t.Callable):
            if hasattr(obj.__init__, "__annotations__") and isinstance(
                getattr(obj.__init__, "__annotations__", None), t.Dict
            ):
                return True
        return False

    if all((_check_obj(obj_a), _check_obj(obj_b))):
        obj_a.__init__.__annotations__.pop("self", None)
        obj_b.__init__.__annotations__.pop("self", None)
        return compare_dicts(obj_a.__init__.__annotations__, obj_b.__init__.__annotations__)
    return False
def run_coroutine_in_new_thread(coroutine: t.Coroutine) -> t.Any:
    """Run an async function in a separate thread and get the result."""
    # Standard Library
    import asyncio
    import threading

    class Resolver(threading.Thread):
        def __init__(self, coro: t.Coroutine) -> None:
            self.result: t.Any = None
            self.coro: t.Coroutine = coro
            super().__init__()

        def run(self):
            self.result = asyncio.run(self.coro())

    thread = Resolver(coroutine)
    thread.start()
    thread.join()
    return thread.result


def compare_lists(left: t.List[t.Any], right: t.List[t.Any], *, ignore: Series[t.Any] = ()) -> bool:
    """Determine if all items in left list exist in right list."""
    left_ignored = [i for i in left if i not in ignore]
    diff_ignored = [i for i in left if i in right and i not in ignore]
    return len(left_ignored) == len(diff_ignored)

# Local
from .files import copyfiles, check_path, move_files, dotenv_to_dict
from .tools import (
    at_least,
    compare_init,
    get_fmt_keys,
    compare_dicts,
    compare_lists,
    snake_to_camel,
    parse_exception,
    repr_from_attrs,
    deep_convert_keys,
    split_on_uppercase,
    run_coroutine_in_new_thread,
)
from .typing import is_type, is_series
from .frontend import build_ui, build_frontend
from .validation import get_driver, resolve_hostname, validate_platform
from .system_info import cpu_count, check_python

__all__ = (
    "at_least",
    "build_frontend",
    "build_ui",
    "check_path",
    "check_python",
    "compare_dicts",
    "compare_init",
    "compare_lists",
    "copyfiles",
    "cpu_count",
    "deep_convert_keys",
    "dotenv_to_dict",
    "get_driver",
    "get_fmt_keys",
    "is_series",
    "is_type",
    "move_files",
    "parse_exception",
    "repr_from_attrs",
    "resolve_hostname",
    "run_coroutine_in_new_thread",
    "snake_to_camel",
    "split_on_uppercase",
    "validate_platform",
)
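# Illustrative usage sketch (not part of the commit): the re-exports above keep
# the public import path stable, so callers can keep importing helpers from the
# package root even though the implementations now live in .tools, .validation, etc.
from hyperglass.util import repr_from_attrs, snake_to_camel

assert snake_to_camel("query_target") == "queryTarget"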
@@ -7,7 +7,6 @@ import math
import shutil
import typing as t
import asyncio
import subprocess
from pathlib import Path

# Project
@@ -21,18 +20,6 @@ if t.TYPE_CHECKING:
    from hyperglass.models.ui import UIParameters


def get_node_version() -> t.Tuple[int, int, int]:
    """Get the system's NodeJS version."""
    node_path = shutil.which("node")

    raw_version = subprocess.check_output([node_path, "--version"]).decode()  # noqa: S603

    # Node returns the version as 'v14.5.0', for example. Remove the v.
    version = raw_version.replace("v", "")
    # Parse the version parts.
    return tuple((int(v) for v in version.split(".")))


def get_ui_build_timeout() -> t.Optional[int]:
    """Read the UI build timeout from environment variables or set a default."""
    timeout = None

@@ -2,8 +2,9 @@

# Standard Library
import os
import sys
import typing as t
import platform
from typing import Dict, Tuple, Union

# Third Party
import psutil as _psutil
@@ -12,10 +13,7 @@ from cpuinfo import get_cpu_info as _get_cpu_info  # type: ignore
# Project
from hyperglass.constants import __version__

# Local
from .frontend import get_node_version

SystemData = Dict[str, Tuple[Union[str, int], str]]
SystemData = t.Dict[str, t.Tuple[t.Union[str, int], str]]


def _cpu() -> SystemData:
@@ -44,6 +42,48 @@ def _disk() -> SystemData:
    return (total_gb, usage_percent)


def get_node_version() -> t.Tuple[int, int, int]:
    """Get the system's NodeJS version."""

    # Standard Library
    import shutil
    import subprocess

    node_path = shutil.which("node")

    raw_version = subprocess.check_output([node_path, "--version"]).decode()  # noqa: S603

    # Node returns the version as 'v14.5.0', for example. Remove the v.
    version = raw_version.replace("v", "")
    # Parse the version parts.
    return tuple((int(v) for v in version.split(".")))
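# Sketch of the parsing step only (no subprocess call, editor's illustration):
# node prints a string like 'v14.5.0', which becomes the tuple (14, 5, 0).
raw = "v14.5.0"
assert tuple(int(v) for v in raw.replace("v", "").split(".")) == (14, 5, 0)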
def cpu_count(multiplier: int = 0) -> int:
    """Get server's CPU core count.

    Used to determine the number of web server workers.
    """
    # Standard Library
    import multiprocessing

    return multiprocessing.cpu_count() * multiplier
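# Illustrative usage (editor's sketch): on a 4-core host this returns 8, which
# is then used to size the pool of web server workers. Callers presumably pass a
# positive multiplier, since the default of 0 would yield 0 workers.
from hyperglass.util import cpu_count

workers = cpu_count(multiplier=2)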
def check_python() -> str:
    """Verify Python Version."""
    # Project
    from hyperglass.constants import MIN_PYTHON_VERSION

    pretty_version = ".".join(tuple(str(v) for v in MIN_PYTHON_VERSION))
    running_version = ".".join(
        str(v) for v in (sys.version_info.major, sys.version_info.minor, sys.version_info.micro)
    )
    if sys.version_info < MIN_PYTHON_VERSION:
        raise RuntimeError(f"Python {pretty_version}+ is required (Running {running_version})")
    return running_version


def get_system_info() -> SystemData:
    """Get system info."""
hyperglass/util/tests/test_tools.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""Test generic utilities."""

# Standard Library
import asyncio

# Third Party
import pytest

# Local
from ..tools import (
    at_least,
    compare_init,
    get_fmt_keys,
    compare_dicts,
    compare_lists,
    snake_to_camel,
    parse_exception,
    repr_from_attrs,
    deep_convert_keys,
    split_on_uppercase,
    run_coroutine_in_new_thread,
)


def test_split_on_uppercase():
    strings = (
        ("TestOne", ["Test", "One"]),
        ("testTwo", ["test", "Two"]),
        ("TestingOneTwoThree", ["Testing", "One", "Two", "Three"]),
    )
    for str_in, list_out in strings:
        result = split_on_uppercase(str_in)
        assert result == list_out


def test_parse_exception():
    with pytest.raises(TypeError):
        parse_exception(1)

    exc1 = RuntimeError("Test1")
    exc1_expected = f"Runtime Error ({(RuntimeError.__doc__ or '').strip('.')})"
    exc2 = RuntimeError("Test2")
    exc2_cause = f"Connection Error ({(ConnectionError.__doc__ or '').strip('.')})"
    exc2_expected = f"{exc1_expected}, caused by {exc2_cause}"
    try:
        raise exc1
    except Exception as err:
        result = parse_exception(err)
        assert result == exc1_expected
    try:
        raise exc2 from ConnectionError
    except Exception as err:
        result = parse_exception(err)
        assert result == exc2_expected


def test_repr_from_attrs():
    # Third Party
    from pydantic import create_model

    model = create_model("TestModel", one=(str, ...), two=(int, ...), three=(bool, ...))
    implementation = model(one="one", two=2, three=True)
    result = repr_from_attrs(implementation, ("one", "two", "three"))
    assert result == "TestModel(one='one', three=True, two=2)"


@pytest.mark.dependency()
def test_snake_to_camel():
    keys = (
        ("test_one", "testOne"),
        ("test_two_three", "testTwoThree"),
        ("Test_four_five_six", "testFourFiveSix"),
    )
    for key_in, key_out in keys:
        result = snake_to_camel(key_in)
        assert result == key_out


def test_get_fmt_keys():
    template = "This is a {template} for a {test}"
    result = get_fmt_keys(template)
    assert len(result) == 2 and "template" in result and "test" in result


@pytest.mark.dependency(
    depends=["hyperglass/util/tests/test_tools.py::test_snake_to_camel"], scope="session"
)
def test_deep_convert_keys():
    dict_in = {
        "key_one": 1,
        "key_two": 2,
        "key_dict": {
            "key_one": "one",
            "key_two": "two",
        },
        "key_list_dicts": [{"key_one": 101, "key_two": 102}, {"key_three": 103, "key_four": 104}],
    }

    result = deep_convert_keys(dict_in, snake_to_camel)
    assert result.get("keyOne") is not None
    assert result.get("keyTwo") is not None
    assert result.get("keyDict") is not None
    assert result["keyDict"].get("keyOne") is not None
    assert result["keyDict"].get("keyTwo") is not None
    assert isinstance(result.get("keyListDicts"), list)
    assert result["keyListDicts"][0].get("keyOne") is not None
    assert result["keyListDicts"][0].get("keyTwo") is not None
    assert result["keyListDicts"][1].get("keyThree") is not None
    assert result["keyListDicts"][1].get("keyFour") is not None


def test_at_least():
    assert at_least(8, 10) == 10
    assert at_least(8, 6) == 8


def test_compare_dicts():

    d1 = {"one": 1, "two": 2}
    d2 = {"one": 1, "two": 2}
    d3 = {"one": 1, "three": 3}
    d4 = {"one": 1, "two": 3}
    d5 = {}
    d6 = {}
    checks = (
        (d1, d2, True),
        (d1, d3, False),
        (d1, d4, False),
        (d1, d1, True),
        (d5, d6, True),
        (d1, [], False),
    )
    for a, b, expected in checks:
        assert compare_dicts(a, b) is expected


def test_compare_init():
    class Compare1:
        def __init__(self, item: str) -> None:
            pass

    class Compare2:
        def __init__(self: "Compare2", item: str) -> None:
            pass

    class Compare3:
        def __init__(self: "Compare3", item: str, other_item: int) -> None:
            pass

    class Compare4:
        def __init__(self: "Compare4", item: bool) -> None:
            pass

    class Compare5:
        pass

    checks = (
        (Compare1, Compare2, True),
        (Compare1, Compare3, False),
        (Compare1, Compare4, False),
        (Compare1, Compare5, False),
        (Compare1, Compare1, True),
    )
    for a, b, expected in checks:
        assert compare_init(a, b) is expected


def test_run_coroutine_in_new_thread():
    async def sleeper():
        await asyncio.sleep(5)

    async def test():
        return True

    asyncio.run(sleeper())
    result = run_coroutine_in_new_thread(test)
    assert result is True


def test_compare_lists():
    # Standard Library
    import random

    list1 = ["one", 2, "3"]
    list2 = [4, "5", "six"]
    list3 = ["one", 11, False]
    list4 = [*list1, *list2]
    random.shuffle(list4)
    assert compare_lists(list1, list2) is False
    assert compare_lists(list1, list3) is False
    assert compare_lists(list1, list4) is True
@@ -1,69 +0,0 @@
"""Test generic utilities."""
# Standard Library
import asyncio

# Local
from .. import compare_init, compare_dicts, run_coroutine_in_new_thread


def test_compare_dicts():

    d1 = {"one": 1, "two": 2}
    d2 = {"one": 1, "two": 2}
    d3 = {"one": 1, "three": 3}
    d4 = {"one": 1, "two": 3}
    d5 = {}
    d6 = {}
    checks = (
        (d1, d2, True),
        (d1, d3, False),
        (d1, d4, False),
        (d1, d1, True),
        (d5, d6, True),
        (d1, [], False),
    )
    for a, b, expected in checks:
        assert compare_dicts(a, b) is expected


def test_compare_init():
    class Compare1:
        def __init__(self, item: str) -> None:
            pass

    class Compare2:
        def __init__(self: "Compare2", item: str) -> None:
            pass

    class Compare3:
        def __init__(self: "Compare3", item: str, other_item: int) -> None:
            pass

    class Compare4:
        def __init__(self: "Compare4", item: bool) -> None:
            pass

    class Compare5:
        pass

    checks = (
        (Compare1, Compare2, True),
        (Compare1, Compare3, False),
        (Compare1, Compare4, False),
        (Compare1, Compare5, False),
        (Compare1, Compare1, True),
    )
    for a, b, expected in checks:
        assert compare_init(a, b) is expected


def test_run_coroutine_in_new_thread():
    async def sleeper():
        await asyncio.sleep(5)

    async def test():
        return True

    asyncio.run(sleeper())
    result = run_coroutine_in_new_thread(test)
    assert result is True
hyperglass/util/tools.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""Collection of generalized functional tools."""

# Standard Library
import typing as t

# Project
from hyperglass.types import Series

DeepConvert = t.TypeVar("DeepConvert", bound=t.Dict[str, t.Any])


def run_coroutine_in_new_thread(coroutine: t.Coroutine) -> t.Any:
    """Run an async function in a separate thread and get the result."""
    # Standard Library
    import asyncio
    import threading

    class Resolver(threading.Thread):
        def __init__(self, coro: t.Coroutine) -> None:
            self.result: t.Any = None
            self.coro: t.Coroutine = coro
            super().__init__()

        def run(self):
            self.result = asyncio.run(self.coro())

    thread = Resolver(coroutine)
    thread.start()
    thread.join()
    return thread.result
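# Illustrative usage (editor's sketch): pass the async callable itself, not an
# awaited coroutine object; the helper spins up a thread with its own event loop
# and returns the result.
import asyncio

from hyperglass.util.tools import run_coroutine_in_new_thread


async def fetch_answer() -> int:
    await asyncio.sleep(0)
    return 42


print(run_coroutine_in_new_thread(fetch_answer))  # 42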
def split_on_uppercase(s: str) -> t.List[str]:
    """Split characters by uppercase letters.

    From: https://stackoverflow.com/a/40382663
    """
    string_length = len(s)
    is_lower_around = lambda: s[i - 1].islower() or string_length > (i + 1) and s[i + 1].islower()

    start = 0
    parts = []
    for i in range(1, string_length):
        if s[i].isupper() and is_lower_around():
            parts.append(s[start:i])
            start = i
    parts.append(s[start:])

    return parts
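# Illustrative usage (editor's sketch): parse_exception, below, uses this to turn
# exception class names into readable words.
from hyperglass.util.tools import split_on_uppercase

assert split_on_uppercase("RuntimeError") == ["Runtime", "Error"]
assert split_on_uppercase("testTwo") == ["test", "Two"]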
def parse_exception(exc: BaseException) -> str:
    """Parse an exception and its direct cause."""

    if not isinstance(exc, BaseException):
        raise TypeError(f"'{repr(exc)}' is not an exception.")

    def get_exc_name(exc):
        return " ".join(split_on_uppercase(exc.__class__.__name__))

    def get_doc_summary(doc):
        return doc.strip().split("\n")[0].strip(".")

    name = get_exc_name(exc)
    parsed = []
    if exc.__doc__:
        detail = get_doc_summary(exc.__doc__)
        parsed.append(f"{name} ({detail})")
    else:
        parsed.append(name)

    if exc.__cause__:
        cause = get_exc_name(exc.__cause__)
        if exc.__cause__.__doc__:
            cause_detail = get_doc_summary(exc.__cause__.__doc__)
            parsed.append(f"{cause} ({cause_detail})")
        else:
            parsed.append(cause)
    return ", caused by ".join(parsed)
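# Illustrative usage (editor's sketch): the exception and its direct __cause__
# are rendered as readable text, using each class's docstring summary when one
# exists; the exact wording depends on the interpreter's docstrings.
from hyperglass.util.tools import parse_exception

try:
    try:
        raise ConnectionError("peer unreachable")
    except ConnectionError as cause:
        raise RuntimeError("query failed") from cause
except RuntimeError as err:
    # e.g. "Runtime Error (Unspecified run-time error), caused by Connection Error (Connection error)"
    print(parse_exception(err))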
def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str:
    """Generate a `__repr__()` value from a specific set of attribute names.

    Useful for complex models/objects where `__repr__()` should only display specific fields.
    """
    # Check the object to ensure each attribute actually exists, and deduplicate
    attr_names = {a for a in attrs if hasattr(obj, a)}
    # Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a
    # `__repr__` method.
    attr_values = {
        f if strip is None else f.strip(strip): v  # noqa: IF100
        for f in attr_names
        if hasattr((v := getattr(obj, f)), "__repr__")
    }
    pairs = (f"{k}={v!r}" for k, v in sorted(attr_values.items()))
    return f"{obj.__class__.__name__}({', '.join(pairs)})"
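# Illustrative usage sketch (the Device class here is hypothetical): only the
# requested attributes appear in the output, which is how BaseExternal.__repr__
# now limits itself to name/base_url/config/parse.
from hyperglass.util.tools import repr_from_attrs


class Device:
    def __init__(self) -> None:
        self.name = "edge1"
        self.platform = "juniper"
        self.port = 22


device = Device()
print(repr_from_attrs(device, ("name", "port")))  # Device(name='edge1', port=22)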
def snake_to_camel(value: str) -> str:
    """Convert a string from snake_case to camelCase."""
    head, *body = value.split("_")
    humps = (hump.capitalize() for hump in body)
    return "".join((head.lower(), *humps))
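# Illustrative usage (editor's sketch): the leading segment is lowercased and the
# remaining segments are capitalized; pairs with deep_convert_keys below.
from hyperglass.util.tools import snake_to_camel

assert snake_to_camel("max_prefix_length") == "maxPrefixLength"
assert snake_to_camel("Cache_timeout") == "cacheTimeout"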
def get_fmt_keys(template: str) -> t.List[str]:
    """Get a list of str.format keys.

    For example, string `"The value of {key} is {value}"` returns
    `["key", "value"]`.
    """
    # Standard Library
    import string

    keys = []
    for block in (b for b in string.Formatter.parse("", template) if isinstance(template, str)):
        key = block[1]
        if key:
            keys.append(key)
    return keys
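# Illustrative usage (editor's sketch): extracts the replacement-field names, in
# order, from a str.format-style template; the template text here is made up.
from hyperglass.util.tools import get_fmt_keys

assert get_fmt_keys("show route {target} table {vrf}") == ["target", "vrf"]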
def deep_convert_keys(_dict: t.Type[DeepConvert], predicate: t.Callable[[str], str]) -> DeepConvert:
    """Convert all dictionary keys and nested dictionary keys."""
    converted = {}

    def get_value(value: t.Any):
        if isinstance(value, t.Dict):
            return {predicate(k): get_value(v) for k, v in value.items()}
        elif isinstance(value, t.List):
            return [get_value(v) for v in value]
        elif isinstance(value, t.Tuple):
            return tuple(get_value(v) for v in value)
        return value

    for key, value in _dict.items():
        converted[predicate(key)] = get_value(value)

    return converted
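# Illustrative usage (editor's sketch): keys are rewritten recursively through
# dicts, lists, and tuples while values are left untouched; the payload below is
# made up.
from hyperglass.util.tools import deep_convert_keys, snake_to_camel

data = {"query_type": "bgp_route", "query_target": {"ip_version": 4}}
assert deep_convert_keys(data, snake_to_camel) == {
    "queryType": "bgp_route",
    "queryTarget": {"ipVersion": 4},
}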
def at_least(
    minimum: int,
    value: int,
) -> int:
    """Get a number value that is at least a specified minimum."""
    if value < minimum:
        return minimum
    return value
def compare_dicts(dict_a: t.Dict[t.Any, t.Any], dict_b: t.Dict[t.Any, t.Any]) -> bool:
    """Determine if two dictionaries are (mostly) equal."""
    if isinstance(dict_a, t.Dict) and isinstance(dict_b, t.Dict):
        dict_a_keys, dict_a_values = set(dict_a.keys()), set(dict_a.values())
        dict_b_keys, dict_b_values = set(dict_b.keys()), set(dict_b.values())
        return all((dict_a_keys == dict_b_keys, dict_a_values == dict_b_values))
    return False
def compare_lists(left: t.List[t.Any], right: t.List[t.Any], *, ignore: Series[t.Any] = ()) -> bool:
    """Determine if all items in left list exist in right list."""
    left_ignored = [i for i in left if i not in ignore]
    diff_ignored = [i for i in left if i in right and i not in ignore]
    return len(left_ignored) == len(diff_ignored)
def compare_init(obj_a: object, obj_b: object) -> bool:
    """Compare the `__init__` annotations of two objects."""

    def _check_obj(obj: object):
        """Ensure `__annotations__` exists on the `__init__` method."""
        if hasattr(obj, "__init__") and isinstance(getattr(obj, "__init__", None), t.Callable):
            if hasattr(obj.__init__, "__annotations__") and isinstance(
                getattr(obj.__init__, "__annotations__", None), t.Dict
            ):
                return True
        return False

    if all((_check_obj(obj_a), _check_obj(obj_b))):
        obj_a.__init__.__annotations__.pop("self", None)
        obj_b.__init__.__annotations__.pop("self", None)
        return compare_dicts(obj_a.__init__.__annotations__, obj_b.__init__.__annotations__)
    return False
hyperglass/util/validation.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Validation Utilities."""

# Standard Library
import typing as t

# Third Party
from netmiko.ssh_dispatcher import CLASS_MAPPER  # type: ignore

# Project
from hyperglass.constants import DRIVER_MAP

ALL_DEVICE_TYPES = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()}
ALL_DRIVERS = {*DRIVER_MAP.values(), "netmiko"}

if t.TYPE_CHECKING:
    # Standard Library
    from ipaddress import IPv4Address, IPv6Address


def validate_platform(_type: str) -> t.Tuple[bool, t.Union[None, str]]:
    """Validate device type is supported."""

    result = (False, None)

    if _type in ALL_DEVICE_TYPES:
        result = (True, DRIVER_MAP.get(_type, "netmiko"))

    return result
def get_driver(_type: str, driver: t.Optional[str]) -> str:
    """Determine the appropriate driver for a device."""

    if driver is None:
        # If no driver is set, use the driver map with netmiko as
        # fallback.
        return DRIVER_MAP.get(_type, "netmiko")
    elif driver in ALL_DRIVERS:
        # If a driver is set and it is valid, allow it.
        return driver
    else:
        # Otherwise, fail validation.
        raise ValueError("{} is not a supported driver.".format(driver))
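# Illustrative usage sketch: "cisco_ios" is an assumed platform name here and is
# not taken from DRIVER_MAP itself.
from hyperglass.util.validation import get_driver

driver = get_driver("cisco_ios", None)  # DRIVER_MAP entry if present, else "netmiko"
same = get_driver("cisco_ios", driver)  # an already-supported driver is passed through
try:
    get_driver("cisco_ios", "not_a_driver")
except ValueError as err:
    print(err)  # not_a_driver is not a supported driver.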
def resolve_hostname(
    hostname: str,
) -> t.Generator[t.Union["IPv4Address", "IPv6Address"], None, None]:
    """Resolve a hostname via DNS/hostfile."""
    # Standard Library
    from socket import gaierror, getaddrinfo
    from ipaddress import ip_address

    # Project
    from hyperglass.log import log

    log.debug("Ensuring {!r} is resolvable...", hostname)

    ip4 = None
    ip6 = None
    try:
        res = getaddrinfo(hostname, None)
        for sock in res:
            if sock[0].value == 2 and ip4 is None:
                ip4 = ip_address(sock[4][0])
            elif sock[0].value in (10, 30) and ip6 is None:
                ip6 = ip_address(sock[4][0])
    except (gaierror, ValueError, IndexError) as err:
        log.debug(str(err))
        pass

    yield ip4
    yield ip6
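# Illustrative usage (editor's sketch): the generator always yields exactly two
# values, IPv4 then IPv6, either of which may be None when that address family
# doesn't resolve.
from hyperglass.util.validation import resolve_hostname

ip4, ip6 = resolve_hostname("localhost")
print(ip4, ip6)  # e.g. 127.0.0.1 ::1, depending on the local resolver/hosts file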
poetry.lock (generated): 6 changes
@@ -820,7 +820,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"

[[package]]
name = "py-cpuinfo"
version = "7.0.0"
version = "8.0.0"
description = "Get CPU info with pure Python 2 & 3"
category = "main"
optional = false
@@ -1340,7 +1340,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[metadata]
lock-version = "1.1"
python-versions = ">=3.8.1,<4.0"
content-hash = "a13c78fca92dfe22206f8b27362c779021ee420756bb7c6d7f39ec72bc9d7148"
content-hash = "59c7bf05d11ded8cd759701bc8e9c8962ae44ecfb8b1fa0b222119b08284c273"

[metadata.files]
aiofiles = [
@@ -1819,7 +1819,7 @@ py = [
    {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
]
py-cpuinfo = [
    {file = "py-cpuinfo-7.0.0.tar.gz", hash = "sha256:9aa2e49675114959697d25cf57fec41c29b55887bff3bc4809b44ac6f5730097"},
    {file = "py-cpuinfo-8.0.0.tar.gz", hash = "sha256:5f269be0e08e33fd959de96b34cd4aeeeacac014dd8305f70eb28d06de2345c5"},
]
pycodestyle = [
    {file = "pycodestyle-2.7.0-py2.py3-none-any.whl", hash = "sha256:514f76d918fcc0b55c6680472f0a37970994e07bbb80725808c17089be302068"},
@@ -43,7 +43,7 @@ loguru = "^0.5.3"
netmiko = "^3.4.0"
paramiko = "^2.7.2"
psutil = "^5.7.2"
py-cpuinfo = "^7.0.0"
py-cpuinfo = "^8.0.0"
pydantic = {extras = ["dotenv"], version = "^1.8.2"}
python = ">=3.8.1,<4.0"
redis = "^3.5.3"