"""Initiate SSH tunnels via a remote gateway. Copyright (c) 2014-2019 Pahaz Blinov Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *sshtunnel* - Initiate SSH tunnels via a remote gateway. ``sshtunnel`` works by opening a port forwarding SSH connection in the background, using threads. The connection(s) are closed when explicitly calling the :meth:`SSHTunnelForwarder.stop` method or using it as a context. """ # Standard Library import os import sys import queue import socket import getpass import logging import argparse import warnings import threading import socketserver from select import select from binascii import hexlify # Third Party import paramiko # Project from hyperglass.log import log from hyperglass.configuration import params if params.debug: logging.getLogger("paramiko").setLevel(logging.DEBUG) log.bind(logger_name="paramiko") TUNNEL_TIMEOUT = 1.0 #: Timeout (seconds) for tunnel connection _DAEMON = False #: Use daemon threads in connections _CONNECTION_COUNTER = 1 _LOCK = threading.Lock() DEPRECATIONS = { "ssh_address": "ssh_address_or_host", "ssh_host": "ssh_address_or_host", "ssh_private_key": "ssh_pkey", "raise_exception_if_any_forwarder_have_a_problem": "mute_exceptions", } if os.name == "posix": DEFAULT_SSH_DIRECTORY = "~/.ssh" UnixStreamServer = socketserver.UnixStreamServer else: DEFAULT_SSH_DIRECTORY = "~/ssh" UnixStreamServer = socketserver.TCPServer #: Path of optional ssh configuration file SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, "config") ######################## # # # Utils # # # ######################## def check_host(host): assert isinstance(host, str), "IP is not a string ({0})".format(type(host).__name__) def check_port(port): assert isinstance(port, int), "PORT is not a number" assert port >= 0, "PORT < 0 ({0})".format(port) def check_address(address): """Check if the format of the address is correct. Arguments: address (tuple): (``str``, ``int``) representing an IP address and port, respectively .. note:: alternatively a local ``address`` can be a ``str`` when working with UNIX domain sockets, if supported by the platform Raises: ValueError: raised when address has an incorrect format Example: >>> check_address(('127.0.0.1', 22)) """ if isinstance(address, tuple): check_host(address[0]) check_port(address[1]) elif isinstance(address, str): if os.name != "posix": raise ValueError("Platform does not support UNIX domain sockets") if not ( os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK) ): raise ValueError( "ADDRESS not a valid socket domain socket ({0})".format(address) ) else: raise ValueError( "ADDRESS is not a tuple, string, or character buffer " "({0})".format(type(address).__name__) ) def check_addresses(address_list, is_remote=False): """ Check if the format of the addresses is correct Arguments: address_list (list[tuple]): Sequence of (``str``, ``int``) pairs, each representing an IP address and port respectively .. note:: when supported by the platform, one or more of the elements in the list can be of type ``str``, representing a valid UNIX domain socket is_remote (boolean): Whether or not the address list Raises: AssertionError: raised when ``address_list`` contains an invalid element ValueError: raised when any address in the list has an incorrect format Example: >>> check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)]) """ assert all(isinstance(x, (tuple, str)) for x in address_list) if is_remote and any(isinstance(x, str) for x in address_list): raise AssertionError("UNIX domain sockets not allowed for remote" "addresses") for address in address_list: check_address(address) def address_to_str(address): if isinstance(address, tuple): return "{0[0]}:{0[1]}".format(address) return str(address) def get_connection_id(): global _CONNECTION_COUNTER with _LOCK: uid = _CONNECTION_COUNTER _CONNECTION_COUNTER += 1 return uid def _remove_none_values(dictionary): """ Remove dictionary keys whose value is None.""" return list(map(dictionary.pop, [i for i in dictionary if dictionary[i] is None])) ######################## # # # Errors # # # ######################## class BaseSSHTunnelForwarderError(Exception): """ Exception raised by :class:`SSHTunnelForwarder` errors """ def __init__(self, *args, **kwargs): self.value = kwargs.pop("value", args[0] if args else "") def __str__(self): return self.value class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError): """ Exception for Tunnel forwarder errors """ pass ######################## # # # Handlers # # # ######################## class _ForwardHandler(socketserver.BaseRequestHandler): """ Base handler for tunnel connections """ remote_address = None ssh_transport = None logger = None info = None def _redirect(self, chan): while chan.active: rqst, _, _ = select([self.request, chan], [], [], 5) if self.request in rqst: data = self.request.recv(1024) if not data: break self.logger.trace( ">>> OUT {0} send to {1}: {2} >>>".format( self.info, self.remote_address, hexlify(data) ), ) chan.sendall(data) if chan in rqst: # else if not chan.recv_ready(): break data = chan.recv(1024) self.logger.trace( "<<< IN {0} recv: {1} <<<".format(self.info, hexlify(data)), ) self.request.sendall(data) def handle(self): uid = get_connection_id() self.info = "#{0} <-- {1}".format( uid, self.client_address or self.server.local_address ) src_address = self.request.getpeername() if not isinstance(src_address, tuple): src_address = ("dummy", 12345) try: chan = self.ssh_transport.open_channel( kind="direct-tcpip", dest_addr=self.remote_address, src_addr=src_address, timeout=TUNNEL_TIMEOUT, ) except paramiko.SSHException: chan = None if chan is None: msg = "{0} to {1} was rejected by the SSH server".format( self.info, self.remote_address ) self.logger.trace(msg) raise HandlerSSHTunnelForwarderError(msg) self.logger.trace("{0} connected".format(self.info)) try: self._redirect(chan) except socket.error: # Sometimes a RST is sent and a socket error is raised, treat this # exception. It was seen that a 3way FIN is processed later on, so # no need to make an ordered close of the connection here or raise # the exception beyond this point... self.logger.trace("{0} sending RST".format(self.info)) except Exception as e: self.logger.trace("{0} error: {1}".format(self.info, repr(e))) finally: chan.close() self.request.close() self.logger.trace("{0} connection closed.".format(self.info)) class _ForwardServer(socketserver.TCPServer): # Not Threading """ Non-threading version of the forward server """ allow_reuse_address = True # faster rebinding def __init__(self, *args, **kwargs): self.logger = kwargs.pop("logger") or log self.tunnel_ok = queue.Queue() socketserver.TCPServer.__init__(self, *args, **kwargs) def handle_error(self, request, client_address): (exc_class, exc, tb) = sys.exc_info() self.logger.error( "Could not establish connection from {0} to remote " "side of the tunnel", request.getsockname(), ) self.tunnel_ok.put(False) @property def local_address(self): return self.server_address @property def local_host(self): return self.server_address[0] @property def local_port(self): return self.server_address[1] @property def remote_address(self): return self.RequestHandlerClass.remote_address @property def remote_host(self): return self.RequestHandlerClass.remote_address[0] @property def remote_port(self): return self.RequestHandlerClass.remote_address[1] class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer): """ Allow concurrent connections to each tunnel """ # If True, cleanly stop threads created by ThreadingMixIn when quitting daemon_threads = _DAEMON class _UnixStreamForwardServer(UnixStreamServer): """ Serve over UNIX domain sockets (does not work on Windows) """ def __init__(self, *args, **kwargs): self.logger = kwargs.pop("logger") or log self.tunnel_ok = queue.Queue() UnixStreamServer.__init__(self, *args, **kwargs) @property def local_address(self): return self.server_address @property def local_host(self): return None @property def local_port(self): return None @property def remote_address(self): return self.RequestHandlerClass.remote_address @property def remote_host(self): return self.RequestHandlerClass.remote_address[0] @property def remote_port(self): return self.RequestHandlerClass.remote_address[1] class _ThreadingUnixStreamForwardServer( socketserver.ThreadingMixIn, _UnixStreamForwardServer ): """ Allow concurrent connections to each tunnel """ # If True, cleanly stop threads created by ThreadingMixIn when quitting daemon_threads = _DAEMON class SSHTunnelForwarder: """ **SSH tunnel class** - Initialize a SSH tunnel to a remote host according to the input arguments - Optionally: + Read an SSH configuration file (typically ``~/.ssh/config``) + Load keys from a running SSH agent (i.e. Pageant, GNOME Keyring) Raises: :class:`.BaseSSHTunnelForwarderError`: raised by SSHTunnelForwarder class methods :class:`.HandlerSSHTunnelForwarderError`: raised by tunnel forwarder threads .. note:: Attributes ``mute_exceptions`` and ``raise_exception_if_any_forwarder_have_a_problem`` (deprecated) may be used to silence most exceptions raised from this class Keyword Arguments: ssh_address_or_host (tuple or str): IP or hostname of ``REMOTE GATEWAY``. It may be a two-element tuple (``str``, ``int``) representing IP and port respectively, or a ``str`` representing the IP address only .. versionadded:: 0.0.4 ssh_config_file (str): SSH configuration file that will be read. If explicitly set to ``None``, parsing of this configuration is omitted Default: :const:`SSH_CONFIG_FILE` .. versionadded:: 0.0.4 ssh_host_key (str): Representation of a line in an OpenSSH-style "known hosts" file. ``REMOTE GATEWAY``'s key fingerprint will be compared to this host key in order to prevent against SSH server spoofing. Important when using passwords in order not to accidentally do a login attempt to a wrong (perhaps an attacker's) machine ssh_username (str): Username to authenticate as in ``REMOTE SERVER`` Default: current local user name ssh_password (str): Text representing the password used to connect to ``REMOTE SERVER`` or for unlocking a private key. .. note:: Avoid coding secret password directly in the code, since this may be visible and make your service vulnerable to attacks ssh_port (int): Optional port number of the SSH service on ``REMOTE GATEWAY``, when `ssh_address_or_host`` is a ``str`` representing the IP part of ``REMOTE GATEWAY``'s address Default: 22 ssh_pkey (str or paramiko.PKey): **Private** key file name (``str``) to obtain the public key from or a **public** key (:class:`paramiko.pkey.PKey`) ssh_private_key_password (str): Password for an encrypted ``ssh_pkey`` .. note:: Avoid coding secret password directly in the code, since this may be visible and make your service vulnerable to attacks ssh_proxy (socket-like object or tuple): Proxy where all SSH traffic will be passed through. It might be for example a :class:`paramiko.proxy.ProxyCommand` instance. See either the :class:`paramiko.transport.Transport`'s sock parameter documentation or ``ProxyCommand`` in ``ssh_config(5)`` for more information. It is also possible to specify the proxy address as a tuple of type (``str``, ``int``) representing proxy's IP and port .. note:: Ignored if ``ssh_proxy_enabled`` is False .. versionadded:: 0.0.5 ssh_proxy_enabled (boolean): Enable/disable SSH proxy. If True and user's ``ssh_config_file`` contains a ``ProxyCommand`` directive that matches the specified ``ssh_address_or_host``, a :class:`paramiko.proxy.ProxyCommand` object will be created where all SSH traffic will be passed through Default: ``True`` .. versionadded:: 0.0.4 local_bind_address (tuple): Local tuple in the format (``str``, ``int``) representing the IP and port of the local side of the tunnel. Both elements in the tuple are optional so both ``('', 8000)`` and ``('10.0.0.1', )`` are valid values Default: ``('0.0.0.0', RANDOM_PORT)`` .. versionchanged:: 0.0.8 Added the ability to use a UNIX domain socket as local bind address local_bind_addresses (list[tuple]): In case more than one tunnel is established at once, a list of tuples (in the same format as ``local_bind_address``) can be specified, such as [(ip1, port_1), (ip_2, port2), ...] Default: ``[local_bind_address]`` .. versionadded:: 0.0.4 remote_bind_address (tuple): Remote tuple in the format (``str``, ``int``) representing the IP and port of the remote side of the tunnel. remote_bind_addresses (list[tuple]): In case more than one tunnel is established at once, a list of tuples (in the same format as ``remote_bind_address``) can be specified, such as [(ip1, port_1), (ip_2, port2), ...] Default: ``[remote_bind_address]`` .. versionadded:: 0.0.4 allow_agent (boolean): Enable/disable load of keys from an SSH agent Default: ``True`` .. versionadded:: 0.0.8 host_pkey_directories (list): Look for pkeys in folders on this list, for example ['~/.ssh']. Default: ``None`` (disabled) .. versionadded:: 0.1.4 compression (boolean): Turn on/off transport compression. By default compression is disabled since it may negatively affect interactive sessions Default: ``False`` .. versionadded:: 0.0.8 logger (logging.Logger): logging instance for sshtunnel and paramiko Default: :class:`logging.Logger` instance with a single :class:`logging.StreamHandler` handler and :const:`DEFAULT_LOGLEVEL` level .. versionadded:: 0.0.3 mute_exceptions (boolean): Allow silencing :class:`BaseSSHTunnelForwarderError` or :class:`HandlerSSHTunnelForwarderError` exceptions when enabled Default: ``False`` .. versionadded:: 0.0.8 set_keepalive (float): Interval in seconds defining the period in which, if no data was sent over the connection, a *'keepalive'* packet will be sent (and ignored by the remote host). This can be useful to keep connections alive over a NAT Default: 0.0 (no keepalive packets are sent) .. versionadded:: 0.0.7 threaded (boolean): Allow concurrent connections over a single tunnel Default: ``True`` .. versionadded:: 0.0.3 ssh_address (str): Superseded by ``ssh_address_or_host``, tuple of type (str, int) representing the IP and port of ``REMOTE SERVER`` .. deprecated:: 0.0.4 ssh_host (str): Superseded by ``ssh_address_or_host``, tuple of type (str, int) representing the IP and port of ``REMOTE SERVER`` .. deprecated:: 0.0.4 ssh_private_key (str or paramiko.PKey): Superseded by ``ssh_pkey``, which can represent either a **private** key file name (``str``) or a **public** key (:class:`paramiko.pkey.PKey`) .. deprecated:: 0.0.8 raise_exception_if_any_forwarder_have_a_problem (boolean): Allow silencing :class:`BaseSSHTunnelForwarderError` or :class:`HandlerSSHTunnelForwarderError` exceptions when set to False Default: ``True`` .. versionadded:: 0.0.4 .. deprecated:: 0.0.8 (use ``mute_exceptions`` instead) Attributes: tunnel_is_up (dict): Describe whether or not the other side of the tunnel was reported to be up (and we must close it) or not (skip shutting down that tunnel) .. note:: This attribute should not be modified .. note:: When :attr:`.skip_tunnel_checkup` is disabled or the local bind is a UNIX socket, the value will always be ``True`` **Example**:: {('127.0.0.1', 55550): True, # this tunnel is up ('127.0.0.1', 55551): False} # this one isn't where 55550 and 55551 are the local bind ports skip_tunnel_checkup (boolean): Disable tunnel checkup (default for backwards compatibility). .. versionadded:: 0.1.0 """ skip_tunnel_checkup = True daemon_forward_servers = _DAEMON #: flag tunnel threads in daemon mode daemon_transport = _DAEMON #: flag SSH transport thread in daemon mode def local_is_up(self, target): """ Check if a tunnel is up (remote target's host is reachable on TCP target's port) Arguments: target (tuple): tuple of type (``str``, ``int``) indicating the listen IP address and port Return: boolean .. deprecated:: 0.1.0 Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up` """ try: check_address(target) except ValueError: self.logger.warning( "Target must be a tuple (IP, port), where IP " 'is a string (i.e. "192.168.0.1") and port is ' "an integer (i.e. 40000). Alternatively " "target can be a valid UNIX domain socket." ) return False if self.skip_tunnel_checkup: # force tunnel check at this point self.skip_tunnel_checkup = False self.check_tunnels() self.skip_tunnel_checkup = True # roll it back return self.tunnel_is_up.get(target, True) def _make_ssh_forward_handler_class(self, remote_address_): """ Make SSH Handler class """ class Handler(_ForwardHandler): remote_address = remote_address_ ssh_transport = self._transport logger = self.logger return Handler def _make_ssh_forward_server_class(self, remote_address_): return _ThreadingForwardServer if self._threaded else _ForwardServer def _make_unix_ssh_forward_server_class(self, remote_address_): return ( _ThreadingUnixStreamForwardServer if self._threaded else _UnixStreamForwardServer ) def _make_ssh_forward_server(self, remote_address, local_bind_address): """ Make SSH forward proxy Server class """ _Handler = self._make_ssh_forward_handler_class(remote_address) try: if isinstance(local_bind_address, str): forward_maker_class = self._make_unix_ssh_forward_server_class else: forward_maker_class = self._make_ssh_forward_server_class _Server = forward_maker_class(remote_address) ssh_forward_server = _Server( local_bind_address, _Handler, logger=self.logger, ) if ssh_forward_server: ssh_forward_server.daemon_threads = self.daemon_forward_servers self._server_list.append(ssh_forward_server) self.tunnel_is_up[ssh_forward_server.server_address] = False else: self._raise( BaseSSHTunnelForwarderError, "Problem setting up ssh {0} <> {1} forwarder. You can " "suppress this exception by using the `mute_exceptions`" "argument".format( address_to_str(local_bind_address), address_to_str(remote_address), ), ) except IOError: self._raise( BaseSSHTunnelForwarderError, "Couldn't open tunnel {0} <> {1} might be in use or " "destination not reachable".format( address_to_str(local_bind_address), address_to_str(remote_address) ), ) def __init__( self, ssh_address_or_host=None, ssh_config_file=SSH_CONFIG_FILE, ssh_host_key=None, ssh_password=None, ssh_pkey=None, ssh_private_key_password=None, ssh_proxy=None, ssh_proxy_enabled=True, ssh_username=None, local_bind_address=None, local_bind_addresses=None, logger=None, mute_exceptions=False, remote_bind_address=None, remote_bind_addresses=None, set_keepalive=0.0, threaded=True, # old version False compression=None, allow_agent=True, # look for keys from an SSH agent host_pkey_directories=None, # look for keys in ~/.ssh gateway_timeout=None, *args, **kwargs, # for backwards compatibility ): self.logger = logger or log self.ssh_host_key = ssh_host_key self.set_keepalive = set_keepalive self._server_list = [] # reset server list self.tunnel_is_up = {} # handle tunnel status self._threaded = threaded self.is_alive = False self.gateway_timeout = gateway_timeout # Check if deprecated arguments ssh_address or ssh_host were used for deprecated_argument in ["ssh_address", "ssh_host"]: ssh_address_or_host = self._process_deprecated( ssh_address_or_host, deprecated_argument, kwargs ) # other deprecated arguments ssh_pkey = self._process_deprecated(ssh_pkey, "ssh_private_key", kwargs) self._raise_fwd_exc = ( self._process_deprecated( None, "raise_exception_if_any_forwarder_have_a_problem", kwargs ) or not mute_exceptions ) if isinstance(ssh_address_or_host, tuple): check_address(ssh_address_or_host) (ssh_host, ssh_port) = ssh_address_or_host else: ssh_host = ssh_address_or_host ssh_port = kwargs.pop("ssh_port", None) if kwargs: raise ValueError("Unknown arguments: {0}".format(kwargs)) # remote binds self._remote_binds = self._get_binds( remote_bind_address, remote_bind_addresses, is_remote=True ) # local binds self._local_binds = self._get_binds(local_bind_address, local_bind_addresses) self._local_binds = self._consolidate_binds( self._local_binds, self._remote_binds ) ( self.ssh_host, self.ssh_username, ssh_pkey, # still needs to go through _consolidate_auth self.ssh_port, self.ssh_proxy, self.compression, ) = self._read_ssh_config( ssh_host, ssh_config_file, ssh_username, ssh_pkey, ssh_port, ssh_proxy if ssh_proxy_enabled else None, compression, self.logger, ) (self.ssh_password, self.ssh_pkeys) = self._consolidate_auth( ssh_password=ssh_password, ssh_pkey=ssh_pkey, ssh_pkey_password=ssh_private_key_password, allow_agent=allow_agent, host_pkey_directories=host_pkey_directories, logger=self.logger, ) check_host(self.ssh_host) check_port(self.ssh_port) self.logger.info( "Connecting to gateway: {h}:{p} as user '{u}', timeout {t}", h=self.ssh_host, p=self.ssh_port, u=self.ssh_username, t=self.gateway_timeout, ) self.logger.debug("Concurrent connections allowed: {0}", self._threaded) @staticmethod def _read_ssh_config( ssh_host, ssh_config_file, ssh_username=None, ssh_pkey=None, ssh_port=None, ssh_proxy=None, compression=None, logger=log, ): """Read ssh_config_file. Read ssh_config_file and try to look for user (ssh_username), identityfile (ssh_pkey), port (ssh_port) and proxycommand (ssh_proxy) entries for ssh_host """ ssh_config = paramiko.SSHConfig() if not ssh_config_file: # handle case where it's an empty string ssh_config_file = None # Try to read SSH_CONFIG_FILE try: # open the ssh config file with open(os.path.expanduser(ssh_config_file), "r") as f: ssh_config.parse(f) # looks for information for the destination system hostname_info = ssh_config.lookup(ssh_host) # gather settings for user, port and identity file # last resort: use the 'login name' of the user ssh_username = ssh_username or hostname_info.get("user") ssh_pkey = ssh_pkey or hostname_info.get("identityfile", [None])[0] ssh_host = hostname_info.get("hostname") ssh_port = ssh_port or hostname_info.get("port") proxycommand = hostname_info.get("proxycommand") ssh_proxy = ssh_proxy or ( paramiko.ProxyCommand(proxycommand) if proxycommand else None ) if compression is None: compression = hostname_info.get("compression", "") compression = True if compression.upper() == "YES" else False except IOError: logger.warning( "Could not read SSH configuration file: {f}", f=ssh_config_file ) except (AttributeError, TypeError): # ssh_config_file is None logger.info("Skipping loading of ssh configuration file") finally: return ( ssh_host, ssh_username or getpass.getuser(), ssh_pkey, int(ssh_port) if ssh_port else 22, # fallback value ssh_proxy, compression, ) @staticmethod def get_agent_keys(logger=log): """Load public keys from any available SSH agent. Arguments: logger (Optional[logging.Logger]) Return: list """ paramiko_agent = paramiko.Agent() agent_keys = paramiko_agent.get_keys() logger.info("{k} keys loaded from agent", k=len(agent_keys)) return list(agent_keys) @staticmethod def get_keys(logger=log, host_pkey_directories=None, allow_agent=False): """Load public keys from any available SSH agent or local .ssh directory. Arguments: logger (Optional[logging.Logger]) host_pkey_directories (Optional[list[str]]): List of local directories where host SSH pkeys in the format "id_*" are searched. For example, ['~/.ssh'] .. versionadded:: 0.1.0 allow_agent (Optional[boolean]): Whether or not load keys from agent Default: False Return: list """ keys = SSHTunnelForwarder.get_agent_keys(logger=logger) if allow_agent else [] if host_pkey_directories is not None: paramiko_key_types = { "rsa": paramiko.RSAKey, "dsa": paramiko.DSSKey, "ecdsa": paramiko.ECDSAKey, "ed25519": paramiko.Ed25519Key, } for directory in host_pkey_directories or [DEFAULT_SSH_DIRECTORY]: for keytype in paramiko_key_types.keys(): ssh_pkey_expanded = os.path.expanduser( os.path.join(directory, "id_{}".format(keytype)) ) if os.path.isfile(ssh_pkey_expanded): ssh_pkey = SSHTunnelForwarder.read_private_key_file( pkey_file=ssh_pkey_expanded, logger=logger, key_type=paramiko_key_types[keytype], ) if ssh_pkey: keys.append(ssh_pkey) logger.info("{k} keys loaded from host directory", k=len(keys)) return keys @staticmethod def _consolidate_binds(local_binds, remote_binds): """Fill local_binds with defaults. Fill local_binds with defaults when no value/s were specified, leaving paramiko to decide in which local port the tunnel will be open. """ count = len(remote_binds) - len(local_binds) if count < 0: raise ValueError( "Too many local bind addresses " "(local_bind_addresses > remote_bind_addresses)" ) local_binds.extend([("0.0.0.0", 0) for x in range(count)]) return local_binds @staticmethod def _consolidate_auth( ssh_password=None, ssh_pkey=None, ssh_pkey_password=None, allow_agent=True, host_pkey_directories=None, logger=log, ): """Get sure authentication information is in place. ``ssh_pkey`` may be of classes: - ``str`` - in this case it represents a private key file; public key will be obtained from it - ``paramiko.Pkey`` - it will be transparently added to loaded keys """ ssh_loaded_pkeys = SSHTunnelForwarder.get_keys( logger=logger, host_pkey_directories=host_pkey_directories, allow_agent=allow_agent, ) if isinstance(ssh_pkey, str): ssh_pkey_expanded = os.path.expanduser(ssh_pkey) if os.path.exists(ssh_pkey_expanded): ssh_pkey = SSHTunnelForwarder.read_private_key_file( pkey_file=ssh_pkey_expanded, pkey_password=ssh_pkey_password or ssh_password, logger=logger, ) else: logger.warning("Private key file not found: {k}", k=ssh_pkey) if isinstance(ssh_pkey, paramiko.pkey.PKey): ssh_loaded_pkeys.insert(0, ssh_pkey) if not ssh_password and not ssh_loaded_pkeys: raise ValueError("No password or public key available!") return (ssh_password, ssh_loaded_pkeys) def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None): if self._raise_fwd_exc: raise exception(reason) else: self.logger.error(repr(exception(reason))) def _get_transport(self): """Return the SSH transport to the remote gateway.""" if self.ssh_proxy: if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand): proxy_repr = repr(self.ssh_proxy.cmd[1]) else: proxy_repr = repr(self.ssh_proxy) self.logger.debug("Connecting via proxy: {0}".format(proxy_repr)) _socket = self.ssh_proxy else: _socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) if isinstance(_socket, socket.socket): _socket.settimeout(self.gateway_timeout) _socket.connect((self.ssh_host, self.ssh_port)) transport = paramiko.Transport(_socket) transport.set_keepalive(self.set_keepalive) transport.use_compression(compress=self.compression) transport.daemon = self.daemon_transport return transport def _create_tunnels(self): """Create SSH tunnels on top of a transport to the remote gateway.""" if not self.is_active: try: self._connect_to_gateway() except socket.gaierror: # raised by paramiko.Transport msg = "Could not resolve IP address for {0}, aborting!".format( self.ssh_host ) self.logger.error(msg) return except (paramiko.SSHException, socket.error) as e: template = "Could not connect to gateway {0}:{1} : {2}" msg = template.format(self.ssh_host, self.ssh_port, e.args[0]) self.logger.error(msg) return for (rem, loc) in zip(self._remote_binds, self._local_binds): try: self._make_ssh_forward_server(rem, loc) except BaseSSHTunnelForwarderError as e: msg = "Problem setting SSH Forwarder up: {0}".format(e.value) self.logger.error(msg) @staticmethod def _get_binds(bind_address, bind_addresses, is_remote=False): addr_kind = "remote" if is_remote else "local" if not bind_address and not bind_addresses: if is_remote: raise ValueError( "No {0} bind addresses specified. Use " "'{0}_bind_address' or '{0}_bind_addresses'" " argument".format(addr_kind) ) else: return [] elif bind_address and bind_addresses: raise ValueError( "You can't use both '{0}_bind_address' and " "'{0}_bind_addresses' arguments. Use one of " "them.".format(addr_kind) ) if bind_address: bind_addresses = [bind_address] if not is_remote: # Add random port if missing in local bind for (i, local_bind) in enumerate(bind_addresses): if isinstance(local_bind, tuple) and len(local_bind) == 1: bind_addresses[i] = (local_bind[0], 0) check_addresses(bind_addresses, is_remote) return bind_addresses @staticmethod def _process_deprecated(attrib, deprecated_attrib, kwargs): """Processes optional deprecate arguments.""" if deprecated_attrib not in DEPRECATIONS: raise ValueError( "{0} not included in deprecations list".format(deprecated_attrib) ) if deprecated_attrib in kwargs: warnings.warn( "'{0}' is DEPRECATED use '{1}' instead".format( deprecated_attrib, DEPRECATIONS[deprecated_attrib] ), DeprecationWarning, ) if attrib: raise ValueError( "You can't use both '{0}' and '{1}'. " "Please only use one of them".format( deprecated_attrib, DEPRECATIONS[deprecated_attrib] ) ) else: return kwargs.pop(deprecated_attrib) return attrib @staticmethod def read_private_key_file(pkey_file, pkey_password=None, key_type=None, logger=log): """Get SSH Public key from a private key file, given an optional password. Arguments: pkey_file (str): File containing a private key (RSA, DSS or ECDSA) Keyword Arguments: pkey_password (Optional[str]): Password to decrypt the private key logger (Optional[logging.Logger]) Return: paramiko.Pkey """ ssh_pkey = None for pkey_class in ( (key_type,) if key_type else ( paramiko.RSAKey, paramiko.DSSKey, paramiko.ECDSAKey, paramiko.Ed25519Key, ) ): try: ssh_pkey = pkey_class.from_private_key_file( pkey_file, password=pkey_password ) logger.debug( "Private key file ({k0}, {k1}) successfully loaded", k0=pkey_file, k1=pkey_class, ) break except paramiko.PasswordRequiredException: logger.error("Password is required for key {k}", k=pkey_file) break except paramiko.SSHException: logger.debug( "Private key file ({k0}) could not be loaded as type {k1} or bad password", k0=pkey_file, k1=pkey_class, ) return ssh_pkey def _check_tunnel(self, _srv): """Check if tunnel is already established.""" if self.skip_tunnel_checkup: self.tunnel_is_up[_srv.local_address] = True return self.logger.info("Checking tunnel to: {a}", a=_srv.remote_address) if isinstance(_srv.local_address, str): # UNIX stream s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(TUNNEL_TIMEOUT) try: # Windows raises WinError 10049 if trying to connect to 0.0.0.0 connect_to = ( ("127.0.0.1", _srv.local_port) if _srv.local_host == "0.0.0.0" else _srv.local_address ) s.connect(connect_to) self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get( timeout=TUNNEL_TIMEOUT * 1.1 ) self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) except socket.error: self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address)) self.tunnel_is_up[_srv.local_address] = False except queue.Empty: self.logger.debug("Tunnel to {0} is UP".format(_srv.remote_address)) self.tunnel_is_up[_srv.local_address] = True finally: s.close() def check_tunnels(self): """Check that if all tunnels are established and populates. :attr:`.tunnel_is_up` """ for _srv in self._server_list: self._check_tunnel(_srv) def start(self): """Start the SSH tunnels.""" if self.is_alive: self.logger.warning("Already started!") return self._create_tunnels() if not self.is_active: self._raise( BaseSSHTunnelForwarderError, reason="Could not establish session to SSH gateway", ) for _srv in self._server_list: thread = threading.Thread( target=self._serve_forever_wrapper, args=(_srv,), name="Srv-{0}".format(address_to_str(_srv.local_port)), ) thread.daemon = self.daemon_forward_servers thread.start() self._check_tunnel(_srv) self.is_alive = any(self.tunnel_is_up.values()) if not self.is_alive: self._raise( HandlerSSHTunnelForwarderError, "An error occurred while opening tunnels.", ) def stop(self): """Shut the tunnel down. .. note:: This **had** to be handled with care before ``0.1.0``: - if a port redirection is opened - the destination is not reachable - we attempt a connection to that tunnel (``SYN`` is sent and acknowledged, then a ``FIN`` packet is sent and never acknowledged... weird) - we try to shutdown: it will not succeed until ``FIN_WAIT_2`` and ``CLOSE_WAIT`` time out. .. note:: Handle these scenarios with :attr:`.tunnel_is_up`: if False, server ``shutdown()`` will be skipped on that tunnel """ self.logger.info("Closing all open connections...") opened_address_text = ( ", ".join((address_to_str(k.local_address) for k in self._server_list)) or "None" ) self.logger.debug("Listening tunnels: " + opened_address_text) self._stop_transport() self._server_list = [] # reset server list self.tunnel_is_up = {} # reset tunnel status def close(self): """Stop the an active tunnel, alias to :meth:`.stop`.""" self.stop() def restart(self): """Restart connection to the gateway and tunnels.""" self.stop() self.start() def _connect_to_gateway(self): """Open connection to SSH gateway. - First try with all keys loaded from an SSH agent (if allowed) - Then with those passed directly or read from ~/.ssh/config - As last resort, try with a provided password """ for key in self.ssh_pkeys: self.logger.debug( "Trying to log in with key: {0}".format(hexlify(key.get_fingerprint())) ) try: self._transport = self._get_transport() self._transport.connect( hostkey=self.ssh_host_key, username=self.ssh_username, pkey=key ) if self._transport.is_alive: return except paramiko.AuthenticationException: self.logger.debug("Authentication error") self._stop_transport() if self.ssh_password: # avoid conflict using both pass and pkey self.logger.debug( "Trying to log in with password: {0}".format( "*" * len(self.ssh_password) ) ) try: self._transport = self._get_transport() self._transport.connect( hostkey=self.ssh_host_key, username=self.ssh_username, password=self.ssh_password, ) if self._transport.is_alive: return except paramiko.AuthenticationException: self.logger.debug("Authentication error") self._stop_transport() self.logger.error("Could not open connection to gateway") def _serve_forever_wrapper(self, _srv, poll_interval=0.1): """Wrapper for the server created for a SSH forward.""" self.logger.info( "Opening tunnel: {0} <> {1}".format( address_to_str(_srv.local_address), address_to_str(_srv.remote_address) ) ) _srv.serve_forever(poll_interval) # blocks until finished self.logger.info( "Tunnel: {0} <> {1} released".format( address_to_str(_srv.local_address), address_to_str(_srv.remote_address) ) ) def _stop_transport(self): """Close the underlying transport when nothing more is needed.""" try: self._check_is_started() except (BaseSSHTunnelForwarderError, HandlerSSHTunnelForwarderError) as e: self.logger.warning(e) for _srv in self._server_list: tunnel = _srv.local_address if self.tunnel_is_up[tunnel]: self.logger.info("Shutting down tunnel {0}".format(tunnel)) _srv.shutdown() _srv.server_close() # clean up the UNIX domain socket if we're using one if isinstance(_srv, _UnixStreamForwardServer): try: os.unlink(_srv.local_address) except Exception as e: self.logger.error( "Unable to unlink socket {0}: {1}".format( self.local_address, repr(e) ) ) self.is_alive = False if self.is_active: self._transport.close() self._transport.stop_thread() self.logger.debug("Transport is closed") @property def local_bind_port(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: raise BaseSSHTunnelForwarderError( "Use .local_bind_ports property for more than one tunnel" ) return self.local_bind_ports[0] @property def local_bind_host(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: raise BaseSSHTunnelForwarderError( "Use .local_bind_hosts property for more than one tunnel" ) return self.local_bind_hosts[0] @property def local_bind_address(self): # BACKWARDS COMPATIBILITY self._check_is_started() if len(self._server_list) != 1: raise BaseSSHTunnelForwarderError( "Use .local_bind_addresses property for more than one tunnel" ) return self.local_bind_addresses[0] @property def local_bind_ports(self): """Return a list containing the ports of local side of the TCP tunnels.""" self._check_is_started() return [ _server.local_port for _server in self._server_list if _server.local_port is not None ] @property def local_bind_hosts(self): """Return a list containing the IP addresses listening for the tunnels.""" self._check_is_started() return [ _server.local_host for _server in self._server_list if _server.local_host is not None ] @property def local_bind_addresses(self): """Return a list of (IP, port) pairs for the local side of the tunnels.""" self._check_is_started() return [_server.local_address for _server in self._server_list] @property def tunnel_bindings(self): """Return a dictionary containing the active local<>remote tunnel_bindings.""" return dict( (_server.remote_address, _server.local_address) for _server in self._server_list if self.tunnel_is_up[_server.local_address] ) @property def is_active(self): """ Return True if the underlying SSH transport is up """ if "_transport" in self.__dict__ and self._transport.is_active(): return True return False def _check_is_started(self): if not self.is_active: # underlying transport not alive msg = "Server is not started. Please .start() first!" raise BaseSSHTunnelForwarderError(msg) if not self.is_alive: msg = "Tunnels are not started. Please .start() first!" raise HandlerSSHTunnelForwarderError(msg) def __str__(self): credentials = { "password": self.ssh_password, "pkeys": [ (key.get_name(), hexlify(key.get_fingerprint())) for key in self.ssh_pkeys ] if any(self.ssh_pkeys) else None, } _remove_none_values(credentials) template = os.linesep.join( [ "{0} object", "ssh gateway: {1}:{2}", "proxy: {3}", "username: {4}", "authentication: {5}", "hostkey: {6}", "status: {7}started", "keepalive messages: {8}", "tunnel connection check: {9}", "concurrent connections: {10}allowed", "compression: {11}requested", "logging level: {12}", "local binds: {13}", "remote binds: {14}", ] ) return template.format( self.__class__, self.ssh_host, self.ssh_port, self.ssh_proxy.cmd[1] if self.ssh_proxy else "no", self.ssh_username, credentials, self.ssh_host_key if self.ssh_host_key else "not checked", "" if self.is_alive else "not ", "disabled" if not self.set_keepalive else "every {0} sec".format(self.set_keepalive), "disabled" if self.skip_tunnel_checkup else "enabled", "" if self._threaded else "not ", "" if self.compression else "not ", os.environ.get("HYPERGLASS_LOG_LEVEL") or "INFO", self._local_binds, self._remote_binds, ) def __repr__(self): return self.__str__() def __enter__(self): try: self.start() return self except KeyboardInterrupt: self.__exit__() def __exit__(self, *args): self._stop_transport() def open_tunnel(*args, **kwargs): """Open an SSH Tunnel, wrapper for :class:`SSHTunnelForwarder`. Arguments: destination (Optional[tuple]): SSH server's IP address and port in the format (``ssh_address``, ``ssh_port``) Keyword Arguments: debug_level (Optional[int or str]): log level for :class:`logging.Logger` instance, i.e. ``DEBUG`` skip_tunnel_checkup (boolean): Enable/disable the local side check and populate :attr:`~SSHTunnelForwarder.tunnel_is_up` Default: True .. versionadded:: 0.1.0 block_on_close (boolean): Wait until all connections are done during close by changing the value of :attr:`~SSHTunnelForwarder.block_on_close` Default: True .. note:: A value of ``debug_level`` set to 1 == ``TRACE`` enables tracing mode .. note:: See :class:`SSHTunnelForwarder` for keyword arguments **Example**:: from sshtunnel import open_tunnel with open_tunnel(SERVER, ssh_username=SSH_USER, ssh_port=22, ssh_password=SSH_PASSWORD, remote_bind_address=(REMOTE_HOST, REMOTE_PORT), local_bind_address=('', LOCAL_PORT)) as server: def do_something(port): pass print("LOCAL PORTS:", server.local_bind_port) do_something(server.local_bind_port) """ # Attach a console handler to the logger or create one if not passed kwargs["logger"] = kwargs.get("logger") or log ssh_address_or_host = kwargs.pop("ssh_address_or_host", None) # Check if deprecated arguments ssh_address or ssh_host were used for deprecated_argument in ["ssh_address", "ssh_host"]: ssh_address_or_host = SSHTunnelForwarder._process_deprecated( ssh_address_or_host, deprecated_argument, kwargs ) ssh_port = kwargs.pop("ssh_port", 22) skip_tunnel_checkup = kwargs.pop("skip_tunnel_checkup", True) block_on_close = kwargs.pop("block_on_close", _DAEMON) if not args: if isinstance(ssh_address_or_host, tuple): args = (ssh_address_or_host,) else: args = ((ssh_address_or_host, ssh_port),) forwarder = SSHTunnelForwarder(*args, **kwargs) forwarder.skip_tunnel_checkup = skip_tunnel_checkup forwarder.daemon_forward_servers = not block_on_close forwarder.daemon_transport = not block_on_close return forwarder def _bindlist(input_str): """Define type of data expected for remote and local bind address lists. Returns a tuple (ip_address, port) whose elements are (str, int) """ try: ip_port = input_str.split(":") if len(ip_port) == 1: _ip = ip_port[0] _port = None else: (_ip, _port) = ip_port if not _ip and not _port: raise AssertionError elif not _port: _port = "22" # default port if not given return _ip, int(_port) except ValueError: raise argparse.ArgumentTypeError( "Address tuple must be of type IP_ADDRESS:PORT" ) except AssertionError: raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!")