1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

improve validation

This commit is contained in:
checktheroads
2020-01-26 02:22:28 -07:00
parent 09f7d98a54
commit 40b714b463

View File

@@ -11,7 +11,6 @@ import re
# Project Imports # Project Imports
from hyperglass.configuration import params from hyperglass.configuration import params
from hyperglass.exceptions import HyperglassError from hyperglass.exceptions import HyperglassError
from hyperglass.exceptions import InputInvalid
from hyperglass.exceptions import InputNotAllowed from hyperglass.exceptions import InputNotAllowed
from hyperglass.util import log from hyperglass.util import log
@@ -96,32 +95,6 @@ class IPType:
return state return state
def ip_validate(target):
"""Validate if input is a valid IP address.
Arguments:
target {str} -- Unvalidated IPv4/IPv6 address
Raises:
ValueError: Raised if target is not a valid IPv4 or IPv6 address
Returns:
{object} -- Valid IPv4Network/IPv6Network object
"""
try:
valid_ip = ipaddress.ip_network(target)
if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback:
_exception = ValueError(params.messages.invalid_input)
_exception.details = {}
raise _exception
except (ipaddress.AddressValueError, ValueError) as ip_error:
log.debug(f"IP {target} is invalid")
_exception = ValueError(ip_error)
_exception.details = {}
raise _exception
return valid_ip
def ip_access_list(query_data, device): def ip_access_list(query_data, device):
"""Check VRF access list for matching prefixes. """Check VRF access list for matching prefixes.
@@ -272,7 +245,8 @@ def ip_type_check(query_type, target, device):
if query_type in ("ping", "traceroute") and IPType().is_cidr(target): if query_type in ("ping", "traceroute") and IPType().is_cidr(target):
log.debug("Failed CIDR format for ping/traceroute check") log.debug("Failed CIDR format for ping/traceroute check")
_exception = ValueError(params.messages.directed_cidr) _exception = ValueError(params.messages.directed_cidr)
_exception.details = {"query_type": getattr(params.branding.text, query_type)} query_type_params = getattr(params.features, query_type)
_exception.details = {"query_type": query_type_params.display_name}
raise _exception raise _exception
return target return target
@@ -305,18 +279,6 @@ class Validate:
""" """
log.debug(f"Validating {self.query_type} query for target {self.target}...") log.debug(f"Validating {self.query_type} query for target {self.target}...")
# Perform basic validation of an IP address, return error if
# not a valid IP.
try:
ip_validate(self.target)
except ValueError as unformatted_error:
raise InputInvalid(
params.messages.invalid_input,
target=self.target,
query_type=getattr(params.branding.text, self.query_type),
**unformatted_error.details,
)
# If target is a not allowed, return an error. # If target is a not allowed, return an error.
try: try:
ip_access_list(self.query_data, self.device) ip_access_list(self.query_data, self.device)
@@ -336,61 +298,14 @@ class Validate:
return self.target return self.target
def validate_dual(self):
"""Validate dual-stack input such as bgp_community & bgp_aspath.
Raises:
InputInvalid: Raised if target community is invalid.
InputInvalid: Raised if target AS_PATh is invalid.
Returns:
{str} -- target if validation passes.
"""
log.debug(f"Validating {self.query_type} query for target {self.target}...")
if self.query_type == "bgp_community":
# Validate input communities against configured or default regex
# pattern.
# Extended Communities, new-format
if re.match(params.features.bgp_community.regex.extended_as, self.target):
pass
# Extended Communities, 32 bit format
elif re.match(params.features.bgp_community.regex.decimal, self.target):
pass
# RFC 8092 Large Community Support
elif re.match(params.features.bgp_community.regex.large, self.target):
pass
else:
raise InputInvalid(
params.messages.invalid_input,
target=self.target,
query_type=getattr(params.branding.text, self.query_type),
)
elif self.query_type == "bgp_aspath":
# Validate input AS_PATH regex pattern against configured or
# default regex pattern.
mode = params.features.bgp_aspath.regex.mode
pattern = getattr(params.features.bgp_aspath.regex, mode)
if re.match(pattern, self.target):
pass
else:
raise InputInvalid(
params.messages.invalid_input,
target=self.target,
query_type=getattr(params.branding.text, self.query_type),
)
return self.target
def validate_query(self): def validate_query(self):
"""Validate input. """Validate input.
Returns: Returns:
{str} -- target if validation passes {str} -- target if validation passes
""" """
if self.query_type in ("bgp_community", "bgp_aspath"):
return self.validate_dual() if self.query_type not in ("bgp_community", "bgp_aspath"):
else:
return self.validate_ip() return self.validate_ip()
return self.target