Skip to content

Commit 35f52a1

Browse files
bors[bot]blckngm
andauthored
Merge #92
92: Simplify windows AsyncTun async read r=sopium a=sopium We don't need to send buffers around, just the event to wait for. Co-authored-by: Guanhao Yin <sopium@mysterious.site>
2 parents 8bd1411 + e517bf3 commit 35f52a1

File tree

3 files changed

+52
-155
lines changed

3 files changed

+52
-155
lines changed

src/wireguard/tun_windows.rs

Lines changed: 33 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ use std::sync::Arc;
3131
use anyhow::Context;
3232
use once_cell::sync::OnceCell;
3333
use parking_lot::Mutex;
34-
use tokio::sync::mpsc::{channel, Receiver, Sender};
3534
use tokio::sync::Mutex as AsyncMutex;
3635
use wchar::*;
3736
use widestring::*;
@@ -48,6 +47,7 @@ use winapi::um::handleapi::*;
4847
use winapi::um::ioapiset::*;
4948
use winapi::um::minwinbase::*;
5049
use winapi::um::namespaceapi::*;
50+
use winapi::um::processthreadsapi::*;
5151
use winapi::um::securitybaseapi::*;
5252
use winapi::um::setupapi::*;
5353
use winapi::um::synchapi::*;
@@ -110,105 +110,21 @@ use self::interface::*;
110110
mod buffer;
111111
use self::buffer::*;
112112

113-
pub struct AsyncTun {
114-
tun: Arc<Tun>,
115-
channels: AsyncMutex<TunChannels>,
116-
}
117-
118-
struct TunChannels {
119-
read_rx: Receiver<io::Result<(Box<[u8]>, usize)>>,
120-
buffer_tx: Sender<Box<[u8]>>,
121-
}
122-
123-
impl Drop for AsyncTun {
124-
fn drop(&mut self) {
125-
self.tun.interrupt();
126-
}
127-
}
128-
129-
impl AsyncTun {
130-
pub fn open(name: &OsStr) -> anyhow::Result<AsyncTun> {
131-
let tun = Tun::open(name)?;
132-
Ok(AsyncTun::new(tun))
133-
}
134-
135-
fn new(tun: Tun) -> AsyncTun {
136-
let tun = Arc::new(tun);
137-
// read thread -> async fn read.
138-
let (mut read_tx, read_rx) = channel(2);
139-
// async fn read -> read thread, to reuse buffers.
140-
let (mut buffer_tx, mut buffer_rx) = channel::<Box<[u8]>>(2);
141-
buffer_tx.try_send(vec![0u8; 65536].into()).unwrap();
142-
buffer_tx.try_send(vec![0u8; 65536].into()).unwrap();
143-
// We run this in a separate thread so that the wintun intance can be
144-
// dropped.
145-
//
146-
// We let tokio manage this thread so that the tokio runtime is not
147-
// dropped before the wintun instance.
148-
tokio::spawn({
149-
let tun = tun.clone();
150-
async move {
151-
'outer: loop {
152-
let mut buf = match buffer_rx.recv().await {
153-
None => break,
154-
Some(buf) => buf,
155-
};
156-
// We don't want to consume the `buf` when we get an
157-
// `Err`. So loop until we get an `Ok`.
158-
loop {
159-
match tokio::task::block_in_place(|| tun.read(&mut buf[..])) {
160-
Err(e) => {
161-
if read_tx.send(Err(e)).await.is_err() {
162-
break 'outer;
163-
}
164-
}
165-
Ok(len) => {
166-
if read_tx.send(Ok((buf, len))).await.is_err() {
167-
break 'outer;
168-
}
169-
break;
170-
}
171-
}
172-
}
173-
}
174-
}
175-
});
176-
AsyncTun {
177-
tun,
178-
channels: AsyncMutex::new(TunChannels { read_rx, buffer_tx }),
179-
}
180-
}
181-
182-
pub(crate) async fn read<'a>(&'a self, buf: &'a mut [u8]) -> io::Result<usize> {
183-
// Don't use `blocking` for operations that may block forever.
184-
185-
let mut channels = self.channels.lock().await;
186-
187-
let (p, p_len) = channels.read_rx.recv().await.unwrap()?;
188-
let len = std::cmp::min(p_len, buf.len());
189-
buf[..len].copy_from_slice(&p[..len]);
190-
channels.buffer_tx.send(p).await.unwrap();
191-
Ok(len)
192-
}
193-
194-
pub(crate) async fn write<'a>(&'a self, buf: &'a [u8]) -> io::Result<usize> {
195-
self.tun.write(buf)
196-
}
197-
}
198-
199113
/// A handle to a tun interface.
200-
struct Tun {
114+
pub struct AsyncTun {
201115
handle: HandleWrapper,
202116
name: OsString,
203117
rings: TunRegisterRings,
204-
read_lock: Mutex<()>,
118+
// A clone of `rings.send.tail_moved` (created by `DuplicateHandle`) wrapped
119+
// in `Arc` that can be sent to `spawn_blocking`.
120+
send_tail_moved_clone: Arc<HandleWrapper>,
121+
read_lock: AsyncMutex<()>,
205122
write_lock: Mutex<()>,
206-
cancel_event: HandleWrapper,
207123
}
208124

209-
impl fmt::Debug for Tun {
125+
impl fmt::Debug for AsyncTun {
210126
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211-
f.debug_struct("Tun")
127+
f.debug_struct("AsyncTun")
212128
.field("name", &self.name)
213129
.field("handle", &self.handle.0)
214130
.finish()
@@ -227,9 +143,9 @@ const TUN_IOCTL_REGISTER_RINGS: u32 = CTL_CODE(
227143
FILE_READ_DATA | FILE_WRITE_DATA,
228144
);
229145

230-
impl Tun {
146+
impl AsyncTun {
231147
/// Open a handle to a wintun interface.
232-
pub fn open(name: &OsStr) -> anyhow::Result<Tun> {
148+
pub fn open(name: &OsStr) -> anyhow::Result<AsyncTun> {
233149
info!("opening wintun device {}", name.to_string_lossy());
234150
let interface = WINTUN_POOL.get_interface(name)?;
235151
if let Some(interface) = interface {
@@ -251,41 +167,48 @@ impl Tun {
251167
))
252168
.context("DeviceIoControl TUN_IOCTL_REGISTER_RINGS")?;
253169

254-
let cancel_event =
255-
unsafe_h!(CreateEventW(null_mut(), 1, 0, null_mut())).context("CreateEventW")?;
170+
let send_tail_moved_clone = {
171+
let mut handle: HANDLE = null_mut();
172+
let self_process = unsafe { GetCurrentProcess() };
173+
unsafe_b!(DuplicateHandle(
174+
self_process,
175+
rings.send.tail_moved.0,
176+
self_process,
177+
&mut handle,
178+
0,
179+
0,
180+
DUPLICATE_SAME_ACCESS,
181+
))
182+
.context("DuplicateHandle")?;
183+
Arc::new(HandleWrapper(handle))
184+
};
256185

257186
Ok(Self {
258187
handle,
259188
name: name.into(),
260189
rings,
261-
read_lock: Mutex::new(()),
190+
send_tail_moved_clone,
191+
read_lock: AsyncMutex::new(()),
262192
write_lock: Mutex::new(()),
263-
cancel_event,
264193
})
265194
}
266195

267196
/// Read a packet from the interface.
268197
///
269198
/// Blocking.
270-
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
271-
let _read_lock_guard = self.read_lock.lock();
272-
unsafe { self.rings.read(buf, self.cancel_event.0) }
199+
pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
200+
let _read_lock_guard = self.read_lock.lock().await;
201+
unsafe { self.rings.read(buf, &self.send_tail_moved_clone).await }
273202
}
274203

275204
/// Write a packet to the interface.
276205
///
277206
/// Does not block.
278-
pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
207+
pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
279208
let _write_lock_guard = self.write_lock.lock();
280209
unsafe { self.rings.write(buf) }
281210
}
282211

283-
/// Interrupt a blocking read operation on this Tun interface.
284-
pub fn interrupt(&self) {
285-
debug!("interrupt tun read");
286-
unsafe_b!(SetEvent(self.cancel_event.0)).unwrap();
287-
}
288-
289212
fn close(&self) -> anyhow::Result<()> {
290213
info!("closing wintun interface");
291214
let it = WINTUN_POOL
@@ -298,10 +221,7 @@ impl Tun {
298221
}
299222
}
300223

301-
unsafe impl Sync for Tun {}
302-
unsafe impl Send for Tun {}
303-
304-
impl Drop for Tun {
224+
impl Drop for AsyncTun {
305225
fn drop(&mut self) {
306226
self.close()
307227
.unwrap_or_else(|e| warn!("failed to close tun: {:#}", e))
@@ -321,7 +241,7 @@ mod tests {
321241
#[test]
322242
fn test_open_and_close() {
323243
let _ = env_logger::try_init();
324-
let t = Tun::open(OsStr::new("tun0"));
244+
let t = AsyncTun::open(OsStr::new("tun0"));
325245
println!("Tun::open(): {:?}", t);
326246
}
327247
}

src/wireguard/tun_windows/pool.rs

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pub const HARDWARE_ID: &str = "Wintun";
2424
const HARDWARE_ID_MULTI_SZ: &[u16] = wch!("Wintun\0\0");
2525
const WAIT_REGISTRY_TIMEOUT: Duration = Duration::from_secs(1);
2626

27-
fn wait_for_single_object(obj: HANDLE, timeout: Option<Duration>) -> io::Result<()> {
27+
pub fn wait_for_single_object(obj: HANDLE, timeout: Option<Duration>) -> io::Result<()> {
2828
let timeout_millis = timeout
2929
.map(|t| {
3030
let millis = t.as_millis();
@@ -923,33 +923,4 @@ mod tests {
923923
fn test_mutex_name() {
924924
assert_eq!(WINTUN_POOL.mutex_name(), "Wintun\\Wintun-Name-Mutex-4843d7de9cb25125d50ac5cec3866519a1cab845a26e84d5593f8ab8bbc5fb2c");
925925
}
926-
927-
// XXX: this test only really works when running as LocalSystem.
928-
//
929-
// You can use .e.g. psexec (from pstools) to run this test as LocalSystem.
930-
#[test]
931-
fn test_take_named_mutex() {
932-
println!("take_named_mutex: {:?}", WINTUN_POOL.take_named_mutex());
933-
}
934-
935-
// XXX: this test only really works when running as LocalSystem.
936-
#[test]
937-
fn test_get_interface() {
938-
std::env::set_var("RUST_LOG", "warn");
939-
let _ = env_logger::try_init();
940-
println!(
941-
"get_interface: {:?}",
942-
WINTUN_POOL.get_interface(OsStr::new("tun0"))
943-
);
944-
}
945-
946-
#[test]
947-
fn test_create_interface() {
948-
std::env::set_var("RUST_LOG", "debug");
949-
let _ = env_logger::try_init();
950-
println!(
951-
"create_interface: {:?}",
952-
WINTUN_POOL.create_interface(OsStr::new("tun0")),
953-
);
954-
}
955926
}

src/wireguard/tun_windows/ring.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,22 @@ fn _assert_long_is_i32() {
4141
}
4242

4343
#[repr(C)]
44-
struct TunRegisterRing {
44+
pub struct TunRegisterRing {
4545
ring_size: u32,
4646
// ABI compatible with *mut TunRing.
4747
ring: Box<UnsafeCell<TunRing>>,
4848
// ABI compatible with HANDLE.
49-
tail_moved: HandleWrapper,
49+
pub tail_moved: HandleWrapper,
5050
}
5151

5252
#[repr(C)]
5353
pub struct TunRegisterRings {
54-
send: TunRegisterRing,
54+
pub send: TunRegisterRing,
5555
receive: TunRegisterRing,
5656
}
5757

58+
unsafe impl Sync for TunRegisterRings {}
59+
5860
fn align_4(len: u32) -> u32 {
5961
(len + 0b11) & !0b11
6062
}
@@ -163,22 +165,26 @@ impl TunRegisterRings {
163165
/// # Safety
164166
///
165167
/// Only one thread should call this.
166-
pub unsafe fn read(&self, buf: &mut [u8], canceled: HANDLE) -> io::Result<usize> {
168+
pub async unsafe fn read(
169+
&self,
170+
buf: &mut [u8],
171+
send_tail_moved_clone: &Arc<HandleWrapper>,
172+
) -> io::Result<usize> {
167173
let send_ring = self.send.ring.get().as_mut().unwrap();
168-
let events = [self.send.tail_moved.0, canceled];
169174
loop {
170175
match send_ring.read(buf) {
171176
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
172177
send_ring.alertable.store(1, Ordering::SeqCst);
173178
match send_ring.read(buf) {
174179
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
175-
let wait_result =
176-
WaitForMultipleObjects(2, events.as_ptr(), 0, INFINITE);
177-
match wait_result {
178-
0 => (),
179-
1 => return Err(io::ErrorKind::Interrupted.into()),
180-
_ => unreachable!(),
181-
}
180+
let send_tail_moved_clone = send_tail_moved_clone.clone();
181+
// It seems like when the interface is closed, the
182+
// event is automatically signalled, so cancellation
183+
// just works as expected.
184+
let _ = tokio::task::spawn_blocking(move || {
185+
wait_for_single_object(send_tail_moved_clone.0, None)
186+
})
187+
.await;
182188
send_ring.alertable.store(0, Ordering::SeqCst);
183189
continue;
184190
}

0 commit comments

Comments
 (0)