mirror of
https://github.com/checktheroads/hyperglass
synced 2024-05-11 05:55:08 +00:00
309 lines
9.5 KiB
Python
309 lines
9.5 KiB
Python
"""Utility functions."""
|
|
|
|
# Standard Library
|
|
import os
|
|
import sys
|
|
import json
|
|
import string
|
|
import typing as t
|
|
import platform
|
|
from asyncio import iscoroutine
|
|
from pathlib import Path
|
|
from ipaddress import IPv4Address, IPv6Address, ip_address
|
|
|
|
# Third Party
|
|
from loguru._logger import Logger as LoguruLogger
|
|
from netmiko.ssh_dispatcher import CLASS_MAPPER # type: ignore
|
|
|
|
# Project
|
|
from hyperglass.types import Series
|
|
from hyperglass.constants import DRIVER_MAP
|
|
|
|
ALL_DEVICE_TYPES = {*DRIVER_MAP.keys(), *CLASS_MAPPER.keys()}
|
|
ALL_DRIVERS = {*DRIVER_MAP.values(), "netmiko"}
|
|
|
|
DeepConvert = t.TypeVar("DeepConvert", bound=t.Dict[str, t.Any])
|
|
|
|
|
|
def cpu_count(multiplier: int = 0) -> int:
|
|
"""Get server's CPU core count.
|
|
|
|
Used to determine the number of web server workers.
|
|
"""
|
|
# Standard Library
|
|
import multiprocessing
|
|
|
|
return multiprocessing.cpu_count() * multiplier
|
|
|
|
|
|
def check_python() -> str:
|
|
"""Verify Python Version."""
|
|
# Project
|
|
from hyperglass.constants import MIN_PYTHON_VERSION
|
|
|
|
pretty_version = ".".join(tuple(str(v) for v in MIN_PYTHON_VERSION))
|
|
if sys.version_info < MIN_PYTHON_VERSION:
|
|
raise RuntimeError(f"Python {pretty_version}+ is required.")
|
|
return platform.python_version()
|
|
|
|
|
|
def split_on_uppercase(s: str) -> t.List[str]:
|
|
"""Split characters by uppercase letters.
|
|
|
|
From: https://stackoverflow.com/a/40382663
|
|
"""
|
|
string_length = len(s)
|
|
is_lower_around = lambda: s[i - 1].islower() or string_length > (i + 1) and s[i + 1].islower()
|
|
|
|
start = 0
|
|
parts = []
|
|
for i in range(1, string_length):
|
|
if s[i].isupper() and is_lower_around():
|
|
parts.append(s[start:i])
|
|
start = i
|
|
parts.append(s[start:])
|
|
|
|
return parts
|
|
|
|
|
|
def parse_exception(exc: BaseException) -> str:
|
|
"""Parse an exception and its direct cause."""
|
|
|
|
if not isinstance(exc, BaseException):
|
|
raise TypeError(f"'{repr(exc)}' is not an exception.")
|
|
|
|
def get_exc_name(exc):
|
|
return " ".join(split_on_uppercase(exc.__class__.__name__))
|
|
|
|
def get_doc_summary(doc):
|
|
return doc.strip().split("\n")[0].strip(".")
|
|
|
|
name = get_exc_name(exc)
|
|
parsed = []
|
|
if exc.__doc__:
|
|
detail = get_doc_summary(exc.__doc__)
|
|
parsed.append(f"{name} ({detail})")
|
|
else:
|
|
parsed.append(name)
|
|
|
|
if exc.__cause__:
|
|
cause = get_exc_name(exc.__cause__)
|
|
if exc.__cause__.__doc__:
|
|
cause_detail = get_doc_summary(exc.__cause__.__doc__)
|
|
parsed.append(f"{cause} ({cause_detail})")
|
|
else:
|
|
parsed.append(cause)
|
|
return ", caused by ".join(parsed)
|
|
|
|
|
|
def make_repr(_class):
|
|
"""Create a user-friendly represention of an object."""
|
|
|
|
def _process_attrs(_dir):
|
|
for attr in _dir:
|
|
if not attr.startswith("_"):
|
|
attr_val = getattr(_class, attr)
|
|
|
|
if callable(attr_val):
|
|
yield f'{attr}=<function name="{attr_val.__name__}">'
|
|
|
|
elif iscoroutine(attr_val):
|
|
yield f'{attr}=<coroutine name="{attr_val.__name__}">'
|
|
|
|
elif isinstance(attr_val, str):
|
|
yield f'{attr}="{attr_val}"'
|
|
|
|
else:
|
|
yield f"{attr}={str(attr_val)}"
|
|
|
|
return f'{_class.__name__}({", ".join(_process_attrs(dir(_class)))})'
|
|
|
|
|
|
def repr_from_attrs(obj: object, attrs: Series[str], strip: t.Optional[str] = None) -> str:
|
|
"""Generate a `__repr__()` value from a specific set of attribute names.
|
|
|
|
Useful for complex models/objects where `__repr__()` should only display specific fields.
|
|
"""
|
|
# Check the object to ensure each attribute actually exists, and deduplicate
|
|
attr_names = {a for a in attrs if hasattr(obj, a)}
|
|
# Dict representation of attr name to obj value (e.g. `obj.attr`), if the value has a
|
|
# `__repr__` method.
|
|
attr_values = {
|
|
f if strip is None else f.strip(strip): v # noqa: IF100
|
|
for f in attr_names
|
|
if hasattr((v := getattr(obj, f)), "__repr__")
|
|
}
|
|
pairs = (f"{k}={v!r}" for k, v in attr_values.items())
|
|
return f"{obj.__class__.__name__}({', '.join(pairs)})"
|
|
|
|
|
|
def validate_platform(_type: str) -> t.Tuple[bool, t.Union[None, str]]:
|
|
"""Validate device type is supported."""
|
|
|
|
result = (False, None)
|
|
|
|
if _type in ALL_DEVICE_TYPES:
|
|
result = (True, DRIVER_MAP.get(_type, "netmiko"))
|
|
|
|
return result
|
|
|
|
|
|
def get_driver(_type: str, driver: t.Optional[str]) -> str:
|
|
"""Determine the appropriate driver for a device."""
|
|
|
|
if driver is None:
|
|
# If no driver is set, use the driver map with netmiko as
|
|
# fallback.
|
|
return DRIVER_MAP.get(_type, "netmiko")
|
|
elif driver in ALL_DRIVERS:
|
|
# If a driver is set and it is valid, allow it.
|
|
return driver
|
|
else:
|
|
# Otherwise, fail validation.
|
|
raise ValueError("{} is not a supported driver.".format(driver))
|
|
|
|
|
|
def current_log_level(logger: LoguruLogger) -> str:
|
|
"""Get the current log level of a logger instance."""
|
|
|
|
try:
|
|
handler = list(logger._core.handlers.values())[0]
|
|
levels = {v.no: k for k, v in logger._core.levels.items()}
|
|
current_level = levels[handler.levelno].lower()
|
|
|
|
except Exception as err:
|
|
logger.error(err)
|
|
current_level = "info"
|
|
|
|
return current_level
|
|
|
|
|
|
def resolve_hostname(hostname: str) -> t.Generator[t.Union[IPv4Address, IPv6Address], None, None]:
|
|
"""Resolve a hostname via DNS/hostfile."""
|
|
# Standard Library
|
|
from socket import gaierror, getaddrinfo
|
|
|
|
# Project
|
|
from hyperglass.log import log
|
|
|
|
log.debug("Ensuring '{}' is resolvable...", hostname)
|
|
|
|
ip4 = None
|
|
ip6 = None
|
|
try:
|
|
res = getaddrinfo(hostname, None)
|
|
for sock in res:
|
|
if sock[0].value == 2 and ip4 is None:
|
|
ip4 = ip_address(sock[4][0])
|
|
elif sock[0].value in (10, 30) and ip6 is None:
|
|
ip6 = ip_address(sock[4][0])
|
|
except (gaierror, ValueError, IndexError) as err:
|
|
log.debug(str(err))
|
|
pass
|
|
|
|
yield ip4
|
|
yield ip6
|
|
|
|
|
|
def snake_to_camel(value: str) -> str:
|
|
"""Convert a string from snake_case to camelCase."""
|
|
parts = value.split("_")
|
|
humps = (hump.capitalize() for hump in parts[1:])
|
|
return "".join((parts[0], *humps))
|
|
|
|
|
|
def get_fmt_keys(template: str) -> t.List[str]:
|
|
"""Get a list of str.format keys.
|
|
|
|
For example, string `"The value of {key} is {value}"` returns
|
|
`["key", "value"]`.
|
|
"""
|
|
keys = []
|
|
for block in (b for b in string.Formatter.parse("", template) if isinstance(template, str)):
|
|
key = block[1]
|
|
if key:
|
|
keys.append(key)
|
|
return keys
|
|
|
|
|
|
def deep_convert_keys(_dict: t.Type[DeepConvert], predicate: t.Callable[[str], str]) -> DeepConvert:
|
|
"""Convert all dictionary keys and nested dictionary keys."""
|
|
converted = {}
|
|
|
|
def get_value(value: t.Any):
|
|
if isinstance(value, t.Dict):
|
|
return {predicate(k): get_value(v) for k, v in value.items()}
|
|
elif isinstance(value, t.List):
|
|
return [get_value(v) for v in value]
|
|
elif isinstance(value, t.Tuple):
|
|
return tuple(get_value(v) for v in value)
|
|
return value
|
|
|
|
for key, value in _dict.items():
|
|
converted[predicate(key)] = get_value(value)
|
|
|
|
return converted
|
|
|
|
|
|
def at_least(minimum: int, value: int,) -> int:
|
|
"""Get a number value that is at least a specified minimum."""
|
|
if value < minimum:
|
|
return minimum
|
|
return value
|
|
|
|
|
|
def compare_dicts(dict_a: t.Dict[t.Any, t.Any], dict_b: t.Dict[t.Any, t.Any]) -> bool:
|
|
"""Determine if two dictationaries are (mostly) equal."""
|
|
if isinstance(dict_a, t.Dict) and isinstance(dict_b, t.Dict):
|
|
dict_a_keys, dict_a_values = set(dict_a.keys()), set(dict_a.values())
|
|
dict_b_keys, dict_b_values = set(dict_b.keys()), set(dict_b.values())
|
|
return all((dict_a_keys == dict_b_keys, dict_a_values == dict_b_values))
|
|
return False
|
|
|
|
|
|
def compare_init(obj_a: object, obj_b: object) -> bool:
|
|
"""Compare the `__init__` annoations of two objects."""
|
|
|
|
def _check_obj(obj: object):
|
|
"""Ensure `__annotations__` exists on the `__init__` method."""
|
|
if hasattr(obj, "__init__") and isinstance(getattr(obj, "__init__", None), t.Callable):
|
|
if hasattr(obj.__init__, "__annotations__") and isinstance(
|
|
getattr(obj.__init__, "__annotations__", None), t.Dict
|
|
):
|
|
return True
|
|
return False
|
|
|
|
if all((_check_obj(obj_a), _check_obj(obj_b))):
|
|
obj_a.__init__.__annotations__.pop("self", None)
|
|
obj_b.__init__.__annotations__.pop("self", None)
|
|
return compare_dicts(obj_a.__init__.__annotations__, obj_b.__init__.__annotations__)
|
|
return False
|
|
|
|
|
|
def run_coroutine_in_new_thread(coroutine: t.Coroutine) -> t.Any:
|
|
"""Run an async function in a separate thread and get the result."""
|
|
# Standard Library
|
|
import asyncio
|
|
import threading
|
|
|
|
class Resolver(threading.Thread):
|
|
def __init__(self, coro: t.Coroutine) -> None:
|
|
self.result: t.Any = None
|
|
self.coro: t.Coroutine = coro
|
|
super().__init__()
|
|
|
|
def run(self):
|
|
self.result = asyncio.run(self.coro())
|
|
|
|
thread = Resolver(coroutine)
|
|
thread.start()
|
|
thread.join()
|
|
return thread.result
|
|
|
|
|
|
def compare_lists(left: t.List[t.Any], right: t.List[t.Any], *, ignore: Series[t.Any] = ()) -> bool:
|
|
"""Determine if all items in left list exist in right list."""
|
|
left_ignored = [i for i in left if i not in ignore]
|
|
diff_ignored = [i for i in left if i in right and i not in ignore]
|
|
return len(left_ignored) == len(diff_ignored)
|