1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00

fix docstrings

This commit is contained in:
checktheroads
2019-12-31 13:30:55 -07:00
parent 4fb4755cba
commit 004ae06e47
3 changed files with 192 additions and 40 deletions

View File

@@ -22,6 +22,14 @@ class Construct:
"""Construct SSH commands/REST API parameters from validated query data.""" """Construct SSH commands/REST API parameters from validated query data."""
def get_device_vrf(self): def get_device_vrf(self):
"""Match query VRF to device VRF.
Raises:
HyperglassError: Raised if VRFs do not match.
Returns:
{object} -- Matched VRF object
"""
_device_vrf = None _device_vrf = None
for vrf in self.device.vrfs: for vrf in self.device.vrfs:
if vrf.name == self.query_vrf: if vrf.name == self.query_vrf:
@@ -35,6 +43,13 @@ class Construct:
return _device_vrf return _device_vrf
def __init__(self, device, query_data, transport): def __init__(self, device, query_data, transport):
"""Initialize command construction.
Arguments:
device {object} -- Device object
query_data {object} -- Validated query object
transport {str} -- Transport name; 'scrape' or 'rest'
"""
self.device = device self.device = device
self.query_data = query_data self.query_data = query_data
self.transport = transport self.transport = transport
@@ -43,7 +58,14 @@ class Construct:
self.device_vrf = self.get_device_vrf() self.device_vrf = self.get_device_vrf()
def format_target(self, target): def format_target(self, target):
"""Formats query target based on NOS requirement""" """Format query target based on NOS requirement.
Arguments:
target {str} -- Query target
Returns:
{str} -- Formatted target
"""
if self.device.nos in target_format_space: if self.device.nos in target_format_space:
_target = re.sub(r"\/", r" ", target) _target = re.sub(r"\/", r" ", target)
else: else:
@@ -54,20 +76,36 @@ class Construct:
@staticmethod @staticmethod
def device_commands(nos, afi, query_type): def device_commands(nos, afi, query_type):
""" """Construct class attribute path for device commansd.
Constructs class attribute path from input parameters, returns
class attribute value for command. This is required because This is required because class attributes are set dynamically
class attributes are set dynamically when devices.yaml is when devices.yaml is imported, so the attribute path is unknown
imported, so the attribute path is unknown until runtime. until runtime.
Arguments:
nos {str} -- NOS short name
afi {str} -- Address family
query_type {str} -- Query type
Returns:
{str} -- Dotted attribute path, e.g. 'cisco_ios.ipv4.bgp_route'
""" """
cmd_path = f"{nos}.{afi}.{query_type}" cmd_path = f"{nos}.{afi}.{query_type}"
return operator.attrgetter(cmd_path)(commands) return operator.attrgetter(cmd_path)(commands)
@staticmethod @staticmethod
def get_cmd_type(query_protocol, query_vrf): def get_cmd_type(query_protocol, query_vrf):
""" """Construct AFI string.
Constructs AFI string. If query_vrf is specified, AFI prefix is
"vpnv", if not, AFI prefix is "ipv" If query_vrf is specified, AFI prefix is "vpnv".
If not, AFI prefix is "ipv".
Arguments:
query_protocol {str} -- 'ipv4' or 'ipv6'
query_vrf {str} -- Query VRF name
Returns:
{str} -- Constructed command name
""" """
if query_vrf and query_vrf != "default": if query_vrf and query_vrf != "default":
cmd_type = f"{query_protocol}_vpn" cmd_type = f"{query_protocol}_vpn"
@@ -76,8 +114,11 @@ class Construct:
return cmd_type return cmd_type
def ping(self): def ping(self):
"""Constructs ping query parameters from pre-validated input""" """Construct ping query parameters from pre-validated input.
Returns:
{str} -- SSH command or stringified JSON
"""
log.debug( log.debug(
f"Constructing ping query for {self.query_target} via {self.transport}" f"Constructing ping query for {self.query_target} via {self.transport}"
) )
@@ -113,8 +154,10 @@ class Construct:
return query return query
def traceroute(self): def traceroute(self):
""" """Construct traceroute query parameters from pre-validated input.
Constructs traceroute query parameters from pre-validated input.
Returns:
{str} -- SSH command or stringified JSON
""" """
log.debug( log.debug(
( (
@@ -154,8 +197,10 @@ class Construct:
return query return query
def bgp_route(self): def bgp_route(self):
""" """Construct bgp_route query parameters from pre-validated input.
Constructs bgp_route query parameters from pre-validated input.
Returns:
{str} -- SSH command or stringified JSON
""" """
log.debug( log.debug(
f"Constructing bgp_route query for {self.query_target} via {self.transport}" f"Constructing bgp_route query for {self.query_target} via {self.transport}"
@@ -192,9 +237,10 @@ class Construct:
return query return query
def bgp_community(self): def bgp_community(self):
""" """Construct bgp_community query parameters from pre-validated input.
Constructs bgp_community query parameters from pre-validated
input. Returns:
{str} -- SSH command or stringified JSON
""" """
log.debug( log.debug(
( (
@@ -243,8 +289,10 @@ class Construct:
return query return query
def bgp_aspath(self): def bgp_aspath(self):
""" """Construct bgp_aspath query parameters from pre-validated input.
Constructs bgp_aspath query parameters from pre-validated input.
Returns:
{str} -- SSH command or stringified JSON
""" """
log.debug( log.debug(
( (

View File

@@ -11,7 +11,19 @@ from hyperglass.exceptions import RestError
async def jwt_decode(payload, secret): async def jwt_decode(payload, secret):
"""Decode & validate an encoded JSON Web Token (JWT).""" """Decode & validate an encoded JSON Web Token (JWT).
Arguments:
payload {str} -- Raw JWT payload
secret {str} -- JWT secret
Raises:
RestError: Raised if decoded payload is improperly formatted
or if the JWT is not able to be decoded.
Returns:
{str} -- Decoded response payload
"""
try: try:
decoded = jwt.decode(payload, secret, algorithm="HS256") decoded = jwt.decode(payload, secret, algorithm="HS256")
decoded = decoded["payload"] decoded = decoded["payload"]
@@ -21,7 +33,16 @@ async def jwt_decode(payload, secret):
async def jwt_encode(payload, secret, duration): async def jwt_encode(payload, secret, duration):
"""Encode a query to a JSON Web Token (JWT).""" """Encode a query to a JSON Web Token (JWT).
Arguments:
payload {str} -- Stringified JSON request
secret {str} -- JWT secret
duration {int} -- Number of seconds claim is valid
Returns:
str -- Encoded request payload
"""
token = { token = {
"payload": payload, "payload": payload,
"nbf": datetime.datetime.utcnow(), "nbf": datetime.datetime.utcnow(),

View File

@@ -17,13 +17,15 @@ from hyperglass.util import log
class IPType: class IPType:
""" """Build IPv4 & IPv6 attributes for input target.
Passes input through IPv4/IPv6 regex patterns to determine if input Passes input through IPv4/IPv6 regex patterns to determine if input
is formatted as a host (e.g. 192.0.2.1), or as CIDR is formatted as a host (e.g. 192.0.2.1), or as CIDR
(e.g. 192.0.2.0/24). is_host() and is_cidr() return a boolean. (e.g. 192.0.2.0/24). is_host() and is_cidr() return a boolean.
""" """
def __init__(self): def __init__(self):
"""Initialize attribute builder."""
self.ipv4_host = ( self.ipv4_host = (
r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4]" r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4]"
r"[0-9]|[01]?[0-9][0-9]?)?$" r"[0-9]|[01]?[0-9][0-9]?)?$"
@@ -58,7 +60,14 @@ class IPType:
) )
def is_host(self, target): def is_host(self, target):
"""Tests input to see if formatted as host""" """Test target to see if it is formatted as a host address.
Arguments:
target {str} -- Target IPv4/IPv6 address
Returns:
{bool} -- True if host, False if not
"""
ip_version = ipaddress.ip_network(target).version ip_version = ipaddress.ip_network(target).version
state = False state = False
if ip_version == 4 and re.match(self.ipv4_host, target): if ip_version == 4 and re.match(self.ipv4_host, target):
@@ -70,7 +79,14 @@ class IPType:
return state return state
def is_cidr(self, target): def is_cidr(self, target):
"""Tests input to see if formatted as CIDR""" """Test target to see if it is formatted as CIDR.
Arguments:
target {str} -- Target IPv4/IPv6 address
Returns:
{bool} -- True if CIDR, False if not
"""
ip_version = ipaddress.ip_network(target).version ip_version = ipaddress.ip_network(target).version
state = False state = False
if ip_version == 4 and re.match(self.ipv4_cidr, target): if ip_version == 4 and re.match(self.ipv4_cidr, target):
@@ -81,7 +97,17 @@ class IPType:
def ip_validate(target): def ip_validate(target):
"""Validates if input is a valid IP address""" """Validate if input is a valid IP address.
Arguments:
target {str} -- Unvalidated IPv4/IPv6 address
Raises:
ValueError: Raised if target is not a valid IPv4 or IPv6 address
Returns:
{object} -- Valid IPv4Network/IPv6Network object
"""
try: try:
valid_ip = ipaddress.ip_network(target) valid_ip = ipaddress.ip_network(target)
if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback: if valid_ip.is_reserved or valid_ip.is_unspecified or valid_ip.is_loopback:
@@ -97,16 +123,31 @@ def ip_validate(target):
def ip_access_list(query_data, device): def ip_access_list(query_data, device):
""" """Check VRF access list for matching prefixes.
Check VRF access list for matching prefixes, returns an error if a
match is found. Arguments:
query_data {object} -- Query object
device {object} -- Device object
Raises:
HyperglassError: Raised if query VRF and ACL VRF do not match
ValueError: Raised if an ACL deny match is found
ValueError: Raised if no ACL permit match is found
Returns:
{str} -- Allowed target
""" """
log.debug(f'Checking Access List for: {query_data["query_target"]}') log.debug(f'Checking Access List for: {query_data["query_target"]}')
def member_of(target, network): def _member_of(target, network):
""" """Check if IP address belongs to network.
Returns boolean if an input target IP is a member of an input
network. Arguments:
target {object} -- Target IPv4/IPv6 address
network {object} -- ACL network
Returns:
{bool} -- True if target is a member of network, False if not
""" """
log.debug(f"Checking membership of {target} for {network}") log.debug(f"Checking membership of {target} for {network}")
@@ -141,12 +182,12 @@ def ip_access_list(query_data, device):
a: n for a, n in ace.items() for ace in vrf_acl if n.version == target_ver a: n for a, n in ace.items() for ace in vrf_acl if n.version == target_ver
}.items(): }.items():
# If the target is a member of an allowed network, exit successfully. # If the target is a member of an allowed network, exit successfully.
if member_of(target, net) and action == "allow": if _member_of(target, net) and action == "allow":
log.debug(f"{target} is specifically allowed") log.debug(f"{target} is specifically allowed")
return target return target
# If the target is a member of a denied network, return an error. # If the target is a member of a denied network, return an error.
elif member_of(target, net) and action == "deny": elif _member_of(target, net) and action == "deny":
log.debug(f"{target} is specifically denied") log.debug(f"{target} is specifically denied")
_exception = ValueError(params.messages.acl_denied) _exception = ValueError(params.messages.acl_denied)
_exception.details = {"denied_network": str(net)} _exception.details = {"denied_network": str(net)}
@@ -160,8 +201,13 @@ def ip_access_list(query_data, device):
def ip_attributes(target): def ip_attributes(target):
""" """Construct dictionary of validated IP attributes for repeated use.
Construct dictionary of validated IP attributes for repeated use.
Arguments:
target {str} -- Target IPv4/IPv6 address
Returns:
{dict} -- IP attribute dict
""" """
network = ipaddress.ip_network(target) network = ipaddress.ip_network(target)
addr = network.network_address addr = network.network_address
@@ -180,7 +226,21 @@ def ip_attributes(target):
def ip_type_check(query_type, target, device): def ip_type_check(query_type, target, device):
"""Checks multiple IP address related validation parameters""" """Check multiple IP address related validation parameters.
Arguments:
query_type {str} -- Query type
target {str} -- Query target
device {object} -- Device
Raises:
ValueError: Raised if max prefix length check fails
ValueError: Raised if Requires IPv6 CIDR check fails
ValueError: Raised if directed CIDR check fails
Returns:
{str} -- target if checks pass
"""
prefix_attr = ip_attributes(target) prefix_attr = ip_attributes(target)
log.debug(f"IP Attributes:\n{prefix_attr}") log.debug(f"IP Attributes:\n{prefix_attr}")
@@ -218,7 +278,8 @@ def ip_type_check(query_type, target, device):
class Validate: class Validate:
""" """Validates query data with selected device.
Accepts raw input and associated device parameters from execute.py Accepts raw input and associated device parameters from execute.py
and validates the input based on specific query type. Returns and validates the input based on specific query type. Returns
boolean for validity, specific error message, and status code. boolean for validity, specific error message, and status code.
@@ -232,7 +293,16 @@ class Validate:
self.target = target self.target = target
def validate_ip(self): def validate_ip(self):
"""Validates IPv4/IPv6 Input""" """Validate IPv4/IPv6 Input.
Raises:
InputInvalid: Raised if IP validation fails
InputNotAllowed: Raised if ACL checks fail
InputNotAllowed: Raised if IP type checks fail
Returns:
{str} -- target if validation passes
"""
log.debug(f"Validating {self.query_type} query for target {self.target}...") log.debug(f"Validating {self.query_type} query for target {self.target}...")
# Perform basic validation of an IP address, return error if # Perform basic validation of an IP address, return error if
@@ -267,7 +337,15 @@ class Validate:
return self.target return self.target
def validate_dual(self): def validate_dual(self):
"""Validates Dual-Stack Input""" """Validate dual-stack input such as bgp_community & bgp_aspath.
Raises:
InputInvalid: Raised if target community is invalid.
InputInvalid: Raised if target AS_PATh is invalid.
Returns:
{str} -- target if validation passes.
"""
log.debug(f"Validating {self.query_type} query for target {self.target}...") log.debug(f"Validating {self.query_type} query for target {self.target}...")
if self.query_type == "bgp_community": if self.query_type == "bgp_community":
@@ -307,6 +385,11 @@ class Validate:
return self.target return self.target
def validate_query(self): def validate_query(self):
"""Validate input.
Returns:
{str} -- target if validation passes
"""
if self.query_type in ("bgp_community", "bgp_aspath"): if self.query_type in ("bgp_community", "bgp_aspath"):
return self.validate_dual() return self.validate_dual()
else: else: