mirror of
				https://github.com/checktheroads/hyperglass
				synced 2024-05-11 05:55:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			364 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			364 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Session handler for external http data sources."""
 | 
						|
 | 
						|
# Standard Library
 | 
						|
import re
 | 
						|
import json as _json
 | 
						|
import socket
 | 
						|
import typing as t
 | 
						|
from json import JSONDecodeError
 | 
						|
from socket import gaierror
 | 
						|
 | 
						|
# Third Party
 | 
						|
import httpx
 | 
						|
 | 
						|
# Project
 | 
						|
from hyperglass.log import log
 | 
						|
from hyperglass.util import parse_exception, repr_from_attrs
 | 
						|
from hyperglass.settings import Settings
 | 
						|
from hyperglass.constants import __version__
 | 
						|
from hyperglass.models.fields import JsonValue, HttpMethod, Primitives
 | 
						|
from hyperglass.exceptions.private import ExternalError
 | 
						|
 | 
						|
if t.TYPE_CHECKING:
 | 
						|
    # Standard Library
 | 
						|
    from types import TracebackType
 | 
						|
 | 
						|
    # Project
 | 
						|
    from hyperglass.exceptions._common import ErrorLevel
 | 
						|
    from hyperglass.models.config.logging import Http
 | 
						|
 | 
						|
D = t.TypeVar("D", bound=t.Dict)
 | 
						|
 | 
						|
 | 
						|
def _prepare_dict(_dict: D) -> D:
 | 
						|
    return _json.loads(_json.dumps(_dict, default=str))
 | 
						|
 | 
						|
 | 
						|
class BaseExternal:
    """Base session handler for external HTTP data sources.

    Holds one synchronous and one asynchronous `httpx` client configured with
    the same base URL, timeout and TLS context, and provides private request
    helpers (`_get`, `_post`, ... and their `_a*` async counterparts).

    Instances can be used as a sync (`with`) or async (`async with`) context
    manager. Entering runs a connectivity test; exiting closes the client that
    matches the kind of context manager used.
    """

    def __init__(
        self,
        base_url: str,
        config: t.Optional["Http"] = None,
        uri_prefix: str = "",
        uri_suffix: str = "",
        verify_ssl: bool = True,
        timeout: int = 10,
        parse: bool = True,
    ) -> None:
        """Initialize connection instance.

        Args:
            base_url: Base URL of the remote service; surrounding `/` are stripped.
            config: Optional configuration object, stored as-is on the instance.
            uri_prefix: Path segment placed before every endpoint.
            uri_suffix: Path segment placed after every endpoint.
            verify_ssl: Passed to `httpx.create_ssl_context(verify=...)`.
            timeout: Default timeout handed to both httpx clients.
            parse: When true, responses are decoded as JSON by `_parse_response`;
                otherwise the raw `httpx.Response` is returned.
        """
        self.__name__ = getattr(self, "name", "BaseExternal")
        self.name = self.__name__
        self.config = config
        self.base_url = base_url.strip("/")
        self.uri_prefix = uri_prefix.strip("/")
        self.uri_suffix = uri_suffix.strip("/")
        self.verify_ssl = verify_ssl
        self.timeout = timeout
        self.parse = parse

        context = httpx.create_ssl_context(verify=verify_ssl)

        # Trust the configured CA certificate file in addition to the defaults.
        if Settings.ca_cert is not None:
            context.load_verify_locations(cafile=str(Settings.ca_cert))

        client_kwargs = {
            "base_url": self.base_url,
            "timeout": self.timeout,
            "verify": context,
        }

        self._session = httpx.Client(**client_kwargs)
        self._asession = httpx.AsyncClient(**client_kwargs)

    # `__init_subclass__` is implicitly a classmethod, so no decorator is needed.
    def __init_subclass__(
        cls: t.Type["BaseExternal"], name: t.Optional[str] = None, **kwargs: t.Any
    ) -> None:
        """Set correct subclass name.

        Uses the `name=` class keyword argument when given, otherwise the
        subclass's own `__name__`.
        """
        super().__init_subclass__(**kwargs)
        cls.name = name or cls.__name__

    async def __aenter__(self: "BaseExternal") -> "BaseExternal":
        """Test connection on entry.

        Raises:
            ExternalError: If the connectivity test fails or reports unavailable.
        """
        available = await self._atest()

        if available:
            log.debug("Initialized session with {}", self.base_url)
            return self
        raise self._exception(f"Unable to create session to {self.name}")

    async def __aexit__(
        self: "BaseExternal",
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional["TracebackType"] = None,
    ) -> bool:
        """Close connection on exit.

        Logs any in-flight exception, closes the async client, and lets the
        exception propagate to the caller.
        """
        log.debug("Closing session with {}", self.base_url)

        if exc_type is not None:
            log.error(str(exc_value))

        await self._asession.aclose()

        # A falsy return tells the interpreter to propagate the active
        # exception; `__aexit__` should not re-raise it itself.
        return exc_value is None

    def __enter__(self: "BaseExternal") -> "BaseExternal":
        """Test connection on entry.

        Raises:
            ExternalError: If the connectivity test fails or reports unavailable.
        """
        available = self._test()

        if available:
            log.debug("Initialized session with {}", self.base_url)
            return self
        raise self._exception(f"Unable to create session to {self.name}")

    def __exit__(
        self: "BaseExternal",
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        exc_traceback: t.Optional["TracebackType"] = None,
    ) -> bool:
        """Close connection on exit.

        Logs any in-flight exception, closes the sync client, and lets the
        exception propagate to the caller.
        """
        log.debug("Closing session with {}", self.base_url)

        if exc_type is not None:
            log.error(str(exc_value))

        self._session.close()

        # A falsy return tells the interpreter to propagate the active
        # exception; `__exit__` should not re-raise it itself.
        return exc_value is None

    def __repr__(self: "BaseExternal") -> str:
        """Return user friendly representation of instance."""
        return repr_from_attrs(self, ("name", "base_url", "config", "parse"))

    def _exception(
        self: "BaseExternal",
        message: str,
        exc: t.Optional[BaseException] = None,
        level: "ErrorLevel" = "warning",
        **kwargs: t.Any,
    ) -> ExternalError:
        """Build (not raise) an `ExternalError`.

        If `exc` is passed, its string form is appended to `message`. Extra
        keyword arguments are forwarded to `ExternalError`.
        """
        if exc is not None:
            message = f"{message!s}: {exc!s}"

        return ExternalError(message=message, level=level, **kwargs)

    def _parse_response(self: "BaseExternal", response: httpx.Response) -> t.Any:
        """Decode a response according to `self.parse`.

        Returns the raw response when parsing is disabled. Otherwise returns
        the decoded JSON body, or `{"data": <response text>}` when the body is
        not valid JSON.
        """
        if not self.parse:
            return response

        try:
            return response.json()
        except JSONDecodeError:
            pass

        try:
            # Retry on the decoded text. The response object itself is not
            # valid input for `json.loads`.
            return _json.loads(response.text)
        except (JSONDecodeError, TypeError):
            log.error("Error parsing JSON for response {}", repr(response))
            return {"data": response.text}

    def _test(self: "BaseExternal") -> bool:
        """Open a low-level connection to the base URL to ensure its port is open.

        The host and port are taken from `self.base_url`. When no port is
        given, 80 is used for `http` URLs and 443 otherwise. The socket uses
        `self.timeout` and is always closed.

        Returns:
            True when the TCP connection succeeds.

        Raises:
            ExternalError: If the URL has no usable host/port, the name cannot
                be resolved, or the connection cannot be established.
        """
        # Standard Library
        from urllib.parse import urlsplit

        log.debug("Testing connection to {}", self.base_url)

        try:
            # `urlsplit` only recognizes the host after a `//`, so add one for
            # scheme-less values such as `www.example.com`.
            target = self.base_url if "://" in self.base_url else f"//{self.base_url}"
            parts = urlsplit(target)

            host = parts.hostname
            if not host:
                raise ValueError("no hostname found in base URL")

            # `parts.port` raises ValueError for a malformed port.
            port = parts.port or (80 if parts.scheme == "http" else 443)

            # Try opening a low-level socket to make sure the target is even
            # listening on the port prior to trying to use it. The context
            # manager closes the socket on success and on failure.
            with socket.create_connection((host, port), timeout=self.timeout):
                pass

        except (OSError, ValueError) as err:
            # OSError covers name resolution failures (`gaierror`), refused
            # connections and timeouts.
            raise self._exception(
                f"{self.name!r} appears to be unreachable at {self.base_url!r}", err
            ) from None

        return True

    async def _atest(self: "BaseExternal") -> bool:
        """Open a low-level connection to the base URL to ensure its port is open.

        NOTE: this delegates to the synchronous `_test`, so the calling
        coroutine blocks while the connection attempt runs.
        """
        return self._test()

    def _build_request(self: "BaseExternal", **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Process requests parameters into structure usable by http library.

        Recognized keyword arguments: `method`, `endpoint`, `item`, `headers`,
        `params`, `data`, `timeout` and `response_required` (accepted but not
        used here). They are looked up by name, so their order is irrelevant.

        Returns:
            Keyword arguments for `httpx.Client.request`.

        Raises:
            ExternalError: On an unsupported method, non-dict `data`, or a
                `timeout` that cannot be converted to `int`.
        """
        supported_methods = ("GET", "POST", "PUT", "DELETE", "HEAD", "PATCH")

        method = kwargs.get("method")
        endpoint = kwargs.get("endpoint") or ""
        item = kwargs.get("item")
        headers = kwargs.get("headers")
        params = kwargs.get("params")
        data = kwargs.get("data")
        timeout = kwargs.get("timeout")

        if not isinstance(method, str) or method.upper() not in supported_methods:
            raise self._exception(
                f'Method must be one of {", ".join(supported_methods)}. ' f"Got: {str(method)}"
            )

        # `item` may be an int (e.g. a numeric ID), so stringify it before
        # joining. Only `None` means "no item".
        segments = (
            self.uri_prefix.strip("/"),
            endpoint.strip("/"),
            self.uri_suffix.strip("/"),
            "" if item is None else str(item),
        )
        url = "/".join(segment for segment in segments if segment)

        # Caller headers are merged with the default user-agent, which is only
        # added when the caller did not supply one (in any letter case).
        merged_headers = dict(headers) if headers is not None else {}
        if not any(str(key).lower() == "user-agent" for key in merged_headers):
            merged_headers["user-agent"] = f"hyperglass/{__version__}"

        request = {
            "method": method,
            "url": url,
            "headers": merged_headers,
        }

        if params is not None:
            # Drop unset parameters and stringify the rest.
            request["params"] = {str(k): str(v) for k, v in params.items() if v is not None}

        if data is not None:
            if not isinstance(data, dict):
                raise self._exception(f"Data must be a dict, got: {str(data)}")
            request["json"] = _prepare_dict(data)

        if timeout is not None:
            if not isinstance(timeout, int):
                try:
                    timeout = int(timeout)
                except (TypeError, ValueError):
                    raise self._exception(f"Timeout must be an int, got: {str(timeout)}")
            request["timeout"] = timeout

        log.debug("Constructed request parameters {}", request)
        return request

    def _handle_response(self: "BaseExternal", response: httpx.Response) -> t.Any:
        """Return the parsed response, raising `ExternalError` on a non-2xx status."""
        if not 200 <= response.status_code < 300:
            try:
                reason = httpx.codes(response.status_code).name.replace("_", " ")
            except ValueError:
                # Status code unknown to httpx; report the number itself.
                reason = str(response.status_code)
            error = self._parse_response(response)
            raise self._exception(f"{reason}: {error}", level="danger")

        return self._parse_response(response)

    async def _arequest(  # noqa: C901
        self: "BaseExternal",
        method: HttpMethod,
        endpoint: str,
        item: t.Union[str, int, None] = None,
        headers: t.Optional[t.Dict[str, str]] = None,
        params: t.Optional[t.Dict[str, JsonValue[Primitives]]] = None,
        data: t.Optional[t.Any] = None,
        timeout: t.Optional[int] = None,
        response_required: bool = False,
    ) -> t.Any:
        """Run an HTTP request with the async client.

        Returns:
            The response as processed by `_parse_response`.

        Raises:
            ExternalError: On invalid request parameters, an httpx transport
                error, or a non-2xx response status.
        """
        request = self._build_request(
            method=method,
            endpoint=endpoint,
            item=item,
            headers=headers,
            params=params,
            data=data,
            timeout=timeout,
            response_required=response_required,
        )

        try:
            response = await self._asession.request(**request)
        except httpx.HTTPError as http_err:
            raise self._exception(parse_exception(http_err), level="danger") from None

        return self._handle_response(response)

    async def _aget(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an async HTTP GET request."""
        return await self._arequest(method="GET", endpoint=endpoint, **kwargs)

    async def _apost(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an async HTTP POST request."""
        return await self._arequest(method="POST", endpoint=endpoint, **kwargs)

    async def _aput(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an async HTTP PUT request."""
        return await self._arequest(method="PUT", endpoint=endpoint, **kwargs)

    async def _adelete(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an async HTTP DELETE request."""
        return await self._arequest(method="DELETE", endpoint=endpoint, **kwargs)

    async def _apatch(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an async HTTP PATCH request."""
        return await self._arequest(method="PATCH", endpoint=endpoint, **kwargs)

    async def _ahead(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an async HTTP HEAD request."""
        return await self._arequest(method="HEAD", endpoint=endpoint, **kwargs)

    def _request(  # noqa: C901
        self: "BaseExternal",
        method: HttpMethod,
        endpoint: str,
        item: t.Union[str, int, None] = None,
        headers: t.Optional[t.Dict[str, str]] = None,
        params: t.Optional[t.Dict[str, JsonValue[Primitives]]] = None,
        data: t.Optional[t.Any] = None,
        timeout: t.Optional[int] = None,
        response_required: bool = False,
    ) -> t.Any:
        """Run an HTTP request with the sync client.

        Returns:
            The response as processed by `_parse_response`.

        Raises:
            ExternalError: On invalid request parameters, an httpx transport
                error, or a non-2xx response status.
        """
        request = self._build_request(
            method=method,
            endpoint=endpoint,
            item=item,
            headers=headers,
            params=params,
            data=data,
            timeout=timeout,
            response_required=response_required,
        )

        try:
            response = self._session.request(**request)
        except httpx.HTTPError as http_err:
            raise self._exception(parse_exception(http_err), level="danger") from None

        return self._handle_response(response)

    def _get(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an HTTP GET request."""
        return self._request(method="GET", endpoint=endpoint, **kwargs)

    def _post(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an HTTP POST request."""
        return self._request(method="POST", endpoint=endpoint, **kwargs)

    def _put(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an HTTP PUT request."""
        return self._request(method="PUT", endpoint=endpoint, **kwargs)

    def _delete(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an HTTP DELETE request."""
        return self._request(method="DELETE", endpoint=endpoint, **kwargs)

    def _patch(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an HTTP PATCH request."""
        return self._request(method="PATCH", endpoint=endpoint, **kwargs)

    def _head(self: "BaseExternal", endpoint: str, **kwargs: t.Any) -> t.Any:
        """Run an HTTP HEAD request."""
        return self._request(method="HEAD", endpoint=endpoint, **kwargs)
 |