Skip to content

Implement highlevel unix socket listeners #3187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions newsfragments/279.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``trio.open_unix_listener``, ``trio.serve_unix``, and ``trio.UnixSocketListener`` to support ``SOCK_STREAM`` `Unix domain sockets <https://en.wikipedia.org/wiki/Unix_domain_socket>`__
4 changes: 4 additions & 0 deletions src/trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
serve_tcp as serve_tcp,
)
from ._highlevel_open_tcp_stream import open_tcp_stream as open_tcp_stream
from ._highlevel_open_unix_listeners import (
open_unix_listener as open_unix_listener,
serve_unix as serve_unix,
)
from ._highlevel_open_unix_stream import open_unix_socket as open_unix_socket
from ._highlevel_serve_listeners import serve_listeners as serve_listeners
from ._highlevel_socket import (
Expand Down
132 changes: 132 additions & 0 deletions src/trio/_highlevel_open_unix_listeners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import trio
import trio.socket as tsocket
from trio import TaskStatus

from ._highlevel_open_tcp_listeners import _compute_backlog

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable


try:
from trio.socket import AF_UNIX

HAS_UNIX = True
except ImportError:
HAS_UNIX = False


async def open_unix_listener(
path: str | bytes | os.PathLike[str] | os.PathLike[bytes],
*,
mode: int | None = None,
backlog: int | None = None,
) -> trio.SocketListener:
"""Create :class:`SocketListener` objects to listen for connections.
Opens a connection to the specified
`Unix domain socket <https://en.wikipedia.org/wiki/Unix_domain_socket>`__.

You must have read/write permission on the specified file to connect.

Args:

path (str): Filename of UNIX socket to create and listen on.
Absolute or relative paths may be used.

mode (int or None): The socket file permissions.
UNIX permissions are usually specified in octal numbers. If
you leave this as ``None``, Trio will not change the mode from
the operating system's default.

backlog (int or None): The listen backlog to use. If you leave this as
``None`` then Trio will pick a good default. (Currently:
whatever your system has configured as the maximum backlog.)

Returns:
:class:`UnixSocketListener`

Raises:
:class:`ValueError` If invalid arguments.
:class:`RuntimeError`: If AF_UNIX sockets are not supported.
:class:`FileNotFoundError`: If folder socket file is to be created in does not exist.
"""
if not HAS_UNIX:
raise RuntimeError("Unix sockets are not supported on this platform")

computed_backlog = _compute_backlog(backlog)

fspath = await trio.Path(os.fsdecode(path)).absolute()

folder = fspath.parent
if not await folder.exists():
raise FileNotFoundError(f"Socket folder does not exist: {folder!r}")

str_path = str(fspath)

# much more simplified logic vs tcp sockets - one socket family and only one
# possible location to connect to
sock = tsocket.socket(AF_UNIX, tsocket.SOCK_STREAM)
try:
await sock.bind(str_path)

if mode is not None:
await fspath.chmod(mode)

sock.listen(computed_backlog)

return trio.SocketListener(sock, str_path)
except BaseException:
sock.close()
os.unlink(str_path)
raise


async def serve_unix(
handler: Callable[[trio.SocketStream], Awaitable[object]],
path: str | bytes | os.PathLike[str] | os.PathLike[bytes],
*,
backlog: int | None = None,
handler_nursery: trio.Nursery | None = None,
task_status: TaskStatus[list[trio.UnixSocketListener]] = trio.TASK_STATUS_IGNORED,

Check failure on line 95 in src/trio/_highlevel_open_unix_listeners.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_highlevel_open_unix_listeners.py:95: Name "trio.UnixSocketListener" is not defined [name-defined]
) -> None:
"""Listen for incoming UNIX connections, and for each one start a task
running ``handler(stream)``.
This is a thin convenience wrapper around :func:`open_unix_listener` and
:func:`serve_listeners` – see them for full details.
.. warning::
If ``handler`` raises an exception, then this function doesn't do
anything special to catch it – so by default the exception will
propagate out and crash your server. If you don't want this, then catch
exceptions inside your ``handler``, or use a ``handler_nursery`` object
that responds to exceptions in some other way.
When used with ``nursery.start`` you get back the newly opened listeners.
Args:
handler: The handler to start for each incoming connection. Passed to
:func:`serve_listeners`.
path: The socket file name.
Passed to :func:`open_unix_listener`.
backlog: The listen backlog, or None to have a good default picked.
Passed to :func:`open_tcp_listener`.
handler_nursery: The nursery to start handlers in, or None to use an
internal nursery. Passed to :func:`serve_listeners`.
task_status: This function can be used with ``nursery.start``.
Returns:
This function only returns when cancelled.
Raises:
RuntimeError: If AF_UNIX sockets are not supported.
"""
if not HAS_UNIX:
raise RuntimeError("Unix sockets are not supported on this platform")

listener = await open_unix_listener(path, backlog=backlog)
await trio.serve_listeners(
handler,
[listener],
handler_nursery=handler_nursery,
task_status=task_status,
)
30 changes: 28 additions & 2 deletions src/trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import errno
from contextlib import contextmanager, suppress
from typing import TYPE_CHECKING, overload
from os import PathLike, stat, unlink
from stat import S_ISSOCK
from typing import TYPE_CHECKING, Final, overload

import trio

Expand Down Expand Up @@ -31,6 +33,8 @@
errno.ENOTSOCK,
}

HAS_UNIX: Final = hasattr(tsocket, "AF_UNIX")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this is cleaner than the definition in _highlevel_open_unix_listeners.py



@contextmanager
def _translate_socket_errors_to_stream_errors() -> Generator[None, None, None]:
Expand Down Expand Up @@ -68,6 +72,8 @@ class SocketStream(HalfCloseableStream):

"""

__slots__ = ("_send_conflict_detector", "socket")

def __init__(self, socket: SocketType) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
Expand Down Expand Up @@ -352,19 +358,34 @@ class SocketListener(Listener[SocketStream]):
incoming connections as :class:`SocketStream` objects.

Args:

socket: The Trio socket object to wrap. Must have type ``SOCK_STREAM``,
and be listening.

path: Used for keeping track of which path a Unix socket is bound
to. If not ``None``, :meth:`aclose` will unlink this path.
File must have socket mode flag set.

Note that the :class:`SocketListener` "takes ownership" of the given
socket; closing the :class:`SocketListener` will also close the socket.

.. attribute:: socket

The Trio socket object that this stream wraps.

.. attribute:: path

The path to unlink in :meth:`aclose` that a Unix socket is bound to.

"""

def __init__(self, socket: SocketType) -> None:
__slots__ = ("path", "socket")

def __init__(
self,
socket: SocketType,
path: str | bytes | PathLike[str] | PathLike[bytes] | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I preferred how the old version didn't require this path argument. Is this really required? I assume we can just check whether a socket is unix type with socket.family == getattr(tsocket, "AF_UNIX", None) and have the old logic to get the path.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was looking deeper into things noted in #279 and saw this comment: #279 (comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but I missed something, we only have to care about that if we are renaming socket files, which this implementation does not do, so it would be fine to do the original thing.

) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
Expand All @@ -377,8 +398,11 @@ def __init__(self, socket: SocketType) -> None:
else:
if not listening:
raise ValueError("SocketListener requires a listening socket")
if path is not None and not S_ISSOCK(stat(path).st_mode):
raise ValueError("Specified path must be a Unix socket file")

self.socket = socket
self.path = path

async def accept(self) -> SocketStream:
"""Accept an incoming connection.
Expand Down Expand Up @@ -411,4 +435,6 @@ async def accept(self) -> SocketStream:
async def aclose(self) -> None:
"""Close this listener and its underlying socket."""
self.socket.close()
if self.path is not None:
unlink(self.path)
await trio.lowlevel.checkpoint()
Loading