1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

validation & construction overhaul

This commit is contained in:
checktheroads
2020-01-31 02:06:27 -10:00
parent b5eefed064
commit b68a75273b
9 changed files with 480 additions and 718 deletions

View File

@@ -11,6 +11,7 @@ from pydantic import validator
# Project Imports
from hyperglass.configuration import devices
from hyperglass.configuration import params
from hyperglass.configuration.models.vrfs import Vrf
from hyperglass.exceptions import InputInvalid
from hyperglass.models.types import SupportedQuery
from hyperglass.models.validators import validate_aspath
@@ -18,13 +19,40 @@ from hyperglass.models.validators import validate_community
from hyperglass.models.validators import validate_ip
def get_vrf_object(vrf_name):
"""Match VRF object from VRF name.
Arguments:
vrf_name {str} -- VRF name
Raises:
InputInvalid: Raised if no VRF is matched.
Returns:
{object} -- Valid VRF object
"""
matched = None
for vrf_obj in devices.vrf_objects:
if vrf_name is not None:
if vrf_name == vrf_obj.name or vrf_name == vrf_obj.display_name:
matched = vrf_obj
break
elif vrf_name is None:
if vrf_obj.name == "default":
matched = vrf_obj
break
if matched is None:
raise InputInvalid(params.messages.vrf_not_found, vrf_name=vrf_name)
return matched
class Query(BaseModel):
"""Validation model for input query parameters."""
query_location: StrictStr
query_type: SupportedQuery
query_vrf: Vrf
query_target: StrictStr
query_vrf: StrictStr
def digest(self):
"""Create SHA256 hash digest of model representation."""
@@ -65,24 +93,20 @@ class Query(BaseModel):
Returns:
{str} -- Valid query_vrf
"""
vrf_object = get_vrf_object(value)
device = getattr(devices, values["query_location"])
default_vrf = "default"
if value is not None and value != default_vrf:
for vrf in device.vrfs:
if value == vrf.name:
value = vrf.name
elif value == vrf.display_name:
value = vrf.name
else:
raise InputInvalid(
params.messages.vrf_not_associated,
level="warning",
vrf_name=vrf.display_name,
device_name=device.display_name,
)
if value is None:
value = default_vrf
return value
device_vrf = None
for vrf in device.vrfs:
if vrf == vrf_object:
device_vrf = vrf
break
if device_vrf is None:
raise InputInvalid(
params.messages.vrf_not_associated,
vrf_name=vrf_object.display_name,
device_name=device.display_name,
)
return device_vrf
@validator("query_target", always=True)
def validate_query_target(cls, value, values):
@@ -98,6 +122,14 @@ class Query(BaseModel):
"ping": validate_ip,
"traceroute": validate_ip,
}
validator_args_map = {
"bgp_aspath": (value,),
"bgp_community": (value,),
"bgp_route": (value, values["query_type"], values["query_vrf"]),
"ping": (value, values["query_type"], values["query_vrf"]),
"traceroute": (value, values["query_type"], values["query_vrf"]),
}
validate_func = validator_map[query_type]
validate_args = validator_args_map[query_type]
return validate_func(value, query_type)
return validate_func(*validate_args)

View File

@@ -1,17 +1,61 @@
# Standard Library Imports
import operator
import re
from ipaddress import ip_network
# Project Imports
from hyperglass.configuration import params
from hyperglass.exceptions import InputInvalid
from hyperglass.exceptions import InputNotAllowed
from hyperglass.util import log
def validate_ip(value, query_type):
def _member_of(target, network):
"""Check if IP address belongs to network.
Arguments:
target {object} -- Target IPv4/IPv6 address
network {object} -- ACL network
Returns:
{bool} -- True if target is a member of network, False if not
"""
log.debug(f"Checking membership of {target} for {network}")
membership = False
if (
network.network_address <= target.network_address
and network.broadcast_address >= target.broadcast_address # NOQA: W503
):
log.debug(f"{target} is a member of {network}")
membership = True
return membership
def _prefix_range(target, ge, le):
"""Verify if target prefix length is within ge/le threshold.
Arguments:
target {IPv4Network|IPv6Network} -- Valid IPv4/IPv6 Network
ge {int} -- Greater than
le {int} -- Less than
Returns:
{bool} -- True if target in range; False if not
"""
matched = False
if target.prefixlen <= le and target.prefixlen >= ge:
matched = True
return matched
def validate_ip(value, query_type, query_vrf): # noqa: C901
"""Ensure input IP address is both valid and not within restricted allocations.
Arguments:
value {str} -- Unvalidated IP Address
query_type {str} -- Valid query type
query_vrf {object} -- Matched query vrf
Raises:
ValueError: Raised if input IP address is not an IP address.
ValueError: Raised if IP address is valid, but is within a restricted range.
@@ -38,7 +82,6 @@ def validate_ip(value, query_type):
- Otherwise IETF Reserved
...and returns an error if so.
"""
if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback:
raise InputInvalid(
params.messages.invalid_input,
@@ -46,29 +89,73 @@ def validate_ip(value, query_type):
query_type=query_type_params.display_name,
)
"""
If the valid IP is a host and not a network, return the
IPv4Address/IPv6Address object instead of IPv4Network/IPv6Network.
"""
ip_version = valid_ip.version
if valid_ip.num_addresses == 1:
valid_ip = valid_ip.network_address
if query_type in ("ping", "traceroute"):
new_ip = valid_ip.network_address
log.debug(
"Converted '{o}' to '{n}' for '{q}' query",
o=valid_ip,
n=new_ip,
q=query_type,
)
valid_ip = new_ip
elif query_type in ("bgp_route",):
max_le = max(
ace.le
for ace in query_vrf[ip_version].access_list
if ace.action == "permit"
)
new_ip = valid_ip.supernet(new_prefix=max_le)
log.debug(
"Converted '{o}' to '{n}' for '{q}' query",
o=valid_ip,
n=new_ip,
q=query_type,
)
valid_ip = new_ip
vrf_acl = operator.attrgetter(f"ipv{ip_version}.access_list")(query_vrf)
for ace in [a for a in vrf_acl if a.network.version == ip_version]:
if _member_of(valid_ip, ace.network):
if query_type == "bgp_route" and _prefix_range(valid_ip, ace.ge, ace.le):
pass
if ace.action == "permit":
log.debug(
"{t} is allowed by access-list {a}", t=str(valid_ip), a=repr(ace)
)
break
elif ace.action == "deny":
raise InputNotAllowed(
params.messages.acl_denied,
target=str(valid_ip),
denied_network=str(ace.network),
)
log.debug("Validation passed for {ip}", ip=value)
return valid_ip
def validate_community(value, query_type):
def validate_community(value):
"""Validate input communities against configured or default regex pattern."""
# RFC4360: Extended Communities (New Format)
if re.match(params.queries.bgp_community.regex.extended_as, value):
if re.match(params.queries.bgp_community.pattern.extended_as, value):
pass
# RFC4360: Extended Communities (32 Bit Format)
elif re.match(params.queries.bgp_community.regex.decimal, value):
elif re.match(params.queries.bgp_community.pattern.decimal, value):
pass
# RFC8092: Large Communities
elif re.match(params.queries.bgp_community.regex.large, value):
elif re.match(params.queries.bgp_community.pattern.large, value):
pass
else:
@@ -80,11 +167,11 @@ def validate_community(value, query_type):
return value
def validate_aspath(value, query_type):
def validate_aspath(value):
"""Validate input AS_PATH against configured or default regext pattern."""
mode = params.queries.bgp_aspath.regex.mode
pattern = getattr(params.queries.bgp_aspath.regex, mode)
mode = params.queries.bgp_aspath.pattern.mode
pattern = getattr(params.queries.bgp_aspath.pattern, mode)
if not re.match(pattern, value):
raise InputInvalid(