Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions asyncssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@

from .config import ConfigParseError

from .forward import SSHForwarder
from .forward import SSHForwarder, SSHForwardTracker

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need to expose SSHForwardTracker as a public symbol. Callers should always base their trackers on either the Port or Path version, depending on whether the non-tunneled traffic is over TCP or UNIX domain sockets.

from .forward import SSHPortForwardTracker, SSHPathForwardTracker

from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection
from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions
Expand Down Expand Up @@ -147,7 +148,8 @@
'SSHAgentKeyPair', 'SSHAuthorizedKeys', 'SSHCertificate', 'SSHClient',
'SSHClientChannel', 'SSHClientConnection', 'SSHClientConnectionOptions',
'SSHClientProcess', 'SSHClientSession', 'SSHCompletedProcess',
'SSHForwarder', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts',
'SSHForwarder', 'SSHForwardTracker', 'SSHPortForwardTracker',

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above.

'SSHPathForwardTracker', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts',
'SSHLineEditorChannel', 'SSHListener', 'SSHReader', 'SSHServer',
'SSHServerChannel', 'SSHServerConnection',
'SSHServerConnectionOptions', 'SSHServerProcess',
Expand Down
77 changes: 55 additions & 22 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
from .encryption import encryption_needs_mac
from .encryption import get_encryption_params, get_encryption

from .forward import SSHForwarder
from .forward import SSHForwarder, SSHForwardTrackerFactory

from .gss import GSSBase, GSSClient, GSSServer, GSSError

Expand Down Expand Up @@ -3210,7 +3210,9 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder:
async def forward_local_port(
self, listen_host: str, listen_port: int,
dest_host: str, dest_port: int,
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
accept_handler: Optional[SSHAcceptHandler] = None,
tracker_factory:
Optional[SSHForwardTrackerFactory] = None) -> SSHListener:

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'll probably need factory types SSHPortForwardTrackerFactory and SSHPathForwardTrackerFactory instead of a single SSHForwardTrackerFactory. Which one you use will be determined by whether the listener is TCP or UNIX, which we'll generally know by what function is being called to set up the forwarding. For instance, forward_local_port() would need the Port version of the tracker factory.

"""Set up local port forwarding

This method is a coroutine which attempts to set up port
Expand All @@ -3233,11 +3235,18 @@ async def forward_local_port(
or not to allow connection forwarding, returning `True` to
accept the connection and begin forwarding or `False` to
reject and close it.
:param tracker_factory:
An optional callable invoked once per accepted connection
which returns a new :class:`SSHPortForwardTracker` (or
:class:`SSHForwardTracker` subclass) for observing that

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need the "or" here on SSHForwardTracker. The tracker_factory should be an SSHPortForwardTracker, or a subclass of that.

connection's lifecycle. `None` (default) disables tracking
with no overhead.
:type listen_host: `str`
:type listen_port: `int`
:type dest_host: `str`
:type dest_port: `int`
:type accept_handler: `callable` or coroutine
:type tracker_factory: `callable` or `None`

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For arguments which are optional because they default to None, I generally don't call out None in the "type" description. The argument would just be left out if a caller didn't want tracker, so there's no reason for them to pass in None explicitly.

We could also consider using the factory type name as the value under "type", since we have one defined.


:returns: :class:`SSHListener`

Expand Down Expand Up @@ -3278,10 +3287,9 @@ async def tunnel_connection(
(dest_host, dest_port))

try:
listener = await create_tcp_forward_listener(self, self._loop,
tunnel_connection,
listen_host,
listen_port)
listener = await create_tcp_forward_listener(
self, self._loop, tunnel_connection, listen_host, listen_port,
tracker_factory=tracker_factory)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need to make this a keyword argument. It should be fine to keep it positional.

except OSError as exc:
self.logger.debug1('Failed to create local TCP listener: %s', exc)
raise
Expand All @@ -3297,8 +3305,10 @@ async def tunnel_connection(
return listener

@async_context_manager
async def forward_local_path(self, listen_path: str,
dest_path: str) -> SSHListener:
async def forward_local_path(

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same changes as above, but with the Path version of the tracker.

self, listen_path: str, dest_path: str,
tracker_factory:
Optional[SSHForwardTrackerFactory] = None) -> SSHListener:
"""Set up local UNIX domain socket forwarding

This method is a coroutine which attempts to set up UNIX domain
Expand All @@ -3311,8 +3321,15 @@ async def forward_local_path(self, listen_path: str,
The path on the local host to listen on
:param dest_path:
The path on the remote host to forward the connections to
:param tracker_factory:
An optional callable invoked once per accepted connection
which returns a new :class:`SSHPathForwardTracker` (or
:class:`SSHForwardTracker` subclass) for observing that
connection's lifecycle. `None` (default) disables tracking
with no overhead.
:type listen_path: `str`
:type dest_path: `str`
:type tracker_factory: `callable` or `None`

:returns: :class:`SSHListener`

Expand All @@ -3332,9 +3349,9 @@ async def tunnel_connection(
listen_path, dest_path)

try:
listener = await create_unix_forward_listener(self, self._loop,
tunnel_connection,
listen_path)
listener = await create_unix_forward_listener(
self, self._loop, tunnel_connection, listen_path,
tracker_factory=tracker_factory)
except OSError as exc:
self.logger.debug1('Failed to create local UNIX listener: %s', exc)
raise
Expand Down Expand Up @@ -5304,7 +5321,9 @@ async def open_tap(self, *args: object, **kwargs: object) -> \
@async_context_manager
async def forward_local_port_to_path(

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be the Port version.

self, listen_host: str, listen_port: int, dest_path: str,
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
accept_handler: Optional[SSHAcceptHandler] = None,
tracker_factory:
Optional[SSHForwardTrackerFactory] = None) -> SSHListener:
"""Set up local TCP port forwarding to a remote UNIX domain socket

This method is a coroutine which attempts to set up port
Expand All @@ -5325,10 +5344,17 @@ async def forward_local_port_to_path(
or not to allow connection forwarding, returning `True` to
accept the connection and begin forwarding or `False` to
reject and close it.
:param tracker_factory:
An optional callable invoked once per accepted connection
which returns a new :class:`SSHPortForwardTracker` (or
:class:`SSHForwardTracker` subclass) for observing that
connection's lifecycle. `None` (default) disables tracking
with no overhead.
:type listen_host: `str`
:type listen_port: `int`
:type dest_path: `str`
:type accept_handler: `callable` or coroutine
:type tracker_factory: `callable` or `None`

:returns: :class:`SSHListener`

Expand Down Expand Up @@ -5362,10 +5388,9 @@ async def tunnel_connection(
(listen_host, listen_port), dest_path)

try:
listener = await create_tcp_forward_listener(self, self._loop,
tunnel_connection,
listen_host,
listen_port)
listener = await create_tcp_forward_listener(
self, self._loop, tunnel_connection, listen_host, listen_port,
tracker_factory=tracker_factory)
except OSError as exc:
self.logger.debug1('Failed to create local TCP listener: %s', exc)
raise
Expand All @@ -5378,9 +5403,10 @@ async def tunnel_connection(
return listener

@async_context_manager
async def forward_local_path_to_port(self, listen_path: str,
dest_host: str,
dest_port: int) -> SSHListener:
async def forward_local_path_to_port(

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...and this would be the Path one.

self, listen_path: str, dest_host: str, dest_port: int,
tracker_factory:
Optional[SSHForwardTrackerFactory] = None) -> SSHListener:
"""Set up local UNIX domain socket forwarding to a remote TCP port

This method is a coroutine which attempts to set up UNIX domain
Expand All @@ -5395,9 +5421,16 @@ async def forward_local_path_to_port(self, listen_path: str,
The hostname or address to forward the connections to
:param dest_port:
The port number to forward the connections to
:param tracker_factory:
An optional callable invoked once per accepted connection
which returns a new :class:`SSHPathForwardTracker` (or
:class:`SSHForwardTracker` subclass) for observing that
connection's lifecycle. `None` (default) disables tracking
with no overhead.
:type listen_path: `str`
:type dest_host: `str`
:type dest_port: `int`
:type tracker_factory: `callable` or `None`

:returns: :class:`SSHListener`

Expand All @@ -5417,9 +5450,9 @@ async def tunnel_connection(
listen_path, (dest_host, dest_port))

try:
listener = await create_unix_forward_listener(self, self._loop,
tunnel_connection,
listen_path)
listener = await create_unix_forward_listener(
self, self._loop, tunnel_connection, listen_path,
tracker_factory=tracker_factory)
except OSError as exc:
self.logger.debug1('Failed to create local UNIX listener: %s', exc)
raise
Expand Down
Loading