Skip to content

Commit 795e974

Browse files
authored
Do not use windows-sys types in public API (#576)
1 parent 071378d commit 795e974

File tree

6 files changed

+139
-47
lines changed

6 files changed

+139
-47
lines changed

.github/workflows/main.yml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,28 @@ jobs:
119119
- uses: dtolnay/rust-toolchain@stable
120120
- name: Run Clippy
121121
run: cargo clippy --all-targets --all-features -- -D warnings
122+
CheckExternalTypes:
123+
name: check-external-types (${{ matrix.os }})
124+
runs-on: ${{ matrix.os }}
125+
strategy:
126+
matrix:
127+
os:
128+
- windows-latest
129+
- ubuntu-latest
130+
rust:
131+
# `check-external-types` requires a specific Rust nightly version. See
132+
# the README for details: https://github.com/awslabs/cargo-check-external-types
133+
- nightly-2024-06-30
134+
steps:
135+
- uses: actions/checkout@v4
136+
- name: Install Rust ${{ matrix.rust }}
137+
uses: dtolnay/rust-toolchain@stable
138+
with:
139+
toolchain: ${{ matrix.rust }}
140+
- name: Install cargo-check-external-types
141+
uses: taiki-e/cache-cargo-install-action@v1
142+
with:
143+
tool: cargo-check-external-types@0.1.13
144+
locked: true
145+
- name: check-external-types
146+
run: cargo check-external-types --all-features

Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,11 @@ features = [
6666
[features]
6767
# Enable all API, even ones not available on all OSs.
6868
all = []
69+
70+
[package.metadata.cargo_check_external_types]
71+
allowed_external_types = [
72+
"libc::*",
73+
# Referenced via a type alias.
74+
"windows_sys::Win32::Networking::WinSock::socklen_t",
75+
"windows_sys::Win32::Networking::WinSock::ADDRESS_FAMILY",
76+
]

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ compile_error!("Socket2 doesn't support the compile target");
185185

186186
use sys::c_int;
187187

188-
pub use sockaddr::SockAddr;
188+
pub use sockaddr::{sa_family_t, socklen_t, SockAddr, SockAddrStorage};
189189
pub use socket::Socket;
190190
pub use sockref::SockRef;
191191

src/sockaddr.rs

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,71 @@
11
use std::hash::Hash;
2-
use std::mem::{self, size_of, MaybeUninit};
2+
use std::mem::{self, size_of};
33
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
44
use std::path::Path;
55
use std::{fmt, io, ptr};
66

77
#[cfg(windows)]
88
use windows_sys::Win32::Networking::WinSock::SOCKADDR_IN6_0;
99

10-
use crate::sys::{
11-
c_int, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_storage, socklen_t, AF_INET,
12-
AF_INET6, AF_UNIX,
13-
};
10+
use crate::sys::{c_int, sockaddr_in, sockaddr_in6, sockaddr_storage, AF_INET, AF_INET6, AF_UNIX};
1411
use crate::Domain;
1512

13+
/// The integer type used with `getsockname` on this platform.
14+
#[allow(non_camel_case_types)]
15+
pub type socklen_t = crate::sys::socklen_t;
16+
17+
/// The integer type for the `ss_family` field on this platform.
18+
#[allow(non_camel_case_types)]
19+
pub type sa_family_t = crate::sys::sa_family_t;
20+
21+
/// Rust version of the [`sockaddr_storage`] type.
22+
///
23+
/// This type is intended to be used with with direct calls to the `getsockname` syscall. See the
24+
/// documentation of [`SockAddr::new`] for examples.
25+
///
26+
/// This crate defines its own `sockaddr_storage` type to avoid semver concerns with upgrading
27+
/// `windows-sys`.
28+
#[repr(transparent)]
29+
pub struct SockAddrStorage {
30+
storage: sockaddr_storage,
31+
}
32+
33+
impl SockAddrStorage {
34+
/// Construct a new storage containing all zeros.
35+
#[inline]
36+
pub fn zeroed() -> Self {
37+
// SAFETY: All zeros is valid for this type.
38+
unsafe { mem::zeroed() }
39+
}
40+
41+
/// Returns the size of this storage.
42+
#[inline]
43+
pub fn size_of(&self) -> socklen_t {
44+
size_of::<Self>() as socklen_t
45+
}
46+
47+
/// View this type as another type.
48+
///
49+
/// # Safety
50+
///
51+
/// The type `T` must be one of the `sockaddr_*` types defined by this platform.
52+
#[inline]
53+
pub unsafe fn view_as<T>(&mut self) -> &mut T {
54+
assert!(size_of::<T>() <= size_of::<Self>());
55+
// SAFETY: This type is repr(transparent) over `sockaddr_storage` and `T` is one of the
56+
// `sockaddr_*` types defined by this platform.
57+
&mut *(self as *mut Self as *mut T)
58+
}
59+
}
60+
61+
impl std::fmt::Debug for SockAddrStorage {
62+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63+
f.debug_struct("sockaddr_storage")
64+
.field("ss_family", &self.storage.ss_family)
65+
.finish_non_exhaustive()
66+
}
67+
}
68+
1669
/// The address of a socket.
1770
///
1871
/// `SockAddr`s may be constructed directly to and from the standard library
@@ -40,23 +93,22 @@ impl SockAddr {
4093
/// # fn main() -> std::io::Result<()> {
4194
/// # #[cfg(unix)] {
4295
/// use std::io;
43-
/// use std::mem;
4496
/// use std::os::unix::io::AsRawFd;
4597
///
46-
/// use socket2::{SockAddr, Socket, Domain, Type};
98+
/// use socket2::{SockAddr, SockAddrStorage, Socket, Domain, Type};
4799
///
48100
/// let socket = Socket::new(Domain::IPV4, Type::STREAM, None)?;
49101
///
50102
/// // Initialise a `SocketAddr` byte calling `getsockname(2)`.
51-
/// let mut addr_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
52-
/// let mut len = mem::size_of_val(&addr_storage) as libc::socklen_t;
103+
/// let mut addr_storage = SockAddrStorage::zeroed();
104+
/// let mut len = addr_storage.size_of();
53105
///
54106
/// // The `getsockname(2)` system call will intiliase `storage` for
55107
/// // us, setting `len` to the correct length.
56108
/// let res = unsafe {
57109
/// libc::getsockname(
58110
/// socket.as_raw_fd(),
59-
/// (&mut addr_storage as *mut libc::sockaddr_storage).cast(),
111+
/// addr_storage.view_as(),
60112
/// &mut len,
61113
/// )
62114
/// };
@@ -70,8 +122,11 @@ impl SockAddr {
70122
/// # Ok(())
71123
/// # }
72124
/// ```
73-
pub const unsafe fn new(storage: sockaddr_storage, len: socklen_t) -> SockAddr {
74-
SockAddr { storage, len }
125+
pub const unsafe fn new(storage: SockAddrStorage, len: socklen_t) -> SockAddr {
126+
SockAddr {
127+
storage: storage.storage,
128+
len: len as socklen_t,
129+
}
75130
}
76131

77132
/// Initialise a `SockAddr` by calling the function `init`.
@@ -121,25 +176,19 @@ impl SockAddr {
121176
/// ```
122177
pub unsafe fn try_init<F, T>(init: F) -> io::Result<(T, SockAddr)>
123178
where
124-
F: FnOnce(*mut sockaddr_storage, *mut socklen_t) -> io::Result<T>,
179+
F: FnOnce(*mut SockAddrStorage, *mut socklen_t) -> io::Result<T>,
125180
{
126181
const STORAGE_SIZE: socklen_t = size_of::<sockaddr_storage>() as socklen_t;
127182
// NOTE: `SockAddr::unix` depends on the storage being zeroed before
128183
// calling `init`.
129184
// NOTE: calling `recvfrom` with an empty buffer also depends on the
130185
// storage being zeroed before calling `init` as the OS might not
131186
// initialise it.
132-
let mut storage = MaybeUninit::<sockaddr_storage>::zeroed();
187+
let mut storage = SockAddrStorage::zeroed();
133188
let mut len = STORAGE_SIZE;
134-
init(storage.as_mut_ptr(), &mut len).map(|res| {
189+
init(&mut storage, &mut len).map(|res| {
135190
debug_assert!(len <= STORAGE_SIZE, "overflown address storage");
136-
let addr = SockAddr {
137-
// Safety: zeroed-out `sockaddr_storage` is valid, caller must
138-
// ensure at least `len` bytes are valid.
139-
storage: storage.assume_init(),
140-
len,
141-
};
142-
(res, addr)
191+
(res, SockAddr::new(storage, len))
143192
})
144193
}
145194

@@ -179,13 +228,15 @@ impl SockAddr {
179228
}
180229

181230
/// Returns a raw pointer to the address.
182-
pub const fn as_ptr(&self) -> *const sockaddr {
183-
ptr::addr_of!(self.storage).cast()
231+
pub const fn as_ptr(&self) -> *const SockAddrStorage {
232+
&self.storage as *const sockaddr_storage as *const SockAddrStorage
184233
}
185234

186235
/// Retuns the address as the storage.
187-
pub const fn as_storage(self) -> sockaddr_storage {
188-
self.storage
236+
pub const fn as_storage(self) -> SockAddrStorage {
237+
SockAddrStorage {
238+
storage: self.storage,
239+
}
189240
}
190241

191242
/// Returns true if this address is in the `AF_INET` (IPv4) family, false otherwise.

src/sys/unix.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ use std::{io, slice};
7979
use libc::ssize_t;
8080
use libc::{in6_addr, in_addr};
8181

82-
use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type};
82+
use crate::{Domain, Protocol, SockAddr, SockAddrStorage, TcpKeepalive, Type};
8383
#[cfg(not(target_os = "redox"))]
8484
use crate::{MsgHdr, MsgHdrMut, RecvFlags};
8585

@@ -658,10 +658,10 @@ pub(crate) fn offset_of_path(storage: &libc::sockaddr_un) -> usize {
658658

659659
#[allow(unsafe_op_in_unsafe_fn)]
660660
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
661-
// SAFETY: a `sockaddr_storage` of all zeros is valid.
662-
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
661+
let mut storage = SockAddrStorage::zeroed();
663662
let len = {
664-
let storage = unsafe { &mut *ptr::addr_of_mut!(storage).cast::<libc::sockaddr_un>() };
663+
// SAFETY: sockaddr_un is one of the sockaddr_* types defined by this platform.
664+
let storage = unsafe { storage.view_as::<libc::sockaddr_un>() };
665665

666666
let bytes = path.as_os_str().as_bytes();
667667
let too_long = match bytes.first() {
@@ -750,11 +750,10 @@ impl SockAddr {
750750
#[allow(unsafe_op_in_unsafe_fn)]
751751
#[cfg(all(feature = "all", any(target_os = "android", target_os = "linux")))]
752752
pub fn vsock(cid: u32, port: u32) -> SockAddr {
753-
// SAFETY: a `sockaddr_storage` of all zeros is valid.
754-
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
753+
let mut storage = SockAddrStorage::zeroed();
755754
{
756-
let storage: &mut libc::sockaddr_vm =
757-
unsafe { &mut *((&mut storage as *mut sockaddr_storage).cast()) };
755+
// SAFETY: sockaddr_vm is one of the sockaddr_* types defined by this platform.
756+
let storage = unsafe { storage.view_as::<libc::sockaddr_vm>() };
758757
storage.svm_family = libc::AF_VSOCK as sa_family_t;
759758
storage.svm_cid = cid;
760759
storage.svm_port = port;
@@ -895,11 +894,11 @@ pub(crate) fn socketpair(family: c_int, ty: c_int, protocol: c_int) -> io::Resul
895894
}
896895

897896
pub(crate) fn bind(fd: Socket, addr: &SockAddr) -> io::Result<()> {
898-
syscall!(bind(fd, addr.as_ptr(), addr.len() as _)).map(|_| ())
897+
syscall!(bind(fd, addr.as_ptr().cast::<sockaddr>(), addr.len() as _)).map(|_| ())
899898
}
900899

901900
pub(crate) fn connect(fd: Socket, addr: &SockAddr) -> io::Result<()> {
902-
syscall!(connect(fd, addr.as_ptr(), addr.len())).map(|_| ())
901+
syscall!(connect(fd, addr.as_ptr().cast::<sockaddr>(), addr.len())).map(|_| ())
903902
}
904903

905904
pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
@@ -1116,7 +1115,7 @@ pub(crate) fn send_to(fd: Socket, buf: &[u8], addr: &SockAddr, flags: c_int) ->
11161115
buf.as_ptr().cast(),
11171116
min(buf.len(), MAX_BUF_LEN),
11181117
flags,
1119-
addr.as_ptr(),
1118+
addr.as_ptr().cast::<sockaddr>(),
11201119
addr.len(),
11211120
))
11221121
.map(|n| n as usize)

src/sys/windows.rs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use windows_sys::Win32::Networking::WinSock::{
3232
};
3333
use windows_sys::Win32::System::Threading::INFINITE;
3434

35-
use crate::{MsgHdr, RecvFlags, SockAddr, TcpKeepalive, Type};
35+
use crate::{MsgHdr, RecvFlags, SockAddr, SockAddrStorage, TcpKeepalive, Type};
3636

3737
#[allow(non_camel_case_types)]
3838
pub(crate) type c_int = std::os::raw::c_int;
@@ -271,11 +271,21 @@ pub(crate) fn socket(family: c_int, mut ty: c_int, protocol: c_int) -> io::Resul
271271
}
272272

273273
pub(crate) fn bind(socket: Socket, addr: &SockAddr) -> io::Result<()> {
274-
syscall!(bind(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
274+
syscall!(
275+
bind(socket, addr.as_ptr().cast::<sockaddr>(), addr.len()),
276+
PartialEq::ne,
277+
0
278+
)
279+
.map(|_| ())
275280
}
276281

277282
pub(crate) fn connect(socket: Socket, addr: &SockAddr) -> io::Result<()> {
278-
syscall!(connect(socket, addr.as_ptr(), addr.len()), PartialEq::ne, 0).map(|_| ())
283+
syscall!(
284+
connect(socket, addr.as_ptr().cast::<sockaddr>(), addr.len()),
285+
PartialEq::ne,
286+
0
287+
)
288+
.map(|_| ())
279289
}
280290

281291
pub(crate) fn poll_connect(socket: &crate::Socket, timeout: Duration) -> io::Result<()> {
@@ -635,7 +645,7 @@ pub(crate) fn send_to(
635645
buf.as_ptr().cast(),
636646
min(buf.len(), MAX_BUF_LEN) as c_int,
637647
flags,
638-
addr.as_ptr(),
648+
addr.as_ptr().cast::<sockaddr>(),
639649
addr.len(),
640650
),
641651
PartialEq::eq,
@@ -659,7 +669,7 @@ pub(crate) fn send_to_vectored(
659669
bufs.len().min(u32::MAX as usize) as u32,
660670
&mut nsent,
661671
flags as u32,
662-
addr.as_ptr(),
672+
addr.as_ptr().cast::<sockaddr>(),
663673
addr.len(),
664674
ptr::null_mut(),
665675
None,
@@ -900,11 +910,10 @@ pub(crate) fn original_dst_ipv6(socket: Socket) -> io::Result<SockAddr> {
900910

901911
#[allow(unsafe_op_in_unsafe_fn)]
902912
pub(crate) fn unix_sockaddr(path: &Path) -> io::Result<SockAddr> {
903-
// SAFETY: a `sockaddr_storage` of all zeros is valid.
904-
let mut storage = unsafe { mem::zeroed::<sockaddr_storage>() };
913+
let mut storage = SockAddrStorage::zeroed();
905914
let len = {
906-
let storage: &mut windows_sys::Win32::Networking::WinSock::SOCKADDR_UN =
907-
unsafe { &mut *(&mut storage as *mut sockaddr_storage).cast() };
915+
let storage =
916+
unsafe { storage.view_as::<windows_sys::Win32::Networking::WinSock::SOCKADDR_UN>() };
908917

909918
// Windows expects a UTF-8 path here even though Windows paths are
910919
// usually UCS-2 encoded. If Rust exposed OsStr's Wtf8 encoded

0 commit comments

Comments
 (0)