From 7f0841b9d6616368f02d04f67033f004962e6ee5 Mon Sep 17 00:00:00 2001 From: link2xt Date: Thu, 10 Jul 2025 01:02:16 +0000 Subject: [PATCH 1/2] feat: log the number of read/written bytes on IMAP stream read error --- src/imap/client.rs | 44 +++++++---- src/log.rs | 4 + src/log/logging_stream.rs | 155 ++++++++++++++++++++++++++++++++++++++ src/net/session.rs | 54 ++++++++++++- 4 files changed, 241 insertions(+), 16 deletions(-) create mode 100644 src/log/logging_stream.rs diff --git a/src/imap/client.rs b/src/imap/client.rs index 5fcadd5b7f..c413829b36 100644 --- a/src/imap/client.rs +++ b/src/imap/client.rs @@ -8,15 +8,13 @@ use tokio::io::BufWriter; use super::capabilities::Capabilities; use crate::context::Context; -use crate::log::{info, warn}; +use crate::log::{LoggingStream, info, warn}; use crate::login_param::{ConnectionCandidate, ConnectionSecurity}; use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp}; use crate::net::proxy::ProxyConfig; use crate::net::session::SessionStream; use crate::net::tls::wrap_tls; -use crate::net::{ - connect_tcp_inner, connect_tls_inner, run_connection_attempts, update_connection_history, -}; +use crate::net::{connect_tcp_inner, run_connection_attempts, update_connection_history}; use crate::tools::time; #[derive(Debug)] @@ -126,12 +124,12 @@ impl Client { ); let res = match security { ConnectionSecurity::Tls => { - Client::connect_secure(resolved_addr, host, strict_tls).await + Client::connect_secure(context, resolved_addr, host, strict_tls).await } ConnectionSecurity::Starttls => { - Client::connect_starttls(resolved_addr, host, strict_tls).await + Client::connect_starttls(context, resolved_addr, host, strict_tls).await } - ConnectionSecurity::Plain => Client::connect_insecure(resolved_addr).await, + ConnectionSecurity::Plain => Client::connect_insecure(context, resolved_addr).await, }; match res { Ok(client) => { @@ -202,8 +200,17 @@ impl Client { } } - async fn connect_secure(addr: SocketAddr, hostname: &str, strict_tls: bool) -> Result { - let tls_stream = connect_tls_inner(addr, hostname, strict_tls, alpn(addr.port())).await?; + async fn connect_secure( + context: &Context, + addr: SocketAddr, + hostname: &str, + strict_tls: bool, + ) -> Result { + let tcp_stream = connect_tcp_inner(addr).await?; + let account_id = context.get_id(); + let events = context.events.clone(); + let logging_stream = LoggingStream::new(tcp_stream, account_id, events); + let tls_stream = wrap_tls(strict_tls, hostname, alpn(addr.port()), logging_stream).await?; let buffered_stream = BufWriter::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); let mut client = Client::new(session_stream); @@ -214,9 +221,12 @@ impl Client { Ok(client) } - async fn connect_insecure(addr: SocketAddr) -> Result { + async fn connect_insecure(context: &Context, addr: SocketAddr) -> Result { let tcp_stream = connect_tcp_inner(addr).await?; - let buffered_stream = BufWriter::new(tcp_stream); + let account_id = context.get_id(); + let events = context.events.clone(); + let logging_stream = LoggingStream::new(tcp_stream, account_id, events); + let buffered_stream = BufWriter::new(logging_stream); let session_stream: Box = Box::new(buffered_stream); let mut client = Client::new(session_stream); let _greeting = client @@ -226,9 +236,18 @@ impl Client { Ok(client) } - async fn connect_starttls(addr: SocketAddr, host: &str, strict_tls: bool) -> Result { + async fn connect_starttls( + context: &Context, + addr: SocketAddr, + host: &str, + strict_tls: bool, + ) -> Result { let tcp_stream = connect_tcp_inner(addr).await?; + let account_id = context.get_id(); + let events = context.events.clone(); + let tcp_stream = LoggingStream::new(tcp_stream, account_id, events); + // Run STARTTLS command and convert the client back into a stream. let buffered_tcp_stream = BufWriter::new(tcp_stream); let mut client = async_imap::Client::new(buffered_tcp_stream); @@ -246,7 +265,6 @@ impl Client { let tls_stream = wrap_tls(strict_tls, host, &[], tcp_stream) .await .context("STARTTLS upgrade failed")?; - let buffered_stream = BufWriter::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); let client = Client::new(session_stream); diff --git a/src/log.rs b/src/log.rs index b4edc82a3a..97fce60a94 100644 --- a/src/log.rs +++ b/src/log.rs @@ -4,6 +4,10 @@ use crate::context::Context; +mod logging_stream; + +pub(crate) use logging_stream::LoggingStream; + macro_rules! info { ($ctx:expr, $msg:expr) => { info!($ctx, $msg,) diff --git a/src/log/logging_stream.rs b/src/log/logging_stream.rs new file mode 100644 index 0000000000..4065502756 --- /dev/null +++ b/src/log/logging_stream.rs @@ -0,0 +1,155 @@ +//! Stream that logs errors as events. +//! +//! This stream can be used to wrap IMAP, +//! SMTP and HTTP streams so errors +//! that occur are logged before +//! they are processed. + +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use anyhow::Result; +use pin_project::pin_project; + +use crate::events::{Event, EventType, Events}; +use crate::net::session::SessionStream; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +#[derive(Debug)] +struct Metrics { + /// Total number of bytes read. + pub total_read: usize, + + /// Total number of bytes written. + pub total_written: usize, +} + +impl Metrics { + fn new() -> Self { + Self { + total_read: 0, + total_written: 0, + } + } +} + +/// Stream that logs errors to the event channel. +#[derive(Debug)] +#[pin_project] +pub(crate) struct LoggingStream { + #[pin] + inner: S, + + /// Account ID for logging. + account_id: u32, + + /// Event channel. + events: Events, + + /// Metrics for this stream. + metrics: Metrics, +} + +impl LoggingStream { + pub fn new(inner: S, account_id: u32, events: Events) -> Self { + Self { + inner, + account_id, + events, + metrics: Metrics::new(), + } + } +} + +impl AsyncRead for LoggingStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.project(); + let peer_addr = this.inner.peer_addr(); + let old_remaining = buf.remaining(); + + let res = this.inner.poll_read(cx, buf); + + if let Poll::Ready(Err(ref err)) = res { + if let Ok(peer_addr) = peer_addr { + let log_message = format!( + "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.", + this.metrics.total_read, this.metrics.total_written + ); + this.events.emit(Event { + id: *this.account_id, + typ: EventType::Warning(log_message), + }); + } + } + + let n = old_remaining - buf.remaining(); + if n > 0 { + this.metrics.total_read = this.metrics.total_read.saturating_add(n); + } + + res + } +} + +impl AsyncWrite for LoggingStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + let res = this.inner.poll_write(cx, buf); + if let Poll::Ready(Ok(n)) = res { + this.metrics.total_written = this.metrics.total_written.saturating_add(n); + } + res + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + let res = this.inner.poll_write_vectored(cx, bufs); + if let Poll::Ready(Ok(n)) = res { + this.metrics.total_written = this.metrics.total_written.saturating_add(n); + } + res + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +impl SessionStream for LoggingStream { + fn set_read_timeout(&mut self, timeout: Option) { + self.inner.set_read_timeout(timeout) + } + + fn peer_addr(&self) -> Result { + self.inner.peer_addr() + } +} diff --git a/src/net/session.rs b/src/net/session.rs index 8c56f2bbcf..b3f09f6083 100644 --- a/src/net/session.rs +++ b/src/net/session.rs @@ -1,7 +1,10 @@ +use anyhow::Result; use fast_socks5::client::Socks5Stream; +use std::net::SocketAddr; use std::pin::Pin; use std::time::Duration; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter}; +use tokio::net::TcpStream; use tokio_io_timeout::TimeoutStream; pub(crate) trait SessionStream: @@ -9,54 +12,99 @@ pub(crate) trait SessionStream: { /// Change the read timeout on the session stream. fn set_read_timeout(&mut self, timeout: Option); + + /// Returns the remote address that this stream is connected to. + fn peer_addr(&self) -> Result; } impl SessionStream for Box { fn set_read_timeout(&mut self, timeout: Option) { self.as_mut().set_read_timeout(timeout); } + + fn peer_addr(&self) -> Result { + let addr = self.as_ref().peer_addr()?; + Ok(addr) + } } impl SessionStream for async_native_tls::TlsStream { fn set_read_timeout(&mut self, timeout: Option) { self.get_mut().set_read_timeout(timeout); } + + fn peer_addr(&self) -> Result { + let addr = self.get_ref().peer_addr()?; + Ok(addr) + } } impl SessionStream for tokio_rustls::client::TlsStream { fn set_read_timeout(&mut self, timeout: Option) { self.get_mut().0.set_read_timeout(timeout); } + fn peer_addr(&self) -> Result { + let addr = self.get_ref().0.peer_addr()?; + Ok(addr) + } } impl SessionStream for BufStream { fn set_read_timeout(&mut self, timeout: Option) { self.get_mut().set_read_timeout(timeout); } + + fn peer_addr(&self) -> Result { + let addr = self.get_ref().peer_addr()?; + Ok(addr) + } } impl SessionStream for BufWriter { fn set_read_timeout(&mut self, timeout: Option) { self.get_mut().set_read_timeout(timeout); } + + fn peer_addr(&self) -> Result { + let addr = self.get_ref().peer_addr()?; + Ok(addr) + } } -impl SessionStream - for Pin>> -{ +impl SessionStream for Pin>> { fn set_read_timeout(&mut self, timeout: Option) { self.as_mut().set_read_timeout_pinned(timeout); } + + fn peer_addr(&self) -> Result { + let addr = self.get_ref().peer_addr()?; + Ok(addr) + } } impl SessionStream for Socks5Stream { fn set_read_timeout(&mut self, timeout: Option) { self.get_socket_mut().set_read_timeout(timeout) } + + fn peer_addr(&self) -> Result { + let addr = self.get_socket_ref().peer_addr()?; + Ok(addr) + } } impl SessionStream for shadowsocks::ProxyClientStream { fn set_read_timeout(&mut self, timeout: Option) { self.get_mut().set_read_timeout(timeout) } + + fn peer_addr(&self) -> Result { + let addr = self.get_ref().peer_addr()?; + Ok(addr) + } } impl SessionStream for async_imap::DeflateStream { fn set_read_timeout(&mut self, timeout: Option) { self.get_mut().set_read_timeout(timeout) } + + fn peer_addr(&self) -> Result { + let addr = self.get_ref().peer_addr()?; + Ok(addr) + } } /// Session stream with a read buffer. From db4fc91be03984a5d52e94869346ae6a420644ca Mon Sep 17 00:00:00 2001 From: link2xt Date: Wed, 16 Jul 2025 13:12:32 +0000 Subject: [PATCH 2/2] check that connection has peer address --- src/log/logging_stream.rs | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/log/logging_stream.rs b/src/log/logging_stream.rs index 4065502756..6b508d9bbe 100644 --- a/src/log/logging_stream.rs +++ b/src/log/logging_stream.rs @@ -77,16 +77,23 @@ impl AsyncRead for LoggingStream { let res = this.inner.poll_read(cx, buf); if let Poll::Ready(Err(ref err)) = res { - if let Ok(peer_addr) = peer_addr { - let log_message = format!( + debug_assert!( + peer_addr.is_ok(), + "Logging stream should be created over bound sockets" + ); + let log_message = match peer_addr { + Ok(peer_addr) => format!( "Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.", this.metrics.total_read, this.metrics.total_written - ); - this.events.emit(Event { - id: *this.account_id, - typ: EventType::Warning(log_message), - }); - } + ), + Err(_) => { + format!("Read error on a stream that does not have a peer address: {err}.") + } + }; + this.events.emit(Event { + id: *this.account_id, + typ: EventType::Warning(log_message), + }); } let n = old_remaining - buf.remaining();