diff --git a/.flake8 b/.flake8 index e15d858..c629b2b 100644 --- a/.flake8 +++ b/.flake8 @@ -3,7 +3,7 @@ max-line-length=88 count=True show-source=False statistics=True -exclude=.git, __pycache__, hyperglass/api/examples/*.py +exclude=.git, __pycache__, hyperglass/api/examples/*.py, hyperglass/compat/_sshtunnel.py, hyperglass/test.py filename=*.py per-file-ignores= # Disable redefinition warning for exception handlers diff --git a/hyperglass/api/__init__.py b/hyperglass/api/__init__.py index 166e292..dfd99c3 100644 --- a/hyperglass/api/__init__.py +++ b/hyperglass/api/__init__.py @@ -7,7 +7,7 @@ from pathlib import Path # Third Party from fastapi import FastAPI from fastapi.exceptions import RequestValidationError -from starlette.responses import UJSONResponse +from starlette.responses import JSONResponse from starlette.exceptions import HTTPException as StarletteHTTPException from fastapi.openapi.utils import get_openapi from starlette.staticfiles import StaticFiles @@ -71,7 +71,7 @@ app = FastAPI( title=params.site_title, description=params.site_description, version=__version__, - default_response_class=UJSONResponse, + default_response_class=JSONResponse, **DOCS_PARAMS, ) @@ -169,7 +169,7 @@ app.add_api_route( endpoint=routers, methods=["GET"], response_model=List[RoutersResponse], - response_class=UJSONResponse, + response_class=JSONResponse, summary=params.docs.devices.summary, description=params.docs.devices.description, tags=[params.docs.devices.title], @@ -178,7 +178,7 @@ app.add_api_route( path="/api/queries", endpoint=queries, methods=["GET"], - response_class=UJSONResponse, + response_class=JSONResponse, response_model=List[SupportedQueryResponse], summary=params.docs.queries.summary, description=params.docs.queries.description, @@ -197,7 +197,7 @@ app.add_api_route( }, response_model=QueryResponse, tags=[params.docs.query.title], - response_class=UJSONResponse, + response_class=JSONResponse, ) # Enable certificate import route only if a device using diff --git a/hyperglass/api/error_handlers.py b/hyperglass/api/error_handlers.py index 5e3e168..d2ac9bd 100644 --- a/hyperglass/api/error_handlers.py +++ b/hyperglass/api/error_handlers.py @@ -1,7 +1,7 @@ """API Error Handlers.""" # Third Party -from starlette.responses import UJSONResponse +from starlette.responses import JSONResponse # Project from hyperglass.configuration import params @@ -9,7 +9,7 @@ from hyperglass.configuration import params async def default_handler(request, exc): """Handle uncaught errors.""" - return UJSONResponse( + return JSONResponse( {"output": params.messages.general, "level": "danger", "keywords": []}, status_code=500, ) @@ -17,7 +17,7 @@ async def default_handler(request, exc): async def http_handler(request, exc): """Handle web server errors.""" - return UJSONResponse( + return JSONResponse( {"output": exc.detail, "level": "danger", "keywords": []}, status_code=exc.status_code, ) @@ -25,7 +25,7 @@ async def http_handler(request, exc): async def app_handler(request, exc): """Handle application errors.""" - return UJSONResponse( + return JSONResponse( {"output": exc.message, "level": exc.level, "keywords": exc.keywords}, status_code=exc.status_code, ) @@ -34,7 +34,7 @@ async def app_handler(request, exc): async def validation_handler(request, exc): """Handle Pydantic validation errors raised by FastAPI.""" error = exc.errors()[0] - return UJSONResponse( + return JSONResponse( {"output": error["msg"], "level": "error", "keywords": error["loc"]}, status_code=422, ) diff --git a/hyperglass/api/models/rfc8522.py b/hyperglass/api/models/rfc8522.py index ae4d55c..9d55dd1 100644 --- a/hyperglass/api/models/rfc8522.py +++ b/hyperglass/api/models/rfc8522.py @@ -8,7 +8,6 @@ from typing import List, Union, Optional from datetime import datetime # Third Party -import ujson from pydantic import BaseModel, StrictInt, StrictStr, StrictFloat, constr, validator @@ -27,8 +26,6 @@ GET /.well-known/looking-glass/v1/cmd class _HyperglassQuery(BaseModel): class Config: - json_loads = ujson.loads - json_dumps = ujson.dumps validate_all = True validate_assignment = True diff --git a/hyperglass/api/models/validators.py b/hyperglass/api/models/validators.py index 88f24ef..9b35bab 100644 --- a/hyperglass/api/models/validators.py +++ b/hyperglass/api/models/validators.py @@ -179,7 +179,7 @@ def validate_aspath(value): mode = params.queries.bgp_aspath.pattern.mode pattern = getattr(params.queries.bgp_aspath.pattern, mode) - if not re.match(pattern, value): + if not bool(re.match(pattern, value)): raise InputInvalid( params.messages.invalid_input, target=value, diff --git a/hyperglass/compat/__init__.py b/hyperglass/compat/__init__.py new file mode 100644 index 0000000..0352479 --- /dev/null +++ b/hyperglass/compat/__init__.py @@ -0,0 +1 @@ +"""Functions for maintaining compatability with older Python versions or libraries.""" diff --git a/hyperglass/compat.py b/hyperglass/compat/_asyncio.py similarity index 97% rename from hyperglass/compat.py rename to hyperglass/compat/_asyncio.py index c784266..7ba156a 100644 --- a/hyperglass/compat.py +++ b/hyperglass/compat/_asyncio.py @@ -1,4 +1,4 @@ -"""Functions for maintaining compatability with older Python versions.""" +"""Functions for maintaining asyncio compatability with other versions of Python.""" # Standard Library import sys diff --git a/hyperglass/compat/_sshtunnel.py b/hyperglass/compat/_sshtunnel.py new file mode 100644 index 0000000..3f9447d --- /dev/null +++ b/hyperglass/compat/_sshtunnel.py @@ -0,0 +1,1624 @@ +"""Initiate SSH tunnels via a remote gateway. + +Copyright (c) 2014-2019 Pahaz Blinov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +*sshtunnel* - Initiate SSH tunnels via a remote gateway. + +``sshtunnel`` works by opening a port forwarding SSH connection in the +background, using threads. + +The connection(s) are closed when explicitly calling the +:meth:`SSHTunnelForwarder.stop` method or using it as a context. +""" + +# Standard Library +import os +import sys +import queue +import socket +import getpass +import logging +import argparse +import warnings +import threading +import socketserver +from select import select +from binascii import hexlify + +# Third Party +import paramiko + +# Project +from hyperglass.util import log +from hyperglass.constants import LOG_FMT +from hyperglass.configuration import params + +if params.debug: + logging.getLogger("paramiko").setLevel(logging.DEBUG) + +TUNNEL_TIMEOUT = 1.0 #: Timeout (seconds) for tunnel connection +_DAEMON = False #: Use daemon threads in connections +_CONNECTION_COUNTER = 1 +_LOCK = threading.Lock() + +DEPRECATIONS = { + "ssh_address": "ssh_address_or_host", + "ssh_host": "ssh_address_or_host", + "ssh_private_key": "ssh_pkey", + "raise_exception_if_any_forwarder_have_a_problem": "mute_exceptions", +} + +if os.name == "posix": + DEFAULT_SSH_DIRECTORY = "~/.ssh" + UnixStreamServer = socketserver.UnixStreamServer +else: + DEFAULT_SSH_DIRECTORY = "~/ssh" + UnixStreamServer = socketserver.TCPServer + +#: Path of optional ssh configuration file +SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, "config") + +######################## +# # +# Utils # +# # +######################## + + +class DefaultHandlers: + sink = sys.stdout + format = LOG_FMT + level = "INFO" + + +def check_host(host): + assert isinstance(host, str), "IP is not a string ({0})".format(type(host).__name__) + + +def check_port(port): + assert isinstance(port, int), "PORT is not a number" + assert port >= 0, "PORT < 0 ({0})".format(port) + + +def check_address(address): + """Check if the format of the address is correct. + + Arguments: + address (tuple): + (``str``, ``int``) representing an IP address and port, + respectively + + .. note:: + alternatively a local ``address`` can be a ``str`` when working + with UNIX domain sockets, if supported by the platform + Raises: + ValueError: + raised when address has an incorrect format + + Example: + >>> check_address(('127.0.0.1', 22)) + """ + if isinstance(address, tuple): + check_host(address[0]) + check_port(address[1]) + elif isinstance(address, str): + if os.name != "posix": + raise ValueError("Platform does not support UNIX domain sockets") + if not ( + os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK) + ): + raise ValueError( + "ADDRESS not a valid socket domain socket ({0})".format(address) + ) + else: + raise ValueError( + "ADDRESS is not a tuple, string, or character buffer " + "({0})".format(type(address).__name__) + ) + + +def check_addresses(address_list, is_remote=False): + """ + Check if the format of the addresses is correct + + Arguments: + address_list (list[tuple]): + Sequence of (``str``, ``int``) pairs, each representing an IP + address and port respectively + + .. note:: + when supported by the platform, one or more of the elements in + the list can be of type ``str``, representing a valid UNIX + domain socket + + is_remote (boolean): + Whether or not the address list + Raises: + AssertionError: + raised when ``address_list`` contains an invalid element + ValueError: + raised when any address in the list has an incorrect format + + Example: + + >>> check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)]) + """ + assert all(isinstance(x, (tuple, str)) for x in address_list) + if is_remote and any(isinstance(x, str) for x in address_list): + raise AssertionError("UNIX domain sockets not allowed for remote" "addresses") + + for address in address_list: + check_address(address) + + +def address_to_str(address): + if isinstance(address, tuple): + return "{0[0]}:{0[1]}".format(address) + return str(address) + + +def get_connection_id(): + global _CONNECTION_COUNTER + with _LOCK: + uid = _CONNECTION_COUNTER + _CONNECTION_COUNTER += 1 + return uid + + +def _remove_none_values(dictionary): + """ Remove dictionary keys whose value is None.""" + return list(map(dictionary.pop, [i for i in dictionary if dictionary[i] is None])) + + +######################## +# # +# Errors # +# # +######################## + + +class BaseSSHTunnelForwarderError(Exception): + """ Exception raised by :class:`SSHTunnelForwarder` errors """ + + def __init__(self, *args, **kwargs): + self.value = kwargs.pop("value", args[0] if args else "") + + def __str__(self): + return self.value + + +class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): + """ Exception for Tunnel forwarder errors """ + + pass + + +######################## +# # +# Handlers # +# # +######################## + + +class _ForwardHandler(socketserver.BaseRequestHandler): + """ Base handler for tunnel connections """ + + remote_address = None + ssh_transport = None + logger = None + info = None + + def _redirect(self, chan): + while chan.active: + rqst, _, _ = select([self.request, chan], [], [], 5) + if self.request in rqst: + data = self.request.recv(1024) + if not data: + break + self.logger.trace( + ">>> OUT {0} send to {1}: {2} >>>".format( + self.info, self.remote_address, hexlify(data) + ), + ) + chan.sendall(data) + if chan in rqst: # else + if not chan.recv_ready(): + break + data = chan.recv(1024) + self.logger.trace( + "<<< IN {0} recv: {1} <<<".format(self.info, hexlify(data)), + ) + self.request.sendall(data) + + def handle(self): + uid = get_connection_id() + self.info = "#{0} <-- {1}".format( + uid, self.client_address or self.server.local_address + ) + src_address = self.request.getpeername() + if not isinstance(src_address, tuple): + src_address = ("dummy", 12345) + try: + chan = self.ssh_transport.open_channel( + kind="direct-tcpip", + dest_addr=self.remote_address, + src_addr=src_address, + timeout=TUNNEL_TIMEOUT, + ) + except paramiko.SSHException: + chan = None + if chan is None: + msg = "{0} to {1} was rejected by the SSH server".format( + self.info, self.remote_address + ) + self.logger.trace(msg) + raise HandlerSSHTunnelForwarderError(msg) + + self.logger.trace("{0} connected".format(self.info)) + try: + self._redirect(chan) + except socket.error: + # Sometimes a RST is sent and a socket error is raised, treat this + # exception. It was seen that a 3way FIN is processed later on, so + # no need to make an ordered close of the connection here or raise + # the exception beyond this point... + self.logger.trace("{0} sending RST".format(self.info)) + except Exception as e: + self.logger.trace("{0} error: {1}".format(self.info, repr(e))) + finally: + chan.close() + self.request.close() + self.logger.trace("{0} connection closed.".format(self.info)) + + +class _ForwardServer(socketserver.TCPServer): # Not Threading + """ + Non-threading version of the forward server + """ + + allow_reuse_address = True # faster rebinding + + def __init__(self, *args, **kwargs): + self.logger = kwargs.pop("logger") or log + self.tunnel_ok = queue.Queue() + socketserver.TCPServer.__init__(self, *args, **kwargs) + + def handle_error(self, request, client_address): + (exc_class, exc, tb) = sys.exc_info() + self.logger.error( + "Could not establish connection from {0} to remote " "side of the tunnel", + request.getsockname(), + ) + self.tunnel_ok.put(False) + + @property + def local_address(self): + return self.server_address + + @property + def local_host(self): + return self.server_address[0] + + @property + def local_port(self): + return self.server_address[1] + + @property + def remote_address(self): + return self.RequestHandlerClass.remote_address + + @property + def remote_host(self): + return self.RequestHandlerClass.remote_address[0] + + @property + def remote_port(self): + return self.RequestHandlerClass.remote_address[1] + + +class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer): + """ + Allow concurrent connections to each tunnel + """ + + # If True, cleanly stop threads created by ThreadingMixIn when quitting + daemon_threads = _DAEMON + + +class _UnixStreamForwardServer(UnixStreamServer): + """ + Serve over UNIX domain sockets (does not work on Windows) + """ + + def __init__(self, *args, **kwargs): + self.logger = kwargs.pop("logger") or log + self.tunnel_ok = queue.Queue() + UnixStreamServer.__init__(self, *args, **kwargs) + + @property + def local_address(self): + return self.server_address + + @property + def local_host(self): + return None + + @property + def local_port(self): + return None + + @property + def remote_address(self): + return self.RequestHandlerClass.remote_address + + @property + def remote_host(self): + return self.RequestHandlerClass.remote_address[0] + + @property + def remote_port(self): + return self.RequestHandlerClass.remote_address[1] + + +class _ThreadingUnixStreamForwardServer( + socketserver.ThreadingMixIn, _UnixStreamForwardServer +): + """ + Allow concurrent connections to each tunnel + """ + + # If True, cleanly stop threads created by ThreadingMixIn when quitting + daemon_threads = _DAEMON + + +class SSHTunnelForwarder: + """ + **SSH tunnel class** + + - Initialize a SSH tunnel to a remote host according to the input + arguments + + - Optionally: + + Read an SSH configuration file (typically ``~/.ssh/config``) + + Load keys from a running SSH agent (i.e. Pageant, GNOME Keyring) + + Raises: + + :class:`.BaseSSHTunnelForwarderError`: + raised by SSHTunnelForwarder class methods + + :class:`.HandlerSSHTunnelForwarderError`: + raised by tunnel forwarder threads + + .. note:: + Attributes ``mute_exceptions`` and + ``raise_exception_if_any_forwarder_have_a_problem`` + (deprecated) may be used to silence most exceptions raised + from this class + + Keyword Arguments: + + ssh_address_or_host (tuple or str): + IP or hostname of ``REMOTE GATEWAY``. It may be a two-element + tuple (``str``, ``int``) representing IP and port respectively, + or a ``str`` representing the IP address only + + .. versionadded:: 0.0.4 + + ssh_config_file (str): + SSH configuration file that will be read. If explicitly set to + ``None``, parsing of this configuration is omitted + + Default: :const:`SSH_CONFIG_FILE` + + .. versionadded:: 0.0.4 + + ssh_host_key (str): + Representation of a line in an OpenSSH-style "known hosts" + file. + + ``REMOTE GATEWAY``'s key fingerprint will be compared to this + host key in order to prevent against SSH server spoofing. + Important when using passwords in order not to accidentally + do a login attempt to a wrong (perhaps an attacker's) machine + + ssh_username (str): + Username to authenticate as in ``REMOTE SERVER`` + + Default: current local user name + + ssh_password (str): + Text representing the password used to connect to ``REMOTE + SERVER`` or for unlocking a private key. + + .. note:: + Avoid coding secret password directly in the code, since this + may be visible and make your service vulnerable to attacks + + ssh_port (int): + Optional port number of the SSH service on ``REMOTE GATEWAY``, + when `ssh_address_or_host`` is a ``str`` representing the + IP part of ``REMOTE GATEWAY``'s address + + Default: 22 + + ssh_pkey (str or paramiko.PKey): + **Private** key file name (``str``) to obtain the public key + from or a **public** key (:class:`paramiko.pkey.PKey`) + + ssh_private_key_password (str): + Password for an encrypted ``ssh_pkey`` + + .. note:: + Avoid coding secret password directly in the code, since this + may be visible and make your service vulnerable to attacks + + ssh_proxy (socket-like object or tuple): + Proxy where all SSH traffic will be passed through. + It might be for example a :class:`paramiko.proxy.ProxyCommand` + instance. + See either the :class:`paramiko.transport.Transport`'s sock + parameter documentation or ``ProxyCommand`` in ``ssh_config(5)`` + for more information. + + It is also possible to specify the proxy address as a tuple of + type (``str``, ``int``) representing proxy's IP and port + + .. note:: + Ignored if ``ssh_proxy_enabled`` is False + + .. versionadded:: 0.0.5 + + ssh_proxy_enabled (boolean): + Enable/disable SSH proxy. If True and user's + ``ssh_config_file`` contains a ``ProxyCommand`` directive + that matches the specified ``ssh_address_or_host``, + a :class:`paramiko.proxy.ProxyCommand` object will be created where + all SSH traffic will be passed through + + Default: ``True`` + + .. versionadded:: 0.0.4 + + local_bind_address (tuple): + Local tuple in the format (``str``, ``int``) representing the + IP and port of the local side of the tunnel. Both elements in + the tuple are optional so both ``('', 8000)`` and + ``('10.0.0.1', )`` are valid values + + Default: ``('0.0.0.0', RANDOM_PORT)`` + + .. versionchanged:: 0.0.8 + Added the ability to use a UNIX domain socket as local bind + address + + local_bind_addresses (list[tuple]): + In case more than one tunnel is established at once, a list + of tuples (in the same format as ``local_bind_address``) + can be specified, such as [(ip1, port_1), (ip_2, port2), ...] + + Default: ``[local_bind_address]`` + + .. versionadded:: 0.0.4 + + remote_bind_address (tuple): + Remote tuple in the format (``str``, ``int``) representing the + IP and port of the remote side of the tunnel. + + remote_bind_addresses (list[tuple]): + In case more than one tunnel is established at once, a list + of tuples (in the same format as ``remote_bind_address``) + can be specified, such as [(ip1, port_1), (ip_2, port2), ...] + + Default: ``[remote_bind_address]`` + + .. versionadded:: 0.0.4 + + allow_agent (boolean): + Enable/disable load of keys from an SSH agent + + Default: ``True`` + + .. versionadded:: 0.0.8 + + host_pkey_directories (list): + Look for pkeys in folders on this list, for example ['~/.ssh']. + + Default: ``None`` (disabled) + + .. versionadded:: 0.1.4 + + compression (boolean): + Turn on/off transport compression. By default compression is + disabled since it may negatively affect interactive sessions + + Default: ``False`` + + .. versionadded:: 0.0.8 + + logger (logging.Logger): + logging instance for sshtunnel and paramiko + + Default: :class:`logging.Logger` instance with a single + :class:`logging.StreamHandler` handler and + :const:`DEFAULT_LOGLEVEL` level + + .. versionadded:: 0.0.3 + + mute_exceptions (boolean): + Allow silencing :class:`BaseSSHTunnelForwarderError` or + :class:`HandlerSSHTunnelForwarderError` exceptions when enabled + + Default: ``False`` + + .. versionadded:: 0.0.8 + + set_keepalive (float): + Interval in seconds defining the period in which, if no data + was sent over the connection, a *'keepalive'* packet will be + sent (and ignored by the remote host). This can be useful to + keep connections alive over a NAT + + Default: 0.0 (no keepalive packets are sent) + + .. versionadded:: 0.0.7 + + threaded (boolean): + Allow concurrent connections over a single tunnel + + Default: ``True`` + + .. versionadded:: 0.0.3 + + ssh_address (str): + Superseded by ``ssh_address_or_host``, tuple of type (str, int) + representing the IP and port of ``REMOTE SERVER`` + + .. deprecated:: 0.0.4 + + ssh_host (str): + Superseded by ``ssh_address_or_host``, tuple of type + (str, int) representing the IP and port of ``REMOTE SERVER`` + + .. deprecated:: 0.0.4 + + ssh_private_key (str or paramiko.PKey): + Superseded by ``ssh_pkey``, which can represent either a + **private** key file name (``str``) or a **public** key + (:class:`paramiko.pkey.PKey`) + + .. deprecated:: 0.0.8 + + raise_exception_if_any_forwarder_have_a_problem (boolean): + Allow silencing :class:`BaseSSHTunnelForwarderError` or + :class:`HandlerSSHTunnelForwarderError` exceptions when set to + False + + Default: ``True`` + + .. versionadded:: 0.0.4 + + .. deprecated:: 0.0.8 (use ``mute_exceptions`` instead) + + Attributes: + + tunnel_is_up (dict): + Describe whether or not the other side of the tunnel was reported + to be up (and we must close it) or not (skip shutting down that + tunnel) + + .. note:: + This attribute should not be modified + + .. note:: + When :attr:`.skip_tunnel_checkup` is disabled or the local bind + is a UNIX socket, the value will always be ``True`` + + **Example**:: + + {('127.0.0.1', 55550): True, # this tunnel is up + ('127.0.0.1', 55551): False} # this one isn't + + where 55550 and 55551 are the local bind ports + + skip_tunnel_checkup (boolean): + Disable tunnel checkup (default for backwards compatibility). + + .. versionadded:: 0.1.0 + + """ + + skip_tunnel_checkup = True + daemon_forward_servers = _DAEMON #: flag tunnel threads in daemon mode + daemon_transport = _DAEMON #: flag SSH transport thread in daemon mode + + def local_is_up(self, target): + """ + Check if a tunnel is up (remote target's host is reachable on TCP + target's port) + + Arguments: + target (tuple): + tuple of type (``str``, ``int``) indicating the listen IP + address and port + Return: + boolean + + .. deprecated:: 0.1.0 + Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up` + """ + try: + check_address(target) + except ValueError: + self.logger.warning( + "Target must be a tuple (IP, port), where IP " + 'is a string (i.e. "192.168.0.1") and port is ' + "an integer (i.e. 40000). Alternatively " + "target can be a valid UNIX domain socket." + ) + return False + + if self.skip_tunnel_checkup: # force tunnel check at this point + self.skip_tunnel_checkup = False + self.check_tunnels() + self.skip_tunnel_checkup = True # roll it back + return self.tunnel_is_up.get(target, True) + + def _make_ssh_forward_handler_class(self, remote_address_): + """ + Make SSH Handler class + """ + + class Handler(_ForwardHandler): + remote_address = remote_address_ + ssh_transport = self._transport + logger = self.logger + + return Handler + + def _make_ssh_forward_server_class(self, remote_address_): + return _ThreadingForwardServer if self._threaded else _ForwardServer + + def _make_unix_ssh_forward_server_class(self, remote_address_): + return ( + _ThreadingUnixStreamForwardServer + if self._threaded + else _UnixStreamForwardServer + ) + + def _make_ssh_forward_server(self, remote_address, local_bind_address): + """ + Make SSH forward proxy Server class + """ + _Handler = self._make_ssh_forward_handler_class(remote_address) + try: + if isinstance(local_bind_address, str): + forward_maker_class = self._make_unix_ssh_forward_server_class + else: + forward_maker_class = self._make_ssh_forward_server_class + _Server = forward_maker_class(remote_address) + ssh_forward_server = _Server( + local_bind_address, _Handler, logger=self.logger, + ) + + if ssh_forward_server: + ssh_forward_server.daemon_threads = self.daemon_forward_servers + self._server_list.append(ssh_forward_server) + self.tunnel_is_up[ssh_forward_server.server_address] = False + else: + self._raise( + BaseSSHTunnelForwarderError, + "Problem setting up ssh {0} <> {1} forwarder. You can " + "suppress this exception by using the `mute_exceptions`" + "argument".format( + address_to_str(local_bind_address), + address_to_str(remote_address), + ), + ) + except IOError: + self._raise( + BaseSSHTunnelForwarderError, + "Couldn't open tunnel {0} <> {1} might be in use or " + "destination not reachable".format( + address_to_str(local_bind_address), address_to_str(remote_address) + ), + ) + + def __init__( + self, + ssh_address_or_host=None, + ssh_config_file=SSH_CONFIG_FILE, + ssh_host_key=None, + ssh_password=None, + ssh_pkey=None, + ssh_private_key_password=None, + ssh_proxy=None, + ssh_proxy_enabled=True, + ssh_username=None, + local_bind_address=None, + local_bind_addresses=None, + logger=None, + mute_exceptions=False, + remote_bind_address=None, + remote_bind_addresses=None, + set_keepalive=0.0, + threaded=True, # old version False + compression=None, + allow_agent=True, # look for keys from an SSH agent + host_pkey_directories=None, # look for keys in ~/.ssh + gateway_timeout=None, + *args, + **kwargs # for backwards compatibility + ): + self.logger = logger or log + self.ssh_host_key = ssh_host_key + self.set_keepalive = set_keepalive + self._server_list = [] # reset server list + self.tunnel_is_up = {} # handle tunnel status + self._threaded = threaded + self.is_alive = False + self.gateway_timeout = gateway_timeout + # Check if deprecated arguments ssh_address or ssh_host were used + for deprecated_argument in ["ssh_address", "ssh_host"]: + ssh_address_or_host = self._process_deprecated( + ssh_address_or_host, deprecated_argument, kwargs + ) + # other deprecated arguments + ssh_pkey = self._process_deprecated(ssh_pkey, "ssh_private_key", kwargs) + + self._raise_fwd_exc = ( + self._process_deprecated( + None, "raise_exception_if_any_forwarder_have_a_problem", kwargs + ) + or not mute_exceptions + ) + + if isinstance(ssh_address_or_host, tuple): + check_address(ssh_address_or_host) + (ssh_host, ssh_port) = ssh_address_or_host + else: + ssh_host = ssh_address_or_host + ssh_port = kwargs.pop("ssh_port", None) + + if kwargs: + raise ValueError("Unknown arguments: {0}".format(kwargs)) + + # remote binds + self._remote_binds = self._get_binds( + remote_bind_address, remote_bind_addresses, is_remote=True + ) + # local binds + self._local_binds = self._get_binds(local_bind_address, local_bind_addresses) + self._local_binds = self._consolidate_binds( + self._local_binds, self._remote_binds + ) + + ( + self.ssh_host, + self.ssh_username, + ssh_pkey, # still needs to go through _consolidate_auth + self.ssh_port, + self.ssh_proxy, + self.compression, + ) = self._read_ssh_config( + ssh_host, + ssh_config_file, + ssh_username, + ssh_pkey, + ssh_port, + ssh_proxy if ssh_proxy_enabled else None, + compression, + self.logger, + ) + + (self.ssh_password, self.ssh_pkeys) = self._consolidate_auth( + ssh_password=ssh_password, + ssh_pkey=ssh_pkey, + ssh_pkey_password=ssh_private_key_password, + allow_agent=allow_agent, + host_pkey_directories=host_pkey_directories, + logger=self.logger, + ) + + check_host(self.ssh_host) + check_port(self.ssh_port) + + self.logger.info( + "Connecting to gateway: {h}:{p} as user '{u}', timeout {t}", + h=self.ssh_host, + p=self.ssh_port, + u=self.ssh_username, + t=self.gateway_timeout, + ) + + self.logger.debug("Concurrent connections allowed: {0}", self._threaded) + + @staticmethod + def _read_ssh_config( + ssh_host, + ssh_config_file, + ssh_username=None, + ssh_pkey=None, + ssh_port=None, + ssh_proxy=None, + compression=None, + logger=log, + ): + """Read ssh_config_file. + + Read ssh_config_file and try to look for user (ssh_username), + identityfile (ssh_pkey), port (ssh_port) and proxycommand + (ssh_proxy) entries for ssh_host + """ + ssh_config = paramiko.SSHConfig() + if not ssh_config_file: # handle case where it's an empty string + ssh_config_file = None + + # Try to read SSH_CONFIG_FILE + try: + # open the ssh config file + with open(os.path.expanduser(ssh_config_file), "r") as f: + ssh_config.parse(f) + # looks for information for the destination system + hostname_info = ssh_config.lookup(ssh_host) + # gather settings for user, port and identity file + # last resort: use the 'login name' of the user + ssh_username = ssh_username or hostname_info.get("user") + ssh_pkey = ssh_pkey or hostname_info.get("identityfile", [None])[0] + ssh_host = hostname_info.get("hostname") + ssh_port = ssh_port or hostname_info.get("port") + + proxycommand = hostname_info.get("proxycommand") + ssh_proxy = ssh_proxy or ( + paramiko.ProxyCommand(proxycommand) if proxycommand else None + ) + if compression is None: + compression = hostname_info.get("compression", "") + compression = True if compression.upper() == "YES" else False + except IOError: + logger.warning( + "Could not read SSH configuration file: {f}", f=ssh_config_file + ) + except (AttributeError, TypeError): # ssh_config_file is None + logger.info("Skipping loading of ssh configuration file") + finally: + return ( + ssh_host, + ssh_username or getpass.getuser(), + ssh_pkey, + int(ssh_port) if ssh_port else 22, # fallback value + ssh_proxy, + compression, + ) + + @staticmethod + def get_agent_keys(logger=log): + """Load public keys from any available SSH agent. + + Arguments: + logger (Optional[logging.Logger]) + + Return: + list + """ + paramiko_agent = paramiko.Agent() + agent_keys = paramiko_agent.get_keys() + + logger.info("{k} keys loaded from agent", k=len(agent_keys)) + + return list(agent_keys) + + @staticmethod + def get_keys(logger=log, host_pkey_directories=None, allow_agent=False): + """Load public keys from any available SSH agent or local .ssh directory. + + Arguments: + logger (Optional[logging.Logger]) + + host_pkey_directories (Optional[list[str]]): + List of local directories where host SSH pkeys in the format + "id_*" are searched. For example, ['~/.ssh'] + + .. versionadded:: 0.1.0 + + allow_agent (Optional[boolean]): + Whether or not load keys from agent + + Default: False + + Return: + list + """ + keys = SSHTunnelForwarder.get_agent_keys(logger=logger) if allow_agent else [] + + if host_pkey_directories is not None: + paramiko_key_types = { + "rsa": paramiko.RSAKey, + "dsa": paramiko.DSSKey, + "ecdsa": paramiko.ECDSAKey, + "ed25519": paramiko.Ed25519Key, + } + for directory in host_pkey_directories or [DEFAULT_SSH_DIRECTORY]: + for keytype in paramiko_key_types.keys(): + ssh_pkey_expanded = os.path.expanduser( + os.path.join(directory, "id_{}".format(keytype)) + ) + if os.path.isfile(ssh_pkey_expanded): + ssh_pkey = SSHTunnelForwarder.read_private_key_file( + pkey_file=ssh_pkey_expanded, + logger=logger, + key_type=paramiko_key_types[keytype], + ) + if ssh_pkey: + keys.append(ssh_pkey) + + logger.info("{k} keys loaded from host directory", k=len(keys)) + + return keys + + @staticmethod + def _consolidate_binds(local_binds, remote_binds): + """Fill local_binds with defaults. + + Fill local_binds with defaults when no value/s were specified, + leaving paramiko to decide in which local port the tunnel will be open. + """ + count = len(remote_binds) - len(local_binds) + if count < 0: + raise ValueError( + "Too many local bind addresses " + "(local_bind_addresses > remote_bind_addresses)" + ) + local_binds.extend([("0.0.0.0", 0) for x in range(count)]) + return local_binds + + @staticmethod + def _consolidate_auth( + ssh_password=None, + ssh_pkey=None, + ssh_pkey_password=None, + allow_agent=True, + host_pkey_directories=None, + logger=log, + ): + """Get sure authentication information is in place. + + ``ssh_pkey`` may be of classes: + - ``str`` - in this case it represents a private key file; public + key will be obtained from it + - ``paramiko.Pkey`` - it will be transparently added to loaded keys + """ + ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( + logger=logger, + host_pkey_directories=host_pkey_directories, + allow_agent=allow_agent, + ) + + if isinstance(ssh_pkey, str): + ssh_pkey_expanded = os.path.expanduser(ssh_pkey) + if os.path.exists(ssh_pkey_expanded): + ssh_pkey = SSHTunnelForwarder.read_private_key_file( + pkey_file=ssh_pkey_expanded, + pkey_password=ssh_pkey_password or ssh_password, + logger=logger, + ) + else: + logger.warning("Private key file not found: {k}", k=ssh_pkey) + + if isinstance(ssh_pkey, paramiko.pkey.PKey): + ssh_loaded_pkeys.insert(0, ssh_pkey) + + if not ssh_password and not ssh_loaded_pkeys: + raise ValueError("No password or public key available!") + return (ssh_password, ssh_loaded_pkeys) + + def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): + if self._raise_fwd_exc: + raise exception(reason) + else: + self.logger.error(repr(exception(reason))) + + def _get_transport(self): + """Return the SSH transport to the remote gateway.""" + if self.ssh_proxy: + if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand): + proxy_repr = repr(self.ssh_proxy.cmd[1]) + else: + proxy_repr = repr(self.ssh_proxy) + self.logger.debug("Connecting via proxy: {0}".format(proxy_repr)) + _socket = self.ssh_proxy + else: + _socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if isinstance(_socket, socket.socket): + _socket.settimeout(self.gateway_timeout) + _socket.connect((self.ssh_host, self.ssh_port)) + transport = paramiko.Transport(_socket) + transport.set_keepalive(self.set_keepalive) + transport.use_compression(compress=self.compression) + transport.daemon = self.daemon_transport + + return transport + + def _create_tunnels(self): + """Create SSH tunnels on top of a transport to the remote gateway.""" + if not self.is_active: + try: + self._connect_to_gateway() + except socket.gaierror: # raised by paramiko.Transport + msg = "Could not resolve IP address for {0}, aborting!".format( + self.ssh_host + ) + self.logger.error(msg) + return + except (paramiko.SSHException, socket.error) as e: + template = "Could not connect to gateway {0}:{1} : {2}" + msg = template.format(self.ssh_host, self.ssh_port, e.args[0]) + self.logger.error(msg) + return + for (rem, loc) in zip(self._remote_binds, self._local_binds): + try: + self._make_ssh_forward_server(rem, loc) + except BaseSSHTunnelForwarderError as e: + msg = "Problem setting SSH Forwarder up: {0}".format(e.value) + self.logger.error(msg) + + @staticmethod + def _get_binds(bind_address, bind_addresses, is_remote=False): + addr_kind = "remote" if is_remote else "local" + + if not bind_address and not bind_addresses: + if is_remote: + raise ValueError( + "No {0} bind addresses specified. Use " + "'{0}_bind_address' or '{0}_bind_addresses'" + " argument".format(addr_kind) + ) + else: + return [] + elif bind_address and bind_addresses: + raise ValueError( + "You can't use both '{0}_bind_address' and " + "'{0}_bind_addresses' arguments. Use one of " + "them.".format(addr_kind) + ) + if bind_address: + bind_addresses = [bind_address] + if not is_remote: + # Add random port if missing in local bind + for (i, local_bind) in enumerate(bind_addresses): + if isinstance(local_bind, tuple) and len(local_bind) == 1: + bind_addresses[i] = (local_bind[0], 0) + check_addresses(bind_addresses, is_remote) + return bind_addresses + + @staticmethod + def _process_deprecated(attrib, deprecated_attrib, kwargs): + """Processes optional deprecate arguments.""" + + if deprecated_attrib not in DEPRECATIONS: + raise ValueError( + "{0} not included in deprecations list".format(deprecated_attrib) + ) + if deprecated_attrib in kwargs: + warnings.warn( + "'{0}' is DEPRECATED use '{1}' instead".format( + deprecated_attrib, DEPRECATIONS[deprecated_attrib] + ), + DeprecationWarning, + ) + if attrib: + raise ValueError( + "You can't use both '{0}' and '{1}'. " + "Please only use one of them".format( + deprecated_attrib, DEPRECATIONS[deprecated_attrib] + ) + ) + else: + return kwargs.pop(deprecated_attrib) + return attrib + + @staticmethod + def read_private_key_file(pkey_file, pkey_password=None, key_type=None, logger=log): + """Get SSH Public key from a private key file, given an optional password. + + Arguments: + pkey_file (str): + File containing a private key (RSA, DSS or ECDSA) + Keyword Arguments: + pkey_password (Optional[str]): + Password to decrypt the private key + logger (Optional[logging.Logger]) + Return: + paramiko.Pkey + """ + ssh_pkey = None + for pkey_class in ( + (key_type,) + if key_type + else ( + paramiko.RSAKey, + paramiko.DSSKey, + paramiko.ECDSAKey, + paramiko.Ed25519Key, + ) + ): + try: + ssh_pkey = pkey_class.from_private_key_file( + pkey_file, password=pkey_password + ) + + logger.debug( + "Private key file ({k0}, {k1}) successfully loaded", + k0=pkey_file, + k1=pkey_class, + ) + + break + except paramiko.PasswordRequiredException: + + logger.error("Password is required for key {k}", k=pkey_file) + + break + except paramiko.SSHException: + logger.debug( + "Private key file ({k0}) could not be loaded as type {k1} or bad password", + k0=pkey_file, + k1=pkey_class, + ) + + return ssh_pkey + + def _check_tunnel(self, _srv): + """Check if tunnel is already established.""" + if self.skip_tunnel_checkup: + self.tunnel_is_up[_srv.local_address] = True + return + + self.logger.info("Checking tunnel to: {a}", a=_srv.remote_address) + + if isinstance(_srv.local_address, str): # UNIX stream + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + else: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(TUNNEL_TIMEOUT) + try: + # Windows raises WinError 10049 if trying to connect to 0.0.0.0 + connect_to = ( + ("127.0.0.1", _srv.local_port) + if _srv.local_host == "0.0.0.0" + else _srv.local_address + ) + s.connect(connect_to) + self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( + timeout=TUNNEL_TIMEOUT * 1.1 + ) + self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) + except socket.error: + self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) + self.tunnel_is_up[_srv.local_address] = False + + except queue.Empty: + self.logger.debug("Tunnel to {0} is UP".format(_srv.remote_address)) + self.tunnel_is_up[_srv.local_address] = True + finally: + s.close() + + def check_tunnels(self): + """Check that if all tunnels are established and populates. + + :attr:`.tunnel_is_up` + """ + for _srv in self._server_list: + self._check_tunnel(_srv) + + def start(self): + """Start the SSH tunnels.""" + if self.is_alive: + self.logger.warning("Already started!") + return + self._create_tunnels() + if not self.is_active: + self._raise( + BaseSSHTunnelForwarderError, + reason="Could not establish session to SSH gateway", + ) + for _srv in self._server_list: + thread = threading.Thread( + target=self._serve_forever_wrapper, + args=(_srv,), + name="Srv-{0}".format(address_to_str(_srv.local_port)), + ) + thread.daemon = self.daemon_forward_servers + thread.start() + self._check_tunnel(_srv) + self.is_alive = any(self.tunnel_is_up.values()) + if not self.is_alive: + self._raise( + HandlerSSHTunnelForwarderError, + "An error occurred while opening tunnels.", + ) + + def stop(self): + """Shut the tunnel down. + + .. note:: This **had** to be handled with care before ``0.1.0``: + + - if a port redirection is opened + - the destination is not reachable + - we attempt a connection to that tunnel (``SYN`` is sent and + acknowledged, then a ``FIN`` packet is sent and never + acknowledged... weird) + - we try to shutdown: it will not succeed until ``FIN_WAIT_2`` and + ``CLOSE_WAIT`` time out. + + .. note:: + Handle these scenarios with :attr:`.tunnel_is_up`: if False, server + ``shutdown()`` will be skipped on that tunnel + """ + self.logger.info("Closing all open connections...") + opened_address_text = ( + ", ".join((address_to_str(k.local_address) for k in self._server_list)) + or "None" + ) + self.logger.debug("Listening tunnels: " + opened_address_text) + self._stop_transport() + self._server_list = [] # reset server list + self.tunnel_is_up = {} # reset tunnel status + + def close(self): + """Stop the an active tunnel, alias to :meth:`.stop`.""" + self.stop() + + def restart(self): + """Restart connection to the gateway and tunnels.""" + self.stop() + self.start() + + def _connect_to_gateway(self): + """Open connection to SSH gateway. + + - First try with all keys loaded from an SSH agent (if allowed) + - Then with those passed directly or read from ~/.ssh/config + - As last resort, try with a provided password + """ + for key in self.ssh_pkeys: + self.logger.debug( + "Trying to log in with key: {0}".format(hexlify(key.get_fingerprint())) + ) + try: + self._transport = self._get_transport() + self._transport.connect( + hostkey=self.ssh_host_key, username=self.ssh_username, pkey=key + ) + if self._transport.is_alive: + return + except paramiko.AuthenticationException: + self.logger.debug("Authentication error") + self._stop_transport() + + if self.ssh_password: # avoid conflict using both pass and pkey + self.logger.debug( + "Trying to log in with password: {0}".format( + "*" * len(self.ssh_password) + ) + ) + try: + self._transport = self._get_transport() + self._transport.connect( + hostkey=self.ssh_host_key, + username=self.ssh_username, + password=self.ssh_password, + ) + if self._transport.is_alive: + return + except paramiko.AuthenticationException: + self.logger.debug("Authentication error") + self._stop_transport() + + self.logger.error("Could not open connection to gateway") + + def _serve_forever_wrapper(self, _srv, poll_interval=0.1): + """Wrapper for the server created for a SSH forward.""" + self.logger.info( + "Opening tunnel: {0} <> {1}".format( + address_to_str(_srv.local_address), address_to_str(_srv.remote_address) + ) + ) + _srv.serve_forever(poll_interval) # blocks until finished + + self.logger.info( + "Tunnel: {0} <> {1} released".format( + address_to_str(_srv.local_address), address_to_str(_srv.remote_address) + ) + ) + + def _stop_transport(self): + """Close the underlying transport when nothing more is needed.""" + + try: + self._check_is_started() + except (BaseSSHTunnelForwarderError, HandlerSSHTunnelForwarderError) as e: + self.logger.warning(e) + for _srv in self._server_list: + tunnel = _srv.local_address + if self.tunnel_is_up[tunnel]: + self.logger.info("Shutting down tunnel {0}".format(tunnel)) + _srv.shutdown() + _srv.server_close() + # clean up the UNIX domain socket if we're using one + if isinstance(_srv, _UnixStreamForwardServer): + try: + os.unlink(_srv.local_address) + except Exception as e: + self.logger.error( + "Unable to unlink socket {0}: {1}".format( + self.local_address, repr(e) + ) + ) + self.is_alive = False + if self.is_active: + self._transport.close() + self._transport.stop_thread() + self.logger.debug("Transport is closed") + + @property + def local_bind_port(self): + + # BACKWARDS COMPATIBILITY + self._check_is_started() + if len(self._server_list) != 1: + raise BaseSSHTunnelForwarderError( + "Use .local_bind_ports property for more than one tunnel" + ) + return self.local_bind_ports[0] + + @property + def local_bind_host(self): + + # BACKWARDS COMPATIBILITY + self._check_is_started() + if len(self._server_list) != 1: + raise BaseSSHTunnelForwarderError( + "Use .local_bind_hosts property for more than one tunnel" + ) + return self.local_bind_hosts[0] + + @property + def local_bind_address(self): + + # BACKWARDS COMPATIBILITY + self._check_is_started() + if len(self._server_list) != 1: + raise BaseSSHTunnelForwarderError( + "Use .local_bind_addresses property for more than one tunnel" + ) + return self.local_bind_addresses[0] + + @property + def local_bind_ports(self): + """Return a list containing the ports of local side of the TCP tunnels.""" + + self._check_is_started() + return [ + _server.local_port + for _server in self._server_list + if _server.local_port is not None + ] + + @property + def local_bind_hosts(self): + """Return a list containing the IP addresses listening for the tunnels.""" + self._check_is_started() + return [ + _server.local_host + for _server in self._server_list + if _server.local_host is not None + ] + + @property + def local_bind_addresses(self): + """Return a list of (IP, port) pairs for the local side of the tunnels.""" + self._check_is_started() + return [_server.local_address for _server in self._server_list] + + @property + def tunnel_bindings(self): + """Return a dictionary containing the active local<>remote tunnel_bindings.""" + return dict( + (_server.remote_address, _server.local_address) + for _server in self._server_list + if self.tunnel_is_up[_server.local_address] + ) + + @property + def is_active(self): + """ Return True if the underlying SSH transport is up """ + if "_transport" in self.__dict__ and self._transport.is_active(): + return True + return False + + def _check_is_started(self): + if not self.is_active: # underlying transport not alive + msg = "Server is not started. Please .start() first!" + raise BaseSSHTunnelForwarderError(msg) + if not self.is_alive: + msg = "Tunnels are not started. Please .start() first!" + raise HandlerSSHTunnelForwarderError(msg) + + def __str__(self): + credentials = { + "password": self.ssh_password, + "pkeys": [ + (key.get_name(), hexlify(key.get_fingerprint())) + for key in self.ssh_pkeys + ] + if any(self.ssh_pkeys) + else None, + } + _remove_none_values(credentials) + template = os.linesep.join( + [ + "{0} object", + "ssh gateway: {1}:{2}", + "proxy: {3}", + "username: {4}", + "authentication: {5}", + "hostkey: {6}", + "status: {7}started", + "keepalive messages: {8}", + "tunnel connection check: {9}", + "concurrent connections: {10}allowed", + "compression: {11}requested", + "logging level: {12}", + "local binds: {13}", + "remote binds: {14}", + ] + ) + return template.format( + self.__class__, + self.ssh_host, + self.ssh_port, + self.ssh_proxy.cmd[1] if self.ssh_proxy else "no", + self.ssh_username, + credentials, + self.ssh_host_key if self.ssh_host_key else "not checked", + "" if self.is_alive else "not ", + "disabled" + if not self.set_keepalive + else "every {0} sec".format(self.set_keepalive), + "disabled" if self.skip_tunnel_checkup else "enabled", + "" if self._threaded else "not ", + "" if self.compression else "not ", + os.environ.get("HYPERGLASS_LOG_LEVEL") or "INFO", + self._local_binds, + self._remote_binds, + ) + + def __repr__(self): + return self.__str__() + + def __enter__(self): + try: + self.start() + return self + except KeyboardInterrupt: + self.__exit__() + + def __exit__(self, *args): + self._stop_transport() + + +def open_tunnel(*args, **kwargs): + """Open an SSH Tunnel, wrapper for :class:`SSHTunnelForwarder`. + + Arguments: + destination (Optional[tuple]): + SSH server's IP address and port in the format + (``ssh_address``, ``ssh_port``) + + Keyword Arguments: + debug_level (Optional[int or str]): + log level for :class:`logging.Logger` instance, i.e. ``DEBUG`` + + skip_tunnel_checkup (boolean): + Enable/disable the local side check and populate + :attr:`~SSHTunnelForwarder.tunnel_is_up` + + Default: True + + .. versionadded:: 0.1.0 + + block_on_close (boolean): + Wait until all connections are done during close by changing the + value of :attr:`~SSHTunnelForwarder.block_on_close` + + Default: True + + .. note:: + A value of ``debug_level`` set to 1 == ``TRACE`` enables tracing mode + .. note:: + See :class:`SSHTunnelForwarder` for keyword arguments + + **Example**:: + + from sshtunnel import open_tunnel + + with open_tunnel(SERVER, + ssh_username=SSH_USER, + ssh_port=22, + ssh_password=SSH_PASSWORD, + remote_bind_address=(REMOTE_HOST, REMOTE_PORT), + local_bind_address=('', LOCAL_PORT)) as server: + def do_something(port): + pass + + print("LOCAL PORTS:", server.local_bind_port) + + do_something(server.local_bind_port) + """ + # Attach a console handler to the logger or create one if not passed + kwargs["logger"] = kwargs.get("logger") or log + + ssh_address_or_host = kwargs.pop("ssh_address_or_host", None) + # Check if deprecated arguments ssh_address or ssh_host were used + for deprecated_argument in ["ssh_address", "ssh_host"]: + ssh_address_or_host = SSHTunnelForwarder._process_deprecated( + ssh_address_or_host, deprecated_argument, kwargs + ) + + ssh_port = kwargs.pop("ssh_port", 22) + skip_tunnel_checkup = kwargs.pop("skip_tunnel_checkup", True) + block_on_close = kwargs.pop("block_on_close", _DAEMON) + if not args: + if isinstance(ssh_address_or_host, tuple): + args = (ssh_address_or_host,) + else: + args = ((ssh_address_or_host, ssh_port),) + forwarder = SSHTunnelForwarder(*args, **kwargs) + forwarder.skip_tunnel_checkup = skip_tunnel_checkup + forwarder.daemon_forward_servers = not block_on_close + forwarder.daemon_transport = not block_on_close + return forwarder + + +def _bindlist(input_str): + """Define type of data expected for remote and local bind address lists. + + Returns a tuple (ip_address, port) whose elements are (str, int) + """ + try: + ip_port = input_str.split(":") + if len(ip_port) == 1: + _ip = ip_port[0] + _port = None + else: + (_ip, _port) = ip_port + if not _ip and not _port: + raise AssertionError + elif not _port: + _port = "22" # default port if not given + return _ip, int(_port) + except ValueError: + raise argparse.ArgumentTypeError( + "Address tuple must be of type IP_ADDRESS:PORT" + ) + except AssertionError: + raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!") diff --git a/hyperglass/configuration/__init__.py b/hyperglass/configuration/__init__.py index 9ac75a7..fe58188 100644 --- a/hyperglass/configuration/__init__.py +++ b/hyperglass/configuration/__init__.py @@ -3,18 +3,17 @@ # Standard Library import os import copy +import json import math from pathlib import Path # Third Party import yaml -import ujson as json from aiofile import AIOFile from pydantic import ValidationError # Project from hyperglass.util import log, check_path, set_app_path -from hyperglass.compat import aiorun from hyperglass.constants import ( CREDIT, LOG_LEVELS, @@ -27,6 +26,7 @@ from hyperglass.constants import ( __version__, ) from hyperglass.exceptions import ConfigError, ConfigInvalid, ConfigMissing +from hyperglass.compat._asyncio import aiorun from hyperglass.configuration.models import params as _params from hyperglass.configuration.models import routers as _routers from hyperglass.configuration.models import commands as _commands @@ -100,6 +100,7 @@ def _set_log_level(debug, log_file=None): log_level = "DEBUG" stdout_handler["level"] = log_level file_handler["level"] = log_level + os.environ["HYPERGLASS_LOG_LEVEL"] = log_level if log_file is not None: file_handler.update({"sink": log_file}) diff --git a/hyperglass/exceptions.py b/hyperglass/exceptions.py index ad8e3ef..a5c8ad3 100644 --- a/hyperglass/exceptions.py +++ b/hyperglass/exceptions.py @@ -1,7 +1,7 @@ """Custom exceptions for hyperglass.""" -# Third Party -import ujson as _json +# Standard Library +import json as _json # Project from hyperglass.util import log diff --git a/hyperglass/util.py b/hyperglass/util.py index 85e86bd..8bac32b 100644 --- a/hyperglass/util.py +++ b/hyperglass/util.py @@ -150,7 +150,7 @@ async def write_env(variables): RuntimeError: Raised on any errors. """ from aiofile import AIOFile - import ujson as json + import json from pathlib import Path env_file = Path("/tmp/hyperglass.env.json") # noqa: S108 @@ -390,14 +390,14 @@ async def read_package_json(): {dict} -- NPM package.json as dict """ from pathlib import Path - import ujson + import json package_json_file = Path(__file__).parent / "ui" / "package.json" try: with package_json_file.open("r") as file: - package_json = ujson.load(file) + package_json = json.load(file) except Exception as e: raise RuntimeError(f"Error reading package.json: {str(e)}")