Skip to content

Commit a7bd58a

Browse files
committed
io_uring: implement Endpoint trait and generic copy function
1 parent 6ceef86 commit a7bd58a

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

src/io_uring.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ use std::time::{Duration, Instant};
77
use tokio::sync::Notify;
88
use tokio::time::timeout;
99
use tokio_uring::buf::BoundedBuf;
10+
use tokio_uring::buf::BoundedBufMut;
11+
use tokio_uring::BufResult;
12+
use tokio_uring::UnsubmittedWrite;
1013

1114
// module name for logging engine
1215
const NAME: &str = "<i><bright-black> proxy: </>";
@@ -16,6 +19,109 @@ const BUFFER_LEN: usize = 16 * 1024;
1619
const READ_TIMEOUT: Duration = Duration::new(5, 0);
1720
const TCP_CLIENT_TIMEOUT: Duration = Duration::new(30, 0);
1821

22+
// tokio_uring::fs::File and tokio_uring::net::TcpStream are using different
23+
// read and write calls:
24+
// File is using read_at() and write_at(),
25+
// TcpStream is using read() and write()
26+
//
27+
// In our case we are reading a special unix character device for
28+
// the USB gadget, which is not a regular file where an offset is important.
29+
// We just use offset 0 for reading and writing, so below is a trait
30+
// for this, to be able to use it in a generic copy() function below.
31+
32+
pub trait Endpoint<E> {
33+
async fn read<T: BoundedBufMut>(&self, buf: T) -> BufResult<usize, T>;
34+
fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T>;
35+
}
36+
37+
impl Endpoint<tokio_uring::fs::File> for tokio_uring::fs::File {
38+
async fn read<T: BoundedBufMut>(&self, buf: T) -> BufResult<usize, T> {
39+
self.read_at(buf, 0).await
40+
}
41+
fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T> {
42+
self.write_at(buf, 0)
43+
}
44+
}
45+
46+
impl Endpoint<tokio_uring::net::TcpStream> for tokio_uring::net::TcpStream {
47+
async fn read<T: BoundedBufMut>(&self, buf: T) -> BufResult<usize, T> {
48+
self.read(buf).await
49+
}
50+
fn write<T: BoundedBuf>(&self, buf: T) -> UnsubmittedWrite<T> {
51+
self.write(buf)
52+
}
53+
}
54+
55+
async fn copy<A: Endpoint<A>, B: Endpoint<B>>(
56+
from: Rc<A>,
57+
to: Rc<B>,
58+
dbg_name: &'static str,
59+
direction: &'static str,
60+
stats_interval: Option<Duration>,
61+
) -> Result<(), std::io::Error> {
62+
// For statistics
63+
let mut bytes_out: usize = 0;
64+
let mut bytes_out_last: usize = 0;
65+
let mut report_time = Instant::now();
66+
67+
let mut buf = vec![0u8; BUFFER_LEN];
68+
loop {
69+
// Handle stats printing
70+
if stats_interval.is_some() && report_time.elapsed() > stats_interval.unwrap() {
71+
let transferred_total = ByteSize::b(bytes_out.try_into().unwrap());
72+
let transferred_last = ByteSize::b(bytes_out_last.try_into().unwrap());
73+
74+
let speed: u64 =
75+
(bytes_out_last as f64 / report_time.elapsed().as_secs_f64()).round() as u64;
76+
let speed = ByteSize::b(speed);
77+
78+
info!(
79+
"{} {} transfer: {:#} ({:#}/s), {:#} total",
80+
NAME,
81+
direction,
82+
transferred_last.to_string_as(true),
83+
speed.to_string_as(true),
84+
transferred_total.to_string_as(true),
85+
);
86+
87+
report_time = Instant::now();
88+
bytes_out_last = 0;
89+
}
90+
91+
// things look weird: we pass ownership of the buffer to `read`, and we get
92+
// it back, _even if there was an error_. There's a whole trait for that,
93+
// which `Vec<u8>` implements!
94+
debug!("{}: before read", dbg_name);
95+
let retval = from.read(buf);
96+
let (res, buf_read) = timeout(READ_TIMEOUT, retval).await?;
97+
// Propagate errors, see how many bytes we read
98+
let n = res?;
99+
debug!("{}: after read, {} bytes", dbg_name, n);
100+
if n == 0 {
101+
// A read of size zero signals EOF (end of file), finish gracefully
102+
return Ok(());
103+
}
104+
105+
// The `slice` method here is implemented in an extension trait: it
106+
// returns an owned slice of our `Vec<u8>`, which we later turn back
107+
// into the full `Vec<u8>`
108+
debug!("{}: before write", dbg_name);
109+
let (res, buf_write) = to.write(buf_read.slice(..n)).submit().await;
110+
let n = res?;
111+
debug!("{}: after write, {} bytes", dbg_name, n);
112+
// Increment byte counters for statistics
113+
if stats_interval.is_some() {
114+
bytes_out += n;
115+
bytes_out_last += n;
116+
}
117+
118+
// Later is now, we want our full buffer back.
119+
// That's why we declared our binding `mut` way back at the start of `copy`,
120+
// even though we moved it into the very first `TcpStream::read` call.
121+
buf = buf_write.into_inner();
122+
}
123+
}
124+
19125
async fn copy_file_to_stream(
20126
from: Rc<tokio_uring::fs::File>,
21127
to: Rc<tokio_uring::net::TcpStream>,

0 commit comments

Comments
 (0)