1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

flatten configuration model

This commit is contained in:
checktheroads
2020-01-28 08:59:27 -07:00
parent 6ad69ae6bb
commit 9b9fd95061
10 changed files with 143 additions and 156 deletions

View File

@@ -33,21 +33,21 @@ UI_DIR = STATIC_DIR / "ui"
IMAGES_DIR = STATIC_DIR / "images" IMAGES_DIR = STATIC_DIR / "images"
ASGI_PARAMS = { ASGI_PARAMS = {
"host": str(params.general.listen_address), "host": str(params.listen_address),
"port": params.general.listen_port, "port": params.listen_port,
"debug": params.general.debug, "debug": params.debug,
} }
# Main App Definition # Main App Definition
app = FastAPI( app = FastAPI(
debug=params.general.debug, debug=params.debug,
title=params.general.site_title, title=params.site_title,
description=params.general.site_description, description=params.site_description,
version=__version__, version=__version__,
default_response_class=UJSONResponse, default_response_class=UJSONResponse,
docs_url=None, docs_url=None,
redoc_url=None, redoc_url=None,
openapi_url=params.general.docs.openapi_url, openapi_url=params.docs.openapi_url,
) )
# Add Event Handlers # Add Event Handlers
@@ -73,9 +73,9 @@ app.add_exception_handler(Exception, default_handler)
def _custom_openapi(): def _custom_openapi():
"""Generate custom OpenAPI config.""" """Generate custom OpenAPI config."""
openapi_schema = get_openapi( openapi_schema = get_openapi(
title=params.general.site_title, title=params.site_title,
version=__version__, version=__version__,
description=params.general.site_description, description=params.site_description,
routes=app.routes, routes=app.routes,
) )
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
@@ -84,11 +84,11 @@ def _custom_openapi():
app.openapi = _custom_openapi app.openapi = _custom_openapi
if params.general.docs.enable: if params.docs.enable:
log.debug(f"API Docs config: {app.openapi()}") log.debug(f"API Docs config: {app.openapi()}")
CORS_ORIGINS = params.general.cors_origins.copy() CORS_ORIGINS = params.cors_origins.copy()
if params.general.developer_mode: if params.developer_mode:
CORS_ORIGINS.append(URL_DEV) CORS_ORIGINS.append(URL_DEV)
# CORS Configuration # CORS Configuration
@@ -103,10 +103,10 @@ app.add_api_route(
path="/api/query/", path="/api/query/",
endpoint=query, endpoint=query,
methods=["POST"], methods=["POST"],
summary=params.general.docs.endpoint_summary, summary=params.docs.endpoint_summary,
description=params.general.docs.endpoint_description, description=params.docs.endpoint_description,
response_model=QueryResponse, response_model=QueryResponse,
tags=[params.general.docs.group_title], tags=[params.docs.group_title],
response_class=UJSONResponse, response_class=UJSONResponse,
) )
app.add_api_route(path="api/docs", endpoint=docs, include_in_schema=False) app.add_api_route(path="api/docs", endpoint=docs, include_in_schema=False)

View File

@@ -58,7 +58,7 @@ async def build_ui():
""" """
try: try:
await build_frontend( await build_frontend(
dev_mode=params.general.developer_mode, dev_mode=params.developer_mode,
dev_url=URL_DEV, dev_url=URL_DEV,
prod_url=URL_PROD, prod_url=URL_PROD,
params=frontend_params, params=frontend_params,

View File

@@ -64,12 +64,11 @@ async def query(query_data: Query, request: Request):
async def docs(): async def docs():
"""Serve custom docs.""" """Serve custom docs."""
if params.general.docs.enable: if params.docs.enable:
docs_func_map = {"swagger": get_swagger_ui_html, "redoc": get_redoc_html} docs_func_map = {"swagger": get_swagger_ui_html, "redoc": get_redoc_html}
docs_func = docs_func_map[params.general.docs.mode] docs_func = docs_func_map[params.docs.mode]
return docs_func( return docs_func(
openapi_url=params.general.docs.openapi_url, openapi_url=params.docs.openapi_url, title=params.site_title + " - API Docs"
title=params.general.site_title + " - API Docs",
) )
else: else:
raise HTTPException(detail="Not found", status_code=404) raise HTTPException(detail="Not found", status_code=404)

View File

@@ -145,7 +145,7 @@ try:
params = _params.Params() params = _params.Params()
try: try:
params.branding.text.subtitle = params.branding.text.subtitle.format( params.branding.text.subtitle = params.branding.text.subtitle.format(
**params.general.dict() **params.dict(exclude={"branding", "features", "messages"})
) )
except KeyError: except KeyError:
pass pass
@@ -167,7 +167,7 @@ except ValidationError as validation_errors:
) )
# Re-evaluate debug state after config is validated # Re-evaluate debug state after config is validated
_set_log_level(params.general.debug, params.general.log_file) _set_log_level(params.debug, params.log_file)
def _build_frontend_networks(): def _build_frontend_networks():
@@ -328,9 +328,7 @@ def _build_queries():
content_params = json.loads( content_params = json.loads(
params.general.json( params.json(include={"primary_asn", "org_name", "site_title", "site_description"})
include={"primary_asn", "org_name", "site_title", "site_description"}
)
) )
@@ -436,11 +434,11 @@ _frontend_params.update(
) )
frontend_params = _frontend_params frontend_params = _frontend_params
URL_DEV = f"http://localhost:{str(params.general.listen_port)}/api/" URL_DEV = f"http://localhost:{str(params.listen_port)}/api/"
URL_PROD = "/api/" URL_PROD = "/api/"
REDIS_CONFIG = { REDIS_CONFIG = {
"host": str(params.general.redis_host), "host": str(params.redis_host),
"port": params.general.redis_port, "port": params.redis_port,
"decode_responses": True, "decode_responses": True,
} }

View File

@@ -10,7 +10,7 @@ from hyperglass.configuration.models._utils import HyperglassModel
class Docs(HyperglassModel): class Docs(HyperglassModel):
"""Validation model for params.general.docs.""" """Validation model for params.docs."""
enable: StrictBool = True enable: StrictBool = True
mode: constr(regex=r"(swagger|redoc)") = "swagger" mode: constr(regex=r"(swagger|redoc)") = "swagger"

View File

@@ -1,117 +0,0 @@
"""Validate general configuration variables."""
# Standard Library Imports
from datetime import datetime
from ipaddress import ip_address
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union
# Third Party Imports
from pydantic import FilePath
from pydantic import IPvAnyAddress
from pydantic import StrictBool
from pydantic import StrictInt
from pydantic import StrictStr
from pydantic import validator
# Project Imports
from hyperglass.configuration.models._utils import HyperglassModel
from hyperglass.configuration.models.docs import Docs
from hyperglass.configuration.models.opengraph import OpenGraph
class General(HyperglassModel):
"""Validation model for params.general."""
debug: StrictBool = False
developer_mode: StrictBool = False
primary_asn: StrictStr = "65001"
org_name: StrictStr = "The Company"
site_title: StrictStr = "hyperglass"
site_description: StrictStr = "{org_name} Network Looking Glass"
site_keywords: List[StrictStr] = [
"hyperglass",
"looking glass",
"lg",
"peer",
"peering",
"ipv4",
"ipv6",
"transit",
"community",
"communities",
"bgp",
"routing",
"network",
"isp",
]
opengraph: OpenGraph = OpenGraph()
docs: Docs = Docs()
google_analytics: StrictStr = ""
redis_host: StrictStr = "localhost"
redis_port: StrictInt = 6379
requires_ipv6_cidr: List[StrictStr] = ["cisco_ios", "cisco_nxos"]
request_timeout: StrictInt = 30
listen_address: Optional[Union[IPvAnyAddress, StrictStr]]
listen_port: StrictInt = 8001
log_file: Optional[FilePath]
cors_origins: List[StrictStr] = []
@validator("listen_address", pre=True, always=True)
def validate_listen_address(cls, value, values):
"""Set default listen_address based on debug mode.
Arguments:
value {str|IPvAnyAddress|None} -- listen_address
values {dict} -- already-validated entries before listen_address
Returns:
{str} -- Validated listen_address
"""
if value is None and not values["debug"]:
listen_address = "localhost"
elif value is None and values["debug"]:
listen_address = ip_address("0.0.0.0") # noqa: S104
elif isinstance(value, str) and value != "localhost":
try:
listen_address = ip_address(value)
except ValueError:
raise ValueError(str(value))
elif isinstance(value, str) and value == "localhost":
listen_address = value
else:
raise ValueError(str(value))
return listen_address
@validator("site_description")
def validate_site_description(cls, value, values):
"""Format the site descripion with the org_name field.
Arguments:
value {str} -- site_description
values {str} -- Values before site_description
Returns:
{str} -- Formatted description
"""
return value.format(org_name=values["org_name"])
@validator("log_file")
def validate_log_file(cls, value):
"""Set default logfile location if none is configured.
Arguments:
value {FilePath} -- Path to log file
Returns:
{Path} -- Logfile path object
"""
if value is None:
now = datetime.now()
now.isoformat
value = Path(
f'/tmp/hyperglass_{now.strftime(r"%Y%M%d-%H%M%S")}.log' # noqa: S108
)
return value

View File

@@ -13,7 +13,7 @@ from hyperglass.configuration.models._utils import HyperglassModel
class OpenGraph(HyperglassModel): class OpenGraph(HyperglassModel):
"""Validation model for params.general.opengraph.""" """Validation model for params.opengraph."""
width: Optional[StrictInt] width: Optional[StrictInt]
height: Optional[StrictInt] height: Optional[StrictInt]

View File

@@ -1,17 +1,124 @@
"""Configuration validation entry point.""" """Configuration validation entry point."""
# Standard Library Imports
from datetime import datetime
from ipaddress import ip_address
from pathlib import Path
from typing import List
from typing import Optional
from typing import Union
# Third Party Imports
from pydantic import FilePath
from pydantic import IPvAnyAddress
from pydantic import StrictBool
from pydantic import StrictInt
from pydantic import StrictStr
from pydantic import validator
# Project Imports # Project Imports
from hyperglass.configuration.models._utils import HyperglassModel from hyperglass.configuration.models._utils import HyperglassModel
from hyperglass.configuration.models.branding import Branding from hyperglass.configuration.models.branding import Branding
from hyperglass.configuration.models.docs import Docs
from hyperglass.configuration.models.features import Features from hyperglass.configuration.models.features import Features
from hyperglass.configuration.models.general import General
from hyperglass.configuration.models.messages import Messages from hyperglass.configuration.models.messages import Messages
from hyperglass.configuration.models.opengraph import OpenGraph
class Params(HyperglassModel): class Params(HyperglassModel):
"""Validation model for all configuration variables.""" """Validation model for all configuration variables."""
general: General = General() debug: StrictBool = False
developer_mode: StrictBool = False
primary_asn: StrictStr = "65001"
org_name: StrictStr = "The Company"
site_title: StrictStr = "hyperglass"
site_description: StrictStr = "{org_name} Network Looking Glass"
site_keywords: List[StrictStr] = [
"hyperglass",
"looking glass",
"lg",
"peer",
"peering",
"ipv4",
"ipv6",
"transit",
"community",
"communities",
"bgp",
"routing",
"network",
"isp",
]
opengraph: OpenGraph = OpenGraph()
docs: Docs = Docs()
google_analytics: StrictStr = ""
redis_host: StrictStr = "localhost"
redis_port: StrictInt = 6379
requires_ipv6_cidr: List[StrictStr] = ["cisco_ios", "cisco_nxos"]
request_timeout: StrictInt = 30
listen_address: Optional[Union[IPvAnyAddress, StrictStr]]
listen_port: StrictInt = 8001
log_file: Optional[FilePath]
cors_origins: List[StrictStr] = []
@validator("listen_address", pre=True, always=True)
def validate_listen_address(cls, value, values):
"""Set default listen_address based on debug mode.
Arguments:
value {str|IPvAnyAddress|None} -- listen_address
values {dict} -- already-validated entries before listen_address
Returns:
{str} -- Validated listen_address
"""
if value is None and not values["debug"]:
listen_address = "localhost"
elif value is None and values["debug"]:
listen_address = ip_address("0.0.0.0") # noqa: S104
elif isinstance(value, str) and value != "localhost":
try:
listen_address = ip_address(value)
except ValueError:
raise ValueError(str(value))
elif isinstance(value, str) and value == "localhost":
listen_address = value
else:
raise ValueError(str(value))
return listen_address
@validator("site_description")
def validate_site_description(cls, value, values):
"""Format the site descripion with the org_name field.
Arguments:
value {str} -- site_description
values {str} -- Values before site_description
Returns:
{str} -- Formatted description
"""
return value.format(org_name=values["org_name"])
@validator("log_file")
def validate_log_file(cls, value):
"""Set default logfile location if none is configured.
Arguments:
value {FilePath} -- Path to log file
Returns:
{Path} -- Logfile path object
"""
if value is None:
now = datetime.now()
now.isoformat
value = Path(
f'/tmp/hyperglass_{now.strftime(r"%Y%M%d-%H%M%S")}.log' # noqa: S108
)
return value
features: Features = Features() features: Features = Features()
branding: Branding = Branding() branding: Branding = Branding()
messages: Messages = Messages() messages: Messages = Messages()

View File

@@ -104,7 +104,7 @@ class Connect:
) )
signal.signal(signal.SIGALRM, handle_timeout) signal.signal(signal.SIGALRM, handle_timeout)
signal.alarm(params.general.request_timeout - 1) signal.alarm(params.request_timeout - 1)
with tunnel: with tunnel:
log.debug( log.debug(
@@ -119,7 +119,7 @@ class Connect:
"username": self.device.credential.username, "username": self.device.credential.username,
"password": self.device.credential.password.get_secret_value(), "password": self.device.credential.password.get_secret_value(),
"global_delay_factor": 0.2, "global_delay_factor": 0.2,
"timeout": params.general.request_timeout - 1, "timeout": params.request_timeout - 1,
} }
try: try:
@@ -194,7 +194,7 @@ class Connect:
"username": self.device.credential.username, "username": self.device.credential.username,
"password": self.device.credential.password.get_secret_value(), "password": self.device.credential.password.get_secret_value(),
"global_delay_factor": 0.2, "global_delay_factor": 0.2,
"timeout": params.general.request_timeout, "timeout": params.request_timeout,
} }
try: try:
@@ -210,7 +210,7 @@ class Connect:
) )
signal.signal(signal.SIGALRM, handle_timeout) signal.signal(signal.SIGALRM, handle_timeout)
signal.alarm(params.general.request_timeout - 1) signal.alarm(params.request_timeout - 1)
responses = [] responses = []
@@ -259,7 +259,7 @@ class Connect:
client_params = { client_params = {
"headers": {"Content-Type": "application/json"}, "headers": {"Content-Type": "application/json"},
"timeout": params.general.request_timeout, "timeout": params.request_timeout,
} }
if self.device.ssl is not None and self.device.ssl.enable: if self.device.ssl is not None and self.device.ssl.enable:
http_protocol = "https" http_protocol = "https"
@@ -286,7 +286,7 @@ class Connect:
encoded_query = await jwt_encode( encoded_query = await jwt_encode(
payload=query, payload=query,
secret=self.device.credential.password.get_secret_value(), secret=self.device.credential.password.get_secret_value(),
duration=params.general.request_timeout, duration=params.request_timeout,
) )
log.debug(f"Encoded JWT: {encoded_query}") log.debug(f"Encoded JWT: {encoded_query}")

View File

@@ -232,7 +232,7 @@ def ip_type_check(query_type, target, device):
if ( if (
query_type == "bgp_route" query_type == "bgp_route"
and prefix_attr["version"] == 6 and prefix_attr["version"] == 6
and device.nos in params.general.requires_ipv6_cidr and device.nos in params.requires_ipv6_cidr
and IPType().is_host(target) and IPType().is_host(target)
): ):
log.debug("Failed requires IPv6 CIDR check") log.debug("Failed requires IPv6 CIDR check")