1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

improve validation

This commit is contained in:
checktheroads
2020-01-26 02:22:28 -07:00
parent 09f7d98a54
commit 40b714b463

View File

@@ -11,7 +11,6 @@ import re
# Project Imports
from hyperglass.configuration import params
from hyperglass.exceptions import HyperglassError
from hyperglass.exceptions import InputInvalid
from hyperglass.exceptions import InputNotAllowed
from hyperglass.util import log
@@ -96,32 +95,6 @@ class IPType:
return state
def ip_validate(target):
"""Validate if input is a valid IP address.
Arguments:
target {str} -- Unvalidated IPv4/IPv6 address
Raises:
ValueError: Raised if target is not a valid IPv4 or IPv6 address
Returns:
{object} -- Valid IPv4Network/IPv6Network object
"""
try:
valid_ip = ipaddress.ip_network(target)
if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback:
_exception = ValueError(params.messages.invalid_input)
_exception.details = {}
raise _exception
except (ipaddress.AddressValueError, ValueError) as ip_error:
log.debug(f"IP {target} is invalid")
_exception = ValueError(ip_error)
_exception.details = {}
raise _exception
return valid_ip
def ip_access_list(query_data, device):
"""Check VRF access list for matching prefixes.
@@ -272,7 +245,8 @@ def ip_type_check(query_type, target, device):
if query_type in ("ping", "traceroute") and IPType().is_cidr(target):
log.debug("Failed CIDR format for ping/traceroute check")
_exception = ValueError(params.messages.directed_cidr)
_exception.details = {"query_type": getattr(params.branding.text, query_type)}
query_type_params = getattr(params.features, query_type)
_exception.details = {"query_type": query_type_params.display_name}
raise _exception
return target
@@ -305,18 +279,6 @@ class Validate:
"""
log.debug(f"Validating {self.query_type} query for target {self.target}...")
# Perform basic validation of an IP address, return error if
# not a valid IP.
try:
ip_validate(self.target)
except ValueError as unformatted_error:
raise InputInvalid(
params.messages.invalid_input,
target=self.target,
query_type=getattr(params.branding.text, self.query_type),
**unformatted_error.details,
)
# If target is a not allowed, return an error.
try:
ip_access_list(self.query_data, self.device)
@@ -336,61 +298,14 @@ class Validate:
return self.target
def validate_dual(self):
"""Validate dual-stack input such as bgp_community & bgp_aspath.
Raises:
InputInvalid: Raised if target community is invalid.
InputInvalid: Raised if target AS_PATh is invalid.
Returns:
{str} -- target if validation passes.
"""
log.debug(f"Validating {self.query_type} query for target {self.target}...")
if self.query_type == "bgp_community":
# Validate input communities against configured or default regex
# pattern.
# Extended Communities, new-format
if re.match(params.features.bgp_community.regex.extended_as, self.target):
pass
# Extended Communities, 32 bit format
elif re.match(params.features.bgp_community.regex.decimal, self.target):
pass
# RFC 8092 Large Community Support
elif re.match(params.features.bgp_community.regex.large, self.target):
pass
else:
raise InputInvalid(
params.messages.invalid_input,
target=self.target,
query_type=getattr(params.branding.text, self.query_type),
)
elif self.query_type == "bgp_aspath":
# Validate input AS_PATH regex pattern against configured or
# default regex pattern.
mode = params.features.bgp_aspath.regex.mode
pattern = getattr(params.features.bgp_aspath.regex, mode)
if re.match(pattern, self.target):
pass
else:
raise InputInvalid(
params.messages.invalid_input,
target=self.target,
query_type=getattr(params.branding.text, self.query_type),
)
return self.target
def validate_query(self):
"""Validate input.
Returns:
{str} -- target if validation passes
"""
if self.query_type in ("bgp_community", "bgp_aspath"):
return self.validate_dual()
else:
if self.query_type not in ("bgp_community", "bgp_aspath"):
return self.validate_ip()
return self.target