Skip to content

Commit ff537b8

Browse files
committed
feat: log the number of read/written bytes on IMAP stream read error
1 parent d45ec7f commit ff537b8

File tree

4 files changed

+243
-16
lines changed

4 files changed

+243
-16
lines changed

src/imap/client.rs

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@ use tokio::io::BufWriter;
88

99
use super::capabilities::Capabilities;
1010
use crate::context::Context;
11-
use crate::log::{info, warn};
11+
use crate::log::{LoggingStream, info, warn};
1212
use crate::login_param::{ConnectionCandidate, ConnectionSecurity};
1313
use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp};
1414
use crate::net::proxy::ProxyConfig;
1515
use crate::net::session::SessionStream;
1616
use crate::net::tls::wrap_tls;
17-
use crate::net::{
18-
connect_tcp_inner, connect_tls_inner, run_connection_attempts, update_connection_history,
19-
};
17+
use crate::net::{connect_tcp_inner, run_connection_attempts, update_connection_history};
2018
use crate::tools::time;
2119

2220
#[derive(Debug)]
@@ -126,12 +124,12 @@ impl Client {
126124
);
127125
let res = match security {
128126
ConnectionSecurity::Tls => {
129-
Client::connect_secure(resolved_addr, host, strict_tls).await
127+
Client::connect_secure(context, resolved_addr, host, strict_tls).await
130128
}
131129
ConnectionSecurity::Starttls => {
132-
Client::connect_starttls(resolved_addr, host, strict_tls).await
130+
Client::connect_starttls(context, resolved_addr, host, strict_tls).await
133131
}
134-
ConnectionSecurity::Plain => Client::connect_insecure(resolved_addr).await,
132+
ConnectionSecurity::Plain => Client::connect_insecure(context, resolved_addr).await,
135133
};
136134
match res {
137135
Ok(client) => {
@@ -202,8 +200,17 @@ impl Client {
202200
}
203201
}
204202

205-
async fn connect_secure(addr: SocketAddr, hostname: &str, strict_tls: bool) -> Result<Self> {
206-
let tls_stream = connect_tls_inner(addr, hostname, strict_tls, alpn(addr.port())).await?;
203+
async fn connect_secure(
204+
context: &Context,
205+
addr: SocketAddr,
206+
hostname: &str,
207+
strict_tls: bool,
208+
) -> Result<Self> {
209+
let tcp_stream = connect_tcp_inner(addr).await?;
210+
let account_id = context.get_id();
211+
let events = context.events.clone();
212+
let logging_stream = LoggingStream::new(tcp_stream, account_id, events);
213+
let tls_stream = wrap_tls(strict_tls, hostname, alpn(addr.port()), logging_stream).await?;
207214
let buffered_stream = BufWriter::new(tls_stream);
208215
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
209216
let mut client = Client::new(session_stream);
@@ -214,9 +221,12 @@ impl Client {
214221
Ok(client)
215222
}
216223

217-
async fn connect_insecure(addr: SocketAddr) -> Result<Self> {
224+
async fn connect_insecure(context: &Context, addr: SocketAddr) -> Result<Self> {
218225
let tcp_stream = connect_tcp_inner(addr).await?;
219-
let buffered_stream = BufWriter::new(tcp_stream);
226+
let account_id = context.get_id();
227+
let events = context.events.clone();
228+
let logging_stream = LoggingStream::new(tcp_stream, account_id, events);
229+
let buffered_stream = BufWriter::new(logging_stream);
220230
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
221231
let mut client = Client::new(session_stream);
222232
let _greeting = client
@@ -226,9 +236,18 @@ impl Client {
226236
Ok(client)
227237
}
228238

229-
async fn connect_starttls(addr: SocketAddr, host: &str, strict_tls: bool) -> Result<Self> {
239+
async fn connect_starttls(
240+
context: &Context,
241+
addr: SocketAddr,
242+
host: &str,
243+
strict_tls: bool,
244+
) -> Result<Self> {
230245
let tcp_stream = connect_tcp_inner(addr).await?;
231246

247+
let account_id = context.get_id();
248+
let events = context.events.clone();
249+
let tcp_stream = LoggingStream::new(tcp_stream, account_id, events);
250+
232251
// Run STARTTLS command and convert the client back into a stream.
233252
let buffered_tcp_stream = BufWriter::new(tcp_stream);
234253
let mut client = async_imap::Client::new(buffered_tcp_stream);
@@ -246,7 +265,6 @@ impl Client {
246265
let tls_stream = wrap_tls(strict_tls, host, &[], tcp_stream)
247266
.await
248267
.context("STARTTLS upgrade failed")?;
249-
250268
let buffered_stream = BufWriter::new(tls_stream);
251269
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
252270
let client = Client::new(session_stream);

src/log.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
use crate::context::Context;
66

7+
mod logging_stream;
8+
9+
pub(crate) use logging_stream::LoggingStream;
10+
711
macro_rules! info {
812
($ctx:expr, $msg:expr) => {
913
info!($ctx, $msg,)

src/log/logging_stream.rs

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//! Stream that logs errors as events.
2+
//!
3+
//! This stream can be used to wrap IMAP,
4+
//! SMTP and HTTP streams so errors
5+
//! that occur are logged before
6+
//! they are processed.
7+
8+
use std::net::SocketAddr;
9+
use std::pin::Pin;
10+
use std::task::{Context, Poll};
11+
use std::time::Duration;
12+
13+
use anyhow::Result;
14+
use pin_project::pin_project;
15+
16+
use crate::events::{Event, EventType, Events};
17+
use crate::net::session::SessionStream;
18+
19+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20+
21+
#[derive(Debug)]
22+
struct Metrics {
23+
/// Total number of bytes read.
24+
pub total_read: usize,
25+
26+
/// Total number of bytes written.
27+
pub total_written: usize,
28+
}
29+
30+
impl Metrics {
31+
fn new() -> Self {
32+
Self {
33+
total_read: 0,
34+
total_written: 0,
35+
}
36+
}
37+
}
38+
39+
/// Stream that logs errors to the event channel.
40+
#[derive(Debug)]
41+
#[pin_project]
42+
pub(crate) struct LoggingStream<S: SessionStream> {
43+
#[pin]
44+
inner: S,
45+
46+
/// Account ID for logging.
47+
account_id: u32,
48+
49+
/// Event channel.
50+
events: Events,
51+
52+
/// Metrics for this stream.
53+
metrics: Metrics,
54+
}
55+
56+
impl<S: SessionStream> LoggingStream<S> {
57+
pub fn new(inner: S, account_id: u32, events: Events) -> Self {
58+
Self {
59+
inner,
60+
account_id,
61+
events,
62+
metrics: Metrics::new(),
63+
}
64+
}
65+
}
66+
67+
impl<S: SessionStream> AsyncRead for LoggingStream<S> {
68+
fn poll_read(
69+
self: Pin<&mut Self>,
70+
cx: &mut Context<'_>,
71+
buf: &mut ReadBuf<'_>,
72+
) -> Poll<std::io::Result<()>> {
73+
let this = self.project();
74+
let peer_addr = this.inner.peer_addr();
75+
let old_remaining = buf.remaining();
76+
77+
let res = this.inner.poll_read(cx, buf);
78+
79+
if let Poll::Ready(Err(ref err)) = res {
80+
if let Ok(peer_addr) = peer_addr {
81+
let log_message = format!(
82+
"Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
83+
this.metrics.total_read, this.metrics.total_written
84+
);
85+
this.events.emit(Event {
86+
id: *this.account_id,
87+
typ: EventType::Warning(log_message),
88+
});
89+
}
90+
}
91+
92+
let n = old_remaining - buf.remaining();
93+
if n > 0 {
94+
this.metrics.total_read = this.metrics.total_read.saturating_add(n);
95+
}
96+
97+
res
98+
}
99+
}
100+
101+
impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
102+
fn poll_write(
103+
self: Pin<&mut Self>,
104+
cx: &mut std::task::Context<'_>,
105+
buf: &[u8],
106+
) -> Poll<std::io::Result<usize>> {
107+
let this = self.project();
108+
let res = this.inner.poll_write(cx, buf);
109+
if let Poll::Ready(Ok(n)) = res {
110+
this.metrics.total_written = this.metrics.total_written.saturating_add(n);
111+
}
112+
res
113+
}
114+
115+
fn poll_flush(
116+
self: Pin<&mut Self>,
117+
cx: &mut std::task::Context<'_>,
118+
) -> Poll<std::io::Result<()>> {
119+
self.project().inner.poll_flush(cx)
120+
}
121+
122+
fn poll_shutdown(
123+
self: Pin<&mut Self>,
124+
cx: &mut std::task::Context<'_>,
125+
) -> Poll<std::io::Result<()>> {
126+
self.project().inner.poll_shutdown(cx)
127+
}
128+
129+
fn poll_write_vectored(
130+
self: Pin<&mut Self>,
131+
cx: &mut Context<'_>,
132+
bufs: &[std::io::IoSlice<'_>],
133+
) -> Poll<std::io::Result<usize>> {
134+
let this = self.project();
135+
let res = this.inner.poll_write_vectored(cx, bufs);
136+
if let Poll::Ready(Ok(n)) = res {
137+
this.metrics.total_written = this.metrics.total_written.saturating_add(n);
138+
}
139+
res
140+
}
141+
142+
fn is_write_vectored(&self) -> bool {
143+
self.inner.is_write_vectored()
144+
}
145+
}
146+
147+
impl<S: SessionStream> SessionStream for LoggingStream<S> {
148+
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
149+
self.inner.set_read_timeout(timeout)
150+
}
151+
152+
fn peer_addr(&self) -> Result<SocketAddr> {
153+
self.inner.peer_addr()
154+
}
155+
}

src/net/session.rs

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,112 @@
1+
use anyhow::Result;
12
use fast_socks5::client::Socks5Stream;
3+
use std::net::SocketAddr;
24
use std::pin::Pin;
35
use std::time::Duration;
46
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
7+
use tokio::net::TcpStream;
58
use tokio_io_timeout::TimeoutStream;
69

710
pub(crate) trait SessionStream:
811
AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug
912
{
1013
/// Change the read timeout on the session stream.
1114
fn set_read_timeout(&mut self, timeout: Option<Duration>);
15+
16+
/// Returns the remote address that this stream is connected to.
17+
///
18+
/// If the connection is proxied, returns `None`.
19+
fn peer_addr(&self) -> Result<SocketAddr>;
1220
}
1321

1422
impl SessionStream for Box<dyn SessionStream> {
1523
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
1624
self.as_mut().set_read_timeout(timeout);
1725
}
26+
27+
fn peer_addr(&self) -> Result<SocketAddr> {
28+
let addr = self.as_ref().peer_addr()?;
29+
Ok(addr)
30+
}
1831
}
1932
impl<T: SessionStream> SessionStream for async_native_tls::TlsStream<T> {
2033
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
2134
self.get_mut().set_read_timeout(timeout);
2235
}
36+
37+
fn peer_addr(&self) -> Result<SocketAddr> {
38+
let addr = self.get_ref().peer_addr()?;
39+
Ok(addr)
40+
}
2341
}
2442
impl<T: SessionStream> SessionStream for tokio_rustls::client::TlsStream<T> {
2543
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
2644
self.get_mut().0.set_read_timeout(timeout);
2745
}
46+
fn peer_addr(&self) -> Result<SocketAddr> {
47+
let addr = self.get_ref().0.peer_addr()?;
48+
Ok(addr)
49+
}
2850
}
2951
impl<T: SessionStream> SessionStream for BufStream<T> {
3052
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
3153
self.get_mut().set_read_timeout(timeout);
3254
}
55+
56+
fn peer_addr(&self) -> Result<SocketAddr> {
57+
let addr = self.get_ref().peer_addr()?;
58+
Ok(addr)
59+
}
3360
}
3461
impl<T: SessionStream> SessionStream for BufWriter<T> {
3562
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
3663
self.get_mut().set_read_timeout(timeout);
3764
}
65+
66+
fn peer_addr(&self) -> Result<SocketAddr> {
67+
let addr = self.get_ref().peer_addr()?;
68+
Ok(addr)
69+
}
3870
}
39-
impl<T: AsyncRead + AsyncWrite + Send + Sync + std::fmt::Debug> SessionStream
40-
for Pin<Box<TimeoutStream<T>>>
41-
{
71+
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
4272
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
4373
self.as_mut().set_read_timeout_pinned(timeout);
4474
}
75+
76+
fn peer_addr(&self) -> Result<SocketAddr> {
77+
let addr = self.get_ref().peer_addr()?;
78+
Ok(addr)
79+
}
4580
}
4681
impl<T: SessionStream> SessionStream for Socks5Stream<T> {
4782
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
4883
self.get_socket_mut().set_read_timeout(timeout)
4984
}
85+
86+
fn peer_addr(&self) -> Result<SocketAddr> {
87+
let addr = self.get_socket_ref().peer_addr()?;
88+
Ok(addr)
89+
}
5090
}
5191
impl<T: SessionStream> SessionStream for shadowsocks::ProxyClientStream<T> {
5292
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
5393
self.get_mut().set_read_timeout(timeout)
5494
}
95+
96+
fn peer_addr(&self) -> Result<SocketAddr> {
97+
let addr = self.get_ref().peer_addr()?;
98+
Ok(addr)
99+
}
55100
}
56101
impl<T: SessionStream> SessionStream for async_imap::DeflateStream<T> {
57102
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
58103
self.get_mut().set_read_timeout(timeout)
59104
}
105+
106+
fn peer_addr(&self) -> Result<SocketAddr> {
107+
let addr = self.get_ref().peer_addr()?;
108+
Ok(addr)
109+
}
60110
}
61111

62112
/// Session stream with a read buffer.

0 commit comments

Comments
 (0)