"""Input validation functions for submitted queries.""" # Standard Library import re from ipaddress import ip_network # Project from hyperglass.util import log, get_containing_prefix from hyperglass.exceptions import InputInvalid, InputNotAllowed from hyperglass.configuration import params def _member_of(target, network): """Check if IP address belongs to network. Arguments: target {object} -- Target IPv4/IPv6 address network {object} -- ACL network Returns: {bool} -- True if target is a member of network, False if not """ log.debug(f"Checking membership of {target} for {network}") membership = False if ( network.network_address <= target.network_address and network.broadcast_address >= target.broadcast_address # NOQA: W503 ): log.debug(f"{target} is a member of {network}") membership = True return membership def _prefix_range(target, ge, le): """Verify if target prefix length is within ge/le threshold. Arguments: target {IPv4Network|IPv6Network} -- Valid IPv4/IPv6 Network ge {int} -- Greater than le {int} -- Less than Returns: {bool} -- True if target in range; False if not """ matched = False if target.prefixlen <= le and target.prefixlen >= ge: matched = True return matched def validate_ip(value, query_type, query_vrf): # noqa: C901 """Ensure input IP address is both valid and not within restricted allocations. Arguments: value {str} -- Unvalidated IP Address query_type {str} -- Valid query type query_vrf {object} -- Matched query vrf Raises: ValueError: Raised if input IP address is not an IP address. ValueError: Raised if IP address is valid, but is within a restricted range. Returns: Union[IPv4Address, IPv6Address] -- Validated IP address object """ query_type_params = getattr(params.queries, query_type) try: # Attempt to use IP object factory to create an IP address object valid_ip = ip_network(value) except ValueError: raise InputInvalid( params.messages.invalid_input, target=value, query_type=query_type_params.display_name, ) # Test the valid IP address to determine if it is: # - Unspecified (See RFC5735, RFC2373) # - Loopback (See RFC5735, RFC2373) # - Otherwise IETF Reserved # ...and returns an error if so. if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback: raise InputInvalid( params.messages.invalid_input, target=value, query_type=query_type_params.display_name, ) ip_version = valid_ip.version vrf_afi = getattr(query_vrf, f"ipv{ip_version}") if vrf_afi is None: raise InputInvalid( params.messages.feature_not_enabled, feature=f"IPv{ip_version}", device_name=f"VRF {query_vrf.display_name}", ) for ace in [a for a in vrf_afi.access_list if a.network.version == ip_version]: if _member_of(valid_ip, ace.network): if query_type == "bgp_route" and _prefix_range(valid_ip, ace.ge, ace.le): pass if ace.action == "permit": log.debug( "{t} is allowed by access-list {a}", t=str(valid_ip), a=repr(ace) ) break elif ace.action == "deny": raise InputNotAllowed( params.messages.acl_denied, target=str(valid_ip), denied_network=str(ace.network), ) # Handling logic for host queries, e.g. 192.0.2.1 vs. 192.0.2.0/24 if valid_ip.num_addresses == 1: # For a host query with ping or traceroute query types, convert # the query_target to an IP address instead of a network. if query_type in ("ping", "traceroute"): new_ip = valid_ip.network_address log.debug( "Converted '{o}' to '{n}' for '{q}' query", o=valid_ip, n=new_ip, q=query_type, ) valid_ip = new_ip # For a host query with bgp_route query type and force_cidr # enabled (the default), convert the host query to a network # query. elif query_type in ("bgp_route",) and vrf_afi.force_cidr: valid_ip = get_containing_prefix(valid_ip.network_address) # For a host query with bgp_route query type and force_cidr # disabled, convert the host query to a single IP address. elif query_type in ("bgp_route",) and not vrf_afi.force_cidr: valid_ip = valid_ip.network_address log.debug("Validation passed for {ip}", ip=value) return valid_ip def validate_community(value): """Validate input communities against configured or default regex pattern.""" # RFC4360: Extended Communities (New Format) if re.match(params.queries.bgp_community.pattern.extended_as, value): pass # RFC4360: Extended Communities (32 Bit Format) elif re.match(params.queries.bgp_community.pattern.decimal, value): pass # RFC8092: Large Communities elif re.match(params.queries.bgp_community.pattern.large, value): pass else: raise InputInvalid( params.messages.invalid_input, target=value, query_type=params.queries.bgp_community.display_name, ) return value def validate_aspath(value): """Validate input AS_PATH against configured or default regext pattern.""" mode = params.queries.bgp_aspath.pattern.mode pattern = getattr(params.queries.bgp_aspath.pattern, mode) if not re.match(pattern, value): raise InputInvalid( params.messages.invalid_input, target=value, query_type=params.queries.bgp_aspath.display_name, ) return value