1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00
2021-09-15 18:25:37 -07:00

1554 lines
54 KiB
Python

"""Initiate SSH tunnels via a remote gateway.
Copyright (c) 2014-2019 Pahaz Blinov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*sshtunnel* - Initiate SSH tunnels via a remote gateway.
``sshtunnel`` works by opening a port forwarding SSH connection in the
background, using threads.
The connection(s) are closed when explicitly calling the
:meth:`SSHTunnelForwarder.stop` method or using it as a context.
"""
# Standard Library
import os
import sys
import queue
import socket
import getpass
import argparse
import warnings
import threading
import socketserver
from select import select
from binascii import hexlify
# Third Party
import paramiko
# Project
from hyperglass.log import log
#: Timeout (seconds) for tunnel connection
TUNNEL_TIMEOUT = 1.0
#: Use daemon threads in connections
_DAEMON = False
# Next id handed out by ``get_connection_id``; only touched while holding _LOCK
_CONNECTION_COUNTER = 1
_LOCK = threading.Lock()

# Deprecated keyword argument -> keyword argument that supersedes it
DEPRECATIONS = dict(
    ssh_address="ssh_address_or_host",
    ssh_host="ssh_address_or_host",
    ssh_private_key="ssh_pkey",
    raise_exception_if_any_forwarder_have_a_problem="mute_exceptions",
)

if os.name != "posix":
    # Non-POSIX platforms get a plain TCP server under the same name
    DEFAULT_SSH_DIRECTORY = "~/ssh"
    UnixStreamServer = socketserver.TCPServer
else:
    DEFAULT_SSH_DIRECTORY = "~/.ssh"
    UnixStreamServer = socketserver.UnixStreamServer

#: Path of optional ssh configuration file
SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, "config")
########################
# #
# Utils #
# #
########################
def check_host(host):
    """Validate that ``host`` is a string.

    Arguments:
        host: value expected to be an IP address or hostname

    Raises:
        AssertionError:
            raised when ``host`` is not a ``str``
    """
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``
    if not isinstance(host, str):
        raise AssertionError("IP is not a string ({0})".format(type(host).__name__))
def check_port(port):
    """Validate that ``port`` is a non-negative integer.

    Arguments:
        port: value expected to be a TCP port number

    Raises:
        AssertionError:
            raised when ``port`` is not an ``int`` or is negative
    """
    # Raised explicitly (not via ``assert``) so the checks survive ``python -O``
    if not isinstance(port, int):
        raise AssertionError("PORT is not a number")
    if port < 0:
        raise AssertionError("PORT < 0 ({0})".format(port))
def check_address(address):
    """Verify that ``address`` is well formed.

    Arguments:
        address (tuple or str):
            (``str``, ``int``) pair holding a host and a port, or a ``str``
            path of a UNIX domain socket (POSIX platforms only)

    Raises:
        ValueError:
            raised when ``address`` is neither a tuple nor a string, when a
            string is given on a non-POSIX platform, or when the path neither
            exists nor lives in a writable directory
        AssertionError:
            raised by :func:`check_host` / :func:`check_port` for a tuple
            whose elements have the wrong type or a negative port

    Example:
        >>> check_address(('127.0.0.1', 22))
    """
    if isinstance(address, tuple):
        host, port = address[0], address[1]
        check_host(host)
        check_port(port)
        return

    if not isinstance(address, str):
        raise ValueError(
            "ADDRESS is not a tuple, string, or character buffer ({0})".format(
                type(address).__name__
            )
        )

    # From here on ``address`` is a UNIX domain socket path
    if os.name != "posix":
        raise ValueError("Platform does not support UNIX domain sockets")
    usable = os.path.exists(address) or os.access(os.path.dirname(address), os.W_OK)
    if not usable:
        raise ValueError("ADDRESS not a valid socket domain socket ({0})".format(address))
def check_addresses(address_list, is_remote=False):
    """Check that every address in ``address_list`` is well formed.

    Arguments:
        address_list (list[tuple]):
            Sequence of (``str``, ``int``) pairs, each representing an IP
            address and port respectively

            .. note::
                when supported by the platform, one or more of the elements in
                the list can be of type ``str``, representing a valid UNIX
                domain socket

        is_remote (boolean):
            Whether the addresses are remote binds; ``str`` elements (UNIX
            domain sockets) are rejected for those

    Raises:
        AssertionError:
            raised when ``address_list`` contains an element that is neither
            a tuple nor a string, or a string while ``is_remote`` is true
        ValueError:
            raised when any address in the list has an incorrect format

    Example:
        >>> check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)])
    """
    # Explicit raises (not ``assert``) so validation survives ``python -O``
    if not all(isinstance(x, (tuple, str)) for x in address_list):
        raise AssertionError("Every address must be a tuple or a string")
    if is_remote and any(isinstance(x, str) for x in address_list):
        raise AssertionError("UNIX domain sockets not allowed for remote addresses")
    for address in address_list:
        check_address(address)
def address_to_str(address):
    """Render ``address`` as text: ``host:port`` for tuples, ``str()`` otherwise."""
    if not isinstance(address, tuple):
        return str(address)
    return "{0[0]}:{0[1]}".format(address)
def get_connection_id():
    """Return the current connection counter and advance it by one.

    The read and the increment both happen while holding ``_LOCK``.
    """
    global _CONNECTION_COUNTER
    with _LOCK:
        _CONNECTION_COUNTER += 1
        return _CONNECTION_COUNTER - 1
def _remove_none_values(dictionary):
    """Drop, in place, every key of ``dictionary`` whose value is None.

    Returns:
        list: the popped values (one ``None`` per removed key)
    """
    # Collect the keys first: a dict must not change size while iterated
    none_keys = [key for key, value in dictionary.items() if value is None]
    return [dictionary.pop(key) for key in none_keys]
########################
# #
# Errors #
# #
########################
class BaseSSHTunnelForwarderError(Exception):
    """Exception raised by :class:`SSHTunnelForwarder` errors.

    The message is kept in ``value``: the ``value`` keyword argument when
    given, otherwise the first positional argument, otherwise ``""``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        self.value = kwargs.pop("value", args[0] if args else "")

    def __str__(self):
        # ``value`` may be any object (e.g. another exception); ``__str__``
        # must return ``str`` or ``str(exc)`` itself raises TypeError
        return str(self.value)
class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError):
    """Exception for tunnel forwarder errors.

    Adds nothing to :class:`BaseSSHTunnelForwarderError`; it only provides a
    distinct type to raise and catch.
    """
########################
# #
# Handlers #
# #
########################
class _ForwardHandler(socketserver.BaseRequestHandler):
    """Base handler for tunnel connections.

    Subclasses are expected to fill in the class attributes below; each
    accepted local connection is relayed through a ``direct-tcpip`` channel
    opened on ``ssh_transport`` towards ``remote_address``.
    """

    remote_address = None  # destination handed to ``open_channel``
    ssh_transport = None  # object providing ``open_channel``
    logger = None  # logger exposing a ``trace`` method
    info = None  # per-connection label, set in ``handle``

    def _redirect(self, chan):
        """Relay data between the local request socket and ``chan``.

        Loops while ``chan.active``; stops when the local socket returns no
        data or the channel is flagged readable without data being ready.
        """
        local_sock = self.request
        while chan.active:
            readable, _, _ = select([local_sock, chan], [], [], 5)
            if local_sock in readable:
                outgoing = local_sock.recv(1024)
                if not outgoing:
                    break
                self.logger.trace(
                    ">>> OUT {0} send to {1}: {2} >>>".format(
                        self.info, self.remote_address, hexlify(outgoing)
                    ),
                )
                chan.sendall(outgoing)
            if chan in readable:  # checked independently of the branch above
                if not chan.recv_ready():
                    break
                incoming = chan.recv(1024)
                self.logger.trace(
                    "<<< IN {0} recv: {1} <<<".format(self.info, hexlify(incoming)),
                )
                local_sock.sendall(incoming)

    def handle(self):
        """Open the SSH channel for this connection and relay until it ends.

        Raises:
            HandlerSSHTunnelForwarderError:
                raised when the channel could not be opened
        """
        uid = get_connection_id()
        self.info = "#{0} <-- {1}".format(uid, self.client_address or self.server.local_address)

        # A non-tuple peer name is replaced by a placeholder source address
        peer = self.request.getpeername()
        src_address = peer if isinstance(peer, tuple) else ("dummy", 12345)

        try:
            chan = self.ssh_transport.open_channel(
                kind="direct-tcpip",
                dest_addr=self.remote_address,
                src_addr=src_address,
                timeout=TUNNEL_TIMEOUT,
            )
        except paramiko.SSHException:
            chan = None

        if chan is None:
            msg = "{0} to {1} was rejected by the SSH server".format(
                self.info, self.remote_address
            )
            self.logger.trace(msg)
            raise HandlerSSHTunnelForwarderError(msg)

        self.logger.trace("{0} connected".format(self.info))
        try:
            self._redirect(chan)
        except socket.error:
            # Sometimes a RST is sent and a socket error is raised, treat this
            # exception. It was seen that a 3way FIN is processed later on, so
            # no need to make an ordered close of the connection here or raise
            # the exception beyond this point...
            self.logger.trace("{0} sending RST".format(self.info))
        except Exception as err:
            self.logger.trace("{0} error: {1}".format(self.info, repr(err)))
        finally:
            chan.close()
            self.request.close()
            self.logger.trace("{0} connection closed.".format(self.info))
class _ForwardServer(socketserver.TCPServer):  # Not Threading
    """Non-threading version of the forward server.

    Keyword Arguments:
        logger:
            object used for logging; when omitted or falsy the module-level
            ``log`` is used. All other arguments go to
            :class:`socketserver.TCPServer` unchanged.

    Attributes:
        tunnel_ok (queue.Queue):
            receives ``False`` whenever :meth:`handle_error` is invoked
    """

    allow_reuse_address = True  # faster rebinding

    def __init__(self, *args, **kwargs):
        # Default of None so an omitted ``logger`` reaches the ``or log``
        # fallback instead of raising KeyError
        self.logger = kwargs.pop("logger", None) or log
        self.tunnel_ok = queue.Queue()
        socketserver.TCPServer.__init__(self, *args, **kwargs)

    def handle_error(self, request, client_address):
        """Log the failed connection and flag the tunnel as not OK."""
        self.logger.error(
            "Could not establish connection from {0} to remote side of the tunnel",
            request.getsockname(),
        )
        self.tunnel_ok.put(False)

    @property
    def local_address(self):
        """Address this server is bound to (``server_address``)."""
        return self.server_address

    @property
    def local_host(self):
        """First element of ``server_address``."""
        return self.server_address[0]

    @property
    def local_port(self):
        """Second element of ``server_address``."""
        return self.server_address[1]

    @property
    def remote_address(self):
        """``remote_address`` of the request handler class."""
        return self.RequestHandlerClass.remote_address

    @property
    def remote_host(self):
        """First element of the handler class's ``remote_address``."""
        return self.RequestHandlerClass.remote_address[0]

    @property
    def remote_port(self):
        """Second element of the handler class's ``remote_address``."""
        return self.RequestHandlerClass.remote_address[1]
class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer):
    """Forward server that allows concurrent connections to each tunnel.

    Combines :class:`socketserver.ThreadingMixIn` with ``_ForwardServer``.
    """

    # ThreadingMixIn setting; follows the module-wide _DAEMON flag
    daemon_threads = _DAEMON
class _UnixStreamForwardServer(UnixStreamServer):
    """Serve over UNIX domain sockets (does not work on Windows).

    Keyword Arguments:
        logger:
            object used for logging; when omitted or falsy the module-level
            ``log`` is used. All other arguments go to the base server class
            unchanged.

    Attributes:
        tunnel_ok (queue.Queue): created empty in ``__init__``
    """

    def __init__(self, *args, **kwargs):
        # Default of None so an omitted ``logger`` reaches the ``or log``
        # fallback instead of raising KeyError
        self.logger = kwargs.pop("logger", None) or log
        self.tunnel_ok = queue.Queue()
        UnixStreamServer.__init__(self, *args, **kwargs)

    @property
    def local_address(self):
        """Address this server is bound to (``server_address``)."""
        return self.server_address

    @property
    def local_host(self):
        """Always ``None`` for this server type."""
        return None

    @property
    def local_port(self):
        """Always ``None`` for this server type."""
        return None

    @property
    def remote_address(self):
        """``remote_address`` of the request handler class."""
        return self.RequestHandlerClass.remote_address

    @property
    def remote_host(self):
        """First element of the handler class's ``remote_address``."""
        return self.RequestHandlerClass.remote_address[0]

    @property
    def remote_port(self):
        """Second element of the handler class's ``remote_address``."""
        return self.RequestHandlerClass.remote_address[1]
class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn, _UnixStreamForwardServer):
    """UNIX stream forward server allowing concurrent connections per tunnel.

    Combines :class:`socketserver.ThreadingMixIn` with
    ``_UnixStreamForwardServer``.
    """

    # ThreadingMixIn setting; follows the module-wide _DAEMON flag
    daemon_threads = _DAEMON
class SSHTunnelForwarder:
"""
**SSH tunnel class**
- Initialize a SSH tunnel to a remote host according to the input
arguments
- Optionally:
+ Read an SSH configuration file (typically ``~/.ssh/config``)
+ Load keys from a running SSH agent (i.e. Pageant, GNOME Keyring)
Raises:
:class:`.BaseSSHTunnelForwarderError`:
raised by SSHTunnelForwarder class methods
:class:`.HandlerSSHTunnelForwarderError`:
raised by tunnel forwarder threads
.. note::
Attributes ``mute_exceptions`` and
``raise_exception_if_any_forwarder_have_a_problem``
(deprecated) may be used to silence most exceptions raised
from this class
Keyword Arguments:
ssh_address_or_host (tuple or str):
IP or hostname of ``REMOTE GATEWAY``. It may be a two-element
tuple (``str``, ``int``) representing IP and port respectively,
or a ``str`` representing the IP address only
.. versionadded:: 0.0.4
ssh_config_file (str):
SSH configuration file that will be read. If explicitly set to
``None``, parsing of this configuration is omitted
Default: :const:`SSH_CONFIG_FILE`
.. versionadded:: 0.0.4
ssh_host_key (str):
Representation of a line in an OpenSSH-style "known hosts"
file.
``REMOTE GATEWAY``'s key fingerprint will be compared to this
host key in order to prevent against SSH server spoofing.
Important when using passwords in order not to accidentally
do a login attempt to a wrong (perhaps an attacker's) machine
ssh_username (str):
Username to authenticate as in ``REMOTE SERVER``
Default: current local user name
ssh_password (str):
Text representing the password used to connect to ``REMOTE
SERVER`` or for unlocking a private key.
.. note::
Avoid coding secret password directly in the code, since this
may be visible and make your service vulnerable to attacks
ssh_port (int):
Optional port number of the SSH service on ``REMOTE GATEWAY``,
when ``ssh_address_or_host`` is a ``str`` representing the
IP part of ``REMOTE GATEWAY``'s address
Default: 22
ssh_pkey (str or paramiko.PKey):
**Private** key file name (``str``) to obtain the public key
from or a **public** key (:class:`paramiko.pkey.PKey`)
ssh_private_key_password (str):
Password for an encrypted ``ssh_pkey``
.. note::
Avoid coding secret password directly in the code, since this
may be visible and make your service vulnerable to attacks
ssh_proxy (socket-like object or tuple):
Proxy where all SSH traffic will be passed through.
It might be for example a :class:`paramiko.proxy.ProxyCommand`
instance.
See either the :class:`paramiko.transport.Transport`'s sock
parameter documentation or ``ProxyCommand`` in ``ssh_config(5)``
for more information.
It is also possible to specify the proxy address as a tuple of
type (``str``, ``int``) representing proxy's IP and port
.. note::
Ignored if ``ssh_proxy_enabled`` is False
.. versionadded:: 0.0.5
ssh_proxy_enabled (boolean):
Enable/disable SSH proxy. If True and user's
``ssh_config_file`` contains a ``ProxyCommand`` directive
that matches the specified ``ssh_address_or_host``,
a :class:`paramiko.proxy.ProxyCommand` object will be created where
all SSH traffic will be passed through
Default: ``True``
.. versionadded:: 0.0.4
local_bind_address (tuple):
Local tuple in the format (``str``, ``int``) representing the
IP and port of the local side of the tunnel. Both elements in
the tuple are optional so both ``('', 8000)`` and
``('10.0.0.1', )`` are valid values
Default: ``('0.0.0.0', RANDOM_PORT)``
.. versionchanged:: 0.0.8
Added the ability to use a UNIX domain socket as local bind
address
local_bind_addresses (list[tuple]):
In case more than one tunnel is established at once, a list
of tuples (in the same format as ``local_bind_address``)
can be specified, such as [(ip1, port_1), (ip_2, port2), ...]
Default: ``[local_bind_address]``
.. versionadded:: 0.0.4
remote_bind_address (tuple):
Remote tuple in the format (``str``, ``int``) representing the
IP and port of the remote side of the tunnel.
remote_bind_addresses (list[tuple]):
In case more than one tunnel is established at once, a list
of tuples (in the same format as ``remote_bind_address``)
can be specified, such as [(ip1, port_1), (ip_2, port2), ...]
Default: ``[remote_bind_address]``
.. versionadded:: 0.0.4
allow_agent (boolean):
Enable/disable load of keys from an SSH agent
Default: ``True``
.. versionadded:: 0.0.8
host_pkey_directories (list):
Look for pkeys in folders on this list, for example ['~/.ssh'].
Default: ``None`` (disabled)
.. versionadded:: 0.1.4
compression (boolean):
Turn on/off transport compression. By default compression is
disabled since it may negatively affect interactive sessions
Default: ``False``
.. versionadded:: 0.0.8
logger (logging.Logger):
logging instance for sshtunnel and paramiko
Default: :class:`logging.Logger` instance with a single
:class:`logging.StreamHandler` handler and
:const:`DEFAULT_LOGLEVEL` level
.. versionadded:: 0.0.3
mute_exceptions (boolean):
Allow silencing :class:`BaseSSHTunnelForwarderError` or
:class:`HandlerSSHTunnelForwarderError` exceptions when enabled
Default: ``False``
.. versionadded:: 0.0.8
set_keepalive (float):
Interval in seconds defining the period in which, if no data
was sent over the connection, a *'keepalive'* packet will be
sent (and ignored by the remote host). This can be useful to
keep connections alive over a NAT
Default: 0.0 (no keepalive packets are sent)
.. versionadded:: 0.0.7
threaded (boolean):
Allow concurrent connections over a single tunnel
Default: ``True``
.. versionadded:: 0.0.3
ssh_address (str):
Superseded by ``ssh_address_or_host``, tuple of type (str, int)
representing the IP and port of ``REMOTE SERVER``
.. deprecated:: 0.0.4
ssh_host (str):
Superseded by ``ssh_address_or_host``, tuple of type
(str, int) representing the IP and port of ``REMOTE SERVER``
.. deprecated:: 0.0.4
ssh_private_key (str or paramiko.PKey):
Superseded by ``ssh_pkey``, which can represent either a
**private** key file name (``str``) or a **public** key
(:class:`paramiko.pkey.PKey`)
.. deprecated:: 0.0.8
raise_exception_if_any_forwarder_have_a_problem (boolean):
Allow silencing :class:`BaseSSHTunnelForwarderError` or
:class:`HandlerSSHTunnelForwarderError` exceptions when set to
False
Default: ``True``
.. versionadded:: 0.0.4
.. deprecated:: 0.0.8 (use ``mute_exceptions`` instead)
Attributes:
tunnel_is_up (dict):
Describe whether or not the other side of the tunnel was reported
to be up (and we must close it) or not (skip shutting down that
tunnel)
.. note::
This attribute should not be modified
.. note::
When :attr:`.skip_tunnel_checkup` is ``True`` (the checkup is skipped)
or the local bind is a UNIX socket, the value will always be ``True``
**Example**::
{('127.0.0.1', 55550): True, # this tunnel is up
('127.0.0.1', 55551): False} # this one isn't
where 55550 and 55551 are the local bind ports
skip_tunnel_checkup (boolean):
Disable tunnel checkup (default for backwards compatibility).
.. versionadded:: 0.1.0
"""
# When true, _check_tunnel marks a tunnel as up without probing it
skip_tunnel_checkup = True
daemon_forward_servers = _DAEMON #: flag tunnel threads in daemon mode
daemon_transport = _DAEMON #: flag SSH transport thread in daemon mode
def local_is_up(self, target):
    """Check if a tunnel is up.

    Arguments:
        target (tuple):
            tuple of type (``str``, ``int``) indicating the listen IP
            address and port

    Return:
        boolean: ``False`` when ``target`` is malformed; otherwise the entry
        for ``target`` in :attr:`.tunnel_is_up`, or ``True`` when there is
        no such entry

    .. deprecated:: 0.1.0
        Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up`
    """
    try:
        check_address(target)
    except (AssertionError, IndexError, ValueError):
        # check_address reports wrong element types via AssertionError and
        # short tuples via IndexError; all mean "malformed target" here
        self.logger.warning(
            "Target must be a tuple (IP, port), where IP "
            'is a string (i.e. "192.168.0.1") and port is '
            "an integer (i.e. 40000). Alternatively "
            "target can be a valid UNIX domain socket."
        )
        return False

    if self.skip_tunnel_checkup:  # force tunnel check at this point
        self.skip_tunnel_checkup = False
        try:
            self.check_tunnels()
        finally:
            # roll it back even when the check raises
            self.skip_tunnel_checkup = True
    return self.tunnel_is_up.get(target, True)
def _make_ssh_forward_handler_class(self, remote_address_):
    """Build a ``_ForwardHandler`` subclass bound to this forwarder.

    The class carries ``remote_address_`` together with this instance's
    ``_transport`` and ``logger`` as class attributes.
    """
    transport = self._transport
    handler_logger = self.logger

    class Handler(_ForwardHandler):
        remote_address = remote_address_
        ssh_transport = transport
        logger = handler_logger

    return Handler
def _make_ssh_forward_server_class(self, remote_address_):
    """Pick the TCP forward server class; threaded when ``_threaded`` is set.

    ``remote_address_`` is accepted but not used.
    """
    if self._threaded:
        return _ThreadingForwardServer
    return _ForwardServer
def _make_unix_ssh_forward_server_class(self, remote_address_):
    """Pick the UNIX stream forward server class; threaded when ``_threaded`` is set.

    ``remote_address_`` is accepted but not used.
    """
    if self._threaded:
        return _ThreadingUnixStreamForwardServer
    return _UnixStreamForwardServer
def _make_ssh_forward_server(self, remote_address, local_bind_address):
    """Create one forward server and register it on this instance.

    A ``str`` ``local_bind_address`` selects the UNIX stream server class,
    anything else the TCP one. On success the server is appended to
    ``_server_list`` and its entry in ``tunnel_is_up`` starts as ``False``.

    Failures (a falsy server object or an ``IOError`` while creating it) are
    reported through :meth:`_raise`, which either raises
    :class:`BaseSSHTunnelForwarderError` or only logs it.
    """
    _Handler = self._make_ssh_forward_handler_class(remote_address)
    try:
        if isinstance(local_bind_address, str):
            forward_maker_class = self._make_unix_ssh_forward_server_class
        else:
            forward_maker_class = self._make_ssh_forward_server_class
        _Server = forward_maker_class(remote_address)
        ssh_forward_server = _Server(local_bind_address, _Handler, logger=self.logger)

        if ssh_forward_server:
            ssh_forward_server.daemon_threads = self.daemon_forward_servers
            self._server_list.append(ssh_forward_server)
            self.tunnel_is_up[ssh_forward_server.server_address] = False
        else:
            self._raise(
                BaseSSHTunnelForwarderError,
                "Problem setting up ssh {0} <> {1} forwarder. You can "
                "suppress this exception by using the `mute_exceptions` "
                "argument".format(
                    address_to_str(local_bind_address), address_to_str(remote_address),
                ),
            )
    except IOError:
        self._raise(
            BaseSSHTunnelForwarderError,
            "Couldn't open tunnel {0} <> {1} might be in use or "
            "destination not reachable".format(
                address_to_str(local_bind_address), address_to_str(remote_address)
            ),
        )
def __init__(
    self,
    ssh_address_or_host=None,
    ssh_config_file=SSH_CONFIG_FILE,
    ssh_host_key=None,
    ssh_password=None,
    ssh_pkey=None,
    ssh_private_key_password=None,
    ssh_proxy=None,
    ssh_proxy_enabled=True,
    ssh_username=None,
    local_bind_address=None,
    local_bind_addresses=None,
    logger=None,
    mute_exceptions=False,
    remote_bind_address=None,
    remote_bind_addresses=None,
    set_keepalive=0.0,
    threaded=True,  # old version False
    compression=None,
    allow_agent=True,  # look for keys from an SSH agent
    host_pkey_directories=None,  # look for keys in ~/.ssh
    gateway_timeout=None,
    *args,
    **kwargs,  # for backwards compatibility
) -> None:
    """Validate the arguments and store the tunnel configuration.

    No connection is opened here. The deprecated keywords listed in
    ``DEPRECATIONS`` and ``ssh_port`` are read from ``kwargs``; extra
    positional ``args`` are accepted but not used.

    Raises:
        ValueError:
            raised for unknown keyword arguments, for a deprecated keyword
            given together with its replacement, for inconsistent bind
            addresses, or when neither a password nor a key is available
        AssertionError:
            raised when the resulting gateway host or port has an invalid
            type or value
    """
    self.logger = logger or log
    self.ssh_host_key = ssh_host_key
    self.set_keepalive = set_keepalive
    self._server_list = []  # reset server list
    self.tunnel_is_up = {}  # handle tunnel status
    self._threaded = threaded
    self.is_alive = False
    self.gateway_timeout = gateway_timeout

    # Check if deprecated arguments ssh_address or ssh_host were used
    for deprecated_argument in ["ssh_address", "ssh_host"]:
        ssh_address_or_host = self._process_deprecated(
            ssh_address_or_host, deprecated_argument, kwargs
        )
    # other deprecated arguments
    ssh_pkey = self._process_deprecated(ssh_pkey, "ssh_private_key", kwargs)

    # An explicitly passed deprecated flag wins (so ``False`` really
    # silences exceptions); otherwise follow ``mute_exceptions``
    legacy_raise = self._process_deprecated(
        None, "raise_exception_if_any_forwarder_have_a_problem", kwargs
    )
    if legacy_raise is None:
        self._raise_fwd_exc = not mute_exceptions
    else:
        self._raise_fwd_exc = bool(legacy_raise)

    if isinstance(ssh_address_or_host, tuple):
        check_address(ssh_address_or_host)
        (ssh_host, ssh_port) = ssh_address_or_host
    else:
        ssh_host = ssh_address_or_host
        ssh_port = kwargs.pop("ssh_port", None)

    # Everything recognised has been popped by now
    if kwargs:
        raise ValueError("Unknown arguments: {0}".format(kwargs))

    # remote binds
    self._remote_binds = self._get_binds(
        remote_bind_address, remote_bind_addresses, is_remote=True
    )
    # local binds
    self._local_binds = self._get_binds(local_bind_address, local_bind_addresses)
    self._local_binds = self._consolidate_binds(self._local_binds, self._remote_binds)

    (
        self.ssh_host,
        self.ssh_username,
        ssh_pkey,  # still needs to go through _consolidate_auth
        self.ssh_port,
        self.ssh_proxy,
        self.compression,
    ) = self._read_ssh_config(
        ssh_host,
        ssh_config_file,
        ssh_username,
        ssh_pkey,
        ssh_port,
        ssh_proxy if ssh_proxy_enabled else None,
        compression,
        self.logger,
    )

    (self.ssh_password, self.ssh_pkeys) = self._consolidate_auth(
        ssh_password=ssh_password,
        ssh_pkey=ssh_pkey,
        ssh_pkey_password=ssh_private_key_password,
        allow_agent=allow_agent,
        host_pkey_directories=host_pkey_directories,
        logger=self.logger,
    )

    check_host(self.ssh_host)
    check_port(self.ssh_port)

    self.logger.info(
        "Connecting to gateway: {h}:{p} as user '{u}', timeout {t}",
        h=self.ssh_host,
        p=self.ssh_port,
        u=self.ssh_username,
        t=self.gateway_timeout,
    )
    self.logger.debug("Concurrent connections allowed: {0}", self._threaded)
@staticmethod
def _read_ssh_config(
    ssh_host,
    ssh_config_file,
    ssh_username=None,
    ssh_pkey=None,
    ssh_port=None,
    ssh_proxy=None,
    compression=None,
    logger=log,
):
    """Read ssh_config_file.

    Read ssh_config_file and try to look for user (ssh_username),
    identityfile (ssh_pkey), port (ssh_port) and proxycommand
    (ssh_proxy) entries for ssh_host. Values passed in explicitly take
    precedence over the file, except ``ssh_host``, which is replaced by the
    file's ``hostname`` entry.

    Reading is best effort: failures are logged and whatever was gathered
    so far is returned.

    Return:
        tuple: ``(ssh_host, ssh_username, ssh_pkey, ssh_port, ssh_proxy,
        compression)``; the username falls back to the current user and the
        port to 22
    """
    ssh_config = paramiko.SSHConfig()
    if not ssh_config_file:  # handle case where it's an empty string
        ssh_config_file = None

    # Try to read SSH_CONFIG_FILE
    try:
        # open the ssh config file
        with open(os.path.expanduser(ssh_config_file), "r") as f:
            ssh_config.parse(f)
        # looks for information for the destination system
        hostname_info = ssh_config.lookup(ssh_host)
        # gather settings for user, port and identity file
        # last resort: use the 'login name' of the user
        ssh_username = ssh_username or hostname_info.get("user")
        ssh_pkey = ssh_pkey or hostname_info.get("identityfile", [None])[0]
        ssh_host = hostname_info.get("hostname")
        ssh_port = ssh_port or hostname_info.get("port")

        proxycommand = hostname_info.get("proxycommand")
        ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if proxycommand else None)
        if compression is None:
            compression = hostname_info.get("compression", "")
            compression = True if compression.upper() == "YES" else False
    except IOError:
        logger.warning("Could not read SSH configuration file: {f}", f=ssh_config_file)
    except (AttributeError, TypeError):  # ssh_config_file is None
        logger.info("Skipping loading of ssh configuration file")
    except Exception as err:
        # Previously a ``return`` inside ``finally`` discarded these (and
        # KeyboardInterrupt/SystemExit) silently; stay best-effort but say so
        logger.warning(
            "Could not process SSH configuration file {f}: {e}",
            f=ssh_config_file,
            e=repr(err),
        )

    return (
        ssh_host,
        ssh_username or getpass.getuser(),
        ssh_pkey,
        int(ssh_port) if ssh_port else 22,  # fallback value
        ssh_proxy,
        compression,
    )
@staticmethod
def get_agent_keys(logger=log):
    """Load public keys from any available SSH agent.

    Arguments:
        logger: object used to log how many keys were found

    Return:
        list
    """
    loaded = paramiko.Agent().get_keys()
    logger.info("{k} keys loaded from agent", k=len(loaded))
    return list(loaded)
@staticmethod
def get_keys(logger=log, host_pkey_directories=None, allow_agent=False):
    """Load public keys from any available SSH agent or local .ssh directory.

    Arguments:
        logger: object used for logging
        host_pkey_directories (Optional[list[str]]):
            List of local directories where host SSH pkeys in the format
            "id_*" are searched. For example, ['~/.ssh']. ``None`` skips the
            directory scan; an empty list scans ``DEFAULT_SSH_DIRECTORY``

            .. versionadded:: 0.1.0
        allow_agent (Optional[boolean]):
            Whether or not load keys from agent

            Default: False

    Return:
        list
    """
    if allow_agent:
        keys = SSHTunnelForwarder.get_agent_keys(logger=logger)
    else:
        keys = []

    if host_pkey_directories is None:
        return keys

    key_classes = {
        "rsa": paramiko.RSAKey,
        "dsa": paramiko.DSSKey,
        "ecdsa": paramiko.ECDSAKey,
        "ed25519": paramiko.Ed25519Key,
    }
    for directory in host_pkey_directories or [DEFAULT_SSH_DIRECTORY]:
        for suffix, key_class in key_classes.items():
            candidate = os.path.expanduser(os.path.join(directory, "id_{}".format(suffix)))
            if not os.path.isfile(candidate):
                continue
            loaded = SSHTunnelForwarder.read_private_key_file(
                pkey_file=candidate, logger=logger, key_type=key_class,
            )
            if loaded:
                keys.append(loaded)

    # The count includes any agent keys already in the list
    logger.info("{k} keys loaded from host directory", k=len(keys))
    return keys
@staticmethod
def _consolidate_binds(local_binds, remote_binds):
    """Pad ``local_binds`` so it is as long as ``remote_binds``.

    Each missing local bind becomes ``("0.0.0.0", 0)``. ``local_binds`` is
    extended in place and also returned.

    Raises:
        ValueError:
            raised when there are more local binds than remote binds
    """
    missing = len(remote_binds) - len(local_binds)
    if missing < 0:
        raise ValueError(
            "Too many local bind addresses (local_bind_addresses > remote_bind_addresses)"
        )
    local_binds.extend(("0.0.0.0", 0) for _ in range(missing))
    return local_binds
@staticmethod
def _consolidate_auth(
    ssh_password=None,
    ssh_pkey=None,
    ssh_pkey_password=None,
    allow_agent=True,
    host_pkey_directories=None,
    logger=log,
):
    """Make sure authentication information is in place.

    ``ssh_pkey`` may be of classes:
        - ``str`` - in this case it represents a private key file; public
          key will be obtained from it
        - ``paramiko.Pkey`` - it will be transparently added to loaded keys

    Return:
        tuple: ``(ssh_password, keys)``, with an explicitly given key first

    Raises:
        ValueError:
            raised when there is neither a password nor any key
    """
    loaded_keys = SSHTunnelForwarder.get_keys(
        logger=logger, host_pkey_directories=host_pkey_directories, allow_agent=allow_agent,
    )

    if isinstance(ssh_pkey, str):
        key_path = os.path.expanduser(ssh_pkey)
        if not os.path.exists(key_path):
            logger.warning("Private key file not found: {k}", k=ssh_pkey)
        else:
            # The login password doubles as key passphrase when none is given
            ssh_pkey = SSHTunnelForwarder.read_private_key_file(
                pkey_file=key_path,
                pkey_password=ssh_pkey_password or ssh_password,
                logger=logger,
            )

    if isinstance(ssh_pkey, paramiko.pkey.PKey):
        loaded_keys.insert(0, ssh_pkey)

    if not (ssh_password or loaded_keys):
        raise ValueError("No password or public key available!")
    return (ssh_password, loaded_keys)
def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None):
    """Raise ``exception(reason)``, or only log it when raising is disabled.

    ``_raise_fwd_exc`` decides: true raises, false logs the ``repr`` of the
    exception as an error.
    """
    if not self._raise_fwd_exc:
        self.logger.error(repr(exception(reason)))
        return
    raise exception(reason)
def _get_transport(self):
    """Return the SSH transport to the remote gateway.

    Uses ``ssh_proxy`` as the underlying socket when set, otherwise a new
    TCP socket. A ``socket.socket`` is given ``gateway_timeout`` and
    connected to the gateway before being handed to paramiko.

    Raises:
        Whatever ``settimeout``/``connect`` raise; a socket created here is
        closed before the exception propagates.
    """
    owns_socket = False
    if self.ssh_proxy:
        if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand):
            proxy_repr = repr(self.ssh_proxy.cmd[1])
        else:
            proxy_repr = repr(self.ssh_proxy)
        self.logger.debug("Connecting via proxy: {0}".format(proxy_repr))
        _socket = self.ssh_proxy
    else:
        _socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        owns_socket = True

    if isinstance(_socket, socket.socket):
        try:
            _socket.settimeout(self.gateway_timeout)
            _socket.connect((self.ssh_host, self.ssh_port))
        except Exception:
            # Do not leak the descriptor of a socket created above; a
            # caller-supplied proxy socket stays under the caller's control
            if owns_socket:
                _socket.close()
            raise

    transport = paramiko.Transport(_socket)
    transport.set_keepalive(self.set_keepalive)
    transport.use_compression(compress=self.compression)
    transport.daemon = self.daemon_transport
    return transport
def _create_tunnels(self):
    """Create SSH tunnels on top of a transport to the remote gateway.

    Connects to the gateway first when ``is_active`` is false; a failed
    connection is logged and ends the method. Then one forward server is
    set up per (remote, local) bind pair, logging any per-tunnel failure.
    """
    if not self.is_active:
        try:
            self._connect_to_gateway()
        except socket.gaierror:  # raised by paramiko.Transport
            self.logger.error(
                "Could not resolve IP address for {0}, aborting!".format(self.ssh_host)
            )
            return
        except (paramiko.SSHException, socket.error) as err:
            self.logger.error(
                "Could not connect to gateway {0}:{1} : {2}".format(
                    self.ssh_host, self.ssh_port, err.args[0]
                )
            )
            return

    for remote_bind, local_bind in zip(self._remote_binds, self._local_binds):
        try:
            self._make_ssh_forward_server(remote_bind, local_bind)
        except BaseSSHTunnelForwarderError as err:
            self.logger.error("Problem setting SSH Forwarder up: {0}".format(err.value))
@staticmethod
def _get_binds(bind_address, bind_addresses, is_remote=False):
    """Turn the single/plural bind arguments into one validated list.

    Local one-element tuples ``(host,)`` get port ``0`` appended (written
    back into the list that was passed in). The list is validated with
    :func:`check_addresses` before being returned.

    Raises:
        ValueError:
            raised when both arguments are given, or when neither is given
            for remote binds (local binds then default to ``[]``)
    """
    kind = "remote" if is_remote else "local"

    if bind_address and bind_addresses:
        raise ValueError(
            "You can't use both '{0}_bind_address' and "
            "'{0}_bind_addresses' arguments. Use one of "
            "them.".format(kind)
        )
    if not bind_address and not bind_addresses:
        if not is_remote:
            return []
        raise ValueError(
            "No {0} bind addresses specified. Use "
            "'{0}_bind_address' or '{0}_bind_addresses'"
            " argument".format(kind)
        )

    if bind_address:
        bind_addresses = [bind_address]

    if not is_remote:
        # Add random port if missing in local bind
        for index, candidate in enumerate(bind_addresses):
            if isinstance(candidate, tuple) and len(candidate) == 1:
                bind_addresses[index] = (candidate[0], 0)

    check_addresses(bind_addresses, is_remote)
    return bind_addresses
@staticmethod
def _process_deprecated(attrib, deprecated_attrib, kwargs):
    """Resolve a deprecated keyword argument against its replacement.

    Arguments:
        attrib: value given through the current (non-deprecated) argument
        deprecated_attrib (str): deprecated name, a key of ``DEPRECATIONS``
        kwargs (dict): leftover keyword arguments; the deprecated entry is
            popped from it when it is used

    Return:
        ``attrib`` when the deprecated keyword is absent, otherwise the
        value given under the deprecated name

    Raises:
        ValueError:
            raised when ``deprecated_attrib`` is not in ``DEPRECATIONS`` or
            when both the deprecated and the current argument are given
    """
    if deprecated_attrib not in DEPRECATIONS:
        raise ValueError("{0} not included in deprecations list".format(deprecated_attrib))

    if deprecated_attrib not in kwargs:
        return attrib

    replacement = DEPRECATIONS[deprecated_attrib]
    warnings.warn(
        "'{0}' is DEPRECATED use '{1}' instead".format(deprecated_attrib, replacement),
        DeprecationWarning,
    )
    if attrib:
        raise ValueError(
            "You can't use both '{0}' and '{1}'. "
            "Please only use one of them".format(deprecated_attrib, replacement)
        )
    return kwargs.pop(deprecated_attrib)
@staticmethod
def read_private_key_file(pkey_file, pkey_password=None, key_type=None, logger=log):
    """Get SSH Public key from a private key file, given an optional password.

    Arguments:
        pkey_file (str):
            File containing a private key

    Keyword Arguments:
        pkey_password (Optional[str]):
            Password to decrypt the private key
        key_type:
            paramiko key class to load the file with; when omitted the RSA,
            DSS, ECDSA and Ed25519 classes are tried in that order
        logger: object used for logging

    Return:
        paramiko.Pkey, or ``None`` when the file could not be loaded
    """
    if key_type:
        candidates = (key_type,)
    else:
        candidates = (
            paramiko.RSAKey,
            paramiko.DSSKey,
            paramiko.ECDSAKey,
            paramiko.Ed25519Key,
        )

    ssh_pkey = None
    for pkey_class in candidates:
        try:
            loaded = pkey_class.from_private_key_file(pkey_file, password=pkey_password)
        except paramiko.PasswordRequiredException:
            # Trying the remaining key classes is pointless without a password
            logger.error("Password is required for key {k}", k=pkey_file)
            break
        except paramiko.SSHException:
            logger.debug(
                "Private key file ({k0}) could not be loaded as type {k1} or bad password",
                k0=pkey_file,
                k1=pkey_class,
            )
        else:
            ssh_pkey = loaded
            logger.debug(
                "Private key file ({k0}, {k1}) successfully loaded",
                k0=pkey_file,
                k1=pkey_class,
            )
            break
    return ssh_pkey
def _check_tunnel(self, _srv) -> None:
    """Record in ``tunnel_is_up`` whether the tunnel served by ``_srv`` is up.

    With ``skip_tunnel_checkup`` set the tunnel is marked up without any
    probing. Otherwise a connection is made to the server's local address:
    a value arriving on ``_srv.tunnel_ok`` within the timeout is stored as
    the status, a socket error stores ``False`` and an empty queue stores
    ``True``.
    """
    if self.skip_tunnel_checkup:
        self.tunnel_is_up[_srv.local_address] = True
        return

    self.logger.info("Checking tunnel to: {a}", a=_srv.remote_address)

    # A string local address denotes a UNIX stream socket
    family = socket.AF_UNIX if isinstance(_srv.local_address, str) else socket.AF_INET
    probe = socket.socket(family, socket.SOCK_STREAM)
    probe.settimeout(TUNNEL_TIMEOUT)
    try:
        # Windows raises WinError 10049 if trying to connect to 0.0.0.0
        if _srv.local_host == "0.0.0.0":
            connect_to = ("127.0.0.1", _srv.local_port)
        else:
            connect_to = _srv.local_address
        probe.connect(connect_to)
        self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get(timeout=TUNNEL_TIMEOUT * 1.1)
        self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address))
    except socket.error:
        self.logger.debug("Tunnel to {0} is DOWN".format(_srv.remote_address))
        self.tunnel_is_up[_srv.local_address] = False
    except queue.Empty:
        self.logger.debug("Tunnel to {0} is UP".format(_srv.remote_address))
        self.tunnel_is_up[_srv.local_address] = True
    finally:
        probe.close()
def check_tunnels(self) -> None:
    """Check every registered tunnel and populate :attr:`.tunnel_is_up`."""
    for server in self._server_list:
        self._check_tunnel(server)
def start(self) -> None:
    """Start the SSH tunnels.

    Only a warning is logged when the forwarder is already alive. Otherwise
    the tunnels are created, one serving thread is launched per forward
    server and each tunnel is probed right after its thread starts. The
    forwarder counts as alive when at least one tunnel is up.

    Failures are reported through ``self._raise``: once when the SSH
    transport is not active after creating the tunnels, and once when no
    tunnel turned out to be up.
    """
    if self.is_alive:
        self.logger.warning("Already started!")
        return

    self._create_tunnels()
    if not self.is_active:
        self._raise(
            BaseSSHTunnelForwarderError, reason="Could not establish session to SSH gateway",
        )

    for server in self._server_list:
        worker = threading.Thread(
            target=self._serve_forever_wrapper,
            args=(server,),
            name="Srv-{0}".format(address_to_str(server.local_port)),
        )
        worker.daemon = self.daemon_forward_servers
        worker.start()
        self._check_tunnel(server)

    self.is_alive = any(self.tunnel_is_up.values())
    if self.is_alive:
        return
    self._raise(
        HandlerSSHTunnelForwarderError, "An error occurred while opening tunnels.",
    )
def stop(self) -> None:
    """Shut the tunnels down and forget about them.

    The transport is stopped through :meth:`._stop_transport`; afterwards the
    server list and the tunnel status map are reset to empty.

    .. note:: This **had** to be handled with care before ``0.1.0``:

        - if a port redirection is opened
        - the destination is not reachable
        - we attempt a connection to that tunnel (``SYN`` is sent and
          acknowledged, then a ``FIN`` packet is sent and never
          acknowledged... weird)
        - we try to shutdown: it will not succeed until ``FIN_WAIT_2`` and
          ``CLOSE_WAIT`` time out.

    .. note::
        These scenarios are handled with :attr:`.tunnel_is_up`: if False,
        server ``shutdown()`` will be skipped on that tunnel
    """
    self.logger.info("Closing all open connections...")
    listening = [address_to_str(server.local_address) for server in self._server_list]
    self.logger.debug("Listening tunnels: " + (", ".join(listening) or "None"))
    self._stop_transport()
    # Start from a clean slate on the next start().
    self._server_list = []
    self.tunnel_is_up = {}
def close(self) -> None:
    """Stop an active tunnel; alias of :meth:`.stop`."""
    self.stop()
def restart(self) -> None:
    """Restart the gateway connection and the tunnels: stop, then start."""
    self.stop()
    self.start()
def _connect_to_gateway(self) -> None:
"""Open connection to SSH gateway.
- First try with all keys loaded from an SSH agent (if allowed)
- Then with those passed directly or read from ~/.ssh/config
- As last resort, try with a provided password
"""
for key in self.ssh_pkeys:
self.logger.debug(
"Trying to log in with key: {0}".format(hexlify(key.get_fingerprint()))
)
try:
self._transport = self._get_transport()
self._transport.connect(
hostkey=self.ssh_host_key, username=self.ssh_username, pkey=key
)
if self._transport.is_alive:
return
except paramiko.AuthenticationException:
self.logger.debug("Authentication error")
self._stop_transport()
if self.ssh_password: # avoid conflict using both pass and pkey
self.logger.debug(
"Trying to log in with password: {0}".format("*" * len(self.ssh_password))
)
try:
self._transport = self._get_transport()
self._transport.connect(
hostkey=self.ssh_host_key,
username=self.ssh_username,
password=self.ssh_password,
)
if self._transport.is_alive:
return
except paramiko.AuthenticationException:
self.logger.debug("Authentication error")
self._stop_transport()
self.logger.error("Could not open connection to gateway")
def _serve_forever_wrapper(self, _srv, poll_interval=0.1) -> None:
    """Run one forward server until it finishes, logging both ends of its life.

    This is the target of the per-tunnel threads created in ``start``.

    Arguments:
        _srv: forward server to run
        poll_interval: value passed on to the server's ``serve_forever``
    """
    opening = "Opening tunnel: {0} <> {1}".format(
        address_to_str(_srv.local_address), address_to_str(_srv.remote_address)
    )
    self.logger.info(opening)

    _srv.serve_forever(poll_interval)  # blocks until finished

    released = "Tunnel: {0} <> {1} released".format(
        address_to_str(_srv.local_address), address_to_str(_srv.remote_address)
    )
    self.logger.info(released)
def _stop_transport(self) -> None:
"""Close the underlying transport when nothing more is needed."""
try:
self._check_is_started()
except (BaseSSHTunnelForwarderError, HandlerSSHTunnelForwarderError) as e:
self.logger.warning(e)
for _srv in self._server_list:
tunnel = _srv.local_address
if self.tunnel_is_up[tunnel]:
self.logger.info("Shutting down tunnel {0}".format(tunnel))
_srv.shutdown()
_srv.server_close()
# clean up the UNIX domain socket if we're using one
if isinstance(_srv, _UnixStreamForwardServer):
try:
os.unlink(_srv.local_address)
except Exception as e:
self.logger.error(
"Unable to unlink socket {0}: {1}".format(self.local_address, repr(e))
)
self.is_alive = False
if self.is_active:
self._transport.close()
self._transport.stop_thread()
self.logger.debug("Transport is closed")
@property
def local_bind_port(self):
    """Return the local port of the only tunnel.

    Kept for backwards compatibility. Raises
    ``BaseSSHTunnelForwarderError`` unless exactly one forward server exists.
    """
    # BACKWARDS COMPATIBILITY
    self._check_is_started()
    if len(self._server_list) == 1:
        return self.local_bind_ports[0]
    raise BaseSSHTunnelForwarderError(
        "Use .local_bind_ports property for more than one tunnel"
    )
@property
def local_bind_host(self):
    """Return the local host of the only tunnel.

    Kept for backwards compatibility. Raises
    ``BaseSSHTunnelForwarderError`` unless exactly one forward server exists.
    """
    # BACKWARDS COMPATIBILITY
    self._check_is_started()
    if len(self._server_list) == 1:
        return self.local_bind_hosts[0]
    raise BaseSSHTunnelForwarderError(
        "Use .local_bind_hosts property for more than one tunnel"
    )
@property
def local_bind_address(self):
    """Return the local address of the only tunnel.

    Kept for backwards compatibility. Raises
    ``BaseSSHTunnelForwarderError`` unless exactly one forward server exists.
    """
    # BACKWARDS COMPATIBILITY
    self._check_is_started()
    if len(self._server_list) == 1:
        return self.local_bind_addresses[0]
    raise BaseSSHTunnelForwarderError(
        "Use .local_bind_addresses property for more than one tunnel"
    )
@property
def local_bind_ports(self):
    """Return the local ports of the tunnels, leaving out servers without one.

    The forwarder must be started, see ``_check_is_started``.
    """
    self._check_is_started()
    candidates = (server.local_port for server in self._server_list)
    return [port for port in candidates if port is not None]
@property
def local_bind_hosts(self):
    """Return the local hosts the tunnels listen on, leaving out servers without one.

    The forwarder must be started, see ``_check_is_started``.
    """
    self._check_is_started()
    candidates = (server.local_host for server in self._server_list)
    return [host for host in candidates if host is not None]
@property
def local_bind_addresses(self):
    """Return the ``local_address`` of every forward server, in server order.

    The forwarder must be started, see ``_check_is_started``.
    """
    self._check_is_started()
    return [server.local_address for server in self._server_list]
@property
def tunnel_bindings(self):
    """Return a dictionary mapping remote addresses to local ones.

    Only the tunnels recorded as up in ``self.tunnel_is_up`` are included.
    """
    return {
        server.remote_address: server.local_address
        for server in self._server_list
        if self.tunnel_is_up[server.local_address]
    }
@property
def is_active(self) -> bool:
    """Return True if a transport has been created and reports being active."""
    has_transport = "_transport" in self.__dict__
    return bool(has_transport and self._transport.is_active())
def _check_is_started(self) -> None:
if not self.is_active: # underlying transport not alive
msg = "Server is not started. Please .start() first!"
raise BaseSSHTunnelForwarderError(msg)
if not self.is_alive:
msg = "Tunnels are not started. Please .start() first!"
raise HandlerSSHTunnelForwarderError(msg)
def __str__(self) -> str:
    """Return a multi-line, human-readable summary of the forwarder.

    The summary lists the gateway, proxy, user name, credentials, host key,
    status and the tunnel options, one item per line.

    The password is never rendered verbatim: a non-empty password is
    replaced by a run of asterisks of the same length, so that printing or
    logging the forwarder does not disclose the credential.
    """
    password = self.ssh_password
    if password:
        password = "*" * len(password)
    credentials = {
        "password": password,
        "pkeys": [(key.get_name(), hexlify(key.get_fingerprint())) for key in self.ssh_pkeys]
        if any(self.ssh_pkeys)
        else None,
    }
    # Helper defined elsewhere in this module; judging by its name it drops
    # the entries whose value is None.
    _remove_none_values(credentials)
    template = os.linesep.join(
        [
            "{0} object",
            "ssh gateway: {1}:{2}",
            "proxy: {3}",
            "username: {4}",
            "authentication: {5}",
            "hostkey: {6}",
            "status: {7}started",
            "keepalive messages: {8}",
            "tunnel connection check: {9}",
            "concurrent connections: {10}allowed",
            "compression: {11}requested",
            "logging level: {12}",
            "local binds: {13}",
            "remote binds: {14}",
        ]
    )
    return template.format(
        self.__class__,
        self.ssh_host,
        self.ssh_port,
        self.ssh_proxy.cmd[1] if self.ssh_proxy else "no",
        self.ssh_username,
        credentials,
        self.ssh_host_key if self.ssh_host_key else "not checked",
        "" if self.is_alive else "not ",
        "disabled" if not self.set_keepalive else "every {0} sec".format(self.set_keepalive),
        "disabled" if self.skip_tunnel_checkup else "enabled",
        "" if self._threaded else "not ",
        "" if self.compression else "not ",
        os.environ.get("HYPERGLASS_LOG_LEVEL") or "INFO",
        self._local_binds,
        self._remote_binds,
    )
def __repr__(self) -> str:
return self.__str__()
def __enter__(self) -> "SSHTunnelForwarder":
try:
self.start()
return self
except KeyboardInterrupt:
self.__exit__()
def __exit__(self, *args) -> None:
self._stop_transport()
def open_tunnel(*args, **kwargs) -> "SSHTunnelForwarder":
    """Open an SSH Tunnel, wrapper for :class:`SSHTunnelForwarder`.

    The forwarder is built and configured here; this function does not call
    its ``start`` method.

    Arguments:
        destination (Optional[tuple]):
            SSH server's IP address and port in the format
            (``ssh_address``, ``ssh_port``). When no positional argument is
            given, it is derived from ``ssh_address_or_host`` and
            ``ssh_port``.

    Keyword Arguments:
        logger:
            Logger handed to the forwarder; the hyperglass ``log`` object is
            used when it is missing or falsy

        ssh_address_or_host (Optional[str or tuple]):
            Gateway host, or a ready-made (host, port) tuple. The deprecated
            names ``ssh_address`` and ``ssh_host`` are processed as well.

        ssh_port (int):
            Gateway port, used only when the host is not given as a tuple

            Default: 22

        debug_level (Optional[int or str]):
            log level for :class:`logging.Logger` instance, i.e. ``DEBUG``
            (not interpreted by this function; passed through unchanged with
            the remaining keyword arguments)

        skip_tunnel_checkup (boolean):
            Enable/disable the local side check and populate
            :attr:`~SSHTunnelForwarder.tunnel_is_up`

            Default: True

            .. versionadded:: 0.1.0

        block_on_close (boolean):
            Wait until all connections are done during close. The forward
            servers and the transport are made daemonic when this is false.

            Default: the module-level ``_DAEMON`` flag

    .. note::
        A value of ``debug_level`` set to 1 == ``TRACE`` enables tracing mode
    .. note::
        See :class:`SSHTunnelForwarder` for keyword arguments

    **Example**::

        with open_tunnel(SERVER,
                         ssh_username=SSH_USER,
                         ssh_port=22,
                         ssh_password=SSH_PASSWORD,
                         remote_bind_address=(REMOTE_HOST, REMOTE_PORT),
                         local_bind_address=('', LOCAL_PORT)) as server:
            def do_something(port):
                pass
            print("LOCAL PORTS:", server.local_bind_port)
            do_something(server.local_bind_port)
    """
    # Fall back to the hyperglass logger when the caller supplied none.
    if not kwargs.get("logger"):
        kwargs["logger"] = log

    ssh_address_or_host = kwargs.pop("ssh_address_or_host", None)
    # Check if deprecated arguments ssh_address or ssh_host were used
    for old_name in ("ssh_address", "ssh_host"):
        ssh_address_or_host = SSHTunnelForwarder._process_deprecated(
            ssh_address_or_host, old_name, kwargs
        )

    # These options are consumed here and must not reach the constructor.
    ssh_port = kwargs.pop("ssh_port", 22)
    skip_tunnel_checkup = kwargs.pop("skip_tunnel_checkup", True)
    block_on_close = kwargs.pop("block_on_close", _DAEMON)

    if not args:
        destination = (
            ssh_address_or_host
            if isinstance(ssh_address_or_host, tuple)
            else (ssh_address_or_host, ssh_port)
        )
        args = (destination,)

    forwarder = SSHTunnelForwarder(*args, **kwargs)
    forwarder.skip_tunnel_checkup = skip_tunnel_checkup
    daemonic = not block_on_close
    forwarder.daemon_forward_servers = daemonic
    forwarder.daemon_transport = daemonic
    return forwarder
def _bindlist(input_str):
"""Define type of data expected for remote and local bind address lists.
Returns a tuple (ip_address, port) whose elements are (str, int)
"""
try:
ip_port = input_str.split(":")
if len(ip_port) == 1:
_ip = ip_port[0]
_port = None
else:
(_ip, _port) = ip_port
if not _ip and not _port:
raise AssertionError
elif not _port:
_port = "22" # default port if not given
return _ip, int(_port)
except ValueError:
raise argparse.ArgumentTypeError("Address tuple must be of type IP_ADDRESS:PORT")
except AssertionError:
raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!")