Skip to content

Use OwnedFd/OwnedSocket in Socket #600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 12 additions & 28 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,47 +73,31 @@ use crate::{MaybeUninitSlice, MsgHdr, RecvFlags};
/// # Ok(()) }
/// ```
pub struct Socket {
inner: Inner,
inner: sys::Socket,
}

/// Store a `TcpStream` internally to take advantage of its niche optimizations on Unix platforms.
pub(crate) type Inner = std::net::TcpStream;

impl Socket {
/// # Safety
///
/// The caller must ensure `raw` is a valid file descriptor/socket. NOTE:
/// this should really be marked `unsafe`, but this being an internal
/// function, often passed as mapping function, it's makes it very
/// inconvenient to mark it as `unsafe`.
pub(crate) fn from_raw(raw: sys::Socket) -> Socket {
pub(crate) fn from_raw(raw: sys::RawSocket) -> Socket {
Socket {
inner: unsafe {
// SAFETY: the caller must ensure that `raw` is a valid file
// descriptor, but when it isn't it could return I/O errors, or
// potentially close a fd it doesn't own. All of that isn't
// memory unsafe, so it's not desired but never memory unsafe or
// causes UB.
//
// However there is one exception. We use `TcpStream` to
// represent the `Socket` internally (see `Inner` type),
// `TcpStream` has a layout optimisation that doesn't allow for
// negative file descriptors (as those are always invalid).
// Violating this assumption (fd never negative) causes UB,
// something we don't want. So check for that we have this
// `assert!`.
#[cfg(unix)]
assert!(raw >= 0, "tried to create a `Socket` with an invalid fd");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there is a test for this assertion.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the test with the error message from std lib, don't think there is value is checking it twice.

sys::socket_from_raw(raw)
},
}
}

pub(crate) fn as_raw(&self) -> sys::Socket {
// SAFETY: the caller must ensure that `raw` is a valid file
// descriptor, but when it isn't it could return I/O errors, or
// potentially close a fd it doesn't own. All of that isn't memory
// unsafe, so it's not desired but never memory unsafe or causes UB.
Comment on lines +90 to +91
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not immediate UB, but we consider it library UB because downstream effects can cause memory unsafety, e.g. by incorrectly touching mmap'ed fds.

https://doc.rust-lang.org/std/io/index.html#io-safety

inner: unsafe { sys::socket_from_raw(raw) },
}
}

pub(crate) fn as_raw(&self) -> sys::RawSocket {
sys::socket_as_raw(&self.inner)
}

pub(crate) fn into_raw(self) -> sys::Socket {
pub(crate) fn into_raw(self) -> sys::RawSocket {
sys::socket_into_raw(self.inner)
}

Expand Down
92 changes: 51 additions & 41 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -859,35 +859,36 @@ impl SockAddr {
}
}

pub(crate) type Socket = c_int;
pub(crate) type Socket = std::os::fd::OwnedFd;
pub(crate) type RawSocket = c_int;

pub(crate) unsafe fn socket_from_raw(socket: Socket) -> crate::socket::Inner {
crate::socket::Inner::from_raw_fd(socket)
pub(crate) unsafe fn socket_from_raw(socket: RawSocket) -> Socket {
Socket::from_raw_fd(socket)
}

pub(crate) fn socket_as_raw(socket: &crate::socket::Inner) -> Socket {
pub(crate) fn socket_as_raw(socket: &Socket) -> RawSocket {
socket.as_raw_fd()
}

pub(crate) fn socket_into_raw(socket: crate::socket::Inner) -> Socket {
pub(crate) fn socket_into_raw(socket: Socket) -> RawSocket {
socket.into_raw_fd()
}

pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result<Socket> {
pub(crate) fn socket(family: c_int, ty: c_int, protocol: c_int) -> io::Result<RawSocket> {
syscall!(socket(family, ty, protocol))
}

#[cfg(all(feature = "all", unix))]
pub(crate) fn socketpair(family: c_int, ty: c_int, protocol: c_int) -> io::Result<[Socket; 2]> {
pub(crate) fn socketpair(family: c_int, ty: c_int, protocol: c_int) -> io::Result<[RawSocket; 2]> {
let mut fds = [0, 0];
syscall!(socketpair(family, ty, protocol, fds.as_mut_ptr())).map(|_| fds)
}

pub(crate) fn bind(fd: Socket, addr: &SockAddr) -> io::Result<()> {
pub(crate) fn bind(fd: RawSocket, addr: &SockAddr) -> io::Result<()> {
syscall!(bind(fd, addr.as_ptr().cast::<sockaddr>(), addr.len() as _)).map(|_| ())
}

pub(crate) fn connect(fd: Socket, addr: &SockAddr) -> io::Result<()> {
pub(crate) fn connect(fd: RawSocket, addr: &SockAddr) -> io::Result<()> {
syscall!(connect(fd, addr.as_ptr().cast::<sockaddr>(), addr.len())).map(|_| ())
}

Expand Down Expand Up @@ -933,46 +934,46 @@ pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Res
}
}

pub(crate) fn listen(fd: Socket, backlog: c_int) -> io::Result<()> {
pub(crate) fn listen(fd: RawSocket, backlog: c_int) -> io::Result<()> {
syscall!(listen(fd, backlog)).map(|_| ())
}

pub(crate) fn accept(fd: Socket) -> io::Result<(Socket, SockAddr)> {
pub(crate) fn accept(fd: RawSocket) -> io::Result<(RawSocket, SockAddr)> {
// Safety: `accept` initialises the `SockAddr` for us.
unsafe { SockAddr::try_init(|storage, len| syscall!(accept(fd, storage.cast(), len))) }
}

pub(crate) fn getsockname(fd: Socket) -> io::Result<SockAddr> {
pub(crate) fn getsockname(fd: RawSocket) -> io::Result<SockAddr> {
// Safety: `accept` initialises the `SockAddr` for us.
unsafe { SockAddr::try_init(|storage, len| syscall!(getsockname(fd, storage.cast(), len))) }
.map(|(_, addr)| addr)
}

pub(crate) fn getpeername(fd: Socket) -> io::Result<SockAddr> {
pub(crate) fn getpeername(fd: RawSocket) -> io::Result<SockAddr> {
// Safety: `accept` initialises the `SockAddr` for us.
unsafe { SockAddr::try_init(|storage, len| syscall!(getpeername(fd, storage.cast(), len))) }
.map(|(_, addr)| addr)
}

pub(crate) fn try_clone(fd: Socket) -> io::Result<Socket> {
pub(crate) fn try_clone(fd: RawSocket) -> io::Result<RawSocket> {
syscall!(fcntl(fd, libc::F_DUPFD_CLOEXEC, 0))
}

#[cfg(all(feature = "all", unix, not(target_os = "vita")))]
pub(crate) fn nonblocking(fd: Socket) -> io::Result<bool> {
pub(crate) fn nonblocking(fd: RawSocket) -> io::Result<bool> {
let file_status_flags = fcntl_get(fd, libc::F_GETFL)?;
Ok((file_status_flags & libc::O_NONBLOCK) != 0)
}

#[cfg(all(feature = "all", target_os = "vita"))]
pub(crate) fn nonblocking(fd: Socket) -> io::Result<bool> {
pub(crate) fn nonblocking(fd: RawSocket) -> io::Result<bool> {
unsafe {
getsockopt::<Bool>(fd, libc::SOL_SOCKET, libc::SO_NONBLOCK).map(|non_block| non_block != 0)
}
}

#[cfg(not(target_os = "vita"))]
pub(crate) fn set_nonblocking(fd: Socket, nonblocking: bool) -> io::Result<()> {
pub(crate) fn set_nonblocking(fd: RawSocket, nonblocking: bool) -> io::Result<()> {
if nonblocking {
fcntl_add(fd, libc::F_GETFL, libc::F_SETFL, libc::O_NONBLOCK)
} else {
Expand All @@ -981,7 +982,7 @@ pub(crate) fn set_nonblocking(fd: Socket, nonblocking: bool) -> io::Result<()> {
}

#[cfg(target_os = "vita")]
pub(crate) fn set_nonblocking(fd: Socket, nonblocking: bool) -> io::Result<()> {
pub(crate) fn set_nonblocking(fd: RawSocket, nonblocking: bool) -> io::Result<()> {
unsafe {
setsockopt(
fd,
Expand All @@ -992,7 +993,7 @@ pub(crate) fn set_nonblocking(fd: Socket, nonblocking: bool) -> io::Result<()> {
}
}

pub(crate) fn shutdown(fd: Socket, how: Shutdown) -> io::Result<()> {
pub(crate) fn shutdown(fd: RawSocket, how: Shutdown) -> io::Result<()> {
let how = match how {
Shutdown::Write => libc::SHUT_WR,
Shutdown::Read => libc::SHUT_RD,
Expand All @@ -1001,7 +1002,7 @@ pub(crate) fn shutdown(fd: Socket, how: Shutdown) -> io::Result<()> {
syscall!(shutdown(fd, how)).map(|_| ())
}

pub(crate) fn recv(fd: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
pub(crate) fn recv(fd: RawSocket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io::Result<usize> {
syscall!(recv(
fd,
buf.as_mut_ptr().cast(),
Expand All @@ -1012,7 +1013,7 @@ pub(crate) fn recv(fd: Socket, buf: &mut [MaybeUninit<u8>], flags: c_int) -> io:
}

pub(crate) fn recv_from(
fd: Socket,
fd: RawSocket,
buf: &mut [MaybeUninit<u8>],
flags: c_int,
) -> io::Result<(usize, SockAddr)> {
Expand All @@ -1032,7 +1033,7 @@ pub(crate) fn recv_from(
}
}

pub(crate) fn peek_sender(fd: Socket) -> io::Result<SockAddr> {
pub(crate) fn peek_sender(fd: RawSocket) -> io::Result<SockAddr> {
// Unix-like platforms simply truncate the returned data, so this implementation is trivial.
// However, for Windows this requires suppressing the `WSAEMSGSIZE` error,
// so that requires a different approach.
Expand All @@ -1043,7 +1044,7 @@ pub(crate) fn peek_sender(fd: Socket) -> io::Result<SockAddr> {

#[cfg(not(target_os = "redox"))]
pub(crate) fn recv_vectored(
fd: Socket,
fd: RawSocket,
bufs: &mut [crate::MaybeUninitSlice<'_>],
flags: c_int,
) -> io::Result<(usize, RecvFlags)> {
Expand All @@ -1054,7 +1055,7 @@ pub(crate) fn recv_vectored(

#[cfg(not(target_os = "redox"))]
pub(crate) fn recv_from_vectored(
fd: Socket,
fd: RawSocket,
bufs: &mut [crate::MaybeUninitSlice<'_>],
flags: c_int,
) -> io::Result<(usize, RecvFlags, SockAddr)> {
Expand All @@ -1076,14 +1077,14 @@ pub(crate) fn recv_from_vectored(

#[cfg(not(target_os = "redox"))]
pub(crate) fn recvmsg(
fd: Socket,
fd: RawSocket,
msg: &mut MsgHdrMut<'_, '_, '_>,
flags: c_int,
) -> io::Result<usize> {
syscall!(recvmsg(fd, &mut msg.inner, flags)).map(|n| n as usize)
}

pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
pub(crate) fn send(fd: RawSocket, buf: &[u8], flags: c_int) -> io::Result<usize> {
syscall!(send(
fd,
buf.as_ptr().cast(),
Expand All @@ -1094,12 +1095,21 @@ pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result<usize> {
}

#[cfg(not(target_os = "redox"))]
pub(crate) fn send_vectored(fd: Socket, bufs: &[IoSlice<'_>], flags: c_int) -> io::Result<usize> {
pub(crate) fn send_vectored(
fd: RawSocket,
bufs: &[IoSlice<'_>],
flags: c_int,
) -> io::Result<usize> {
let msg = MsgHdr::new().with_buffers(bufs);
sendmsg(fd, &msg, flags)
}

pub(crate) fn send_to(fd: Socket, buf: &[u8], addr: &SockAddr, flags: c_int) -> io::Result<usize> {
pub(crate) fn send_to(
fd: RawSocket,
buf: &[u8],
addr: &SockAddr,
flags: c_int,
) -> io::Result<usize> {
syscall!(sendto(
fd,
buf.as_ptr().cast(),
Expand All @@ -1113,7 +1123,7 @@ pub(crate) fn send_to(fd: Socket, buf: &[u8], addr: &SockAddr, flags: c_int) ->

#[cfg(not(target_os = "redox"))]
pub(crate) fn send_to_vectored(
fd: Socket,
fd: RawSocket,
bufs: &[IoSlice<'_>],
addr: &SockAddr,
flags: c_int,
Expand All @@ -1123,12 +1133,12 @@ pub(crate) fn send_to_vectored(
}

#[cfg(not(target_os = "redox"))]
pub(crate) fn sendmsg(fd: Socket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize> {
pub(crate) fn sendmsg(fd: RawSocket, msg: &MsgHdr<'_, '_, '_>, flags: c_int) -> io::Result<usize> {
syscall!(sendmsg(fd, &msg.inner, flags)).map(|n| n as usize)
}

/// Wrapper around `getsockopt` to deal with platform specific timeouts.
pub(crate) fn timeout_opt(fd: Socket, opt: c_int, val: c_int) -> io::Result<Option<Duration>> {
pub(crate) fn timeout_opt(fd: RawSocket, opt: c_int, val: c_int) -> io::Result<Option<Duration>> {
unsafe { getsockopt(fd, opt, val).map(from_timeval) }
}

Expand All @@ -1144,7 +1154,7 @@ const fn from_timeval(duration: libc::timeval) -> Option<Duration> {

/// Wrapper around `setsockopt` to deal with platform specific timeouts.
pub(crate) fn set_timeout_opt(
fd: Socket,
fd: RawSocket,
opt: c_int,
val: c_int,
duration: Option<Duration>,
Expand Down Expand Up @@ -1172,15 +1182,15 @@ fn into_timeval(duration: Option<Duration>) -> libc::timeval {
feature = "all",
not(any(target_os = "haiku", target_os = "openbsd", target_os = "vita"))
))]
pub(crate) fn tcp_keepalive_time(fd: Socket) -> io::Result<Duration> {
pub(crate) fn tcp_keepalive_time(fd: RawSocket) -> io::Result<Duration> {
unsafe {
getsockopt::<c_int>(fd, IPPROTO_TCP, KEEPALIVE_TIME)
.map(|secs| Duration::from_secs(secs as u64))
}
}

#[allow(unused_variables)]
pub(crate) fn set_tcp_keepalive(fd: Socket, keepalive: &TcpKeepalive) -> io::Result<()> {
pub(crate) fn set_tcp_keepalive(fd: RawSocket, keepalive: &TcpKeepalive) -> io::Result<()> {
#[cfg(not(any(
target_os = "haiku",
target_os = "openbsd",
Expand Down Expand Up @@ -1241,13 +1251,13 @@ fn into_secs(duration: Duration) -> c_int {

/// Get the flags using `cmd`.
#[cfg(not(target_os = "vita"))]
fn fcntl_get(fd: Socket, cmd: c_int) -> io::Result<c_int> {
fn fcntl_get(fd: RawSocket, cmd: c_int) -> io::Result<c_int> {
syscall!(fcntl(fd, cmd))
}

/// Add `flag` to the current set flags of `F_GETFD`.
#[cfg(not(target_os = "vita"))]
fn fcntl_add(fd: Socket, get_cmd: c_int, set_cmd: c_int, flag: c_int) -> io::Result<()> {
fn fcntl_add(fd: RawSocket, get_cmd: c_int, set_cmd: c_int, flag: c_int) -> io::Result<()> {
let previous = fcntl_get(fd, get_cmd)?;
let new = previous | flag;
if new != previous {
Expand All @@ -1260,7 +1270,7 @@ fn fcntl_add(fd: Socket, get_cmd: c_int, set_cmd: c_int, flag: c_int) -> io::Res

/// Remove `flag` to the current set flags of `F_GETFD`.
#[cfg(not(target_os = "vita"))]
fn fcntl_remove(fd: Socket, get_cmd: c_int, set_cmd: c_int, flag: c_int) -> io::Result<()> {
fn fcntl_remove(fd: RawSocket, get_cmd: c_int, set_cmd: c_int, flag: c_int) -> io::Result<()> {
let previous = fcntl_get(fd, get_cmd)?;
let new = previous & !flag;
if new != previous {
Expand All @@ -1272,7 +1282,7 @@ fn fcntl_remove(fd: Socket, get_cmd: c_int, set_cmd: c_int, flag: c_int) -> io::
}

/// Caller must ensure `T` is the correct type for `opt` and `val`.
pub(crate) unsafe fn getsockopt<T>(fd: Socket, opt: c_int, val: c_int) -> io::Result<T> {
pub(crate) unsafe fn getsockopt<T>(fd: RawSocket, opt: c_int, val: c_int) -> io::Result<T> {
let mut payload: MaybeUninit<T> = MaybeUninit::uninit();
let mut len = size_of::<T>() as libc::socklen_t;
syscall!(getsockopt(
Expand All @@ -1291,7 +1301,7 @@ pub(crate) unsafe fn getsockopt<T>(fd: Socket, opt: c_int, val: c_int) -> io::Re

/// Caller must ensure `T` is the correct type for `opt` and `val`.
pub(crate) unsafe fn setsockopt<T>(
fd: Socket,
fd: RawSocket,
opt: c_int,
val: c_int,
payload: T,
Expand Down Expand Up @@ -1365,7 +1375,7 @@ pub(crate) const fn to_mreqn(
feature = "all",
any(target_os = "android", target_os = "fuchsia", target_os = "linux")
))]
pub(crate) fn original_dst_v4(fd: Socket) -> io::Result<SockAddr> {
pub(crate) fn original_dst_v4(fd: RawSocket) -> io::Result<SockAddr> {
// Safety: `getsockopt` initialises the `SockAddr` for us.
unsafe {
SockAddr::try_init(|storage, len| {
Expand All @@ -1386,7 +1396,7 @@ pub(crate) fn original_dst_v4(fd: Socket) -> io::Result<SockAddr> {
/// This value contains the original destination IPv6 address of the connection
/// redirected using `ip6tables` `REDIRECT` or `TPROXY`.
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))]
pub(crate) fn original_dst_v6(fd: Socket) -> io::Result<SockAddr> {
pub(crate) fn original_dst_v6(fd: RawSocket) -> io::Result<SockAddr> {
// Safety: `getsockopt` initialises the `SockAddr` for us.
unsafe {
SockAddr::try_init(|storage, len| {
Expand Down
Loading