Source code for asyncssh.listener

# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""SSH listeners"""

import asyncio
import errno
import socket
from types import TracebackType
from typing import TYPE_CHECKING, AnyStr, Callable, Generic, List, Optional
from typing import Sequence, Set, Tuple, Type, Union
from typing_extensions import Self

from .forward import SSHForwarderCoro
from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder
from .misc import HostPort, MaybeAwait
from .session import SSHTCPSession, SSHUNIXSession
from .socks import SSHSOCKSForwarder


if TYPE_CHECKING:
    # pylint: disable=cyclic-import
    from .channel import SSHTCPChannel, SSHUNIXChannel
    from .connection import SSHConnection, SSHClientConnection


_LocalListenerFactory = Callable[[], asyncio.BaseProtocol]

ListenKey = Union[HostPort, str]

TCPListenerFactory = Callable[[str, int], MaybeAwait[SSHTCPSession[AnyStr]]]
UNIXListenerFactory = Callable[[], MaybeAwait[SSHUNIXSession[AnyStr]]]


[docs] class SSHListener: """SSH listener for inbound connections""" def __init__(self) -> None: self._tunnel: Optional['SSHConnection'] = None async def __aenter__(self) -> Self: return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _exc_value: Optional[BaseException], _traceback: Optional[TracebackType]) -> bool: self.close() await self.wait_closed() return False
[docs] def get_port(self) -> int: """Return the port number being listened on This method returns the port number that the remote listener was bound to. When the requested remote listening port is `0` to indicate a dynamic port, this method can be called to determine what listening port was selected. This function only applies to TCP listeners. :returns: The port number being listened on """ # pylint: disable=no-self-use return 0
def set_tunnel(self, tunnel: 'SSHConnection') -> None: """Set tunnel associated with listener""" self._tunnel = tunnel
[docs] def close(self) -> None: """Stop listening for new connections This method can be called to stop listening for connections. Existing connections will remain open. """ if self._tunnel: self._tunnel.close()
[docs] async def wait_closed(self) -> None: """Wait for the listener to close This method is a coroutine which waits for the associated listeners to be closed. """ if self._tunnel: await self._tunnel.wait_closed() self._tunnel = None
class SSHClientListener(SSHListener): """Client listener used to accept inbound forwarded connections""" def __init__(self, conn: 'SSHClientConnection', encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__() self._conn: Optional['SSHClientConnection'] = conn self._encoding = encoding self._errors = errors self._window = window self._max_pktsize = max_pktsize self._close_event = asyncio.Event() async def _close(self) -> None: """Close this listener""" self._close_event.set() self._conn = None def close(self) -> None: """Close this listener asynchronously""" super().close() if self._conn: self._conn.create_task(self._close()) async def wait_closed(self) -> None: """Wait for this listener to finish closing""" await super().wait_closed() await self._close_event.wait() class SSHTCPClientListener(SSHClientListener, Generic[AnyStr]): """Client listener used to accept inbound forwarded TCP connections""" def __init__(self, conn: 'SSHClientConnection', session_factory: TCPListenerFactory[AnyStr], listen_host: str, listen_port: int, encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__(conn, encoding, errors, window, max_pktsize) self._session_factory: TCPListenerFactory[AnyStr] = session_factory self._listen_host = listen_host self._listen_port = listen_port async def _close(self) -> None: """Close this listener""" if self._conn: # pragma: no branch await self._conn.close_client_tcp_listener(self._listen_host, self._listen_port) await super()._close() def process_connection(self, orig_host: str, orig_port: int) -> \ Tuple['SSHTCPChannel[AnyStr]', MaybeAwait[SSHTCPSession[AnyStr]]]: """Process a forwarded TCP connection""" assert self._conn is not None chan = self._conn.create_tcp_channel(self._encoding, self._errors, self._window, self._max_pktsize) chan.set_inbound_peer_names(self._listen_host, self._listen_port, orig_host, orig_port) return chan, self._session_factory(orig_host, orig_port) def get_addresses(self) -> List[Tuple]: """Return the socket addresses being listened on""" return [(self._listen_host, self._listen_port)] def get_port(self) -> int: """Return the port number being listened on""" return self._listen_port class SSHUNIXClientListener(SSHClientListener, Generic[AnyStr]): """Client listener used to accept inbound forwarded UNIX connections""" def __init__(self, conn: 'SSHClientConnection', session_factory: UNIXListenerFactory[AnyStr], listen_path: str, encoding: Optional[str], errors: str, window: int, max_pktsize: int): super().__init__(conn, encoding, errors, window, max_pktsize) self._session_factory: UNIXListenerFactory[AnyStr] = session_factory self._listen_path = listen_path async def _close(self) -> None: """Close this listener""" if self._conn: # pragma: no branch await self._conn.close_client_unix_listener(self._listen_path) await super()._close() def process_connection(self) -> \ Tuple['SSHUNIXChannel[AnyStr]', MaybeAwait[SSHUNIXSession[AnyStr]]]: """Process a forwarded UNIX connection""" assert self._conn is not None chan = self._conn.create_unix_channel(self._encoding, self._errors, self._window, self._max_pktsize) chan.set_inbound_peer_names(self._listen_path) return chan, self._session_factory() class SSHForwardListener(SSHListener): """A listener used when forwarding traffic from local ports""" def __init__(self, conn: 'SSHConnection', servers: Sequence[asyncio.AbstractServer], listen_key: ListenKey, listen_port: int = 0): super().__init__() self._conn: Optional['SSHConnection'] = conn self._servers = servers self._listen_key = listen_key self._listen_port = listen_port def get_port(self) -> int: """Return the port number being listened on""" return self._listen_port def close(self) -> None: """Close this listener""" if self._conn: # pragma: no branch self._conn.close_forward_listener(self._listen_key) for server in self._servers: server.close() self._conn = None async def wait_closed(self) -> None: """Wait for this listener to finish closing""" await super().wait_closed() for server in self._servers: await server.wait_closed() self._servers = [] async def create_tcp_local_listener( conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, protocol_factory: _LocalListenerFactory, listen_host: str, listen_port: int) -> 'SSHForwardListener': """Create a listener to forward traffic from a local TCP port over SSH""" addrinfo = await loop.getaddrinfo(listen_host or None, listen_port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE) if not addrinfo: # pragma: no cover raise OSError('getaddrinfo() returned empty list') seen_addrinfo: Set[Tuple] = set() servers: List[asyncio.AbstractServer] = [] for addrinfo_entry in addrinfo: # Work around an issue where getaddrinfo() on some systems may # return duplicate results, causing bind to fail. if addrinfo_entry in seen_addrinfo: # pragma: no cover continue seen_addrinfo.add(addrinfo_entry) family, socktype, proto, _, sa = addrinfo_entry try: sock = socket.socket(family, socktype, proto) except OSError: # pragma: no cover continue sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) if family == socket.AF_INET6: try: sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) except AttributeError: # pragma: no cover pass if sa[1] == 0: sa = sa[:1] + (listen_port,) + sa[2:] # type: ignore try: sock.bind(sa) except (OSError, OverflowError) as exc: sock.close() for server in servers: server.close() if isinstance(exc, OverflowError): # pragma: no cover exc.errno = errno.EOVERFLOW # type: ignore exc.strerror = str(exc) # type: ignore # pylint: disable=no-member raise OSError(exc.errno, f'error while attempting ' # type: ignore f'to bind on address {sa!r}: ' f'{exc.strerror}') from None # type: ignore if listen_port == 0: listen_port = sock.getsockname()[1] conn.logger.debug1('Assigning dynamic port %d', listen_port) server = await loop.create_server(protocol_factory, sock=sock) servers.append(server) listen_key = listen_host or '', listen_port return SSHForwardListener(conn, servers, listen_key, listen_port) async def create_tcp_forward_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_host: str, listen_port: int) -> \ 'SSHForwardListener': """Create a listener to forward traffic from a local TCP port over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a port forwarder for each new local connection""" return SSHLocalPortForwarder(conn, coro) return await create_tcp_local_listener(conn, loop, protocol_factory, listen_host, listen_port) async def create_unix_forward_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_path: str) -> \ 'SSHForwardListener': """Create a listener to forward a local UNIX domain socket over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a path forwarder for each new local connection""" return SSHLocalPathForwarder(conn, coro) server = await loop.create_unix_server(protocol_factory, listen_path) return SSHForwardListener(conn, [server], listen_path) async def create_socks_listener(conn: 'SSHConnection', loop: asyncio.AbstractEventLoop, coro: SSHForwarderCoro, listen_host: str, listen_port: int) -> SSHForwardListener: """Create a SOCKS listener to forward traffic over SSH""" def protocol_factory() -> asyncio.BaseProtocol: """Start a port forwarder for each new SOCKS connection""" return SSHSOCKSForwarder(conn, coro) return await create_tcp_local_listener(conn, loop, protocol_factory, listen_host, listen_port)