From 6aaed215b6e3088e7e52864d7b84154bccc884af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 21 May 2025 15:00:14 +0200 Subject: [PATCH 01/80] Separate code path for websocket handshake --- iroh-relay/src/client/conn.rs | 103 +++++++++++++++++++++++---- iroh-relay/src/protos/relay.rs | 11 +-- iroh-relay/src/server/http_server.rs | 6 +- 3 files changed, 100 insertions(+), 20 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index e224e3374bd..ba1480526db 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -11,7 +11,7 @@ use std::{ use anyhow::{bail, Result}; use bytes::Bytes; use iroh_base::{NodeId, SecretKey}; -use n0_future::{time::Duration, Sink, Stream}; +use n0_future::{time::Duration, Sink, SinkExt, Stream}; #[cfg(not(wasm_browser))] use tokio_util::codec::Framed; use tracing::debug; @@ -118,14 +118,25 @@ impl Conn { key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { + use n0_future::SinkExt; + let conn = Framed::new(conn, RelayCodec::new(key_cache)); - let mut conn = Self::Relay { conn }; + let mut conn = conn.sink_err_into(); // exchange information with the server - server_handshake(&mut conn, secret_key).await?; - - Ok(conn) + debug!("server_handshake: started"); + let client_info = ClientInfo { + version: PROTOCOL_VERSION, + }; + debug!("server_handshake: sending client_key: {:?}", &client_info); + #[allow(deprecated)] + crate::protos::relay::legacy_send_client_key(&mut conn, secret_key, &client_info).await?; + debug!("server_handshake: done"); + + Ok(Self::Relay { + conn: conn.into_inner(), + }) } #[cfg(not(wasm_browser))] @@ -134,23 +145,86 @@ impl Conn { key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { - let mut conn = Self::Ws { conn, key_cache }; + let mut io = HandshakeIo { io: conn }; // exchange information with the server - server_handshake(&mut conn, secret_key).await?; + server_handshake(&mut io, secret_key).await?; - Ok(conn) + Ok(Self::Ws { + conn: io.io, + key_cache, + }) + } +} + +#[derive(derive_more::Debug)] +struct HandshakeIo { + #[cfg(not(wasm_browser))] + #[debug("WebSocketStream>")] + io: tokio_websockets::WebSocketStream>, + #[cfg(wasm_browser)] + #[debug("WebSocketStream")] + io: ws_stream_wasm::WsStream, +} + +impl Stream for HandshakeIo { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(Pin::new(&mut self.io).poll_next(cx)) { + None => return Poll::Ready(None), + Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Some(Ok(msg)) => { + if msg.is_close() { + // Indicate the stream is done when we receive a close message. + // Note: We don't have to poll the stream to completion for it to close gracefully. + return Poll::Ready(None); + } + if msg.is_ping() || msg.is_pong() { + continue; // Responding appropriately to these is done inside of tokio_websockets/browser impls + } + if !msg.is_binary() { + tracing::warn!( + ?msg, + "Got websocket message of unsupported type, skipping." + ); + continue; + } + return Poll::Ready(Some(Ok(msg.into_payload().into()))); + } + } + } + } +} + +impl Sink for HandshakeIo { + type Error = anyhow::Error; + + fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { + #[cfg(not(wasm_browser))] + let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes)); + #[cfg(wasm_browser)] + let msg = ws_stream_wasm::WsMessage::Binary(bytes.to_vec()); + Pin::new(&mut self.io).start_send(msg).map_err(Into::into) + } + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_ready(cx).map_err(Into::into) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_close(cx).map_err(Into::into) } } /// Sends the server handshake message. -async fn server_handshake(writer: &mut Conn, secret_key: &SecretKey) -> Result<()> { +async fn server_handshake(io: &mut HandshakeIo, secret_key: &SecretKey) -> Result<()> { debug!("server_handshake: started"); - let client_info = ClientInfo { - version: PROTOCOL_VERSION, - }; - debug!("server_handshake: sending client_key: {:?}", &client_info); - crate::protos::relay::send_client_key(&mut *writer, secret_key, &client_info).await?; debug!("server_handshake: done"); Ok(()) @@ -213,6 +287,7 @@ impl Stream for Conn { } } +// TODO(matheus23): Remove this impl, make `new_relay` work on the `Framed` directly, make the impl not rely on `ConnSendError`. impl Sink for Conn { type Error = ConnSendError; diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 76668e77714..1c8717d19f6 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -149,7 +149,8 @@ pub(crate) async fn write_frame + Unpin>( /// and the client's [`ClientInfo`], sealed using the server's [`PublicKey`]. /// /// Flushes after writing. -pub(crate) async fn send_client_key + Unpin>( +#[deprecated = "switch to proper handshake"] +pub(crate) async fn legacy_send_client_key + Unpin>( mut writer: S, client_secret_key: &SecretKey, client_info: &ClientInfo, @@ -171,7 +172,8 @@ pub(crate) async fn send_client_key + Unpi /// Reads the `FrameType::ClientInfo` frame from the client (its proof of identity) /// upon it's initial connection. #[cfg(any(test, feature = "server"))] -pub(crate) async fn recv_client_key> + Unpin>( +#[deprecated = "switch to proper handshake"] +pub(crate) async fn legacy_recv_client_key> + Unpin>( stream: S, ) -> anyhow::Result<(PublicKey, ClientInfo)> { use anyhow::Context; @@ -630,6 +632,7 @@ mod tests { } #[tokio::test] + #[allow(deprecated)] async fn test_send_recv_client_key() -> anyhow::Result<()> { let (reader, writer) = tokio::io::duplex(1024); let mut reader = FramedRead::new(reader, RelayCodec::test()); @@ -641,8 +644,8 @@ mod tests { version: PROTOCOL_VERSION, }; println!("client_key pub {:?}", client_key.public()); - send_client_key(&mut writer, &client_key, &client_info).await?; - let (client_pub_key, got_client_info) = recv_client_key(&mut reader).await?; + legacy_send_client_key(&mut writer, &client_key, &client_info).await?; + let (client_pub_key, got_client_info) = legacy_recv_client_key(&mut reader).await?; assert_eq!(client_key.public(), client_pub_key); assert_eq!(client_info, got_client_info); Ok(()) diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index ca4ea755edb..e891afbb2b4 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -20,11 +20,12 @@ use tokio_util::{codec::Framed, sync::CancellationToken, task::AbortOnDropHandle use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; use super::{clients::Clients, AccessConfig}; +#[allow(deprecated)] use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, http::{Protocol, LEGACY_RELAY_PATH, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, protos::relay::{ - recv_client_key, Frame, RelayCodec, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION, + legacy_recv_client_key, Frame, RelayCodec, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION, }, server::{ client::Config, @@ -541,7 +542,8 @@ impl Inner { } }; trace!("accept: recv client key"); - let (client_key, info) = recv_client_key(&mut io) + #[allow(deprecated)] + let (client_key, info) = legacy_recv_client_key(&mut io) .await .context("unable to receive client information")?; From 55fd3b28d7ef0856b872d4c54a5ffa581ed5b350 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 21 May 2025 19:07:59 +0200 Subject: [PATCH 02/80] Implement full handshake protocol for websocket transport only --- iroh-relay/src/client.rs | 9 -- iroh-relay/src/client/conn.rs | 96 +++------------- iroh-relay/src/protos.rs | 2 + iroh-relay/src/protos/handshake.rs | 160 +++++++++++++++++++++++++++ iroh-relay/src/protos/io.rs | 75 +++++++++++++ iroh-relay/src/quic.rs | 8 +- iroh-relay/src/server/http_server.rs | 48 +++++--- 7 files changed, 289 insertions(+), 109 deletions(-) create mode 100644 iroh-relay/src/protos/handshake.rs create mode 100644 iroh-relay/src/protos/io.rs diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index bf0d4680652..5c444575757 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -43,8 +43,6 @@ pub struct ClientBuilder { /// Default is None #[debug("address family selector callback")] address_family_selector: Option bool + Send + Sync>>, - /// Default is false - is_prober: bool, /// Server url. url: RelayUrl, /// Relay protocol @@ -72,7 +70,6 @@ impl ClientBuilder { ) -> Self { ClientBuilder { address_family_selector: None, - is_prober: false, url: url.into(), // Resolves to websockets in browsers and relay otherwise @@ -110,12 +107,6 @@ impl ClientBuilder { self } - /// Indicates this client is a prober - pub fn is_prober(mut self, is: bool) -> Self { - self.is_prober = is; - self - } - /// Skip the verification of the relay server's SSL certificates. /// /// May only be used in tests. diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index ba1480526db..5d3aeb1e340 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -11,16 +11,20 @@ use std::{ use anyhow::{bail, Result}; use bytes::Bytes; use iroh_base::{NodeId, SecretKey}; -use n0_future::{time::Duration, Sink, SinkExt, Stream}; +use n0_future::{time::Duration, Sink, Stream}; #[cfg(not(wasm_browser))] use tokio_util::codec::Framed; use tracing::debug; use super::KeyCache; -use crate::protos::relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}; +use crate::protos::{ + handshake, + relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}, +}; #[cfg(not(wasm_browser))] use crate::{ client::streams::{MaybeTlsStream, MaybeTlsStreamChained, ProxyStream}, + protos::io::HandshakeIo, protos::relay::RelayCodec, }; @@ -103,12 +107,17 @@ impl Conn { key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { - let mut conn = Self::WsBrowser { conn, key_cache }; + let mut io = HandshakeIo { io: conn }; // exchange information with the server - server_handshake(&mut conn, secret_key).await?; + debug!("server_handshake: started"); + handshake::clientside(&mut io, secret_key).await?; + debug!("server_handshake: done"); - Ok(conn) + Ok(Self::WsBrowser { + conn: io.io, + key_cache, + }) } /// Constructs a new websocket connection, including the initial server handshake. @@ -148,7 +157,9 @@ impl Conn { let mut io = HandshakeIo { io: conn }; // exchange information with the server - server_handshake(&mut io, secret_key).await?; + debug!("server_handshake: started"); + handshake::clientside(&mut io, secret_key).await?; + debug!("server_handshake: done"); Ok(Self::Ws { conn: io.io, @@ -157,79 +168,6 @@ impl Conn { } } -#[derive(derive_more::Debug)] -struct HandshakeIo { - #[cfg(not(wasm_browser))] - #[debug("WebSocketStream>")] - io: tokio_websockets::WebSocketStream>, - #[cfg(wasm_browser)] - #[debug("WebSocketStream")] - io: ws_stream_wasm::WsStream, -} - -impl Stream for HandshakeIo { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - match ready!(Pin::new(&mut self.io).poll_next(cx)) { - None => return Poll::Ready(None), - Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))), - Some(Ok(msg)) => { - if msg.is_close() { - // Indicate the stream is done when we receive a close message. - // Note: We don't have to poll the stream to completion for it to close gracefully. - return Poll::Ready(None); - } - if msg.is_ping() || msg.is_pong() { - continue; // Responding appropriately to these is done inside of tokio_websockets/browser impls - } - if !msg.is_binary() { - tracing::warn!( - ?msg, - "Got websocket message of unsupported type, skipping." - ); - continue; - } - return Poll::Ready(Some(Ok(msg.into_payload().into()))); - } - } - } - } -} - -impl Sink for HandshakeIo { - type Error = anyhow::Error; - - fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { - #[cfg(not(wasm_browser))] - let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes)); - #[cfg(wasm_browser)] - let msg = ws_stream_wasm::WsMessage::Binary(bytes.to_vec()); - Pin::new(&mut self.io).start_send(msg).map_err(Into::into) - } - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_ready(cx).map_err(Into::into) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_flush(cx).map_err(Into::into) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.io).poll_close(cx).map_err(Into::into) - } -} - -/// Sends the server handshake message. -async fn server_handshake(io: &mut HandshakeIo, secret_key: &SecretKey) -> Result<()> { - debug!("server_handshake: started"); - - debug!("server_handshake: done"); - Ok(()) -} - impl Stream for Conn { type Item = Result; diff --git a/iroh-relay/src/protos.rs b/iroh-relay/src/protos.rs index 82bea8180d1..1e1866dc966 100644 --- a/iroh-relay/src/protos.rs +++ b/iroh-relay/src/protos.rs @@ -1,5 +1,7 @@ //! Protocols used by the iroh-relay pub mod disco; +pub mod handshake; +pub mod io; pub mod relay; pub mod stun; diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs new file mode 100644 index 00000000000..ca397b58878 --- /dev/null +++ b/iroh-relay/src/protos/handshake.rs @@ -0,0 +1,160 @@ +//! TODO(matheus23) docs + +use anyhow::Result; +use bytes::{BufMut, Bytes, BytesMut}; +use iroh_base::{PublicKey, SecretKey, Signature}; +use n0_future::{time, Sink, SinkExt, Stream, TryStreamExt}; +use quinn_proto::{coding::Codec, VarInt}; +use rand::{CryptoRng, RngCore}; + +/// TODO(matheus23) docs +pub const PROTOCOL_VERSION: &[u8] = b"1"; + +/// A challenge for the client to sign with their secret key for NodeId authentication. +#[derive(derive_more::Debug, serde::Deserialize)] +#[cfg_attr(feature = "server", derive(serde::Serialize))] +pub struct ServerChallenge { + /// The challenge to sign. + /// Must be randomly generated with an RNG that is safe to use for crypto. + pub challenge: [u8; 16], +} + +const SERVER_CHALLENGE_TAG: VarInt = VarInt::from_u32(1); + +/// Info about the client. Also serves as authentication. +#[derive(derive_more::Debug, serde::Serialize)] +#[cfg_attr(feature = "server", derive(serde::Deserialize))] +pub struct ClientInfo { + /// The client's public key, a.k.a. the `NodeId` + pub public_key: PublicKey, + /// A signature of the server challenge, serves as authentication. + pub signature: Signature, + /// Part of the extracted key material, if that's what was signed. + pub key_material_suffix: Option<[u8; 16]>, + /// Supported versions/protocol features for version negotiation + /// with other connected relay clients + pub versions: Vec>, +} + +const CLIENT_INFO_TAG: VarInt = VarInt::from_u32(2); + +/// TODO(matheus23) docs +pub trait BytesStreamSink: + Stream> + Sink + Unpin +{ +} + +impl> + Sink + Unpin> BytesStreamSink + for T +{ +} + +/// TODO(matheus23) docs +pub async fn clientside(io: &mut impl BytesStreamSink, secret_key: &SecretKey) -> Result<()> { + let challenge: ServerChallenge = + read_postcard_frame(io, SERVER_CHALLENGE_TAG, time::Duration::from_secs(30)).await?; + + let client_info = ClientInfo { + public_key: secret_key.public(), + signature: secret_key.sign(&challenge.challenge), // TODO(matheus23) add some context to the signature, so we're not signing arbitrary stuff + key_material_suffix: None, + versions: vec![PROTOCOL_VERSION.to_vec()], + }; + write_postcard_frame(io, CLIENT_INFO_TAG, client_info).await?; + + Ok(()) +} + +/// TODO(matheus23) docs +#[cfg(feature = "server")] +pub async fn serverside( + io: &mut impl BytesStreamSink, + mut rng: impl RngCore + CryptoRng, +) -> Result { + let mut challenge = [0u8; 16]; + rng.fill_bytes(&mut challenge); + + write_postcard_frame(io, SERVER_CHALLENGE_TAG, ServerChallenge { challenge }).await?; + + let client_info: ClientInfo = + read_postcard_frame(io, CLIENT_INFO_TAG, time::Duration::from_secs(10)).await?; + + // TODO(matheus23): Add context bytes to this verification check + client_info + .public_key + .verify(&challenge, &client_info.signature)?; + + Ok(client_info) +} + +async fn write_postcard_frame( + io: &mut impl BytesStreamSink, + tag: VarInt, + frame: impl serde::Serialize, +) -> Result<()> { + let mut bytes = BytesMut::new(); + tag.encode(&mut bytes); + let bytes = postcard::to_io(&frame, bytes.writer())? + .into_inner() + .freeze(); + io.send(bytes).await?; + io.flush().await?; + Ok(()) +} + +async fn read_postcard_frame( + io: &mut impl BytesStreamSink, + expected_tag: VarInt, + timeout: time::Duration, +) -> Result { + let recv = time::timeout(timeout, io.try_next()) + .await?? + .ok_or_else(|| anyhow::anyhow!("disconnected"))?; + let mut cursor = std::io::Cursor::new(recv); + let tag = VarInt::decode(&mut cursor)?; + anyhow::ensure!(tag == expected_tag); + let start = cursor.position() as usize; + let frame: F = postcard::from_bytes( + &cursor + .into_inner() + .get(start..) + .expect("cursor confirmed position"), + )?; + + Ok(frame) +} + +#[cfg(test)] +mod tests { + use bytes::BytesMut; + use iroh_base::SecretKey; + use n0_future::{SinkExt, TryStreamExt}; + use testresult::TestResult; + use tokio_util::codec::{Framed, LengthDelimitedCodec}; + + #[tokio::test] + #[cfg(feature = "server")] + async fn simulate_handshake() -> TestResult { + let (client, server) = tokio::io::duplex(1024); + let secret_key = SecretKey::generate(rand::rngs::OsRng); + + let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) + .map_ok(BytesMut::freeze) + .map_err(anyhow::Error::from) + .sink_err_into(); + let mut server_io = Framed::new(server, LengthDelimitedCodec::new()) + .map_ok(BytesMut::freeze) + .map_err(anyhow::Error::from) + .sink_err_into(); + + let (_, client_info) = n0_future::future::try_zip( + super::clientside(&mut client_io, &secret_key), + super::serverside(&mut server_io, rand::rngs::OsRng), + ) + .await?; + + println!("{client_info:#?}"); + + Ok(()) + } +} diff --git a/iroh-relay/src/protos/io.rs b/iroh-relay/src/protos/io.rs new file mode 100644 index 00000000000..15922c1ea67 --- /dev/null +++ b/iroh-relay/src/protos/io.rs @@ -0,0 +1,75 @@ +//! TODO(matheus23) docs +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use anyhow::Result; +use bytes::Bytes; +use n0_future::{ready, Sink, Stream}; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[derive(derive_more::Debug)] +pub(crate) struct HandshakeIo { + #[cfg(not(wasm_browser))] + #[debug("WebSocketStream>")] + pub(crate) io: tokio_websockets::WebSocketStream, + #[cfg(wasm_browser)] + #[debug("WebSocketStream")] + pub(crate) io: ws_stream_wasm::WsStream, +} + +impl Stream for HandshakeIo { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(Pin::new(&mut self.io).poll_next(cx)) { + None => return Poll::Ready(None), + Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Some(Ok(msg)) => { + if msg.is_close() { + // Indicate the stream is done when we receive a close message. + // Note: We don't have to poll the stream to completion for it to close gracefully. + return Poll::Ready(None); + } + if msg.is_ping() || msg.is_pong() { + continue; // Responding appropriately to these is done inside of tokio_websockets/browser impls + } + if !msg.is_binary() { + tracing::warn!( + ?msg, + "Got websocket message of unsupported type, skipping." + ); + continue; + } + return Poll::Ready(Some(Ok(msg.into_payload().into()))); + } + } + } + } +} + +impl Sink for HandshakeIo { + type Error = anyhow::Error; + + fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { + #[cfg(not(wasm_browser))] + let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes)); + #[cfg(wasm_browser)] + let msg = ws_stream_wasm::WsMessage::Binary(bytes.to_vec()); + Pin::new(&mut self.io).start_send(msg).map_err(Into::into) + } + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_ready(cx).map_err(Into::into) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_close(cx).map_err(Into::into) + } +} diff --git a/iroh-relay/src/quic.rs b/iroh-relay/src/quic.rs index 678d1e5910e..d22564cec8a 100644 --- a/iroh-relay/src/quic.rs +++ b/iroh-relay/src/quic.rs @@ -292,14 +292,14 @@ mod tests { use tracing_test::traced_test; use webpki_types::PrivatePkcs8KeyDer; - use super::{ - server::{QuicConfig, QuicServer}, - *, - }; + use super::*; #[tokio::test] #[traced_test] + #[cfg(feature = "test-utils")] async fn quic_endpoint_basic() -> anyhow::Result<()> { + use super::server::{QuicConfig, QuicServer}; + let host: Ipv4Addr = "127.0.0.1".parse()?; // create a server config with self signed certificates let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config(); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index e891afbb2b4..e7698b841a9 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -20,6 +20,7 @@ use tokio_util::{codec::Framed, sync::CancellationToken, task::AbortOnDropHandle use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; use super::{clients::Clients, AccessConfig}; +use crate::protos::{handshake, io::HandshakeIo}; #[allow(deprecated)] use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, @@ -525,10 +526,27 @@ impl Inner { /// [`AsyncWrite`]: tokio::io::AsyncWrite async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<()> { trace!(?protocol, "accept: start"); - let mut io = match protocol { + let (client_key, mut io) = match protocol { Protocol::Relay => { self.metrics.relay_accepts.inc(); - RelayedStream::Relay(Framed::new(io, RelayCodec::new(self.key_cache.clone()))) + let mut io = + RelayedStream::Relay(Framed::new(io, RelayCodec::new(self.key_cache.clone()))); + + trace!("accept: recv client key"); + #[allow(deprecated)] + let (client_key, info) = legacy_recv_client_key(&mut io) + .await + .context("unable to receive client information")?; + + if info.version != PROTOCOL_VERSION { + bail!( + "unexpected client version {}, expected {}", + info.version, + PROTOCOL_VERSION + ); + } + + (client_key, io) } Protocol::Websocket => { self.metrics.websocket_accepts.inc(); @@ -538,14 +556,18 @@ impl Inner { let builder = tokio_websockets::ServerBuilder::new(); // Serve will create a WebSocketStream on an already upgraded connection let websocket = builder.serve(io); - RelayedStream::Ws(websocket, self.key_cache.clone()) + + // TODO(matheus23): Change to use `RelayedStream` or similar, so we inherit the rate limiting + let mut io = HandshakeIo { io: websocket }; + + let client_info = handshake::serverside(&mut io, rand::rngs::OsRng).await?; + + ( + client_info.public_key, + RelayedStream::Ws(io.io, self.key_cache.clone()), + ) } }; - trace!("accept: recv client key"); - #[allow(deprecated)] - let (client_key, info) = legacy_recv_client_key(&mut io) - .await - .context("unable to receive client information")?; trace!("accept: checking access: {:?}", self.access); if !self.access.is_allowed(client_key).await { @@ -555,15 +577,7 @@ impl Inner { .await?; io.flush().await?; - bail!("client is not authenticated: {}", client_key); - } - - if info.version != PROTOCOL_VERSION { - bail!( - "unexpected client version {}, expected {}", - info.version, - PROTOCOL_VERSION - ); + bail!("client is not authenticated: {client_key}"); } trace!("accept: build client conn"); From 3f52da0d1ccad50f8903e02e7157e7d263da76b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 23 May 2025 15:51:58 +0200 Subject: [PATCH 03/80] WIP --- iroh-relay/src/client/conn.rs | 18 +++++++++++------- iroh-relay/src/protos/relay.rs | 14 +++++++------- iroh-relay/src/server/http_server.rs | 1 - iroh-relay/src/server/streams.rs | 9 ++++++--- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 5d3aeb1e340..4bb0fd92c71 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -9,7 +9,7 @@ use std::{ }; use anyhow::{bail, Result}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use iroh_base::{NodeId, SecretKey}; use n0_future::{time::Duration, Sink, Stream}; #[cfg(not(wasm_browser))] @@ -253,9 +253,11 @@ impl Sink for Conn { Self::Relay { ref mut conn } => Pin::new(conn).start_send(frame).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn) - .start_send(tokio_websockets::Message::binary( - tokio_websockets::Payload::from(frame.encode_for_ws_msg()), - )) + .start_send(tokio_websockets::Message::binary({ + let mut buf = BytesMut::new(); + frame.encode_for_ws_msg(&mut buf); + tokio_websockets::Payload::from(buf.freeze()) + })) .map_err(Into::into), #[cfg(wasm_browser)] Self::WsBrowser { ref mut conn, .. } => Pin::new(conn) @@ -319,9 +321,11 @@ impl Sink for Conn { Self::Relay { ref mut conn } => Pin::new(conn).start_send(frame).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn) - .start_send(tokio_websockets::Message::binary( - tokio_websockets::Payload::from(frame.encode_for_ws_msg()), - )) + .start_send(tokio_websockets::Message::binary({ + let mut buf = BytesMut::new(); + frame.encode_for_ws_msg(&mut buf); + tokio_websockets::Payload::from(buf.freeze()) + })) .map_err(Into::into), #[cfg(wasm_browser)] Self::WsBrowser { ref mut conn, .. } => Pin::new(conn) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 1c8717d19f6..b295aee7c79 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -328,11 +328,9 @@ impl Frame { /// Encodes this frame for sending over websockets. /// /// Specifically meant for being put into a binary websocket message frame. - pub(crate) fn encode_for_ws_msg(self) -> Vec { - let mut bytes = Vec::new(); - bytes.put_u8(self.typ().into()); - self.write_to(&mut bytes); - bytes + pub(crate) fn encode_for_ws_msg(self, dst: &mut impl BufMut) { + dst.put_u8(self.typ().into()); + self.write_to(dst); } /// Writes it self to the given buffer. @@ -728,7 +726,8 @@ mod tests { ]; for (frame, expected_hex) in frames { - let bytes = frame.encode_for_ws_msg(); + let mut bytes = Vec::new(); + frame.encode_for_ws_msg(&mut bytes); let stripped: Vec = expected_hex .chars() .filter_map(|s| { @@ -854,7 +853,8 @@ mod proptests { #[test] fn frame_ws_roundtrip(frame in frame()) { - let encoded = frame.clone().encode_for_ws_msg(); + let mut encoded = Vec::new(); + frame.clone().encode_for_ws_msg(&mut encoded); let decoded = Frame::decode_from_ws_msg(Bytes::from(encoded), &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index e7698b841a9..c7fde07d5fe 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -557,7 +557,6 @@ impl Inner { // Serve will create a WebSocketStream on an already upgraded connection let websocket = builder.serve(io); - // TODO(matheus23): Change to use `RelayedStream` or similar, so we inherit the rate limiting let mut io = HandshakeIo { io: websocket }; let client_info = handshake::serverside(&mut io, rand::rngs::OsRng).await?; diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 94fa77735ef..804efc7fd2f 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -6,6 +6,7 @@ use std::{ }; use anyhow::Result; +use bytes::BytesMut; use n0_future::{Sink, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::Framed; @@ -46,9 +47,11 @@ impl Sink for RelayedStream { match *self { Self::Relay(ref mut framed) => Pin::new(framed).start_send(item), Self::Ws(ref mut ws, _) => Pin::new(ws) - .start_send(tokio_websockets::Message::binary( - tokio_websockets::Payload::from(item.encode_for_ws_msg()), - )) + .start_send(tokio_websockets::Message::binary({ + let mut buf = BytesMut::new(); + item.encode_for_ws_msg(&mut buf); + tokio_websockets::Payload::from(buf.freeze()) + })) .map_err(ws_to_io_err), } } From db787e1d9eeb756c998481d52eb8d0b7611e7fdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 23 May 2025 17:54:43 +0200 Subject: [PATCH 04/80] WIP --- Cargo.lock | 32 ++++ iroh-relay/Cargo.toml | 1 + iroh-relay/src/client/conn.rs | 12 +- iroh-relay/src/client/streams.rs | 19 +++ iroh-relay/src/lib.rs | 9 ++ iroh-relay/src/protos/handshake.rs | 252 ++++++++++++++++++++++++----- iroh-relay/src/protos/io.rs | 17 ++ iroh-relay/src/server/streams.rs | 20 ++- 8 files changed, 319 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94a9bc5f737..d1aab33f075 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -159,6 +159,18 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "asn1-rs" version = "0.6.2" @@ -448,6 +460,19 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +[[package]] +name = "blake3" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -666,6 +691,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "cordyceps" version = "0.3.3" @@ -2461,6 +2492,7 @@ version = "0.35.0" dependencies = [ "ahash", "anyhow", + "blake3", "bytes", "cfg_aliases", "clap", diff --git a/iroh-relay/Cargo.toml b/iroh-relay/Cargo.toml index 0e56551fb9a..8521d342f1d 100644 --- a/iroh-relay/Cargo.toml +++ b/iroh-relay/Cargo.toml @@ -95,6 +95,7 @@ toml = { version = "0.8", optional = true } tracing-subscriber = { version = "0.3", features = [ "env-filter", ], optional = true } +blake3 = "1.8.2" # non-wasm-in-browser dependencies [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 4bb0fd92c71..a647d2edb4f 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -12,21 +12,25 @@ use anyhow::{bail, Result}; use bytes::{Bytes, BytesMut}; use iroh_base::{NodeId, SecretKey}; use n0_future::{time::Duration, Sink, Stream}; +use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(not(wasm_browser))] use tokio_util::codec::Framed; use tracing::debug; use super::KeyCache; -use crate::protos::{ - handshake, - relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}, -}; #[cfg(not(wasm_browser))] use crate::{ client::streams::{MaybeTlsStream, MaybeTlsStreamChained, ProxyStream}, protos::io::HandshakeIo, protos::relay::RelayCodec, }; +use crate::{ + protos::{ + handshake, + relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}, + }, + ExportKeyingMaterial, +}; /// Error for sending messages to the relay server. #[derive(Debug, thiserror::Error)] diff --git a/iroh-relay/src/client/streams.rs b/iroh-relay/src/client/streams.rs index 12e570b10cc..015d1fbda40 100644 --- a/iroh-relay/src/client/streams.rs +++ b/iroh-relay/src/client/streams.rs @@ -13,6 +13,8 @@ use tokio::{ net::TcpStream, }; +use crate::ExportKeyingMaterial; + use super::util; #[allow(clippy::large_enum_variant)] @@ -198,6 +200,23 @@ pub enum MaybeTlsStream { Tls(tokio_rustls::client::TlsStream), } +impl ExportKeyingMaterial for MaybeTlsStream { + fn export_keying_material>( + &self, + output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option { + let Self::Tls(ref tls) = self else { + return None; + }; + tls.get_ref() + .1 + .export_keying_material(output, label, context) + .ok() + } +} + impl AsyncRead for MaybeTlsStream { fn poll_read( mut self: Pin<&mut Self>, diff --git a/iroh-relay/src/lib.rs b/iroh-relay/src/lib.rs index 71a18d7aaf3..232b794b0e4 100644 --- a/iroh-relay/src/lib.rs +++ b/iroh-relay/src/lib.rs @@ -54,3 +54,12 @@ pub use self::{ ping_tracker::PingTracker, relay_map::{RelayMap, RelayNode, RelayQuicConfig}, }; + +pub(crate) trait ExportKeyingMaterial { + fn export_keying_material>( + &self, + output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option; +} diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index ca397b58878..50d92ee35f4 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -7,6 +7,8 @@ use n0_future::{time, Sink, SinkExt, Stream, TryStreamExt}; use quinn_proto::{coding::Codec, VarInt}; use rand::{CryptoRng, RngCore}; +use crate::ExportKeyingMaterial; + /// TODO(matheus23) docs pub const PROTOCOL_VERSION: &[u8] = b"1"; @@ -38,6 +40,20 @@ pub struct ClientInfo { const CLIENT_INFO_TAG: VarInt = VarInt::from_u32(2); +/// Confirmation of successful connection. +#[derive(derive_more::Debug, serde::Deserialize)] +#[cfg_attr(feature = "server", derive(serde::Serialize))] +pub struct ServerConfirmsConnected; + +const SERVER_CONFIRMS_CONNECTED_TAG: VarInt = VarInt::from_u32(3); + +/// Denial of connection. The client couldn't be verified as authentic. +#[derive(derive_more::Debug, serde::Deserialize)] +#[cfg_attr(feature = "server", derive(serde::Serialize))] +pub struct ServerDeniesConnection; + +const SERVER_DENIES_CONNECTION_TAG: VarInt = VarInt::from_u32(4); + /// TODO(matheus23) docs pub trait BytesStreamSink: Stream> + Sink + Unpin @@ -50,44 +66,113 @@ impl> + Sink + Unpi } /// TODO(matheus23) docs -pub async fn clientside(io: &mut impl BytesStreamSink, secret_key: &SecretKey) -> Result<()> { - let challenge: ServerChallenge = - read_postcard_frame(io, SERVER_CHALLENGE_TAG, time::Duration::from_secs(30)).await?; - - let client_info = ClientInfo { - public_key: secret_key.public(), - signature: secret_key.sign(&challenge.challenge), // TODO(matheus23) add some context to the signature, so we're not signing arbitrary stuff - key_material_suffix: None, - versions: vec![PROTOCOL_VERSION.to_vec()], +pub(crate) async fn clientside( + io: &mut (impl BytesStreamSink + ExportKeyingMaterial), + secret_key: &SecretKey, +) -> Result { + let public_key = secret_key.public(); + let versions = vec![PROTOCOL_VERSION.to_vec()]; + + let key_material = io.export_keying_material( + [0u8; 32], + b"iroh-relay handshake v1", + Some(secret_key.public().as_bytes()), + ); + + if let Some(key_material) = key_material { + write_frame( + io, + CLIENT_INFO_TAG, + ClientInfo { + public_key, + signature: secret_key.sign(&blake3::derive_key( + "iroh-relay handshake v1 key material signature", + &key_material[..16], + )), + key_material_suffix: Some(key_material[16..].try_into().expect("split right")), + versions: versions.clone(), + }, + ) + .await?; + } + + let (tag, frame) = read_frame( + io, + &[ + SERVER_CHALLENGE_TAG, + SERVER_CONFIRMS_CONNECTED_TAG, + SERVER_DENIES_CONNECTION_TAG, + ], + time::Duration::from_secs(30), + ) + .await?; + + let (tag, frame) = if tag == SERVER_CHALLENGE_TAG { + let challenge: ServerChallenge = postcard::from_bytes(&frame)?; + + let client_info = ClientInfo { + public_key, + signature: secret_key.sign(&blake3::derive_key( + "iroh-relay handshake v1 challenge signature", + &challenge.challenge, + )), + key_material_suffix: None, + versions, + }; + write_frame(io, CLIENT_INFO_TAG, client_info).await?; + + read_frame( + io, + &[SERVER_CONFIRMS_CONNECTED_TAG, SERVER_DENIES_CONNECTION_TAG], + time::Duration::from_secs(30), + ) + .await? + } else { + (tag, frame) }; - write_postcard_frame(io, CLIENT_INFO_TAG, client_info).await?; - Ok(()) + match tag { + SERVER_CONFIRMS_CONNECTED_TAG => { + let confirmation: ServerConfirmsConnected = postcard::from_bytes(&frame)?; + Ok(confirmation) + } + SERVER_DENIES_CONNECTION_TAG => { + let denial: ServerDeniesConnection = postcard::from_bytes(&frame)?; + anyhow::bail!("server denied connection: {denial:?}"); + } + _ => unreachable!(), + } } /// TODO(matheus23) docs #[cfg(feature = "server")] -pub async fn serverside( - io: &mut impl BytesStreamSink, +pub(crate) async fn serverside( + io: &mut (impl BytesStreamSink + ExportKeyingMaterial), mut rng: impl RngCore + CryptoRng, ) -> Result { let mut challenge = [0u8; 16]; rng.fill_bytes(&mut challenge); - write_postcard_frame(io, SERVER_CHALLENGE_TAG, ServerChallenge { challenge }).await?; + write_frame(io, SERVER_CHALLENGE_TAG, ServerChallenge { challenge }).await?; + + let (_, frame) = read_frame(io, &[CLIENT_INFO_TAG], time::Duration::from_secs(10)).await?; + let client_info: ClientInfo = postcard::from_bytes(&frame)?; - let client_info: ClientInfo = - read_postcard_frame(io, CLIENT_INFO_TAG, time::Duration::from_secs(10)).await?; + let result = client_info.public_key.verify( + &blake3::derive_key("iroh-relay handshake v1 challenge signature", &challenge), + &client_info.signature, + ); - // TODO(matheus23): Add context bytes to this verification check - client_info - .public_key - .verify(&challenge, &client_info.signature)?; + if result.is_ok() { + write_frame(io, SERVER_CONFIRMS_CONNECTED_TAG, ServerConfirmsConnected).await?; + } else { + write_frame(io, SERVER_DENIES_CONNECTION_TAG, ServerDeniesConnection).await?; + } Ok(client_info) } -async fn write_postcard_frame( +async fn write_frame( io: &mut impl BytesStreamSink, tag: VarInt, frame: impl serde::Serialize, @@ -102,54 +187,145 @@ async fn write_postcard_frame( Ok(()) } -async fn read_postcard_frame( +async fn read_frame( io: &mut impl BytesStreamSink, - expected_tag: VarInt, + expected_tags: &[VarInt], timeout: time::Duration, -) -> Result { +) -> Result<(VarInt, Bytes)> { let recv = time::timeout(timeout, io.try_next()) .await?? .ok_or_else(|| anyhow::anyhow!("disconnected"))?; + let mut cursor = std::io::Cursor::new(recv); let tag = VarInt::decode(&mut cursor)?; - anyhow::ensure!(tag == expected_tag); + anyhow::ensure!( + expected_tags.contains(&tag), + "Unexpected tag {tag}, expected one of {expected_tags:?}" + ); + let start = cursor.position() as usize; - let frame: F = postcard::from_bytes( - &cursor - .into_inner() - .get(start..) - .expect("cursor confirmed position"), - )?; - - Ok(frame) + let payload = cursor.into_inner().slice(start..); + + Ok((tag, payload)) } #[cfg(test)] mod tests { use bytes::BytesMut; use iroh_base::SecretKey; - use n0_future::{SinkExt, TryStreamExt}; + use n0_future::{Sink, SinkExt, Stream, TryStreamExt}; use testresult::TestResult; use tokio_util::codec::{Framed, LengthDelimitedCodec}; + use crate::ExportKeyingMaterial; + + struct TestKeyingMaterial { + shared_secret: Option, + inner: IO, + } + + trait WithTlsSharedSecret: Sized { + fn with_shared_secret(self, shared_secret: Option) -> TestKeyingMaterial; + } + + impl WithTlsSharedSecret for T { + fn with_shared_secret(self, shared_secret: Option) -> TestKeyingMaterial { + TestKeyingMaterial { + shared_secret, + inner: self, + } + } + } + + impl ExportKeyingMaterial for TestKeyingMaterial { + fn export_keying_material>( + &self, + mut output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option { + // we simulate something like exporting keying material using blake3 + + let label_key = blake3::hash(label); + let context_key = blake3::keyed_hash(label_key.as_bytes(), context.unwrap_or(&[])); + let mut hasher = blake3::Hasher::new_keyed(context_key.as_bytes()); + hasher.update(&self.shared_secret?.to_le_bytes()); + hasher.finalize_xof().fill(output.as_mut()); + + Some(output) + } + } + + impl + Unpin> Stream for TestKeyingMaterial { + type Item = V; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_next(cx) + } + } + + impl + Unpin> Sink for TestKeyingMaterial { + type Error = E; + + fn poll_ready( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: std::pin::Pin<&mut Self>, item: V) -> Result<(), Self::Error> { + std::pin::Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.inner).poll_close(cx) + } + } + #[tokio::test] #[cfg(feature = "server")] async fn simulate_handshake() -> TestResult { + use anyhow::Context; + let (client, server) = tokio::io::duplex(1024); let secret_key = SecretKey::generate(rand::rngs::OsRng); let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) .map_err(anyhow::Error::from) - .sink_err_into(); + .sink_err_into() + .with_shared_secret(Some(42)); let mut server_io = Framed::new(server, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) .map_err(anyhow::Error::from) - .sink_err_into(); + .sink_err_into() + .with_shared_secret(Some(42)); let (_, client_info) = n0_future::future::try_zip( - super::clientside(&mut client_io, &secret_key), - super::serverside(&mut server_io, rand::rngs::OsRng), + async { + super::clientside(&mut client_io, &secret_key) + .await + .context("clientside") + }, + async { + super::serverside(&mut server_io, rand::rngs::OsRng) + .await + .context("serverside") + }, ) .await?; diff --git a/iroh-relay/src/protos/io.rs b/iroh-relay/src/protos/io.rs index 15922c1ea67..48e70f26d1e 100644 --- a/iroh-relay/src/protos/io.rs +++ b/iroh-relay/src/protos/io.rs @@ -9,6 +9,8 @@ use bytes::Bytes; use n0_future::{ready, Sink, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::ExportKeyingMaterial; + #[derive(derive_more::Debug)] pub(crate) struct HandshakeIo { #[cfg(not(wasm_browser))] @@ -19,6 +21,21 @@ pub(crate) struct HandshakeIo { pub(crate) io: ws_stream_wasm::WsStream, } +impl ExportKeyingMaterial + for HandshakeIo +{ + fn export_keying_material>( + &self, + output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option { + self.io + .get_ref() + .export_keying_material(output, label, context) + } +} + impl Stream for HandshakeIo { type Item = Result; diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 804efc7fd2f..63df604c39e 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -14,7 +14,7 @@ use tokio_websockets::WebSocketStream; use crate::{ protos::relay::{Frame, RelayCodec}, - KeyCache, + ExportKeyingMaterial, KeyCache, }; /// A Stream and Sink for [`Frame`]s connected to a single relay client. @@ -119,6 +119,24 @@ pub enum MaybeTlsStream { Test(tokio::io::DuplexStream), } +impl ExportKeyingMaterial for MaybeTlsStream { + fn export_keying_material>( + &self, + output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option { + let Self::Tls(ref tls) = self else { + return None; + }; + + tls.get_ref() + .1 + .export_keying_material(output, label, context) + .ok() + } +} + impl AsyncRead for MaybeTlsStream { fn poll_read( mut self: Pin<&mut Self>, From 20f08c9da35c1e4fccc9c02f361260a11821e838 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Sun, 25 May 2025 14:42:26 +0200 Subject: [PATCH 05/80] Test & fix `ClientInfo` serialization round-trip --- Cargo.lock | 1 + iroh-relay/Cargo.toml | 1 + iroh-relay/src/client/conn.rs | 12 ++-- iroh-relay/src/protos/handshake.rs | 107 +++++++++++++++++++++-------- 4 files changed, 86 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d1aab33f075..d554c9619ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2532,6 +2532,7 @@ dependencies = [ "rustls-pki-types", "rustls-webpki", "serde", + "serde_bytes", "serde_json", "sha1", "simdutf8", diff --git a/iroh-relay/Cargo.toml b/iroh-relay/Cargo.toml index 8521d342f1d..426ae6002ab 100644 --- a/iroh-relay/Cargo.toml +++ b/iroh-relay/Cargo.toml @@ -96,6 +96,7 @@ tracing-subscriber = { version = "0.3", features = [ "env-filter", ], optional = true } blake3 = "1.8.2" +serde_bytes = "0.11.17" # non-wasm-in-browser dependencies [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index a647d2edb4f..4bb0fd92c71 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -12,25 +12,21 @@ use anyhow::{bail, Result}; use bytes::{Bytes, BytesMut}; use iroh_base::{NodeId, SecretKey}; use n0_future::{time::Duration, Sink, Stream}; -use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(not(wasm_browser))] use tokio_util::codec::Framed; use tracing::debug; use super::KeyCache; +use crate::protos::{ + handshake, + relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}, +}; #[cfg(not(wasm_browser))] use crate::{ client::streams::{MaybeTlsStream, MaybeTlsStreamChained, ProxyStream}, protos::io::HandshakeIo, protos::relay::RelayCodec, }; -use crate::{ - protos::{ - handshake, - relay::{ClientInfo, Frame, MAX_PACKET_SIZE, PROTOCOL_VERSION}, - }, - ExportKeyingMaterial, -}; /// Error for sending messages to the relay server. #[derive(Debug, thiserror::Error)] diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 50d92ee35f4..73ed89bddc4 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -30,7 +30,8 @@ pub struct ClientInfo { /// The client's public key, a.k.a. the `NodeId` pub public_key: PublicKey, /// A signature of the server challenge, serves as authentication. - pub signature: Signature, + #[serde(with = "serde_bytes")] + pub signature: [u8; 64], /// Part of the extracted key material, if that's what was signed. pub key_material_suffix: Option<[u8; 16]>, /// Supported versions/protocol features for version negotiation @@ -65,6 +66,44 @@ impl> + Sink + Unpi { } +impl ServerChallenge { + /// TODO(matheus23): docs + pub fn new(mut rng: impl RngCore + CryptoRng) -> Self { + let mut challenge = [0u8; 16]; + rng.fill_bytes(&mut challenge); + Self { challenge } + } + + fn message_to_sign(&self) -> [u8; 32] { + blake3::derive_key( + "iroh-relay handshake v1 challenge signature", + &self.challenge, + ) + } +} + +impl ClientInfo { + /// TODO(matheus23): docs + pub fn new_from_challenge(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self { + Self { + public_key: secret_key.public(), + key_material_suffix: None, + signature: secret_key.sign(&challenge.message_to_sign()).to_bytes(), + versions: vec![PROTOCOL_VERSION.to_vec()], + } + } + + /// TODO(matheus23): docs + pub fn verify_from_challenge(&self, challenge: &ServerChallenge) -> bool { + self.public_key + .verify( + &challenge.message_to_sign(), + &Signature::from_bytes(&self.signature), + ) + .is_ok() + } +} + /// TODO(matheus23) docs pub(crate) async fn clientside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), @@ -80,15 +119,16 @@ pub(crate) async fn clientside( ); if let Some(key_material) = key_material { + let message = blake3::derive_key( + "iroh-relay handshake v1 key material signature", + &key_material[..16], + ); write_frame( io, CLIENT_INFO_TAG, ClientInfo { public_key, - signature: secret_key.sign(&blake3::derive_key( - "iroh-relay handshake v1 key material signature", - &key_material[..16], - )), + signature: secret_key.sign(&message).to_bytes(), key_material_suffix: Some(key_material[16..].try_into().expect("split right")), versions: versions.clone(), }, @@ -110,15 +150,7 @@ pub(crate) async fn clientside( let (tag, frame) = if tag == SERVER_CHALLENGE_TAG { let challenge: ServerChallenge = postcard::from_bytes(&frame)?; - let client_info = ClientInfo { - public_key, - signature: secret_key.sign(&blake3::derive_key( - "iroh-relay handshake v1 challenge signature", - &challenge.challenge, - )), - key_material_suffix: None, - versions, - }; + let client_info = ClientInfo::new_from_challenge(secret_key, &challenge); write_frame(io, CLIENT_INFO_TAG, client_info).await?; read_frame( @@ -148,22 +180,15 @@ pub(crate) async fn clientside( #[cfg(feature = "server")] pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), - mut rng: impl RngCore + CryptoRng, + rng: impl RngCore + CryptoRng, ) -> Result { - let mut challenge = [0u8; 16]; - rng.fill_bytes(&mut challenge); - - write_frame(io, SERVER_CHALLENGE_TAG, ServerChallenge { challenge }).await?; + let challenge = ServerChallenge::new(rng); + write_frame(io, SERVER_CHALLENGE_TAG, &challenge).await?; let (_, frame) = read_frame(io, &[CLIENT_INFO_TAG], time::Duration::from_secs(10)).await?; let client_info: ClientInfo = postcard::from_bytes(&frame)?; - let result = client_info.public_key.verify( - &blake3::derive_key("iroh-relay handshake v1 challenge signature", &challenge), - &client_info.signature, - ); - - if result.is_ok() { + if client_info.verify_from_challenge(&challenge) { write_frame(io, SERVER_CONFIRMS_CONNECTED_TAG, ServerConfirmsConnected).await?; } else { write_frame(io, SERVER_DENIES_CONNECTION_TAG, ServerDeniesConnection).await?; @@ -209,7 +234,7 @@ async fn read_frame( Ok((tag, payload)) } -#[cfg(test)] +#[cfg(all(test, feature = "server"))] mod tests { use bytes::BytesMut; use iroh_base::SecretKey; @@ -219,6 +244,8 @@ mod tests { use crate::ExportKeyingMaterial; + use super::{ClientInfo, ServerChallenge}; + struct TestKeyingMaterial { shared_secret: Option, inner: IO, @@ -297,7 +324,6 @@ mod tests { } #[tokio::test] - #[cfg(feature = "server")] async fn simulate_handshake() -> TestResult { use anyhow::Context; @@ -333,4 +359,31 @@ mod tests { Ok(()) } + + #[test] + fn test_client_info_roundtrip() -> TestResult { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let challenge = ServerChallenge::new(rand::rngs::OsRng); + let client_info = ClientInfo::new_from_challenge(&secret_key, &challenge); + + let bytes = postcard::to_allocvec(&client_info)?; + let decoded: ClientInfo = postcard::from_bytes(&bytes)?; + + assert_eq!(client_info.public_key, decoded.public_key); + assert_eq!(client_info.key_material_suffix, decoded.key_material_suffix); + assert_eq!(client_info.signature, decoded.signature); + assert_eq!(client_info.versions, decoded.versions); + + Ok(()) + } + + #[test] + fn test_challenge_verification() -> TestResult { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let challenge = ServerChallenge::new(rand::rngs::OsRng); + let client_info = ClientInfo::new_from_challenge(&secret_key, &challenge); + assert!(client_info.verify_from_challenge(&challenge)); + + Ok(()) + } } From 5cb22482222450f2abc13e69d9907057f5065fae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 30 May 2025 16:17:39 +0200 Subject: [PATCH 06/80] Rename handshake stuff, implement key-export based verification on server side --- iroh-relay/src/protos/handshake.rs | 256 +++++++++++++++++++---------- iroh-relay/src/protos/io.rs | 11 ++ 2 files changed, 182 insertions(+), 85 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 73ed89bddc4..6299a594e7b 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -10,53 +10,63 @@ use rand::{CryptoRng, RngCore}; use crate::ExportKeyingMaterial; /// TODO(matheus23) docs -pub const PROTOCOL_VERSION: &[u8] = b"1"; +pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; + +/// Message that tells the server the client needs a challenge to authenticate. +#[derive(derive_more::Debug, serde::Serialize)] +#[cfg_attr(feature = "server", derive(serde::Deserialize))] +pub(crate) struct ClientRequestChallenge; + +const TAG_CLIENT_REQUEST_CHALLENGE: VarInt = VarInt::from_u32(5); /// A challenge for the client to sign with their secret key for NodeId authentication. #[derive(derive_more::Debug, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] -pub struct ServerChallenge { +pub(crate) struct ServerChallenge { /// The challenge to sign. /// Must be randomly generated with an RNG that is safe to use for crypto. - pub challenge: [u8; 16], + pub(crate) challenge: [u8; 16], } -const SERVER_CHALLENGE_TAG: VarInt = VarInt::from_u32(1); +const TAG_SERVER_CHALLENGE: VarInt = VarInt::from_u32(1); -/// Info about the client. Also serves as authentication. +/// Authentintiation message from the client. +/// +/// Also serves to inform the server about the client's send message version, +/// which will be passed on to other connecting clients. #[derive(derive_more::Debug, serde::Serialize)] #[cfg_attr(feature = "server", derive(serde::Deserialize))] -pub struct ClientInfo { +pub(crate) struct ClientAuth { /// The client's public key, a.k.a. the `NodeId` - pub public_key: PublicKey, + pub(crate) public_key: PublicKey, /// A signature of the server challenge, serves as authentication. #[serde(with = "serde_bytes")] - pub signature: [u8; 64], + pub(crate) signature: [u8; 64], /// Part of the extracted key material, if that's what was signed. - pub key_material_suffix: Option<[u8; 16]>, + pub(crate) key_material_suffix: Option<[u8; 16]>, /// Supported versions/protocol features for version negotiation /// with other connected relay clients - pub versions: Vec>, + pub(crate) versions: Vec>, } -const CLIENT_INFO_TAG: VarInt = VarInt::from_u32(2); +const TAG_CLIENT_AUTH: VarInt = VarInt::from_u32(2); /// Confirmation of successful connection. #[derive(derive_more::Debug, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] -pub struct ServerConfirmsConnected; +pub(crate) struct ServerConfirmsAuth; -const SERVER_CONFIRMS_CONNECTED_TAG: VarInt = VarInt::from_u32(3); +const TAG_SERVER_CONFIRMS_AUTH: VarInt = VarInt::from_u32(3); /// Denial of connection. The client couldn't be verified as authentic. #[derive(derive_more::Debug, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] -pub struct ServerDeniesConnection; +pub(crate) struct ServerDeniesAuth; -const SERVER_DENIES_CONNECTION_TAG: VarInt = VarInt::from_u32(4); +const TAG_SERVER_DENIES_AUTH: VarInt = VarInt::from_u32(4); /// TODO(matheus23) docs -pub trait BytesStreamSink: +pub(crate) trait BytesStreamSink: Stream> + Sink + Unpin { } @@ -68,7 +78,7 @@ impl> + Sink + Unpi impl ServerChallenge { /// TODO(matheus23): docs - pub fn new(mut rng: impl RngCore + CryptoRng) -> Self { + pub(crate) fn new(mut rng: impl RngCore + CryptoRng) -> Self { let mut challenge = [0u8; 16]; rng.fill_bytes(&mut challenge); Self { challenge } @@ -82,9 +92,9 @@ impl ServerChallenge { } } -impl ClientInfo { +impl ClientAuth { /// TODO(matheus23): docs - pub fn new_from_challenge(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self { + pub(crate) fn new_from_challenge(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self { Self { public_key: secret_key.public(), key_material_suffix: None, @@ -94,7 +104,7 @@ impl ClientInfo { } /// TODO(matheus23): docs - pub fn verify_from_challenge(&self, challenge: &ServerChallenge) -> bool { + pub(crate) fn verify_from_challenge(&self, challenge: &ServerChallenge) -> bool { self.public_key .verify( &challenge.message_to_sign(), @@ -102,60 +112,81 @@ impl ClientInfo { ) .is_ok() } -} -/// TODO(matheus23) docs -pub(crate) async fn clientside( - io: &mut (impl BytesStreamSink + ExportKeyingMaterial), - secret_key: &SecretKey, -) -> Result { - let public_key = secret_key.public(); - let versions = vec![PROTOCOL_VERSION.to_vec()]; - - let key_material = io.export_keying_material( - [0u8; 32], - b"iroh-relay handshake v1", - Some(secret_key.public().as_bytes()), - ); + pub(crate) fn new_from_key_export( + secret_key: &SecretKey, + io: &mut impl ExportKeyingMaterial, + ) -> Option { + let public_key = secret_key.public(); + let key_material = io.export_keying_material( + [0u8; 32], + b"iroh-relay handshake v1", + Some(secret_key.public().as_bytes()), + )?; - if let Some(key_material) = key_material { let message = blake3::derive_key( "iroh-relay handshake v1 key material signature", &key_material[..16], ); - write_frame( - io, - CLIENT_INFO_TAG, - ClientInfo { - public_key, - signature: secret_key.sign(&message).to_bytes(), - key_material_suffix: Some(key_material[16..].try_into().expect("split right")), - versions: versions.clone(), - }, - ) - .await?; + Some(ClientAuth { + public_key, + signature: secret_key.sign(&message).to_bytes(), + key_material_suffix: Some(key_material[16..].try_into().expect("split right")), + versions: vec![PROTOCOL_VERSION.to_vec()], + }) + } + + pub(crate) fn verify_from_key_export(&self, io: &mut impl ExportKeyingMaterial) -> bool { + let Some(key_material) = io.export_keying_material( + [0u8; 32], + b"iroh-relay handshake v1", + Some(self.public_key.as_bytes()), + ) else { + return false; + }; + + let message = blake3::derive_key( + "iroh-relay handshake v1 key material signature", + &key_material[..16], + ); + self.public_key + .verify(&message, &Signature::from_bytes(&self.signature)) + .is_ok() + } +} + +/// TODO(matheus23) docs +pub(crate) async fn clientside( + io: &mut (impl BytesStreamSink + ExportKeyingMaterial), + secret_key: &SecretKey, +) -> Result { + if let Some(client_auth) = ClientAuth::new_from_key_export(secret_key, io) { + write_frame(io, TAG_CLIENT_AUTH, client_auth).await?; + } else { + // we can't use key exporting, so request a challenge. + write_frame(io, TAG_CLIENT_REQUEST_CHALLENGE, ClientRequestChallenge).await?; } let (tag, frame) = read_frame( io, &[ - SERVER_CHALLENGE_TAG, - SERVER_CONFIRMS_CONNECTED_TAG, - SERVER_DENIES_CONNECTION_TAG, + TAG_SERVER_CHALLENGE, + TAG_SERVER_CONFIRMS_AUTH, + TAG_SERVER_DENIES_AUTH, ], time::Duration::from_secs(30), ) .await?; - let (tag, frame) = if tag == SERVER_CHALLENGE_TAG { + let (tag, frame) = if tag == TAG_SERVER_CHALLENGE { let challenge: ServerChallenge = postcard::from_bytes(&frame)?; - let client_info = ClientInfo::new_from_challenge(secret_key, &challenge); - write_frame(io, CLIENT_INFO_TAG, client_info).await?; + let client_info = ClientAuth::new_from_challenge(secret_key, &challenge); + write_frame(io, TAG_CLIENT_AUTH, client_info).await?; read_frame( io, - &[SERVER_CONFIRMS_CONNECTED_TAG, SERVER_DENIES_CONNECTION_TAG], + &[TAG_SERVER_CONFIRMS_AUTH, TAG_SERVER_DENIES_AUTH], time::Duration::from_secs(30), ) .await? @@ -164,12 +195,12 @@ pub(crate) async fn clientside( }; match tag { - SERVER_CONFIRMS_CONNECTED_TAG => { - let confirmation: ServerConfirmsConnected = postcard::from_bytes(&frame)?; + TAG_SERVER_CONFIRMS_AUTH => { + let confirmation: ServerConfirmsAuth = postcard::from_bytes(&frame)?; Ok(confirmation) } - SERVER_DENIES_CONNECTION_TAG => { - let denial: ServerDeniesConnection = postcard::from_bytes(&frame)?; + TAG_SERVER_DENIES_AUTH => { + let denial: ServerDeniesAuth = postcard::from_bytes(&frame)?; anyhow::bail!("server denied connection: {denial:?}"); } _ => unreachable!(), @@ -181,20 +212,38 @@ pub(crate) async fn clientside( pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), rng: impl RngCore + CryptoRng, -) -> Result { +) -> Result { + let (tag, frame) = read_frame( + io, + &[TAG_CLIENT_REQUEST_CHALLENGE, TAG_CLIENT_AUTH], + time::Duration::from_secs(10), + ) + .await?; + + // it might be fast-path authentication using TLS exported key material + if tag == TAG_CLIENT_AUTH { + let client_auth: ClientAuth = postcard::from_bytes(&frame)?; + if client_auth.verify_from_key_export(io) { + write_frame(io, TAG_SERVER_CONFIRMS_AUTH, ServerConfirmsAuth).await?; + return Ok(client_auth); + } + } else { + let _frame: ClientRequestChallenge = postcard::from_bytes(&frame)?; + } + let challenge = ServerChallenge::new(rng); - write_frame(io, SERVER_CHALLENGE_TAG, &challenge).await?; + write_frame(io, TAG_SERVER_CHALLENGE, &challenge).await?; - let (_, frame) = read_frame(io, &[CLIENT_INFO_TAG], time::Duration::from_secs(10)).await?; - let client_info: ClientInfo = postcard::from_bytes(&frame)?; + let (_, frame) = read_frame(io, &[TAG_CLIENT_AUTH], time::Duration::from_secs(10)).await?; + let client_auth: ClientAuth = postcard::from_bytes(&frame)?; - if client_info.verify_from_challenge(&challenge) { - write_frame(io, SERVER_CONFIRMS_CONNECTED_TAG, ServerConfirmsConnected).await?; + if client_auth.verify_from_challenge(&challenge) { + write_frame(io, TAG_SERVER_CONFIRMS_AUTH, ServerConfirmsAuth).await?; } else { - write_frame(io, SERVER_DENIES_CONNECTION_TAG, ServerDeniesConnection).await?; + write_frame(io, TAG_SERVER_DENIES_AUTH, ServerDeniesAuth).await?; } - Ok(client_info) + Ok(client_auth) } async fn write_frame( @@ -236,6 +285,7 @@ async fn read_frame( #[cfg(all(test, feature = "server"))] mod tests { + use anyhow::Context; use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::{Sink, SinkExt, Stream, TryStreamExt}; @@ -244,7 +294,7 @@ mod tests { use crate::ExportKeyingMaterial; - use super::{ClientInfo, ServerChallenge}; + use super::{ClientAuth, ServerChallenge}; struct TestKeyingMaterial { shared_secret: Option, @@ -323,25 +373,25 @@ mod tests { } } - #[tokio::test] - async fn simulate_handshake() -> TestResult { - use anyhow::Context; - + async fn simulate_handshake( + secret_key: &SecretKey, + client_shared_secret: Option, + server_shared_secret: Option, + ) -> TestResult { let (client, server) = tokio::io::duplex(1024); - let secret_key = SecretKey::generate(rand::rngs::OsRng); let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) .map_err(anyhow::Error::from) .sink_err_into() - .with_shared_secret(Some(42)); + .with_shared_secret(client_shared_secret); let mut server_io = Framed::new(server, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) .map_err(anyhow::Error::from) .sink_err_into() - .with_shared_secret(Some(42)); + .with_shared_secret(server_shared_secret); - let (_, client_info) = n0_future::future::try_zip( + let (_, client_auth) = n0_future::future::try_zip( async { super::clientside(&mut client_io, &secret_key) .await @@ -355,24 +405,60 @@ mod tests { ) .await?; - println!("{client_info:#?}"); + Ok(client_auth) + } + #[tokio::test] + async fn test_handshake_via_shared_secrets() -> TestResult { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let auth = simulate_handshake(&secret_key, Some(42), Some(42)).await?; + assert_eq!(auth.public_key, secret_key.public()); + assert!(auth.key_material_suffix.is_some()); // it got verified via shared key material + Ok(()) + } + + #[tokio::test] + async fn test_handshake_via_challenge() -> TestResult { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let auth = simulate_handshake(&secret_key, None, None).await?; + assert_eq!(auth.public_key, secret_key.public()); + assert!(auth.key_material_suffix.is_none()); + Ok(()) + } + + #[tokio::test] + async fn test_handshake_mismatching_shared_secrets() -> TestResult { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret + let auth = simulate_handshake(&secret_key, Some(10), Some(99)).await?; + assert_eq!(auth.public_key, secret_key.public()); + assert!(auth.key_material_suffix.is_none()); + Ok(()) + } + + #[tokio::test] + async fn test_handshake_challenge_fallback() -> TestResult { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + // clients might not have access to shared secrets + let auth = simulate_handshake(&secret_key, None, Some(99)).await?; + assert_eq!(auth.public_key, secret_key.public()); + assert!(auth.key_material_suffix.is_none()); Ok(()) } #[test] - fn test_client_info_roundtrip() -> TestResult { + fn test_client_auth_roundtrip() -> TestResult { let secret_key = SecretKey::generate(rand::rngs::OsRng); let challenge = ServerChallenge::new(rand::rngs::OsRng); - let client_info = ClientInfo::new_from_challenge(&secret_key, &challenge); + let client_auth = ClientAuth::new_from_challenge(&secret_key, &challenge); - let bytes = postcard::to_allocvec(&client_info)?; - let decoded: ClientInfo = postcard::from_bytes(&bytes)?; + let bytes = postcard::to_allocvec(&client_auth)?; + let decoded: ClientAuth = postcard::from_bytes(&bytes)?; - assert_eq!(client_info.public_key, decoded.public_key); - assert_eq!(client_info.key_material_suffix, decoded.key_material_suffix); - assert_eq!(client_info.signature, decoded.signature); - assert_eq!(client_info.versions, decoded.versions); + assert_eq!(client_auth.public_key, decoded.public_key); + assert_eq!(client_auth.key_material_suffix, decoded.key_material_suffix); + assert_eq!(client_auth.signature, decoded.signature); + assert_eq!(client_auth.versions, decoded.versions); Ok(()) } @@ -381,8 +467,8 @@ mod tests { fn test_challenge_verification() -> TestResult { let secret_key = SecretKey::generate(rand::rngs::OsRng); let challenge = ServerChallenge::new(rand::rngs::OsRng); - let client_info = ClientInfo::new_from_challenge(&secret_key, &challenge); - assert!(client_info.verify_from_challenge(&challenge)); + let client_auth = ClientAuth::new_from_challenge(&secret_key, &challenge); + assert!(client_auth.verify_from_challenge(&challenge)); Ok(()) } diff --git a/iroh-relay/src/protos/io.rs b/iroh-relay/src/protos/io.rs index 48e70f26d1e..3e645e6f6e6 100644 --- a/iroh-relay/src/protos/io.rs +++ b/iroh-relay/src/protos/io.rs @@ -24,6 +24,17 @@ pub(crate) struct HandshakeIo { impl ExportKeyingMaterial for HandshakeIo { + #[cfg(wasm_browser)] + fn export_keying_material>( + &self, + output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option { + None + } + + #[cfg(not(wasm_browser))] fn export_keying_material>( &self, output: T, From 04d34e3964c0ec471e43bd5d436427bf2561b237 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 2 Jun 2025 09:13:11 +0200 Subject: [PATCH 07/80] Introduce `FrameType` enum --- iroh-relay/src/protos/handshake.rs | 117 +++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 33 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 6299a594e7b..0f4f032de5e 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -12,13 +12,29 @@ use crate::ExportKeyingMaterial; /// TODO(matheus23) docs pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; +#[repr(u32)] +#[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] +pub(crate) enum FrameType { + ClientRequestChallenge = 0, + ServerChallenge = 1, + ClientAuth = 2, + ServerConfirmsAuth = 3, + ServerDeniesAuth = 4, + #[num_enum(default)] + Unknown, +} + +impl From for VarInt { + fn from(value: FrameType) -> Self { + (value as u32).into() + } +} + /// Message that tells the server the client needs a challenge to authenticate. #[derive(derive_more::Debug, serde::Serialize)] #[cfg_attr(feature = "server", derive(serde::Deserialize))] pub(crate) struct ClientRequestChallenge; -const TAG_CLIENT_REQUEST_CHALLENGE: VarInt = VarInt::from_u32(5); - /// A challenge for the client to sign with their secret key for NodeId authentication. #[derive(derive_more::Debug, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] @@ -28,8 +44,6 @@ pub(crate) struct ServerChallenge { pub(crate) challenge: [u8; 16], } -const TAG_SERVER_CHALLENGE: VarInt = VarInt::from_u32(1); - /// Authentintiation message from the client. /// /// Also serves to inform the server about the client's send message version, @@ -49,22 +63,16 @@ pub(crate) struct ClientAuth { pub(crate) versions: Vec>, } -const TAG_CLIENT_AUTH: VarInt = VarInt::from_u32(2); - /// Confirmation of successful connection. #[derive(derive_more::Debug, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] pub(crate) struct ServerConfirmsAuth; -const TAG_SERVER_CONFIRMS_AUTH: VarInt = VarInt::from_u32(3); - /// Denial of connection. The client couldn't be verified as authentic. #[derive(derive_more::Debug, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] pub(crate) struct ServerDeniesAuth; -const TAG_SERVER_DENIES_AUTH: VarInt = VarInt::from_u32(4); - /// TODO(matheus23) docs pub(crate) trait BytesStreamSink: Stream> + Sink + Unpin @@ -76,6 +84,34 @@ impl> + Sink + Unpi { } +trait Frame { + const TAG: FrameType; +} + +impl Frame for &T { + const TAG: FrameType = T::TAG; +} + +impl Frame for ClientRequestChallenge { + const TAG: FrameType = FrameType::ClientRequestChallenge; +} + +impl Frame for ServerChallenge { + const TAG: FrameType = FrameType::ServerChallenge; +} + +impl Frame for ClientAuth { + const TAG: FrameType = FrameType::ClientAuth; +} + +impl Frame for ServerConfirmsAuth { + const TAG: FrameType = FrameType::ServerConfirmsAuth; +} + +impl Frame for ServerDeniesAuth { + const TAG: FrameType = FrameType::ServerDeniesAuth; +} + impl ServerChallenge { /// TODO(matheus23): docs pub(crate) fn new(mut rng: impl RngCore + CryptoRng) -> Self { @@ -161,32 +197,32 @@ pub(crate) async fn clientside( secret_key: &SecretKey, ) -> Result { if let Some(client_auth) = ClientAuth::new_from_key_export(secret_key, io) { - write_frame(io, TAG_CLIENT_AUTH, client_auth).await?; + write_frame(io, client_auth).await?; } else { // we can't use key exporting, so request a challenge. - write_frame(io, TAG_CLIENT_REQUEST_CHALLENGE, ClientRequestChallenge).await?; + write_frame(io, ClientRequestChallenge).await?; } - let (tag, frame) = read_frame( + let (tag, frame) = read_handshake_frame( io, &[ - TAG_SERVER_CHALLENGE, - TAG_SERVER_CONFIRMS_AUTH, - TAG_SERVER_DENIES_AUTH, + ServerChallenge::TAG, + ServerConfirmsAuth::TAG, + ServerDeniesAuth::TAG, ], time::Duration::from_secs(30), ) .await?; - let (tag, frame) = if tag == TAG_SERVER_CHALLENGE { + let (tag, frame) = if tag == ServerChallenge::TAG { let challenge: ServerChallenge = postcard::from_bytes(&frame)?; let client_info = ClientAuth::new_from_challenge(secret_key, &challenge); - write_frame(io, TAG_CLIENT_AUTH, client_info).await?; + write_frame(io, client_info).await?; - read_frame( + read_handshake_frame( io, - &[TAG_SERVER_CONFIRMS_AUTH, TAG_SERVER_DENIES_AUTH], + &[ServerConfirmsAuth::TAG, ServerDeniesAuth::TAG], time::Duration::from_secs(30), ) .await? @@ -195,11 +231,11 @@ pub(crate) async fn clientside( }; match tag { - TAG_SERVER_CONFIRMS_AUTH => { + FrameType::ServerConfirmsAuth => { let confirmation: ServerConfirmsAuth = postcard::from_bytes(&frame)?; Ok(confirmation) } - TAG_SERVER_DENIES_AUTH => { + FrameType::ServerDeniesAuth => { let denial: ServerDeniesAuth = postcard::from_bytes(&frame)?; anyhow::bail!("server denied connection: {denial:?}"); } @@ -213,18 +249,18 @@ pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), rng: impl RngCore + CryptoRng, ) -> Result { - let (tag, frame) = read_frame( + let (tag, frame) = read_handshake_frame( io, - &[TAG_CLIENT_REQUEST_CHALLENGE, TAG_CLIENT_AUTH], + &[ClientRequestChallenge::TAG, ClientAuth::TAG], time::Duration::from_secs(10), ) .await?; // it might be fast-path authentication using TLS exported key material - if tag == TAG_CLIENT_AUTH { + if tag == ClientAuth::TAG { let client_auth: ClientAuth = postcard::from_bytes(&frame)?; if client_auth.verify_from_key_export(io) { - write_frame(io, TAG_SERVER_CONFIRMS_AUTH, ServerConfirmsAuth).await?; + write_frame(io, ServerConfirmsAuth).await?; return Ok(client_auth); } } else { @@ -232,25 +268,26 @@ pub(crate) async fn serverside( } let challenge = ServerChallenge::new(rng); - write_frame(io, TAG_SERVER_CHALLENGE, &challenge).await?; + write_frame(io, &challenge).await?; - let (_, frame) = read_frame(io, &[TAG_CLIENT_AUTH], time::Duration::from_secs(10)).await?; + let (_, frame) = + read_handshake_frame(io, &[ClientAuth::TAG], time::Duration::from_secs(10)).await?; let client_auth: ClientAuth = postcard::from_bytes(&frame)?; if client_auth.verify_from_challenge(&challenge) { - write_frame(io, TAG_SERVER_CONFIRMS_AUTH, ServerConfirmsAuth).await?; + write_frame(io, ServerConfirmsAuth).await?; } else { - write_frame(io, TAG_SERVER_DENIES_AUTH, ServerDeniesAuth).await?; + write_frame(io, ServerDeniesAuth).await?; } Ok(client_auth) } -async fn write_frame( +async fn write_frame( io: &mut impl BytesStreamSink, - tag: VarInt, - frame: impl serde::Serialize, + frame: F, ) -> Result<()> { + let tag: VarInt = F::TAG.into(); let mut bytes = BytesMut::new(); tag.encode(&mut bytes); let bytes = postcard::to_io(&frame, bytes.writer())? @@ -283,6 +320,20 @@ async fn read_frame( Ok((tag, payload)) } +async fn read_handshake_frame( + io: &mut impl BytesStreamSink, + expected_types: &[FrameType], + timeout: time::Duration, +) -> Result<(FrameType, Bytes)> { + let expected_tags = expected_types + .into_iter() + .map(|frame_type| VarInt::from(*frame_type)) + .collect::>(); + let (tag, frame) = read_frame(io, &expected_tags, timeout).await?; + let frame_type = u32::try_from(tag.into_inner()).map_or(FrameType::Unknown, FrameType::from); + Ok((frame_type, frame)) +} + #[cfg(all(test, feature = "server"))] mod tests { use anyhow::Context; From 41f5f3f104b4b6b5d26a2a00b0d02d0914543b88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 2 Jun 2025 11:43:17 +0200 Subject: [PATCH 08/80] Fix typo --- iroh-relay/src/protos/handshake.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 0f4f032de5e..accbc9bb290 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -44,7 +44,7 @@ pub(crate) struct ServerChallenge { pub(crate) challenge: [u8; 16], } -/// Authentintiation message from the client. +/// Authentication message from the client. /// /// Also serves to inform the server about the client's send message version, /// which will be passed on to other connecting clients. From eeebc804afff95cedb8fbb5d88e3e06efe314694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 4 Jun 2025 18:12:35 +0200 Subject: [PATCH 09/80] Fix merge --- Cargo.lock | 33 +++++++ iroh-relay/src/client.rs | 4 +- iroh-relay/src/client/conn.rs | 6 +- iroh-relay/src/protos/handshake.rs | 137 +++++++++++++++++++-------- iroh-relay/src/protos/io.rs | 7 +- iroh-relay/src/protos/relay.rs | 2 +- iroh-relay/src/quic.rs | 2 + iroh-relay/src/server/http_server.rs | 17 +++- 8 files changed, 154 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c357841c0cb..fef802c2d59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,6 +162,18 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "asn1-rs" version = "0.6.2" @@ -451,6 +463,19 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +[[package]] +name = "blake3" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -686,6 +711,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "cordyceps" version = "0.3.3" @@ -2481,6 +2512,7 @@ name = "iroh-relay" version = "0.35.0" dependencies = [ "ahash", + "blake3", "bytes", "cfg_aliases", "clap", @@ -2522,6 +2554,7 @@ dependencies = [ "rustls-pki-types", "rustls-webpki", "serde", + "serde_bytes", "serde_json", "sha1", "simdutf8", diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 830b18a3814..46a1aa5f895 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -27,7 +27,7 @@ pub use self::conn::{ReceivedMessage, RecvError, SendError, SendMessage}; use crate::dns::{DnsError, DnsResolver}; use crate::{ http::{Protocol, RELAY_PATH}, - protos::relay::SendError as SendRelayError, + protos::handshake, KeyCache, }; @@ -64,7 +64,7 @@ pub enum ConnectError { source: ws_stream_wasm::WsErr, }, #[snafu(transparent)] - Handshake { source: SendRelayError }, + Handshake { source: handshake::Error }, #[snafu(transparent)] Dial { source: DialError }, #[snafu(display("Unexpected status during upgrade: {code}"))] diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 1b03f9e283c..da7df2e479f 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -127,7 +127,7 @@ impl Conn { conn: ws_stream_wasm::WsStream, key_cache: KeyCache, secret_key: &SecretKey, - ) -> Result { + ) -> Result { let mut io = HandshakeIo { io: conn }; // exchange information with the server @@ -147,7 +147,7 @@ impl Conn { conn: MaybeTlsStreamChained, key_cache: KeyCache, secret_key: &SecretKey, - ) -> Result { + ) -> Result { use n0_future::SinkExt; let conn = Framed::new(conn, RelayCodec::new(key_cache)); @@ -174,7 +174,7 @@ impl Conn { conn: tokio_websockets::WebSocketStream>, key_cache: KeyCache, secret_key: &SecretKey, - ) -> Result { + ) -> Result { let mut io = HandshakeIo { io: conn }; // exchange information with the server diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index accbc9bb290..c81540076f8 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -1,25 +1,41 @@ //! TODO(matheus23) docs -use anyhow::Result; use bytes::{BufMut, Bytes, BytesMut}; use iroh_base::{PublicKey, SecretKey, Signature}; -use n0_future::{time, Sink, SinkExt, Stream, TryStreamExt}; +use n0_future::{ + time::{self, Elapsed}, + Sink, SinkExt, Stream, TryStreamExt, +}; +use nested_enum_utils::common_fields; use quinn_proto::{coding::Codec, VarInt}; use rand::{CryptoRng, RngCore}; +use snafu::{Backtrace, ResultExt, Snafu}; use crate::ExportKeyingMaterial; +use super::relay::SendError; + /// TODO(matheus23) docs pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; +/// Possible frame types during handshaking #[repr(u32)] #[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] -pub(crate) enum FrameType { +pub enum FrameType { + /// The frame type for the client challenge request ClientRequestChallenge = 0, + /// The server frame type for the challenge response ServerChallenge = 1, + /// The client frame type for the authentication frame ClientAuth = 2, + /// The server frame type for authentication confirmation ServerConfirmsAuth = 3, + /// The server frame type for authentication denial ServerDeniesAuth = 4, + /// The frame type was unknown. + /// + /// This frame is the result of parsing any future frame types that this implementation + /// does not yet understand. #[num_enum(default)] Unknown, } @@ -75,12 +91,12 @@ pub(crate) struct ServerDeniesAuth; /// TODO(matheus23) docs pub(crate) trait BytesStreamSink: - Stream> + Sink + Unpin + Stream> + Sink + Unpin { } -impl> + Sink + Unpin> BytesStreamSink - for T +impl BytesStreamSink for T where + T: Stream> + Sink + Unpin { } @@ -112,6 +128,37 @@ impl Frame for ServerDeniesAuth { const TAG: FrameType = FrameType::ServerDeniesAuth; } +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum Error { + #[snafu(transparent)] + Websocket { source: tokio_websockets::Error }, + #[snafu(transparent)] + Legacy { source: SendError }, + #[snafu(display("Handshake timeout reached"))] + Timeout { source: Elapsed }, + #[snafu(display("Handshake stream ended prematurely"))] + UnexpectedEnd {}, + #[snafu(display("The relay denied our authentication"))] + ServerDeniedAuth {}, + #[snafu(display("Unexpected tag, got {tag}, but expected one of {expected_tags:?}"))] + UnexpectedTag { + tag: VarInt, + expected_tags: Vec, + }, + #[snafu(display("Handshake failed while deserializing {frame_type:?} frame"))] + DeserializationError { + frame_type: FrameType, + source: postcard::Error, + }, +} + impl ServerChallenge { /// TODO(matheus23): docs pub(crate) fn new(mut rng: impl RngCore + CryptoRng) -> Self { @@ -195,7 +242,7 @@ impl ClientAuth { pub(crate) async fn clientside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), secret_key: &SecretKey, -) -> Result { +) -> Result { if let Some(client_auth) = ClientAuth::new_from_key_export(secret_key, io) { write_frame(io, client_auth).await?; } else { @@ -215,7 +262,7 @@ pub(crate) async fn clientside( .await?; let (tag, frame) = if tag == ServerChallenge::TAG { - let challenge: ServerChallenge = postcard::from_bytes(&frame)?; + let challenge: ServerChallenge = deserialize_frame(frame)?; let client_info = ClientAuth::new_from_challenge(secret_key, &challenge); write_frame(io, client_info).await?; @@ -232,12 +279,12 @@ pub(crate) async fn clientside( match tag { FrameType::ServerConfirmsAuth => { - let confirmation: ServerConfirmsAuth = postcard::from_bytes(&frame)?; + let confirmation: ServerConfirmsAuth = deserialize_frame(frame)?; Ok(confirmation) } FrameType::ServerDeniesAuth => { - let denial: ServerDeniesAuth = postcard::from_bytes(&frame)?; - anyhow::bail!("server denied connection: {denial:?}"); + let _denial: ServerDeniesAuth = deserialize_frame(frame)?; + return Err(ServerDeniedAuthSnafu.build()); } _ => unreachable!(), } @@ -248,7 +295,7 @@ pub(crate) async fn clientside( pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), rng: impl RngCore + CryptoRng, -) -> Result { +) -> Result { let (tag, frame) = read_handshake_frame( io, &[ClientRequestChallenge::TAG, ClientAuth::TAG], @@ -258,13 +305,13 @@ pub(crate) async fn serverside( // it might be fast-path authentication using TLS exported key material if tag == ClientAuth::TAG { - let client_auth: ClientAuth = postcard::from_bytes(&frame)?; + let client_auth: ClientAuth = deserialize_frame(frame)?; if client_auth.verify_from_key_export(io) { write_frame(io, ServerConfirmsAuth).await?; return Ok(client_auth); } } else { - let _frame: ClientRequestChallenge = postcard::from_bytes(&frame)?; + let _frame: ClientRequestChallenge = deserialize_frame(frame)?; } let challenge = ServerChallenge::new(rng); @@ -272,7 +319,7 @@ pub(crate) async fn serverside( let (_, frame) = read_handshake_frame(io, &[ClientAuth::TAG], time::Duration::from_secs(10)).await?; - let client_auth: ClientAuth = postcard::from_bytes(&frame)?; + let client_auth: ClientAuth = deserialize_frame(frame)?; if client_auth.verify_from_challenge(&challenge) { write_frame(io, ServerConfirmsAuth).await?; @@ -286,11 +333,12 @@ pub(crate) async fn serverside( async fn write_frame( io: &mut impl BytesStreamSink, frame: F, -) -> Result<()> { +) -> Result<(), Error> { let tag: VarInt = F::TAG.into(); let mut bytes = BytesMut::new(); tag.encode(&mut bytes); - let bytes = postcard::to_io(&frame, bytes.writer())? + let bytes = postcard::to_io(&frame, bytes.writer()) + .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization .into_inner() .freeze(); io.send(bytes).await?; @@ -302,16 +350,21 @@ async fn read_frame( io: &mut impl BytesStreamSink, expected_tags: &[VarInt], timeout: time::Duration, -) -> Result<(VarInt, Bytes)> { +) -> Result<(VarInt, Bytes), Error> { let recv = time::timeout(timeout, io.try_next()) - .await?? - .ok_or_else(|| anyhow::anyhow!("disconnected"))?; + .await + .context(TimeoutSnafu)?? + .ok_or_else(|| UnexpectedEndSnafu.build())?; let mut cursor = std::io::Cursor::new(recv); - let tag = VarInt::decode(&mut cursor)?; - anyhow::ensure!( + let tag = VarInt::decode(&mut cursor) + .map_err(|quinn_proto::coding::UnexpectedEnd| UnexpectedEndSnafu.build())?; + snafu::ensure!( expected_tags.contains(&tag), - "Unexpected tag {tag}, expected one of {expected_tags:?}" + UnexpectedTagSnafu { + tag, + expected_tags: expected_tags.into_iter().cloned().collect::>() + } ); let start = cursor.position() as usize; @@ -324,7 +377,7 @@ async fn read_handshake_frame( io: &mut impl BytesStreamSink, expected_types: &[FrameType], timeout: time::Duration, -) -> Result<(FrameType, Bytes)> { +) -> Result<(FrameType, Bytes), Error> { let expected_tags = expected_types .into_iter() .map(|frame_type| VarInt::from(*frame_type)) @@ -334,17 +387,19 @@ async fn read_handshake_frame( Ok((frame_type, frame)) } +fn deserialize_frame(frame: Bytes) -> Result { + postcard::from_bytes(&frame).context(DeserializationSnafu { frame_type: F::TAG }) +} + #[cfg(all(test, feature = "server"))] mod tests { - use anyhow::Context; + use crate::ExportKeyingMaterial; use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::{Sink, SinkExt, Stream, TryStreamExt}; - use testresult::TestResult; + use n0_snafu::{Result, ResultExt}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; - use crate::ExportKeyingMaterial; - use super::{ClientAuth, ServerChallenge}; struct TestKeyingMaterial { @@ -428,18 +483,18 @@ mod tests { secret_key: &SecretKey, client_shared_secret: Option, server_shared_secret: Option, - ) -> TestResult { + ) -> Result { let (client, server) = tokio::io::duplex(1024); let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) - .map_err(anyhow::Error::from) - .sink_err_into() + .map_err(|e| tokio_websockets::Error::Io(e).into()) + .sink_map_err(|e| tokio_websockets::Error::Io(e).into()) .with_shared_secret(client_shared_secret); let mut server_io = Framed::new(server, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) - .map_err(anyhow::Error::from) - .sink_err_into() + .map_err(|e| tokio_websockets::Error::Io(e).into()) + .sink_map_err(|e| tokio_websockets::Error::Io(e).into()) .with_shared_secret(server_shared_secret); let (_, client_auth) = n0_future::future::try_zip( @@ -460,7 +515,7 @@ mod tests { } #[tokio::test] - async fn test_handshake_via_shared_secrets() -> TestResult { + async fn test_handshake_via_shared_secrets() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); let auth = simulate_handshake(&secret_key, Some(42), Some(42)).await?; assert_eq!(auth.public_key, secret_key.public()); @@ -469,7 +524,7 @@ mod tests { } #[tokio::test] - async fn test_handshake_via_challenge() -> TestResult { + async fn test_handshake_via_challenge() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); let auth = simulate_handshake(&secret_key, None, None).await?; assert_eq!(auth.public_key, secret_key.public()); @@ -478,7 +533,7 @@ mod tests { } #[tokio::test] - async fn test_handshake_mismatching_shared_secrets() -> TestResult { + async fn test_handshake_mismatching_shared_secrets() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret let auth = simulate_handshake(&secret_key, Some(10), Some(99)).await?; @@ -488,7 +543,7 @@ mod tests { } #[tokio::test] - async fn test_handshake_challenge_fallback() -> TestResult { + async fn test_handshake_challenge_fallback() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); // clients might not have access to shared secrets let auth = simulate_handshake(&secret_key, None, Some(99)).await?; @@ -498,13 +553,13 @@ mod tests { } #[test] - fn test_client_auth_roundtrip() -> TestResult { + fn test_client_auth_roundtrip() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); let challenge = ServerChallenge::new(rand::rngs::OsRng); let client_auth = ClientAuth::new_from_challenge(&secret_key, &challenge); - let bytes = postcard::to_allocvec(&client_auth)?; - let decoded: ClientAuth = postcard::from_bytes(&bytes)?; + let bytes = postcard::to_allocvec(&client_auth).e()?; + let decoded: ClientAuth = postcard::from_bytes(&bytes).e()?; assert_eq!(client_auth.public_key, decoded.public_key); assert_eq!(client_auth.key_material_suffix, decoded.key_material_suffix); @@ -515,7 +570,7 @@ mod tests { } #[test] - fn test_challenge_verification() -> TestResult { + fn test_challenge_verification() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); let challenge = ServerChallenge::new(rand::rngs::OsRng); let client_auth = ClientAuth::new_from_challenge(&secret_key, &challenge); diff --git a/iroh-relay/src/protos/io.rs b/iroh-relay/src/protos/io.rs index 3e645e6f6e6..68337f37a91 100644 --- a/iroh-relay/src/protos/io.rs +++ b/iroh-relay/src/protos/io.rs @@ -4,13 +4,14 @@ use std::{ task::{Context, Poll}, }; -use anyhow::Result; use bytes::Bytes; use n0_future::{ready, Sink, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; use crate::ExportKeyingMaterial; +use super::handshake::Error; + #[derive(derive_more::Debug)] pub(crate) struct HandshakeIo { #[cfg(not(wasm_browser))] @@ -48,7 +49,7 @@ impl ExportKeyingMate } impl Stream for HandshakeIo { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -79,7 +80,7 @@ impl Stream for HandshakeIo { } impl Sink for HandshakeIo { - type Error = anyhow::Error; + type Error = Error; fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { #[cfg(not(wasm_browser))] diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 46670bc793e..cdd4b09c9fc 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -225,7 +225,7 @@ pub(crate) async fn legacy_send_client_key /// upon it's initial connection. #[cfg(any(test, feature = "server"))] #[deprecated = "switch to proper handshake"] -pub(crate) async fn legacy_recv_client_key> + Unpin>( +pub(crate) async fn legacy_recv_client_key> + Unpin>( stream: S, ) -> Result<(PublicKey, ClientInfo), E> where diff --git a/iroh-relay/src/quic.rs b/iroh-relay/src/quic.rs index 45acdf7a89d..384dbfef531 100644 --- a/iroh-relay/src/quic.rs +++ b/iroh-relay/src/quic.rs @@ -357,6 +357,8 @@ mod tests { #[traced_test] #[cfg(feature = "test-utils")] async fn quic_endpoint_basic() -> Result { + use super::server::{QuicConfig, QuicServer}; + let host: Ipv4Addr = "127.0.0.1".parse().unwrap(); // create a server config with self signed certificates let (_, server_config) = super::super::server::testing::self_signed_tls_certs_and_config(); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 88295bc20bd..dd55c170fa0 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -159,7 +159,7 @@ pub(super) struct TlsConfig { #[non_exhaustive] pub enum ServeConnectionError { #[snafu(display("TLS[acme] handshake"))] - Handshake { + TlsHandshake { source: std::io::Error, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, @@ -217,6 +217,13 @@ pub enum AcceptError { #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, + #[snafu(display("Handshake failed"))] + Handshake { + #[allow(clippy::result_large_err)] + source: handshake::Error, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + }, #[snafu(display("Unexpected client version {version}, expected {expected_version}"))] UnexpectedClientVersion { version: usize, @@ -651,7 +658,7 @@ impl Inner { #[allow(deprecated)] let (client_key, info) = legacy_recv_client_key(&mut io) .await - .context("unable to receive client information")?; + .context(RecvClientKeySnafu)?; if info.version != PROTOCOL_VERSION { return Err(UnexpectedClientVersionSnafu { @@ -674,7 +681,9 @@ impl Inner { let mut io = HandshakeIo { io: websocket }; - let client_info = handshake::serverside(&mut io, rand::rngs::OsRng).await?; + let client_info = handshake::serverside(&mut io, rand::rngs::OsRng) + .await + .context(HandshakeSnafu)?; ( client_info.public_key, @@ -808,7 +817,7 @@ impl RelayService { let tls_stream = start_handshake .into_stream(config) .await - .context(HandshakeSnafu)?; + .context(TlsHandshakeSnafu)?; self.serve_connection(MaybeTlsStream::Tls(tls_stream)) .await .context(HttpsSnafu)?; From 140e967eccd2facafbf9373a200fe2e54afd4623 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 5 Jun 2025 16:52:36 +0200 Subject: [PATCH 10/80] (WIP) remove legacy relay protocol path --- iroh-relay/src/client.rs | 7 - iroh-relay/src/client/conn.rs | 77 +-------- iroh-relay/src/client/streams.rs | 125 ++------------ iroh-relay/src/client/tls.rs | 141 +-------------- iroh-relay/src/http.rs | 15 +- iroh-relay/src/protos/handshake.rs | 1 + iroh-relay/src/protos/relay.rs | 246 +-------------------------- iroh-relay/src/server.rs | 132 +------------- iroh-relay/src/server/client.rs | 14 +- iroh-relay/src/server/clients.rs | 16 +- iroh-relay/src/server/http_server.rs | 68 ++------ iroh-relay/src/server/streams.rs | 115 +++++++------ 12 files changed, 133 insertions(+), 824 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 46a1aa5f895..0ff544662e1 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -220,13 +220,6 @@ impl ClientBuilder { let (conn, local_addr) = self.connect_ws().await?; (conn, Some(local_addr)) } - #[cfg(not(wasm_browser))] - Protocol::Relay => { - let (conn, local_addr) = self.connect_relay().await?; - (conn, Some(local_addr)) - } - #[cfg(wasm_browser)] - Protocol::Relay => return Err(RelayProtoNotAvailableSnafu.build()), }; event!( diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index da7df2e479f..6e117dd9421 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -14,23 +14,20 @@ use iroh_base::{NodeId, SecretKey}; use n0_future::{time::Duration, Sink, Stream}; use nested_enum_utils::common_fields; use snafu::{Backtrace, ResultExt, Snafu}; -#[cfg(not(wasm_browser))] -use tokio_util::codec::Framed; use tracing::debug; use super::KeyCache; -use crate::protos::{ - handshake, - relay::{ - ClientInfo, Frame, RecvError as RecvRelayError, SendError as SendRelayError, - MAX_PACKET_SIZE, PROTOCOL_VERSION, - }, -}; #[cfg(not(wasm_browser))] use crate::{ - client::streams::{MaybeTlsStream, MaybeTlsStreamChained, ProxyStream}, + client::streams::{MaybeTlsStream, ProxyStream}, protos::io::HandshakeIo, - protos::relay::RelayCodec, +}; +use crate::{ + protos::{ + handshake, + relay::{Frame, RecvError as RecvRelayError, SendError as SendRelayError}, + }, + MAX_PACKET_SIZE, }; /// Error for sending messages to the relay server. @@ -101,11 +98,6 @@ pub enum RecvError { /// invariants. #[derive(derive_more::Debug)] pub(crate) enum Conn { - #[cfg(not(wasm_browser))] - Relay { - #[debug("Framed")] - conn: Framed, - }, #[cfg(not(wasm_browser))] Ws { #[debug("WebSocketStream>")] @@ -141,34 +133,6 @@ impl Conn { }) } - /// Constructs a new websocket connection, including the initial server handshake. - #[cfg(not(wasm_browser))] - pub(crate) async fn new_relay( - conn: MaybeTlsStreamChained, - key_cache: KeyCache, - secret_key: &SecretKey, - ) -> Result { - use n0_future::SinkExt; - - let conn = Framed::new(conn, RelayCodec::new(key_cache)); - - let mut conn = conn.sink_err_into(); - - // exchange information with the server - debug!("server_handshake: started"); - let client_info = ClientInfo { - version: PROTOCOL_VERSION, - }; - debug!("server_handshake: sending client_key: {:?}", &client_info); - #[allow(deprecated)] - crate::protos::relay::legacy_send_client_key(&mut conn, secret_key, &client_info).await?; - debug!("server_handshake: done"); - - Ok(Self::Relay { - conn: conn.into_inner(), - }) - } - #[cfg(not(wasm_browser))] pub(crate) async fn new_ws( conn: tokio_websockets::WebSocketStream>, @@ -194,15 +158,6 @@ impl Stream for Conn { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => match ready!(Pin::new(conn).poll_next(cx)) { - Some(Ok(frame)) => { - let message = ReceivedMessage::try_from(frame); - Poll::Ready(Some(message)) - } - Some(Err(err)) => Poll::Ready(Some(Err(err.into()))), - None => Poll::Ready(None), - }, #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, @@ -253,8 +208,6 @@ impl Sink for Conn { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).poll_ready(cx).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_ready(cx).map_err(Into::into), #[cfg(wasm_browser)] @@ -271,8 +224,6 @@ impl Sink for Conn { } } match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).start_send(frame).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn) .start_send(tokio_websockets::Message::binary({ @@ -290,8 +241,6 @@ impl Sink for Conn { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).poll_flush(cx).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into), #[cfg(wasm_browser)] @@ -303,8 +252,6 @@ impl Sink for Conn { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).poll_close(cx).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into), #[cfg(wasm_browser)] @@ -320,8 +267,6 @@ impl Sink for Conn { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).poll_ready(cx).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_ready(cx).map_err(Into::into), #[cfg(wasm_browser)] @@ -338,8 +283,6 @@ impl Sink for Conn { } let frame = Frame::from(item); match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).start_send(frame).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn) .start_send(tokio_websockets::Message::binary({ @@ -357,8 +300,6 @@ impl Sink for Conn { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).poll_flush(cx).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into), #[cfg(wasm_browser)] @@ -370,8 +311,6 @@ impl Sink for Conn { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { - #[cfg(not(wasm_browser))] - Self::Relay { ref mut conn } => Pin::new(conn).poll_close(cx).map_err(Into::into), #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_close(cx).map_err(Into::into), #[cfg(wasm_browser)] diff --git a/iroh-relay/src/client/streams.rs b/iroh-relay/src/client/streams.rs index e1ecdc5afa4..b2695427843 100644 --- a/iroh-relay/src/client/streams.rs +++ b/iroh-relay/src/client/streams.rs @@ -5,8 +5,6 @@ use std::{ }; use bytes::Bytes; -use hyper::upgrade::{Parts, Upgraded}; -use hyper_util::rt::TokioIo; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpStream, @@ -16,111 +14,6 @@ use crate::ExportKeyingMaterial; use super::util; -#[allow(clippy::large_enum_variant)] -pub enum MaybeTlsStreamChained { - Raw(util::Chain, ProxyStream>), - Tls(util::Chain, tokio_rustls::client::TlsStream>), - #[cfg(all(test, feature = "server"))] - Mem(tokio::io::DuplexStream), -} - -impl AsyncRead for MaybeTlsStreamChained { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - match &mut *self { - Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf), - Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), - #[cfg(all(test, feature = "server"))] - Self::Mem(stream) => Pin::new(stream).poll_read(cx, buf), - } - } -} - -impl AsyncWrite for MaybeTlsStreamChained { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match &mut *self { - Self::Raw(stream) => Pin::new(stream.get_mut().1).poll_write(cx, buf), - Self::Tls(stream) => Pin::new(stream.get_mut().1).poll_write(cx, buf), - #[cfg(all(test, feature = "server"))] - Self::Mem(stream) => Pin::new(stream).poll_write(cx, buf), - } - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match &mut *self { - Self::Raw(stream) => Pin::new(stream.get_mut().1).poll_flush(cx), - Self::Tls(stream) => Pin::new(stream.get_mut().1).poll_flush(cx), - #[cfg(all(test, feature = "server"))] - Self::Mem(stream) => Pin::new(stream).poll_flush(cx), - } - } - - fn poll_shutdown( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - match &mut *self { - Self::Raw(stream) => Pin::new(stream.get_mut().1).poll_shutdown(cx), - Self::Tls(stream) => Pin::new(stream.get_mut().1).poll_shutdown(cx), - #[cfg(all(test, feature = "server"))] - Self::Mem(stream) => Pin::new(stream).poll_shutdown(cx), - } - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[std::io::IoSlice<'_>], - ) -> Poll> { - match &mut *self { - Self::Raw(stream) => Pin::new(stream.get_mut().1).poll_write_vectored(cx, bufs), - Self::Tls(stream) => Pin::new(stream.get_mut().1).poll_write_vectored(cx, bufs), - #[cfg(all(test, feature = "server"))] - Self::Mem(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), - } - } -} - -pub fn downcast_upgrade(upgraded: Upgraded) -> Option { - match upgraded.downcast::>>() { - Ok(Parts { read_buf, io, .. }) => { - // Prepend data to the reader to avoid data loss - let read_buf = std::io::Cursor::new(read_buf); - match io.into_inner() { - MaybeTlsStream::Raw(conn) => { - Some(MaybeTlsStreamChained::Raw(util::chain(read_buf, conn))) - } - MaybeTlsStream::Tls(conn) => { - Some(MaybeTlsStreamChained::Tls(util::chain(read_buf, conn))) - } - } - } - Err(upgraded) => { - if let Ok(Parts { read_buf, io, .. }) = - upgraded.downcast::>>() - { - let conn = io.into_inner(); - - // Prepend data to the reader to avoid data loss - let conn = util::chain(std::io::Cursor::new(read_buf), conn); - return Some(MaybeTlsStreamChained::Tls(conn)); - } - - None - } - } -} - #[derive(Debug)] #[allow(clippy::large_enum_variant)] pub enum ProxyStream { @@ -205,6 +98,8 @@ impl ProxyStream { pub enum MaybeTlsStream { Raw(IO), Tls(tokio_rustls::client::TlsStream), + #[cfg(test)] + Test(tokio::io::DuplexStream), } impl ExportKeyingMaterial for MaybeTlsStream { @@ -233,6 +128,8 @@ impl AsyncRead for MaybeTlsStream { match &mut *self { Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + #[cfg(test)] + Self::Test(stream) => Pin::new(stream).poll_read(cx, buf), } } } @@ -246,6 +143,8 @@ impl AsyncWrite for MaybeTlsStream { match &mut *self { Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf), Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + #[cfg(test)] + Self::Test(stream) => Pin::new(stream).poll_write(cx, buf), } } @@ -256,6 +155,8 @@ impl AsyncWrite for MaybeTlsStream { match &mut *self { Self::Raw(stream) => Pin::new(stream).poll_flush(cx), Self::Tls(stream) => Pin::new(stream).poll_flush(cx), + #[cfg(test)] + Self::Test(stream) => Pin::new(stream).poll_flush(cx), } } @@ -266,6 +167,8 @@ impl AsyncWrite for MaybeTlsStream { match &mut *self { Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx), Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), + #[cfg(test)] + Self::Test(stream) => Pin::new(stream).poll_shutdown(cx), } } fn poll_write_vectored( @@ -276,6 +179,8 @@ impl AsyncWrite for MaybeTlsStream { match &mut *self { Self::Raw(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), Self::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), + #[cfg(test)] + Self::Test(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), } } } @@ -285,6 +190,8 @@ impl MaybeTlsStream { match self { Self::Raw(s) => s.local_addr(), Self::Tls(s) => s.get_ref().0.local_addr(), + #[cfg(test)] + Self::Test(_) => Ok(SocketAddr::new(std::net::Ipv4Addr::LOCALHOST.into(), 1337)), } } @@ -292,6 +199,8 @@ impl MaybeTlsStream { match self { Self::Raw(s) => s.peer_addr(), Self::Tls(s) => s.get_ref().0.peer_addr(), + #[cfg(test)] + Self::Test(_) => Ok(SocketAddr::new(std::net::Ipv4Addr::LOCALHOST.into(), 1337)), } } } @@ -301,6 +210,8 @@ impl AsRef for MaybeTlsStream { match self { Self::Raw(s) => s, Self::Tls(s) => s.get_ref().0, + #[cfg(test)] + Self::Test(_) => unimplemented!("can't grab underlying IO in MaybeTlsStream::Test"), } } } diff --git a/iroh-relay/src/client/tls.rs b/iroh-relay/src/client/tls.rs index fc3ba7eded0..447290968d3 100644 --- a/iroh-relay/src/client/tls.rs +++ b/iroh-relay/src/client/tls.rs @@ -17,20 +17,13 @@ use bytes::Bytes; use data_encoding::BASE64URL; use http_body_util::Empty; -use hyper::{ - body::Incoming, - header::{HOST, UPGRADE}, - upgrade::Parts, - Request, -}; +use hyper::{upgrade::Parts, Request}; use n0_future::{task, time}; use rustls::client::Resumption; use snafu::{OptionExt, ResultExt}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tracing::{error, info_span, Instrument}; use super::{ - streams::{downcast_upgrade, MaybeTlsStream, ProxyStream}, + streams::{MaybeTlsStream, ProxyStream}, *, }; use crate::defaults::timeouts::*; @@ -258,7 +251,7 @@ impl MaybeTlsStreamBuilder { .context(ProxyConnectSnafu)?; task::spawn(async move { if let Err(err) = conn.with_upgrades().await { - error!("Proxy connection failed: {:?}", err); + tracing::error!("Proxy connection failed: {:?}", err); } }); @@ -282,47 +275,6 @@ impl MaybeTlsStreamBuilder { } impl ClientBuilder { - /// Connects to configured relay using HTTP(S) with an upgrade header - /// set to [`HTTP_UPGRADE_PROTOCOL`]. - /// - /// [`HTTP_UPGRADE_PROTOCOL`]: crate::http::HTTP_UPGRADE_PROTOCOL - pub(super) async fn connect_relay(&self) -> Result<(Conn, SocketAddr), ConnectError> { - #[allow(unused_mut)] - let mut builder = - MaybeTlsStreamBuilder::new(self.url.clone().into(), self.dns_resolver.clone()) - .prefer_ipv6(self.prefer_ipv6()) - .proxy_url(self.proxy_url.clone()); - - #[cfg(any(test, feature = "test-utils"))] - if self.insecure_skip_cert_verify { - builder = builder.insecure_skip_cert_verify(self.insecure_skip_cert_verify); - } - - let stream = builder.connect().await?; - let local_addr = stream - .as_ref() - .local_addr() - .map_err(|_| NoLocalAddrSnafu.build())?; - let response = self.http_upgrade_relay(stream).await?; - - if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { - UnexpectedUpgradeStatusSnafu { - code: response.status(), - } - .fail()?; - } - - debug!("starting upgrade"); - let upgraded = hyper::upgrade::on(response).await.context(UpgradeSnafu)?; - - debug!("connection upgraded"); - let conn = downcast_upgrade(upgraded).expect("must use TcpStream or client::TlsStream"); - - let conn = Conn::new_relay(conn, self.key_cache.clone(), &self.secret_key).await?; - - Ok((conn, local_addr)) - } - pub(super) async fn connect_ws(&self) -> Result<(Conn, SocketAddr), ConnectError> { let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); @@ -381,46 +333,6 @@ impl ClientBuilder { Ok((conn, local_addr)) } - /// Sends the HTTP upgrade request to the relay server. - async fn http_upgrade_relay(&self, io: T) -> Result, ConnectError> - where - T: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - use hyper_util::rt::TokioIo; - let host_header_value = - host_header_value(self.url.clone()).context(InvalidRelayUrlSnafu { - url: Url::from(self.url.clone()), - })?; - - let io = TokioIo::new(io); - let (mut request_sender, connection) = hyper::client::conn::http1::Builder::new() - .handshake(io) - .await - .context(UpgradeSnafu)?; - task::spawn( - // This task drives the HTTP exchange, completes once connection is upgraded. - async move { - debug!("HTTP upgrade driver started"); - if let Err(err) = connection.with_upgrades().await { - error!("HTTP upgrade error: {err:#}"); - } - debug!("HTTP upgrade driver finished"); - } - .instrument(info_span!("http-driver")), - ); - debug!("Sending upgrade request"); - let req = Request::builder() - .uri(RELAY_PATH) - .header(UPGRADE, Protocol::Relay.upgrade_header()) - // https://datatracker.ietf.org/doc/html/rfc2616#section-14.23 - // > A client MUST include a Host header field in all HTTP/1.1 request messages. - // This header value helps reverse proxies identify how to forward requests. - .header(HOST, host_header_value) - .body(http_body_util::Empty::::new()) - .expect("fixed config"); - request_sender.send_request(req).await.context(UpgradeSnafu) - } - /// Reports whether IPv4 dials should be slightly /// delayed to give IPv6 a better chance of winning dial races. /// Implementations should only return true if IPv6 is expected @@ -434,23 +346,6 @@ impl ClientBuilder { } } -/// Returns none if no valid url host was found. -fn host_header_value(relay_url: RelayUrl) -> Option { - // grab the host, turns e.g. https://example.com:8080/xyz -> example.com. - let relay_url_host = relay_url.host_str()?; - - // strip the trailing dot, if present: example.com. -> example.com - let relay_url_host = relay_url_host.strip_suffix('.').unwrap_or(relay_url_host); - // build the host header value (reserve up to 6 chars for the ":" and port digits): - let mut host_header_value = String::with_capacity(relay_url_host.len() + 6); - host_header_value += relay_url_host; - if let Some(port) = relay_url.port() { - host_header_value += ":"; - host_header_value += &port.to_string(); - } - Some(host_header_value) -} - fn url_port(url: &Url) -> Option { if let Some(port) = url.port() { return Some(port); @@ -462,33 +357,3 @@ fn url_port(url: &Url) -> Option { _ => None, } } - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use n0_snafu::Result; - use tracing_test::traced_test; - - use super::*; - - #[test] - #[traced_test] - fn test_host_header_value() -> Result { - let cases = [ - ( - "https://euw1-1.relay.iroh.network.", - "euw1-1.relay.iroh.network", - ), - ("http://localhost:8080", "localhost:8080"), - ]; - - for (url, expected_host) in cases { - let relay_url = RelayUrl::from_str(url)?; - let host = host_header_value(relay_url).unwrap(); - assert_eq!(host, expected_host); - } - - Ok(()) - } -} diff --git a/iroh-relay/src/http.rs b/iroh-relay/src/http.rs index 4f371f5b979..415d22ea3d4 100644 --- a/iroh-relay/src/http.rs +++ b/iroh-relay/src/http.rs @@ -1,6 +1,5 @@ //! HTTP-specific constants for the relay server and client. -pub(crate) const HTTP_UPGRADE_PROTOCOL: &str = "iroh derp http"; pub(crate) const WEBSOCKET_UPGRADE_PROTOCOL: &str = "websocket"; #[cfg(feature = "server")] // only used in the server for now pub(crate) const SUPPORTED_WEBSOCKET_VERSION: &str = "13"; @@ -12,14 +11,10 @@ pub const RELAY_PATH: &str = "/relay"; pub const RELAY_PROBE_PATH: &str = "/ping"; /// The legacy HTTP path under which the relay used to accept relaying connections. /// We keep this for backwards compatibility. -#[cfg(feature = "server")] // legacy paths only used on server-side for backwards compat -pub(crate) const LEGACY_RELAY_PATH: &str = "/derp"; /// The HTTP upgrade protocol used for relaying. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Protocol { - /// Relays over the custom relaying protocol with a custom HTTP upgrade header. - Relay, /// Relays over websockets. /// /// Originally introduced to support browser connections. @@ -28,10 +23,7 @@ pub enum Protocol { impl Default for Protocol { fn default() -> Self { - #[cfg(not(wasm_browser))] - return Self::Relay; - #[cfg(wasm_browser)] - return Self::Websocket; + Self::Websocket } } @@ -39,7 +31,6 @@ impl Protocol { /// The HTTP upgrade header used or expected. pub const fn upgrade_header(&self) -> &'static str { match self { - Protocol::Relay => HTTP_UPGRADE_PROTOCOL, Protocol::Websocket => WEBSOCKET_UPGRADE_PROTOCOL, } } @@ -47,9 +38,7 @@ impl Protocol { /// Tries to match the value of an HTTP upgrade header to figure out which protocol should be initiated. pub fn parse_header(header: &http::HeaderValue) -> Option { let header_bytes = header.as_bytes(); - if header_bytes == Protocol::Relay.upgrade_header().as_bytes() { - Some(Protocol::Relay) - } else if header_bytes == Protocol::Websocket.upgrade_header().as_bytes() { + if header_bytes == Protocol::Websocket.upgrade_header().as_bytes() { Some(Protocol::Websocket) } else { None diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index c81540076f8..702bcb40205 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -161,6 +161,7 @@ pub enum Error { impl ServerChallenge { /// TODO(matheus23): docs + #[cfg(feature = "server")] pub(crate) fn new(mut rng: impl RngCore + CryptoRng) -> Self { let mut challenge = [0u8; 16]; rng.fill_bytes(&mut challenge); diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index cdd4b09c9fc..0bc656659e2 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -13,7 +13,7 @@ //! * server then sends `FrameType::RecvPacket` to recipient use bytes::{BufMut, Bytes}; -use iroh_base::{PublicKey, SecretKey, Signature, SignatureError}; +use iroh_base::{PublicKey, Signature, SignatureError}; #[cfg(feature = "server")] use n0_future::time::Duration; use n0_future::{time, Sink, SinkExt}; @@ -35,7 +35,7 @@ pub const MAX_PACKET_SIZE: usize = 64 * 1024; /// /// This is also the minimum burst size that a rate-limiter has to accept. #[cfg(not(wasm_browser))] -const MAX_FRAME_SIZE: usize = 1024 * 1024; +pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; /// The Relay magic number, sent in the FrameType::ClientInfo frame upon initial connection. const MAGIC: &str = "RELAY🔑"; @@ -197,97 +197,6 @@ pub(crate) async fn write_frame + Unpin>( Ok(()) } -/// Writes a `FrameType::ClientInfo`, including the client's [`PublicKey`], -/// and the client's [`ClientInfo`], sealed using the server's [`PublicKey`]. -/// -/// Flushes after writing. -#[deprecated = "switch to proper handshake"] -pub(crate) async fn legacy_send_client_key + Unpin>( - mut writer: S, - client_secret_key: &SecretKey, - client_info: &ClientInfo, -) -> Result<(), SendError> { - let msg = postcard::to_stdvec(client_info)?; - let signature = client_secret_key.sign(&msg); - - writer - .send(Frame::ClientInfo { - client_public_key: client_secret_key.public(), - message: msg.into(), - signature, - }) - .await?; - writer.flush().await?; - Ok(()) -} - -/// Reads the `FrameType::ClientInfo` frame from the client (its proof of identity) -/// upon it's initial connection. -#[cfg(any(test, feature = "server"))] -#[deprecated = "switch to proper handshake"] -pub(crate) async fn legacy_recv_client_key> + Unpin>( - stream: S, -) -> Result<(PublicKey, ClientInfo), E> -where - E: From, -{ - // the client is untrusted at this point, limit the input size even smaller than our usual - // maximum frame size, and give a timeout - - // TODO: variable recv size: 256 * 1024 - let buf = tokio::time::timeout( - std::time::Duration::from_secs(10), - recv_frame(FrameType::ClientInfo, stream), - ) - .await - .map_err(RecvError::from)??; - - if let Frame::ClientInfo { - client_public_key, - message, - signature, - } = buf - { - client_public_key - .verify(&message, &signature) - .map_err(RecvError::from)?; - - let info: ClientInfo = postcard::from_bytes(&message).map_err(RecvError::from)?; - Ok((client_public_key, info)) - } else { - Err(UnexpectedFrameSnafu { - got: buf.typ(), - expected: FrameType::ClientInfo, - } - .build() - .into()) - } -} - -/// The protocol for the relay server. -/// -/// This is a framed protocol, using [`tokio_util::codec`] to turn the streams of bytes into -/// [`Frame`]s. -#[cfg(not(wasm_browser))] -#[derive(Debug, Clone)] -pub(crate) struct RelayCodec { - cache: KeyCache, -} - -#[cfg(not(wasm_browser))] -impl RelayCodec { - #[cfg(test)] - pub fn test() -> Self { - Self { - cache: KeyCache::test(), - } - } - - pub(crate) fn new(cache: KeyCache) -> Self { - Self { cache } - } -} - /// The frames in the [`RelayCodec`]. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum Frame { @@ -369,6 +278,7 @@ impl Frame { /// Serialized length with frame header. #[cfg(feature = "server")] pub(crate) fn len_with_header(&self) -> usize { + const HEADER_LEN: usize = 5; // TODO(matheus23): This is used with the rate-limiter. It really shouldn't be. The websocket frames work on a different level! self.len() + HEADER_LEN } @@ -568,89 +478,6 @@ impl Frame { } } -// No need for framing when using websockets, thus this is cfg-ed out for browsers: -#[cfg(not(wasm_browser))] -// rustc doesn't figure out that the trait impls mean it's not unused -#[cfg_attr(not(wasm_browser), allow(unused))] -pub use framing::*; - -#[cfg(not(wasm_browser))] -mod framing { - use bytes::{Buf, BytesMut}; - use tokio_util::codec::{Decoder, Encoder}; - - use super::*; - - pub(super) const HEADER_LEN: usize = 5; - - impl Decoder for RelayCodec { - type Item = Frame; - type Error = RecvError; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - // Need at least 5 bytes - if src.len() < HEADER_LEN { - return Ok(None); - } - - // Can't use the `Buf::get_*` APIs, as that advances the buffer. - let Some(frame_type) = src.first().map(|b| FrameType::from(*b)) else { - return Ok(None); // Not enough bytes - }; - let Some(frame_len) = src - .get(1..5) - .and_then(|s| TryInto::<[u8; 4]>::try_into(s).ok()) - .map(u32::from_be_bytes) - .map(|l| l as usize) - else { - return Ok(None); // Not enough bytes - }; - - if frame_len > MAX_FRAME_SIZE { - return Err(FrameTooLargeSnafu { frame_len }.build()); - } - - if src.len() < HEADER_LEN + frame_len { - // Optimization: prereserve the buffer space - src.reserve(HEADER_LEN + frame_len - src.len()); - - return Ok(None); - } - - // advance the header - src.advance(HEADER_LEN); - - let content = src.split_to(frame_len).freeze(); - let frame = Frame::from_bytes(frame_type, content, &self.cache)?; - - Ok(Some(frame)) - } - } - - impl Encoder for RelayCodec { - type Error = std::io::Error; - - fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { - let frame_len: usize = frame.len(); - if frame_len > MAX_FRAME_SIZE { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Frame of length {} is too large.", frame_len), - )); - } - - let frame_len_u32 = u32::try_from(frame_len).expect("just checked"); - - dst.reserve(HEADER_LEN + frame_len); - dst.put_u8(frame.typ().into()); - dst.put_u32(frame_len_u32); - frame.write_to(dst); - - Ok(()) - } - } -} - /// Receives the next frame and matches the frame type. If the correct type is found returns the content, /// otherwise an error. #[cfg(any(test, feature = "server"))] @@ -685,52 +512,11 @@ where #[cfg(test)] mod tests { use data_encoding::HEXLOWER; + use iroh_base::SecretKey; use n0_snafu::{Result, ResultExt}; - use tokio_util::codec::{FramedRead, FramedWrite}; use super::*; - #[tokio::test] - #[cfg(feature = "server")] - async fn test_basic_read_write() -> Result { - let (reader, writer) = tokio::io::duplex(1024); - let mut reader = FramedRead::new(reader, RelayCodec::test()); - let mut writer = FramedWrite::new(writer, RelayCodec::test()); - - let expect_buf = b"hello world!"; - let expected_frame = Frame::Health { - problem: expect_buf.to_vec().into(), - }; - write_frame(&mut writer, expected_frame.clone(), None).await?; - writer.flush().await.context("flush")?; - println!("{:?}", reader); - let buf = recv_frame(FrameType::Health, &mut reader).await?; - assert_eq!(expect_buf.len(), buf.len()); - assert_eq!(expected_frame, buf); - - Ok(()) - } - - #[tokio::test] - #[allow(deprecated)] - async fn test_send_recv_client_key() -> Result { - let (reader, writer) = tokio::io::duplex(1024); - let mut reader = FramedRead::new(reader, RelayCodec::test()); - let mut writer = - FramedWrite::new(writer, RelayCodec::test()).sink_map_err(ConnSendError::from); - - let client_key = SecretKey::generate(rand::thread_rng()); - let client_info = ClientInfo { - version: PROTOCOL_VERSION, - }; - println!("client_key pub {:?}", client_key.public()); - legacy_send_client_key(&mut writer, &client_key, &client_info).await?; - let (client_pub_key, got_client_info) = legacy_recv_client_key(&mut reader).await?; - assert_eq!(client_key.public(), client_pub_key); - assert_eq!(client_info, got_client_info); - Ok(()) - } - #[test] fn test_frame_snapshot() -> Result { let client_key = SecretKey::from_bytes(&[42u8; 32]); @@ -831,8 +617,8 @@ mod tests { #[cfg(test)] mod proptests { use bytes::BytesMut; + use iroh_base::SecretKey; use proptest::prelude::*; - use tokio_util::codec::{Decoder, Encoder}; use super::*; @@ -922,17 +708,6 @@ mod proptests { } proptest! { - - // Test that we can roundtrip a frame to bytes - #[test] - fn frame_roundtrip(frame in frame()) { - let mut buf = BytesMut::new(); - let mut codec = RelayCodec::test(); - codec.encode(frame.clone(), &mut buf).unwrap(); - let decoded = codec.decode(&mut buf).unwrap().unwrap(); - prop_assert_eq!(frame, decoded); - } - #[test] fn frame_ws_roundtrip(frame in frame()) { let mut encoded = Vec::new(); @@ -940,16 +715,5 @@ mod proptests { let decoded = Frame::decode_from_ws_msg(Bytes::from(encoded), &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } - - // Test that typical invalid frames will result in an error - #[test] - fn broken_frame_handling(frame in frame()) { - let mut buf = BytesMut::new(); - let mut codec = RelayCodec::test(); - codec.encode(frame.clone(), &mut buf).unwrap(); - inject_error(&mut buf); - let decoded = codec.decode(&mut buf); - prop_assert!(decoded.is_err(), "{:?}", decoded); - } } } diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 73bd29d712e..cb096856c41 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -877,7 +877,7 @@ mod tests { use std::{net::Ipv4Addr, time::Duration}; use bytes::Bytes; - use http::{header::UPGRADE, StatusCode}; + use http::StatusCode; use iroh_base::{NodeId, RelayUrl, SecretKey}; use n0_future::{FutureExt, SinkExt, StreamExt}; use n0_snafu::{Result, ResultExt}; @@ -892,7 +892,7 @@ mod tests { use crate::{ client::{conn::ReceivedMessage, ClientBuilder, SendMessage}, dns::DnsResolver, - http::{Protocol, HTTP_UPGRADE_PROTOCOL}, + http::Protocol, protos, }; @@ -1010,77 +1010,7 @@ mod tests { #[tokio::test] #[traced_test] - async fn test_relay_client_legacy_route() { - let server = spawn_local_relay().await.unwrap(); - // We're testing the legacy endpoint at `/derp` - let endpoint_url = format!("http://{}/derp", server.http_addr().unwrap()); - - let client = reqwest::Client::new(); - let result = client - .get(endpoint_url) - .header(UPGRADE, HTTP_UPGRADE_PROTOCOL) - .send() - .await - .unwrap(); - - assert_eq!(result.status(), StatusCode::SWITCHING_PROTOCOLS); - } - - #[tokio::test] - #[traced_test] - async fn test_relay_clients_both_relay() -> Result<()> { - let server = spawn_local_relay().await.unwrap(); - let relay_url = format!("http://{}", server.http_addr().unwrap()); - let relay_url: RelayUrl = relay_url.parse().unwrap(); - - // set up client a - let a_secret_key = SecretKey::generate(rand::thread_rng()); - let a_key = a_secret_key.public(); - let resolver = dns_resolver(); - let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver.clone()) - .connect() - .await?; - - // set up client b - let b_secret_key = SecretKey::generate(rand::thread_rng()); - let b_key = b_secret_key.public(); - let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver.clone()) - .connect() - .await?; - - // send message from a to b - let msg = Bytes::from("hello, b"); - let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - if let ReceivedMessage::ReceivedPacket { - remote_node_id, - data, - } = res - { - assert_eq!(a_key, remote_node_id); - assert_eq!(msg, data); - } else { - panic!("client_b received unexpected message {res:?}"); - } - - // send message from b to a - let msg = Bytes::from("howdy, a"); - let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - if let ReceivedMessage::ReceivedPacket { - remote_node_id, - data, - } = res - { - assert_eq!(b_key, remote_node_id); - assert_eq!(msg, data); - } else { - panic!("client_a received unexpected message {res:?}"); - } - Ok(()) - } - - #[tokio::test] - #[traced_test] - async fn test_relay_clients_both_websockets() -> Result<()> { + async fn test_relay_clients() -> Result<()> { let server = spawn_local_relay().await?; let relay_url = format!("http://{}", server.http_addr().unwrap()); @@ -1140,62 +1070,6 @@ mod tests { Ok(()) } - #[tokio::test] - #[traced_test] - async fn test_relay_clients_websocket_and_relay() -> Result<()> { - let server = spawn_local_relay().await.unwrap(); - - let relay_url = format!("http://{}", server.http_addr().unwrap()); - let relay_url: RelayUrl = relay_url.parse().unwrap(); - - // set up client a - let a_secret_key = SecretKey::generate(rand::thread_rng()); - let a_key = a_secret_key.public(); - let resolver = dns_resolver(); - let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver) - .connect() - .await?; - - // set up client b - let b_secret_key = SecretKey::generate(rand::thread_rng()); - let b_key = b_secret_key.public(); - let resolver = dns_resolver(); - let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver) - .protocol(Protocol::Websocket) // Use websockets - .connect() - .await?; - - // send message from a to b - let msg = Bytes::from("hello, b"); - let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - - if let ReceivedMessage::ReceivedPacket { - remote_node_id, - data, - } = res - { - assert_eq!(a_key, remote_node_id); - assert_eq!(msg, data); - } else { - panic!("client_b received unexpected message {res:?}"); - } - - // send message from b to a - let msg = Bytes::from("howdy, a"); - let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - if let ReceivedMessage::ReceivedPacket { - remote_node_id, - data, - } = res - { - assert_eq!(b_key, remote_node_id); - assert_eq!(msg, data); - } else { - panic!("client_a received unexpected message {res:?}"); - } - Ok(()) - } - #[tokio::test] #[traced_test] async fn test_stun() { diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 6692b7aa10c..3ebe5b1da13 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -761,7 +761,7 @@ mod tests { use super::*; use crate::{ - protos::relay::{recv_frame, FrameType, RelayCodec}, + protos::relay::{recv_frame, FrameType}, server::streams::MaybeTlsStream, }; @@ -774,9 +774,8 @@ mod tests { let node_id = SecretKey::generate(rand::thread_rng()).public(); let (io, io_rw) = tokio::io::duplex(1024); - let mut io_rw = Framed::new(io_rw, RelayCodec::test()); - let stream = - RelayedStream::Relay(Framed::new(MaybeTlsStream::Test(io), RelayCodec::test())); + let mut io_rw = RelayedStream::test_client(io_rw); + let stream = RelayedStream::test_server(io); let clients = Clients::default(); let metrics = Arc::new(Metrics::default()); @@ -895,11 +894,8 @@ mod tests { // Build the rate limited stream. let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _); - let mut frame_writer = Framed::new(io_write, RelayCodec::test()); - let stream = RelayedStream::Relay(Framed::new( - MaybeTlsStream::Test(io_read), - RelayCodec::test(), - )); + let mut frame_writer = RelayedStream::test_client(io_write); + let stream = RelayedStream::test_server(io_read); let mut stream = RateLimitedRelayedStream::new(stream, limiter, Default::default()); // Prepare a frame to send, assert its size. diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 1122d560527..273f423b70c 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -195,29 +195,25 @@ mod tests { use bytes::Bytes; use iroh_base::SecretKey; use n0_snafu::{Result, ResultExt}; - use tokio::io::DuplexStream; - use tokio_util::codec::{Framed, FramedRead}; use super::*; use crate::{ - protos::relay::{recv_frame, Frame, FrameType, RelayCodec}, + protos::relay::{recv_frame, Frame, FrameType}, server::streams::{MaybeTlsStream, RelayedStream}, + KeyCache, }; - fn test_client_builder(key: NodeId) -> (Config, FramedRead) { - let (test_io, io) = tokio::io::duplex(1024); + fn test_client_builder(key: NodeId) -> (Config, RelayedStream) { + let (server, client) = tokio::io::duplex(1024); ( Config { node_id: key, - stream: RelayedStream::Relay(Framed::new( - MaybeTlsStream::Test(io), - RelayCodec::test(), - )), + stream: RelayedStream::test_client(client), write_timeout: Duration::from_secs(1), channel_capacity: 10, rate_limit: None, }, - FramedRead::new(test_io, RelayCodec::test()), + RelayedStream::test_server(server), ) } diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index dd55c170fa0..3e077dd5b5c 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -18,18 +18,16 @@ use nested_enum_utils::common_fields; use snafu::{Backtrace, ResultExt, Snafu}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls_acme::AcmeAcceptor; -use tokio_util::{codec::Framed, sync::CancellationToken, task::AbortOnDropHandle}; +use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; -use super::{clients::Clients, streams::StreamError, AccessConfig, SpawnError}; +use super::{clients::Clients, AccessConfig, SpawnError}; use crate::protos::{handshake, io::HandshakeIo}; #[allow(deprecated)] use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, - http::{Protocol, LEGACY_RELAY_PATH, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, - protos::relay::{ - legacy_recv_client_key, Frame, RelayCodec, PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION, - }, + http::{Protocol, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, + protos::relay::{Frame, PER_CLIENT_SEND_QUEUE_DEPTH}, server::{ client::Config, metrics::Metrics, @@ -210,13 +208,6 @@ pub enum ServeConnectionError { #[derive(Debug, Snafu)] #[non_exhaustive] pub enum AcceptError { - #[snafu(display("Unable to receive client information"))] - RecvClientKey { - #[allow(clippy::result_large_err)] - source: StreamError, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, - }, #[snafu(display("Handshake failed"))] Handshake { #[allow(clippy::result_large_err)] @@ -224,13 +215,6 @@ pub enum AcceptError { #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, - #[snafu(display("Unexpected client version {version}, expected {expected_version}"))] - UnexpectedClientVersion { - version: usize, - expected_version: usize, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, - }, #[snafu(transparent)] Io { source: std::io::Error }, #[snafu(display("Client not authenticated: {key:?}"))] @@ -573,7 +557,7 @@ impl Service> for RelayService { // Create a client if the request hits the relay endpoint. if matches!( (req.method(), req.uri().path()), - (&hyper::Method::GET, LEGACY_RELAY_PATH | RELAY_PATH) + (&hyper::Method::GET, RELAY_PATH) ) { let this = self.clone(); return Box::pin(async move { this.call_client_conn(req).await.map_err(Into::into) }); @@ -649,27 +633,6 @@ impl Inner { trace!(?protocol, "accept: start"); let (client_key, mut io) = match protocol { - Protocol::Relay => { - self.metrics.relay_accepts.inc(); - let mut io = - RelayedStream::Relay(Framed::new(io, RelayCodec::new(self.key_cache.clone()))); - - trace!("accept: recv client key"); - #[allow(deprecated)] - let (client_key, info) = legacy_recv_client_key(&mut io) - .await - .context(RecvClientKeySnafu)?; - - if info.version != PROTOCOL_VERSION { - return Err(UnexpectedClientVersionSnafu { - version: info.version, - expected_version: PROTOCOL_VERSION, - } - .build()); - } - - (client_key, io) - } Protocol::Websocket => { self.metrics.websocket_accepts.inc(); // Since we already did the HTTP upgrade in the previous step, @@ -687,7 +650,10 @@ impl Inner { ( client_info.public_key, - RelayedStream::Ws(io.io, self.key_cache.clone()), + RelayedStream { + inner: io.io, + key_cache: self.key_cache.clone(), + }, ) } }; @@ -895,7 +861,6 @@ mod tests { use crate::{ client::{ conn::{Conn, ReceivedMessage, SendMessage}, - streams::MaybeTlsStreamChained, Client, ClientBuilder, ConnectError, }, dns::DnsResolver, @@ -1111,8 +1076,9 @@ mod tests { } async fn make_test_client(client: tokio::io::DuplexStream, key: &SecretKey) -> Result { - let client = MaybeTlsStreamChained::Mem(client); - let client = Conn::new_relay(client, KeyCache::test(), key).await?; + let client = crate::client::streams::MaybeTlsStream::Test(client); + let client = tokio_websockets::ClientBuilder::new().take_over(client); + let client = Conn::new_ws(client, KeyCache::test(), key).await?; Ok(client) } @@ -1135,7 +1101,7 @@ mod tests { let (client_a, rw_a) = tokio::io::duplex(10); let s = service.clone(); let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Relay, MaybeTlsStream::Test(rw_a)) + s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a)) .await }); let mut client_a = make_test_client(client_a, &key_a).await?; @@ -1147,7 +1113,7 @@ mod tests { let (client_b, rw_b) = tokio::io::duplex(10); let s = service.clone(); let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Relay, MaybeTlsStream::Test(rw_b)) + s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b)) .await }); let mut client_b = make_test_client(client_b, &key_b).await?; @@ -1223,7 +1189,7 @@ mod tests { let (client_a, rw_a) = tokio::io::duplex(10); let s = service.clone(); let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Relay, MaybeTlsStream::Test(rw_a)) + s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a)) .await }); let mut client_a = make_test_client(client_a, &key_a).await?; @@ -1235,7 +1201,7 @@ mod tests { let (client_b, rw_b) = tokio::io::duplex(10); let s = service.clone(); let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Relay, MaybeTlsStream::Test(rw_b)) + s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b)) .await }); let mut client_b = make_test_client(client_b, &key_b).await?; @@ -1281,7 +1247,7 @@ mod tests { let (new_client_b, new_rw_b) = tokio::io::duplex(10); let s = service.clone(); let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Relay, MaybeTlsStream::Test(new_rw_b)) + s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(new_rw_b)) .await }); let mut new_client_b = make_test_client(new_client_b, &key_b).await?; diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index b566fd60e45..09ff1c4dd83 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -9,11 +9,10 @@ use bytes::BytesMut; use n0_future::{Sink, Stream}; use snafu::Snafu; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; use tokio_websockets::WebSocketStream; use crate::{ - protos::relay::{Frame, RecvError, RelayCodec}, + protos::relay::{Frame, RecvError}, ExportKeyingMaterial, KeyCache, }; @@ -21,9 +20,36 @@ use crate::{ /// /// The stream receives message from the client while the sink sends them to the client. #[derive(Debug)] -pub(crate) enum RelayedStream { - Relay(Framed), - Ws(WebSocketStream, KeyCache), +pub(crate) struct RelayedStream { + pub(crate) inner: WebSocketStream, + pub(crate) key_cache: KeyCache, +} + +#[cfg(test)] +impl RelayedStream { + pub(crate) fn test_client(stream: tokio::io::DuplexStream) -> Self { + Self { + inner: tokio_websockets::ClientBuilder::new() + .limits( + tokio_websockets::Limits::default() + .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)), + ) + .take_over(MaybeTlsStream::Test(stream)), + key_cache: KeyCache::test(), + } + } + + pub(crate) fn test_server(stream: tokio::io::DuplexStream) -> Self { + Self { + inner: tokio_websockets::ServerBuilder::new() + .limits( + tokio_websockets::Limits::default() + .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)), + ) + .serve(MaybeTlsStream::Test(stream)), + key_cache: KeyCache::test(), + } + } } fn ws_to_io_err(e: tokio_websockets::Error) -> std::io::Error { @@ -37,37 +63,31 @@ impl Sink for RelayedStream { type Error = std::io::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - Self::Relay(ref mut framed) => Pin::new(framed).poll_ready(cx), - Self::Ws(ref mut ws, _) => Pin::new(ws).poll_ready(cx).map_err(ws_to_io_err), - } + Pin::new(&mut self.inner) + .poll_ready(cx) + .map_err(ws_to_io_err) } fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { - match *self { - Self::Relay(ref mut framed) => Pin::new(framed).start_send(item), - Self::Ws(ref mut ws, _) => Pin::new(ws) - .start_send(tokio_websockets::Message::binary({ - let mut buf = BytesMut::new(); - item.encode_for_ws_msg(&mut buf); - tokio_websockets::Payload::from(buf.freeze()) - })) - .map_err(ws_to_io_err), - } + Pin::new(&mut self.inner) + .start_send(tokio_websockets::Message::binary({ + let mut buf = BytesMut::new(); + item.encode_for_ws_msg(&mut buf); + tokio_websockets::Payload::from(buf.freeze()) + })) + .map_err(ws_to_io_err) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - Self::Relay(ref mut framed) => Pin::new(framed).poll_flush(cx), - Self::Ws(ref mut ws, _) => Pin::new(ws).poll_flush(cx).map_err(ws_to_io_err), - } + Pin::new(&mut self.inner) + .poll_flush(cx) + .map_err(ws_to_io_err) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - Self::Relay(ref mut framed) => Pin::new(framed).poll_close(cx), - Self::Ws(ref mut ws, _) => Pin::new(ws).poll_close(cx).map_err(ws_to_io_err), - } + Pin::new(&mut self.inner) + .poll_close(cx) + .map_err(ws_to_io_err) } } @@ -85,30 +105,25 @@ impl Stream for RelayedStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - Self::Relay(ref mut framed) => Pin::new(framed).poll_next(cx).map_err(Into::into), - Self::Ws(ref mut ws, ref cache) => match Pin::new(ws).poll_next(cx) { - Poll::Ready(Some(Ok(msg))) => { - if msg.is_close() { - // Indicate the stream is done when we receive a close message. - // Note: We don't have to poll the stream to completion for it to close gracefully. - return Poll::Ready(None); - } - if !msg.is_binary() { - tracing::warn!( - ?msg, - "Got websocket message of unsupported type, skipping." - ); - return Poll::Pending; - } - let frame = Frame::decode_from_ws_msg(msg.into_payload().into(), cache) - .map_err(Into::into); - Poll::Ready(Some(frame)) + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(Ok(msg))) => { + if msg.is_close() { + // Indicate the stream is done when we receive a close message. + // Note: We don't have to poll the stream to completion for it to close gracefully. + return Poll::Ready(None); + } + if !msg.is_binary() { + tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); + return Poll::Pending; } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - }, + Poll::Ready(Some( + Frame::decode_from_ws_msg(msg.into_payload().into(), &self.key_cache) + .map_err(Into::into), + )) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } From ae67e9ce1c94de616ab6430332ec3e184e0fc79c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 5 Jun 2025 16:53:28 +0200 Subject: [PATCH 11/80] `cargo make format` --- iroh-relay/src/client/streams.rs | 3 +-- iroh-relay/src/protos/handshake.rs | 5 ++--- iroh-relay/src/protos/io.rs | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/iroh-relay/src/client/streams.rs b/iroh-relay/src/client/streams.rs index b2695427843..d1565f69b06 100644 --- a/iroh-relay/src/client/streams.rs +++ b/iroh-relay/src/client/streams.rs @@ -10,9 +10,8 @@ use tokio::{ net::TcpStream, }; -use crate::ExportKeyingMaterial; - use super::util; +use crate::ExportKeyingMaterial; #[derive(Debug)] #[allow(clippy::large_enum_variant)] diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 702bcb40205..12a192a203d 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -11,9 +11,8 @@ use quinn_proto::{coding::Codec, VarInt}; use rand::{CryptoRng, RngCore}; use snafu::{Backtrace, ResultExt, Snafu}; -use crate::ExportKeyingMaterial; - use super::relay::SendError; +use crate::ExportKeyingMaterial; /// TODO(matheus23) docs pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; @@ -394,7 +393,6 @@ fn deserialize_frame(frame: Bytes) -> Re #[cfg(all(test, feature = "server"))] mod tests { - use crate::ExportKeyingMaterial; use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::{Sink, SinkExt, Stream, TryStreamExt}; @@ -402,6 +400,7 @@ mod tests { use tokio_util::codec::{Framed, LengthDelimitedCodec}; use super::{ClientAuth, ServerChallenge}; + use crate::ExportKeyingMaterial; struct TestKeyingMaterial { shared_secret: Option, diff --git a/iroh-relay/src/protos/io.rs b/iroh-relay/src/protos/io.rs index 68337f37a91..c62061fda81 100644 --- a/iroh-relay/src/protos/io.rs +++ b/iroh-relay/src/protos/io.rs @@ -8,9 +8,8 @@ use bytes::Bytes; use n0_future::{ready, Sink, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::ExportKeyingMaterial; - use super::handshake::Error; +use crate::ExportKeyingMaterial; #[derive(derive_more::Debug)] pub(crate) struct HandshakeIo { From 0433a33d0e646f55b50270083a0302e5d3929c46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 6 Jun 2025 11:59:35 +0200 Subject: [PATCH 12/80] WIP --- iroh-relay/src/server/client.rs | 222 ++------------------------ iroh-relay/src/server/clients.rs | 4 +- iroh-relay/src/server/http_server.rs | 8 +- iroh-relay/src/server/metrics.rs | 4 +- iroh-relay/src/server/streams.rs | 223 +++++++++++++++++++++++++-- 5 files changed, 230 insertions(+), 231 deletions(-) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 3ebe5b1da13..daeb9ba7cbf 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -1,13 +1,10 @@ //! The server-side representation of an ongoing client relaying connection. -use std::{ - collections::HashSet, future::Future, num::NonZeroU32, pin::Pin, sync::Arc, task::Poll, - time::Duration, -}; +use std::{collections::HashSet, sync::Arc, time::Duration}; use bytes::Bytes; use iroh_base::NodeId; -use n0_future::{FutureExt, Sink, SinkExt, Stream, StreamExt}; +use n0_future::{SinkExt, StreamExt}; use nested_enum_utils::common_fields; use rand::Rng; use snafu::{Backtrace, GenerateImplicitData, Snafu}; @@ -17,7 +14,7 @@ use tokio::{ time::MissedTickBehavior, }; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; -use tracing::{debug, error, instrument, trace, warn, Instrument}; +use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ @@ -28,7 +25,6 @@ use crate::{ clients::Clients, metrics::Metrics, streams::{RelayedStream, StreamError}, - ClientRateLimit, }, PingTracker, }; @@ -49,7 +45,6 @@ pub(super) struct Config { pub(super) stream: RelayedStream, pub(super) write_timeout: Duration, pub(super) channel_capacity: usize, - pub(super) rate_limit: Option, } /// The [`Server`] side representation of a [`Client`]'s connection. @@ -86,24 +81,11 @@ impl Client { ) -> Client { let Config { node_id, - stream: io, + stream, write_timeout, channel_capacity, - rate_limit, } = config; - let stream = match rate_limit { - Some(cfg) => { - let mut quota = governor::Quota::per_second(cfg.bytes_per_second); - if let Some(max_burst) = cfg.max_burst_bytes { - quota = quota.allow_burst(max_burst); - } - let limiter = governor::RateLimiter::direct(quota); - RateLimitedRelayedStream::new(io, limiter, metrics.clone()) - } - None => RateLimitedRelayedStream::unlimited(io, metrics.clone()), - }; - let done = CancellationToken::new(); let (send_queue_s, send_queue_r) = mpsc::channel(channel_capacity); @@ -296,7 +278,7 @@ pub enum RunError { #[derive(Debug)] struct Actor { /// IO Stream to talk to the client - stream: RateLimitedRelayedStream, + stream: RelayedStream, /// Maximum time we wait to complete a write to the client timeout: Duration, /// Packets queued to send to the client @@ -537,187 +519,6 @@ impl ForwardPacketError { } } -/// Rate limiter for reading from a [`RelayedStream`]. -/// -/// The writes to the sink are not rate limited. -/// -/// This potentially buffers one frame if the rate limiter does not allows this frame. -/// While the frame is buffered the undernlying stream is no longer polled. -#[derive(Debug)] -struct RateLimitedRelayedStream { - inner: RelayedStream, - limiter: Option>, - state: State, - /// Keeps track if this stream was ever rate-limited. - limited_once: bool, - metrics: Arc, -} - -#[derive(derive_more::Debug)] -enum State { - #[debug("Blocked")] - Blocked { - /// Future which will complete when the item can be yielded. - delay: Pin + Send + Sync>>, - /// Item to yield when the `delay` future completes. - item: Result, - }, - Ready, -} - -impl RateLimitedRelayedStream { - fn new( - inner: RelayedStream, - limiter: governor::DefaultDirectRateLimiter, - metrics: Arc, - ) -> Self { - Self { - inner, - limiter: Some(Arc::new(limiter)), - state: State::Ready, - limited_once: false, - metrics, - } - } - - fn unlimited(inner: RelayedStream, metrics: Arc) -> Self { - Self { - inner, - limiter: None, - state: State::Ready, - limited_once: false, - metrics, - } - } -} - -impl RateLimitedRelayedStream { - /// Records metrics about being rate-limited. - fn record_rate_limited(&mut self) { - // TODO: add a label for the frame type. - self.metrics.frames_rx_ratelimited_total.inc(); - if !self.limited_once { - self.metrics.conns_rx_ratelimited_total.inc(); - self.limited_once = true; - } - } -} - -impl Stream for RateLimitedRelayedStream { - type Item = Result; - - #[instrument(name = "rate_limited_relayed_stream", skip_all)] - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let Some(ref limiter) = self.limiter else { - // If there is no rate-limiter directly poll the inner. - return Pin::new(&mut self.inner).poll_next(cx); - }; - let limiter = limiter.clone(); - loop { - match &mut self.state { - State::Ready => { - // Poll inner for a new item. - match Pin::new(&mut self.inner).poll_next(cx) { - Poll::Ready(Some(item)) => { - match &item { - Ok(frame) => { - // How many bytes does this frame consume? - let Ok(frame_len) = - TryInto::::try_into(frame.len_with_header()) - .and_then(TryInto::::try_into) - else { - error!("frame len not NonZeroU32, is MAX_FRAME_SIZE too large?"); - // Let this frame through so to not completely break. - return Poll::Ready(Some(item)); - }; - - match limiter.check_n(frame_len) { - Ok(Ok(_)) => return Poll::Ready(Some(item)), - Ok(Err(_)) => { - // Item is rate-limited. - self.record_rate_limited(); - let delay = Box::pin({ - let limiter = limiter.clone(); - async move { - limiter.until_n_ready(frame_len).await.ok(); - } - }); - self.state = State::Blocked { delay, item }; - continue; - } - Err(_insufficient_capacity) => { - error!( - "frame larger than bucket capacity: \ - configuration error: \ - max_burst_bytes < MAX_FRAME_SIZE?" - ); - // Let this frame through so to not completely break. - return Poll::Ready(Some(item)); - } - } - } - Err(_) => { - // Yielding errors is not rate-limited. - return Poll::Ready(Some(item)); - } - } - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - State::Blocked { delay, .. } => { - match delay.poll(cx) { - Poll::Ready(_) => { - match std::mem::replace(&mut self.state, State::Ready) { - State::Ready => unreachable!(), - State::Blocked { item, .. } => { - // Yield the item directly, rate-limit has already been - // accounted for by awaiting the future. - return Poll::Ready(Some(item)); - } - } - } - Poll::Pending => return Poll::Pending, - } - } - } - } - } -} - -impl Sink for RateLimitedRelayedStream { - type Error = std::io::Error; - - fn poll_ready( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_ready(cx) - } - - fn start_send(mut self: Pin<&mut Self>, item: Frame) -> std::result::Result<(), Self::Error> { - Pin::new(&mut self.inner).start_send(item) - } - - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx) - } -} - /// Tracks how many unique nodes have been seen during the last day. #[derive(Debug)] struct ClientCounter { @@ -752,18 +553,16 @@ impl ClientCounter { #[cfg(test)] mod tests { + use std::num::NonZeroU32; + use bytes::Bytes; use iroh_base::SecretKey; use n0_snafu::{Result, ResultExt}; - use tokio_util::codec::Framed; use tracing::info; use tracing_test::traced_test; use super::*; - use crate::{ - protos::relay::{recv_frame, FrameType}, - server::streams::MaybeTlsStream, - }; + use crate::protos::relay::{recv_frame, FrameType}; #[tokio::test] #[traced_test] @@ -780,7 +579,7 @@ mod tests { let clients = Clients::default(); let metrics = Arc::new(Metrics::default()); let actor = Actor { - stream: RateLimitedRelayedStream::unlimited(stream, metrics.clone()), + stream, timeout: Duration::from_secs(1), send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, @@ -895,8 +694,7 @@ mod tests { // Build the rate limited stream. let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _); let mut frame_writer = RelayedStream::test_client(io_write); - let stream = RelayedStream::test_server(io_read); - let mut stream = RateLimitedRelayedStream::new(stream, limiter, Default::default()); + let mut stream = RelayedStream::test_server_limited(io_read, limiter); // Prepare a frame to send, assert its size. let data = Bytes::from_static(b"hello world!!"); diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 273f423b70c..f863d268f51 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -199,8 +199,7 @@ mod tests { use super::*; use crate::{ protos::relay::{recv_frame, Frame, FrameType}, - server::streams::{MaybeTlsStream, RelayedStream}, - KeyCache, + server::streams::RelayedStream, }; fn test_client_builder(key: NodeId) -> (Config, RelayedStream) { @@ -211,7 +210,6 @@ mod tests { stream: RelayedStream::test_client(client), write_timeout: Duration::from_secs(1), channel_capacity: 10, - rate_limit: None, }, RelayedStream::test_server(server), ) diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 3e077dd5b5c..8093045fd96 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -22,7 +22,6 @@ use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; use super::{clients::Clients, AccessConfig, SpawnError}; -use crate::protos::{handshake, io::HandshakeIo}; #[allow(deprecated)] use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, @@ -36,6 +35,10 @@ use crate::{ }, KeyCache, }; +use crate::{ + protos::{handshake, io::HandshakeIo}, + server::streams::RateLimited, +}; type BytesBody = http_body_util::Full; type HyperError = Box; @@ -631,6 +634,8 @@ impl Inner { async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<(), AcceptError> { use snafu::ResultExt; + let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone()); + trace!(?protocol, "accept: start"); let (client_key, mut io) = match protocol { Protocol::Websocket => { @@ -675,7 +680,6 @@ impl Inner { stream: io, write_timeout: self.write_timeout, channel_capacity: PER_CLIENT_SEND_QUEUE_DEPTH, - rate_limit: self.rate_limit, }; trace!("accept: create client"); let node_id = client_conn_builder.node_id; diff --git a/iroh-relay/src/server/metrics.rs b/iroh-relay/src/server/metrics.rs index bf9e52df125..476c042f414 100644 --- a/iroh-relay/src/server/metrics.rs +++ b/iroh-relay/src/server/metrics.rs @@ -56,8 +56,8 @@ pub struct Metrics { #[metrics(help = "Number of unknown frames sent to this server.")] pub unknown_frames: Counter, - /// Number of frames received from client connection which have been rate-limited. - pub frames_rx_ratelimited_total: Counter, + /// Number of bytes received from client connection which have been rate-limited. + pub bytes_rx_ratelimited_total: Counter, /// Number of client connections which have had any frames rate-limited. pub conns_rx_ratelimited_total: Counter, diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 09ff1c4dd83..b6c681a26e8 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -1,55 +1,78 @@ //! Streams used in the server-side implementation of iroh relays. use std::{ + num::NonZeroU32, pin::Pin, + sync::Arc, task::{Context, Poll}, }; use bytes::BytesMut; -use n0_future::{Sink, Stream}; +use governor::clock::Clock; +use n0_future::{ready, time, FutureExt, Sink, Stream}; use snafu::Snafu; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_websockets::WebSocketStream; +use tracing::{error, instrument}; use crate::{ protos::relay::{Frame, RecvError}, ExportKeyingMaterial, KeyCache, }; +use super::{ClientRateLimit, Metrics}; + /// A Stream and Sink for [`Frame`]s connected to a single relay client. /// /// The stream receives message from the client while the sink sends them to the client. #[derive(Debug)] pub(crate) struct RelayedStream { - pub(crate) inner: WebSocketStream, + pub(crate) inner: WebSocketStream>, pub(crate) key_cache: KeyCache, } #[cfg(test)] impl RelayedStream { pub(crate) fn test_client(stream: tokio::io::DuplexStream) -> Self { + let stream = MaybeTlsStream::Test(stream); + let stream = RateLimited::unlimited(stream, Arc::new(Metrics::default())); Self { inner: tokio_websockets::ClientBuilder::new() - .limits( - tokio_websockets::Limits::default() - .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)), - ) - .take_over(MaybeTlsStream::Test(stream)), + .limits(Self::limits()) + .take_over(stream), key_cache: KeyCache::test(), } } pub(crate) fn test_server(stream: tokio::io::DuplexStream) -> Self { + let stream = MaybeTlsStream::Test(stream); + let stream = RateLimited::unlimited(stream, Arc::new(Metrics::default())); Self { inner: tokio_websockets::ServerBuilder::new() - .limits( - tokio_websockets::Limits::default() - .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)), - ) - .serve(MaybeTlsStream::Test(stream)), + .limits(Self::limits()) + .serve(stream), key_cache: KeyCache::test(), } } + + pub(crate) fn test_server_limited( + stream: tokio::io::DuplexStream, + limiter: governor::DefaultDirectRateLimiter, + ) -> Self { + let stream = MaybeTlsStream::Test(stream); + let stream = RateLimited::new(stream, limiter, Arc::new(Metrics::default())); + Self { + inner: tokio_websockets::ServerBuilder::new() + .limits(Self::limits()) + .serve(stream), + key_cache: KeyCache::test(), + } + } + + fn limits() -> tokio_websockets::Limits { + tokio_websockets::Limits::default() + .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)) + } } fn ws_to_io_err(e: tokio_websockets::Error) -> std::io::Error { @@ -227,3 +250,179 @@ impl AsyncWrite for MaybeTlsStream { } } } + +/// Rate limiter for reading from a [`RelayedStream`]. +/// +/// The writes to the sink are not rate limited. +/// +/// This potentially buffers one frame if the rate limiter does not allows this frame. +/// While the frame is buffered the undernlying stream is no longer polled. +#[derive(Debug)] +pub(crate) struct RateLimited { + inner: S, + limiter: Option>, + state: State, + /// Keeps track if this stream was ever rate-limited. + limited_once: bool, + metrics: Arc, +} + +#[derive(derive_more::Debug)] +enum State { + #[debug("Blocked")] + Blocked { + /// Future which will complete when the item can be yielded. + delay: Pin>, + }, + Ready, +} + +impl RateLimited { + pub(crate) fn from_cfg(cfg: Option, io: S, metrics: Arc) -> Self { + match cfg { + Some(cfg) => { + let mut quota = governor::Quota::per_second(cfg.bytes_per_second); + if let Some(max_burst) = cfg.max_burst_bytes { + quota = quota.allow_burst(max_burst); + } + let limiter = governor::RateLimiter::direct(quota); + Self::new(io, limiter, metrics) + } + None => Self::unlimited(io, metrics), + } + } + + pub(crate) fn new( + inner: S, + limiter: governor::DefaultDirectRateLimiter, + metrics: Arc, + ) -> Self { + Self { + inner, + limiter: Some(Arc::new(limiter)), + state: State::Ready, + limited_once: false, + metrics, + } + } + + pub(crate) fn unlimited(inner: S, metrics: Arc) -> Self { + Self { + inner, + limiter: None, + state: State::Ready, + limited_once: false, + metrics, + } + } +} + +impl RateLimited { + /// Records metrics about being rate-limited. + fn record_rate_limited(&mut self, bytes: NonZeroU32) { + // TODO: add a label for the frame type. + self.metrics + .bytes_rx_ratelimited_total + .inc_by(u32::from(bytes) as u64); + if !self.limited_once { + self.metrics.conns_rx_ratelimited_total.inc(); + self.limited_once = true; + } + } +} + +impl AsyncRead for RateLimited { + #[instrument(name = "rate_limited_poll_read", skip_all)] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + let Some(ref limiter) = self.limiter else { + // If there is no rate-limiter directly poll the inner. + return Pin::new(&mut self.inner).poll_read(cx, buf); + }; + let limiter = limiter.clone(); + loop { + match &mut self.state { + State::Ready => { + let bytes_before = buf.remaining(); + + // Poll inner for a new item. + ready!(Pin::new(&mut self.inner).poll_read(cx, buf))?; + + let bytes_read = bytes_before - buf.remaining(); + let Ok(bytes_read) = u32::try_from(bytes_read).and_then(NonZeroU32::try_from) + else { + // 0 bytes read, nothing to rate limit + return Poll::Ready(Ok(())); + }; + + match limiter.check_n(bytes_read) { + Ok(Ok(())) => {} + Ok(Err(not_until)) => { + let delay = not_until.wait_time_from(limiter.clock().now()); + // Item is rate-limited. + self.record_rate_limited(bytes_read); + self.state = State::Blocked { + delay: Box::pin(time::sleep(delay)), + }; + // Continue in `State::Blocked` + continue; + } + Err(capacity_err) => { + error!( + ?capacity_err, + ?bytes_read, + "read burst larger than bucket capacity" + ); + // Continue as normal though + } + } + return Poll::Ready(Ok(())); + } + State::Blocked { delay } => { + ready!(delay.poll(cx)); + self.state = State::Ready; + // Allow polling again, since the delay expired + continue; + } + } + } + } +} + +impl AsyncWrite for RateLimited { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +impl ExportKeyingMaterial for RateLimited { + fn export_keying_material>( + &self, + output: T, + label: &[u8], + context: Option<&[u8]>, + ) -> Option { + self.inner.export_keying_material(output, label, context) + } +} From ab193efc8028d313fbe7f37a5bf0918ac1e4b583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 6 Jun 2025 15:33:42 +0200 Subject: [PATCH 13/80] WIP --- iroh-relay/src/server/streams.rs | 99 +++++++++++++------------------- 1 file changed, 41 insertions(+), 58 deletions(-) diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index b6c681a26e8..2bbaa486b1f 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -261,22 +261,12 @@ impl AsyncWrite for MaybeTlsStream { pub(crate) struct RateLimited { inner: S, limiter: Option>, - state: State, + ready: Option>>, /// Keeps track if this stream was ever rate-limited. limited_once: bool, metrics: Arc, } -#[derive(derive_more::Debug)] -enum State { - #[debug("Blocked")] - Blocked { - /// Future which will complete when the item can be yielded. - delay: Pin>, - }, - Ready, -} - impl RateLimited { pub(crate) fn from_cfg(cfg: Option, io: S, metrics: Arc) -> Self { match cfg { @@ -300,7 +290,7 @@ impl RateLimited { Self { inner, limiter: Some(Arc::new(limiter)), - state: State::Ready, + ready: None, limited_once: false, metrics, } @@ -310,7 +300,7 @@ impl RateLimited { Self { inner, limiter: None, - state: State::Ready, + ready: None, limited_once: false, metrics, } @@ -339,54 +329,47 @@ impl AsyncRead for RateLimited { buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { let Some(ref limiter) = self.limiter else { - // If there is no rate-limiter directly poll the inner. + // If there is no rate-limiter, then directly poll the inner. return Pin::new(&mut self.inner).poll_read(cx, buf); }; let limiter = limiter.clone(); - loop { - match &mut self.state { - State::Ready => { - let bytes_before = buf.remaining(); - - // Poll inner for a new item. - ready!(Pin::new(&mut self.inner).poll_read(cx, buf))?; - - let bytes_read = bytes_before - buf.remaining(); - let Ok(bytes_read) = u32::try_from(bytes_read).and_then(NonZeroU32::try_from) - else { - // 0 bytes read, nothing to rate limit - return Poll::Ready(Ok(())); - }; - - match limiter.check_n(bytes_read) { - Ok(Ok(())) => {} - Ok(Err(not_until)) => { - let delay = not_until.wait_time_from(limiter.clock().now()); - // Item is rate-limited. - self.record_rate_limited(bytes_read); - self.state = State::Blocked { - delay: Box::pin(time::sleep(delay)), - }; - // Continue in `State::Blocked` - continue; - } - Err(capacity_err) => { - error!( - ?capacity_err, - ?bytes_read, - "read burst larger than bucket capacity" - ); - // Continue as normal though - } - } - return Poll::Ready(Ok(())); - } - State::Blocked { delay } => { - ready!(delay.poll(cx)); - self.state = State::Ready; - // Allow polling again, since the delay expired - continue; - } + // If we're currently limited, wait + if let Some(ready) = &mut self.ready { + ready!(ready.poll(cx)); + self.ready = None; + } + + // Poll inner for a new item. + let bytes_before = buf.remaining(); + ready!(Pin::new(&mut self.inner).poll_read(cx, buf))?; + let bytes_read = bytes_before - buf.remaining(); + + let Ok(bytes_read) = u32::try_from(bytes_read).and_then(NonZeroU32::try_from) else { + // 0 bytes read, nothing to rate limit + return Poll::Ready(Ok(())); + }; + + match limiter.check_n(bytes_read) { + Ok(Ok(())) => { + // We're fine + Poll::Ready(Ok(())) + } + Ok(Err(not_until)) => { + let delay = not_until.wait_time_from(limiter.clock().now()); + // Item is rate-limited. + self.record_rate_limited(bytes_read); + self.ready = Some(Box::pin(time::sleep(delay))); + // We already read the bytes into the buffer we were given, though + Poll::Ready(Ok(())) + } + Err(capacity_err) => { + error!( + ?capacity_err, + ?bytes_read, + "read burst larger than bucket capacity" + ); + // Continue as normal though + Poll::Ready(Ok(())) } } } From 6ada3d64ec5716d81f9e81ce819d8e3b1ff6daf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 6 Jun 2025 16:52:15 +0200 Subject: [PATCH 14/80] Our own rate limiter --- iroh-relay/src/server/client.rs | 9 +- iroh-relay/src/server/streams.rs | 217 +++++++++++++++++++++++-------- 2 files changed, 163 insertions(+), 63 deletions(-) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index daeb9ba7cbf..9615c562c54 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -681,20 +681,17 @@ mod tests { Ok(()) } - #[tokio::test] + #[tokio::test(start_paused = true)] #[traced_test] async fn test_rate_limit() -> Result { const LIMIT: u32 = 50; const MAX_FRAMES: u32 = 100; - // Rate limiter allowing LIMIT bytes/s - let quota = governor::Quota::per_second(NonZeroU32::try_from(LIMIT).unwrap()); - let limiter = governor::RateLimiter::direct(quota); - // Build the rate limited stream. let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _); let mut frame_writer = RelayedStream::test_client(io_write); - let mut stream = RelayedStream::test_server_limited(io_read, limiter); + // Rate limiter allowing LIMIT bytes/s + let mut stream = RelayedStream::test_server_limited(io_read, LIMIT / 10, LIMIT); // Prepare a frame to send, assert its size. let data = Bytes::from_static(b"hello world!!"); diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 2bbaa486b1f..47c516f5d5a 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -1,19 +1,17 @@ //! Streams used in the server-side implementation of iroh relays. use std::{ - num::NonZeroU32, pin::Pin, sync::Arc, task::{Context, Poll}, }; use bytes::BytesMut; -use governor::clock::Clock; use n0_future::{ready, time, FutureExt, Sink, Stream}; use snafu::Snafu; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_websockets::WebSocketStream; -use tracing::{error, instrument}; +use tracing::instrument; use crate::{ protos::relay::{Frame, RecvError}, @@ -57,10 +55,16 @@ impl RelayedStream { pub(crate) fn test_server_limited( stream: tokio::io::DuplexStream, - limiter: governor::DefaultDirectRateLimiter, + max_burst_bytes: u32, + bytes_per_second: u32, ) -> Self { let stream = MaybeTlsStream::Test(stream); - let stream = RateLimited::new(stream, limiter, Arc::new(Metrics::default())); + let stream = RateLimited::new( + stream, + max_burst_bytes, + bytes_per_second, + Arc::new(Metrics::default()), + ); Self { inner: tokio_websockets::ServerBuilder::new() .limits(Self::limits()) @@ -260,23 +264,81 @@ impl AsyncWrite for MaybeTlsStream { #[derive(Debug)] pub(crate) struct RateLimited { inner: S, - limiter: Option>, - ready: Option>>, + bucket: Option, + bucket_refilled: Option>>, /// Keeps track if this stream was ever rate-limited. limited_once: bool, metrics: Arc, } +#[derive(Debug)] +struct Bucket { + // The current bucket fill + fill: i64, + // The maximum bucket fill + max: i64, + // The bucket's last fill time + last_fill: time::Instant, + // Interval length of one refill + refill_period: time::Duration, + // How much we re-fill per refill period + refill: i64, +} + +impl Bucket { + fn new(max: i64, bytes_per_second: i64, refill_period: time::Duration) -> Self { + // TODO(matheus23) convert to errors + debug_assert!(max > 0); + debug_assert!(bytes_per_second > 0); + debug_assert_ne!(refill_period.as_millis(), 0); + // milliseconds is the tokio timer resolution + Self { + fill: max, + max, + last_fill: time::Instant::now(), + refill_period, + refill: bytes_per_second * refill_period.as_millis() as i64 / 1000, + } + } + + fn update_state(&mut self) { + let now = time::Instant::now(); + let refill_periods = now.saturating_duration_since(self.last_fill).as_millis() as u32 + / self.refill_period.as_millis() as u32; + if refill_periods == 0 { + // Nothing to do - we won't refill yet + return; + } + + self.fill += refill_periods as i64 * self.refill; + self.fill = std::cmp::min(self.fill, self.max); + self.last_fill += self.refill_period * refill_periods; + } + + fn consume(&mut self, bytes: i64) -> Result<(), time::Instant> { + self.update_state(); + + self.fill -= bytes; + + if self.fill > 0 { + return Ok(()); + } + + let missing = -self.fill; + + let periods_needed = (missing / self.refill) + 1; + + Err(self.last_fill + periods_needed as u32 * self.refill_period) + } +} + impl RateLimited { pub(crate) fn from_cfg(cfg: Option, io: S, metrics: Arc) -> Self { match cfg { Some(cfg) => { - let mut quota = governor::Quota::per_second(cfg.bytes_per_second); - if let Some(max_burst) = cfg.max_burst_bytes { - quota = quota.allow_burst(max_burst); - } - let limiter = governor::RateLimiter::direct(quota); - Self::new(io, limiter, metrics) + let bytes_per_second = cfg.bytes_per_second.into(); + let max_burst_bytes = cfg.max_burst_bytes.map_or(bytes_per_second / 10, u32::from); + Self::new(io, max_burst_bytes, bytes_per_second, metrics) } None => Self::unlimited(io, metrics), } @@ -284,13 +346,18 @@ impl RateLimited { pub(crate) fn new( inner: S, - limiter: governor::DefaultDirectRateLimiter, + max_burst_bytes: u32, + bytes_per_second: u32, metrics: Arc, ) -> Self { Self { inner, - limiter: Some(Arc::new(limiter)), - ready: None, + bucket: Some(Bucket::new( + max_burst_bytes as i64, + bytes_per_second as i64, + time::Duration::from_millis(100), + )), + bucket_refilled: None, limited_once: false, metrics, } @@ -299,8 +366,8 @@ impl RateLimited { pub(crate) fn unlimited(inner: S, metrics: Arc) -> Self { Self { inner, - limiter: None, - ready: None, + bucket: None, + bucket_refilled: None, limited_once: false, metrics, } @@ -309,11 +376,9 @@ impl RateLimited { impl RateLimited { /// Records metrics about being rate-limited. - fn record_rate_limited(&mut self, bytes: NonZeroU32) { + fn record_rate_limited(&mut self, bytes: usize) { // TODO: add a label for the frame type. - self.metrics - .bytes_rx_ratelimited_total - .inc_by(u32::from(bytes) as u64); + self.metrics.bytes_rx_ratelimited_total.inc_by(bytes as u64); if !self.limited_once { self.metrics.conns_rx_ratelimited_total.inc(); self.limited_once = true; @@ -328,50 +393,32 @@ impl AsyncRead for RateLimited { cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { - let Some(ref limiter) = self.limiter else { + let this = &mut *self; + let Some(bucket) = &mut this.bucket else { // If there is no rate-limiter, then directly poll the inner. - return Pin::new(&mut self.inner).poll_read(cx, buf); + return Pin::new(&mut this.inner).poll_read(cx, buf); }; - let limiter = limiter.clone(); - // If we're currently limited, wait - if let Some(ready) = &mut self.ready { - ready!(ready.poll(cx)); - self.ready = None; + + // If we're currently limited, wait until we've got some bucket space again + if let Some(bucket_refilled) = &mut this.bucket_refilled { + ready!(bucket_refilled.poll(cx)); + this.bucket_refilled = None; } + // We're not currently limited, let's read + // Poll inner for a new item. let bytes_before = buf.remaining(); - ready!(Pin::new(&mut self.inner).poll_read(cx, buf))?; + ready!(Pin::new(&mut this.inner).poll_read(cx, buf))?; let bytes_read = bytes_before - buf.remaining(); - let Ok(bytes_read) = u32::try_from(bytes_read).and_then(NonZeroU32::try_from) else { - // 0 bytes read, nothing to rate limit - return Poll::Ready(Ok(())); - }; - - match limiter.check_n(bytes_read) { - Ok(Ok(())) => { - // We're fine - Poll::Ready(Ok(())) - } - Ok(Err(not_until)) => { - let delay = not_until.wait_time_from(limiter.clock().now()); - // Item is rate-limited. - self.record_rate_limited(bytes_read); - self.ready = Some(Box::pin(time::sleep(delay))); - // We already read the bytes into the buffer we were given, though - Poll::Ready(Ok(())) - } - Err(capacity_err) => { - error!( - ?capacity_err, - ?bytes_read, - "read burst larger than bucket capacity" - ); - // Continue as normal though - Poll::Ready(Ok(())) - } + // Record how much we've read, rate limit accordingly, if need be. + if let Err(refill_time) = bucket.consume(bytes_read as i64) { + this.record_rate_limited(bytes_read); + this.bucket_refilled = Some(Box::pin(time::sleep_until(refill_time))); } + + Poll::Ready(Ok(())) } } @@ -409,3 +456,59 @@ impl ExportKeyingMaterial for RateLimited { self.inner.export_keying_material(output, label, context) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use n0_future::time; + use n0_snafu::{Result, ResultExt}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tracing_test::traced_test; + + use crate::server::{streams::RateLimited, Metrics}; + + #[tokio::test(start_paused = true)] + #[traced_test] + async fn test_ratelimiter() -> Result { + let (read, mut write) = tokio::io::duplex(4096); + + let send_total = 10 * 1024 * 1024; // 10MiB + let send_data = vec![42u8; send_total]; + + let bytes_per_second = 12_345; + + let mut rate_limited = RateLimited::new( + read, + bytes_per_second / 10, + bytes_per_second, + Arc::new(Metrics::default()), + ); + + let before = time::Instant::now(); + n0_future::future::try_zip( + async { + let mut remaining = send_total; + let mut buf = [0u8; 4096]; + while remaining > 0 { + remaining -= rate_limited.read(&mut buf).await?; + } + Ok(()) + }, + async { + write.write_all(&send_data).await?; + write.flush().await + }, + ) + .await + .e()?; + + let duration = time::Instant::now().duration_since(before); + assert_ne!(duration.as_millis(), 0); + + let actual_bytes_per_second = send_total as f64 / duration.as_secs_f64(); + assert_eq!(actual_bytes_per_second.round() as u32, bytes_per_second); + + Ok(()) + } +} From cfe4e4d3f1b264e7288ef0fb4e836b01b2067502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 9 Jun 2025 09:58:33 +0200 Subject: [PATCH 15/80] Remove unused code --- iroh-relay/src/client/conn.rs | 60 ----------------------------------- 1 file changed, 60 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 6e117dd9421..2e7225ebbc7 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -202,66 +202,6 @@ impl Stream for Conn { } } -// TODO(matheus23): Remove this impl, make `new_relay` work on the `Framed` directly, make the impl not rely on `ConnSendError`. -impl Sink for Conn { - type Error = SendError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_ready(cx).map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => { - Pin::new(conn).poll_ready(cx).map_err(Into::into) - } - } - } - - fn start_send(mut self: Pin<&mut Self>, frame: Frame) -> Result<(), Self::Error> { - if let Frame::SendPacket { dst_key: _, packet } = &frame { - if packet.len() > MAX_PACKET_SIZE { - return Err(ExceedsMaxPacketSizeSnafu { size: packet.len() }.build()); - } - } - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn) - .start_send(tokio_websockets::Message::binary({ - let mut buf = BytesMut::new(); - frame.encode_for_ws_msg(&mut buf); - tokio_websockets::Payload::from(buf.freeze()) - })) - .map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => Pin::new(conn) - .start_send(ws_stream_wasm::WsMessage::Binary(frame.encode_for_ws_msg())) - .map_err(Into::into), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => { - Pin::new(conn).poll_flush(cx).map_err(Into::into) - } - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => { - Pin::new(conn).poll_close(cx).map_err(Into::into) - } - } - } -} - impl Sink for Conn { type Error = SendError; From bb27d3d4150e3662214425ff103892730b6f0dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 9 Jun 2025 10:40:12 +0200 Subject: [PATCH 16/80] Restructure tests, remove unused code --- iroh-relay/src/protos/handshake.rs | 1 + iroh-relay/src/protos/relay.rs | 100 ++--------------------------- iroh-relay/src/server/client.rs | 44 ++++++++++--- iroh-relay/src/server/clients.rs | 26 +++++++- 4 files changed, 65 insertions(+), 106 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 12a192a203d..f9424bb5017 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -8,6 +8,7 @@ use n0_future::{ }; use nested_enum_utils::common_fields; use quinn_proto::{coding::Codec, VarInt}; +#[cfg(feature = "server")] use rand::{CryptoRng, RngCore}; use snafu::{Backtrace, ResultExt, Snafu}; diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 0bc656659e2..0b7e5e50c27 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -17,8 +17,6 @@ use iroh_base::{PublicKey, Signature, SignatureError}; #[cfg(feature = "server")] use n0_future::time::Duration; use n0_future::{time, Sink, SinkExt}; -#[cfg(any(test, feature = "server"))] -use n0_future::{Stream, StreamExt}; use nested_enum_utils::common_fields; use postcard::experimental::max_size::MaxSize; use serde::{Deserialize, Serialize}; @@ -251,37 +249,6 @@ impl Frame { } } - /// Serialized length (without the frame header) - #[cfg(not(wasm_browser))] // Not needed with websocket framing - thus not needed in browsers - pub(crate) fn len(&self) -> usize { - match self { - Frame::ClientInfo { - client_public_key: _, - message, - signature: _, - } => MAGIC.len() + PublicKey::LENGTH + message.len() + Signature::BYTE_SIZE, - Frame::SendPacket { dst_key: _, packet } => PublicKey::LENGTH + packet.len(), - Frame::RecvPacket { - src_key: _, - content, - } => PublicKey::LENGTH + content.len(), - Frame::KeepAlive => 0, - Frame::NotePreferred { .. } => 1, - Frame::NodeGone { .. } => PublicKey::LENGTH, - Frame::Ping { .. } => 8, - Frame::Pong { .. } => 8, - Frame::Health { problem } => problem.len(), - Frame::Restarting { .. } => 4 + 4, - } - } - - /// Serialized length with frame header. - #[cfg(feature = "server")] - pub(crate) fn len_with_header(&self) -> usize { - const HEADER_LEN: usize = 5; // TODO(matheus23): This is used with the rate-limiter. It really shouldn't be. The websocket frames work on a different level! - self.len() + HEADER_LEN - } - /// Tries to decode a frame received over websockets. /// /// Specifically, bytes received from a binary websocket message frame. @@ -298,13 +265,13 @@ impl Frame { /// Encodes this frame for sending over websockets. /// /// Specifically meant for being put into a binary websocket message frame. - pub(crate) fn encode_for_ws_msg(self, dst: &mut impl BufMut) { + pub(crate) fn encode_for_ws_msg(self, mut dst: O) -> O { dst.put_u8(self.typ().into()); - self.write_to(dst); + self.write_to(dst) } /// Writes it self to the given buffer. - fn write_to(&self, dst: &mut impl BufMut) { + fn write_to(&self, mut dst: O) -> O { match self { Frame::ClientInfo { client_public_key, @@ -352,6 +319,7 @@ impl Frame { dst.put_u32(*try_for); } } + dst } #[allow(clippy::result_large_err)] @@ -478,37 +446,6 @@ impl Frame { } } -/// Receives the next frame and matches the frame type. If the correct type is found returns the content, -/// otherwise an error. -#[cfg(any(test, feature = "server"))] -pub(crate) async fn recv_frame> + Unpin>( - frame_type: FrameType, - mut stream: S, -) -> Result -where - RecvError: Into, -{ - match stream.next().await { - Some(Ok(frame)) => { - if frame_type != frame.typ() { - return Err(UnexpectedFrameSnafu { - got: frame.typ(), - expected: frame_type, - } - .build() - .into()); - } - Ok(frame) - } - Some(Err(err)) => Err(err), - None => Err(RecvError::from(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "expected frame".to_string(), - )) - .into()), - } -} - #[cfg(test)] mod tests { use data_encoding::HEXLOWER; @@ -616,7 +553,6 @@ mod tests { #[cfg(test)] mod proptests { - use bytes::BytesMut; use iroh_base::SecretKey; use proptest::prelude::*; @@ -679,34 +615,6 @@ mod proptests { ] } - fn inject_error(buf: &mut BytesMut) { - fn is_fixed_size(tpe: FrameType) -> bool { - match tpe { - FrameType::KeepAlive - | FrameType::NotePreferred - | FrameType::Ping - | FrameType::Pong - | FrameType::Restarting - | FrameType::PeerGone => true, - FrameType::ClientInfo - | FrameType::Health - | FrameType::SendPacket - | FrameType::RecvPacket - | FrameType::Unknown => false, - } - } - let tpe: FrameType = buf[0].into(); - let mut len = u32::from_be_bytes(buf[1..5].try_into().unwrap()) as usize; - if is_fixed_size(tpe) { - buf.put_u8(0); - len += 1; - } else { - buf.resize(MAX_FRAME_SIZE + 1, 0); - len = MAX_FRAME_SIZE + 1; - } - buf[1..5].copy_from_slice(&u32::to_be_bytes(len as u32)); - } - proptest! { #[test] fn frame_ws_roundtrip(frame in frame()) { diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 9615c562c54..c913cdfa608 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -553,16 +553,38 @@ impl ClientCounter { #[cfg(test)] mod tests { - use std::num::NonZeroU32; - - use bytes::Bytes; + use bytes::{Bytes, BytesMut}; use iroh_base::SecretKey; + use n0_future::Stream; use n0_snafu::{Result, ResultExt}; use tracing::info; use tracing_test::traced_test; use super::*; - use crate::protos::relay::{recv_frame, FrameType}; + use crate::protos::relay::FrameType; + + async fn recv_frame< + E: snafu::Error + Sync + Send + 'static, + S: Stream> + Unpin, + >( + frame_type: FrameType, + mut stream: S, + ) -> Result { + match stream.next().await { + Some(Ok(frame)) => { + if frame_type != frame.typ() { + snafu::whatever!( + "Unepxected frame, got {}, but expected {}", + frame.typ(), + frame_type + ); + } + Ok(frame) + } + Some(Err(err)) => Err(err).e(), + None => snafu::whatever!("Unexpected EOF, expected frame {frame_type}"), + } + } #[tokio::test] #[traced_test] @@ -607,7 +629,7 @@ mod tests { data: Bytes::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; - let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?; + let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, Frame::RecvPacket { @@ -622,7 +644,7 @@ mod tests { .send(packet.clone()) .await .context("send")?; - let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await?; + let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, Frame::RecvPacket { @@ -634,7 +656,7 @@ mod tests { // send peer_gone println!("send peer gone"); peer_gone_s.send(node_id).await.context("send")?; - let frame = recv_frame(FrameType::PeerGone, &mut io_rw).await?; + let frame = recv_frame(FrameType::PeerGone, &mut io_rw).await.e()?; assert_eq!(frame, Frame::NodeGone { node_id }); // Read tests @@ -694,13 +716,17 @@ mod tests { let mut stream = RelayedStream::test_server_limited(io_read, LIMIT / 10, LIMIT); // Prepare a frame to send, assert its size. - let data = Bytes::from_static(b"hello world!!"); + let data = Bytes::from_static(b"hello world!1eins"); let target = SecretKey::generate(rand::thread_rng()).public(); let frame = Frame::SendPacket { dst_key: target, packet: data.clone(), }; - let frame_len = frame.len_with_header(); + let frame_len = frame + .clone() + .encode_for_ws_msg(BytesMut::new()) + .freeze() + .len(); assert_eq!(frame_len, LIMIT as usize); // Send a frame, it should arrive. diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index f863d268f51..3848d021fae 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -194,14 +194,38 @@ mod tests { use bytes::Bytes; use iroh_base::SecretKey; + use n0_future::{Stream, StreamExt}; use n0_snafu::{Result, ResultExt}; use super::*; use crate::{ - protos::relay::{recv_frame, Frame, FrameType}, + protos::relay::{Frame, FrameType}, server::streams::RelayedStream, }; + async fn recv_frame< + E: snafu::Error + Sync + Send + 'static, + S: Stream> + Unpin, + >( + frame_type: FrameType, + mut stream: S, + ) -> Result { + match stream.next().await { + Some(Ok(frame)) => { + if frame_type != frame.typ() { + snafu::whatever!( + "Unepxected frame, got {}, but expected {}", + frame.typ(), + frame_type + ); + } + Ok(frame) + } + Some(Err(err)) => Err(err).e(), + None => snafu::whatever!("Unexpected EOF, expected frame {frame_type}"), + } + } + fn test_client_builder(key: NodeId) -> (Config, RelayedStream) { let (server, client) = tokio::io::duplex(1024); ( From 9a92b0da354e0bee4ccce4ecb07063b11a00d423 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 9 Jun 2025 11:20:04 +0200 Subject: [PATCH 17/80] Remove unused frames --- iroh-relay/src/client/conn.rs | 13 +- iroh-relay/src/protos/handshake.rs | 47 +++++- iroh-relay/src/protos/relay.rs | 213 ++++----------------------- iroh-relay/src/server/client.rs | 6 +- iroh-relay/src/server/http_server.rs | 6 +- iroh-relay/src/server/streams.rs | 4 +- iroh/src/magicsock/relay_actor.rs | 2 +- 7 files changed, 76 insertions(+), 215 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 2e7225ebbc7..d710e9fba4c 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -176,7 +176,7 @@ impl Stream for Conn { ); return Poll::Pending; } - let frame = Frame::decode_from_ws_msg(msg.into_payload().into(), key_cache)?; + let frame = Frame::from_bytes(msg.into_payload().into(), key_cache)?; let message = ReceivedMessage::try_from(frame); Poll::Ready(Some(message)) } @@ -227,7 +227,7 @@ impl Sink for Conn { Self::Ws { ref mut conn, .. } => Pin::new(conn) .start_send(tokio_websockets::Message::binary({ let mut buf = BytesMut::new(); - frame.encode_for_ws_msg(&mut buf); + frame.write_to(&mut buf); tokio_websockets::Payload::from(buf.freeze()) })) .map_err(Into::into), @@ -283,10 +283,6 @@ pub enum ReceivedMessage { /// Reply to a [`ReceivedMessage::Ping`] from a client or server /// with the payload sent previously in the ping. Pong([u8; 8]), - /// A one-way empty message from server to client, just to - /// keep the connection alive. It's like a [`ReceivedMessage::Ping`], but doesn't solicit - /// a reply from the client. - KeepAlive, /// A one-way message from server to client, declaring the connection health state. Health { /// If set, is a description of why the connection is unhealthy. @@ -315,11 +311,6 @@ impl TryFrom for ReceivedMessage { fn try_from(frame: Frame) -> std::result::Result { match frame { - Frame::KeepAlive => { - // A one-way keep-alive message that doesn't require an ack. - // This predated FrameType::Ping/FrameType::Pong. - Ok(ReceivedMessage::KeepAlive) - } Frame::NodeGone { node_id } => Ok(ReceivedMessage::NodeGone(node_id)), Frame::RecvPacket { src_key, content } => { let packet = ReceivedMessage::ReceivedPacket { diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index f9424bb5017..e2554168127 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -23,15 +23,44 @@ pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; #[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] pub enum FrameType { /// The frame type for the client challenge request - ClientRequestChallenge = 0, + ClientRequestChallenge = 1, /// The server frame type for the challenge response - ServerChallenge = 1, + ServerChallenge = 2, /// The client frame type for the authentication frame - ClientAuth = 2, + ClientAuth = 3, /// The server frame type for authentication confirmation - ServerConfirmsAuth = 3, + ServerConfirmsAuth = 4, /// The server frame type for authentication denial - ServerDeniesAuth = 4, + ServerDeniesAuth = 5, + /// 32B dest pub key + packet bytes + SendPacket = 10, + /// v0/1 packet bytes, v2: 32B src pub key + packet bytes + RecvPacket = 11, + /// no payload, no-op (to be replaced with ping/pong) + KeepAlive = 12, + /// Sent from server to client to signal that a previous sender is no longer connected. + /// + /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` + /// to B so B can forget that a reverse path exists on that connection to get back to A + /// + /// 32B pub key of peer that's gone + PeerGone = 14, + /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. + /// Messages with these frames will be ignored. + /// 8 byte ping payload, to be echoed back in FrameType::Pong + Ping = 15, + /// 8 byte payload, the contents of ping being replied to + Pong = 16, + /// Sent from server to client to tell the client if their connection is + /// unhealthy somehow. + Health = 17, + + /// Sent from server to client for the server to declare that it's restarting. + /// Payload is two big endian u32 durations in milliseconds: when to reconnect, + /// and how long to try total. + /// + /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` + Restarting = 18, /// The frame type was unknown. /// /// This frame is the result of parsing any future frame types that this implementation @@ -40,6 +69,12 @@ pub enum FrameType { Unknown, } +impl std::fmt::Display for FrameType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + impl From for VarInt { fn from(value: FrameType) -> Self { (value as u32).into() @@ -152,7 +187,7 @@ pub enum Error { tag: VarInt, expected_tags: Vec, }, - #[snafu(display("Handshake failed while deserializing {frame_type:?} frame"))] + #[snafu(display("Handshake failed while deserializing {frame_type} frame"))] DeserializationError { frame_type: FrameType, source: postcard::Error, diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 0b7e5e50c27..74b400fb762 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -13,7 +13,7 @@ //! * server then sends `FrameType::RecvPacket` to recipient use bytes::{BufMut, Bytes}; -use iroh_base::{PublicKey, Signature, SignatureError}; +use iroh_base::{PublicKey, SignatureError}; #[cfg(feature = "server")] use n0_future::time::Duration; use n0_future::{time, Sink, SinkExt}; @@ -35,9 +35,6 @@ pub const MAX_PACKET_SIZE: usize = 64 * 1024; #[cfg(not(wasm_browser))] pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; -/// The Relay magic number, sent in the FrameType::ClientInfo frame upon initial connection. -const MAGIC: &str = "RELAY🔑"; - /// Interval in which we ping the relay server to ensure the connection is alive. /// /// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some @@ -47,26 +44,7 @@ pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15); /// The number of packets buffered for sending per client #[cfg(feature = "server")] -pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; //32; - -/// ProtocolVersion is bumped whenever there's a wire-incompatible change. -/// - version 1 (zero on wire): consistent box headers, in use by employee dev nodes a bit -/// - version 2: received packets have src addrs in FrameType::RecvPacket at beginning. -/// -/// NOTE: we are technically running a modified version of the protocol. -/// `FrameType::PeerPresent`, `FrameType::WatchConn`, `FrameType::ClosePeer`, have been removed. -/// The server will error on that connection if a client sends one of these frames. -/// We have split with the DERP protocol significantly starting with our relay protocol 3 -/// `FrameType::PeerPresent`, `FrameType::WatchConn`, `FrameType::ClosePeer`, `FrameType::ServerKey`, and `FrameType::ServerInfo` have been removed. -/// The server will error on that connection if a client sends one of these frames. -/// This materially affects the handshake protocol, and so relay nodes on version 3 will be unable to communicate -/// with nodes running earlier protocol versions. -pub(crate) const PROTOCOL_VERSION: usize = 3; - -/// Indicates this IS the client's home node -const PREFERRED: u8 = 1u8; -/// Indicates this IS NOT the client's home node -const NOT_PREFERRED: u8 = 0u8; +pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; /// The one byte frame type at the beginning of the frame /// header. The second field is a big-endian u32 describing the @@ -80,10 +58,6 @@ pub enum FrameType { SendPacket = 4, /// v0/1 packet bytes, v2: 32B src pub key + packet bytes RecvPacket = 5, - /// no payload, no-op (to be replaced with ping/pong) - KeepAlive = 6, - /// 1 byte payload: 0x01 or 0x00 for whether this is client's home node - NotePreferred = 7, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` @@ -198,49 +172,20 @@ pub(crate) async fn write_frame + Unpin>( /// The frames in the [`RelayCodec`]. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum Frame { - ClientInfo { - client_public_key: PublicKey, - message: Bytes, - signature: Signature, - }, - SendPacket { - dst_key: PublicKey, - packet: Bytes, - }, - RecvPacket { - src_key: PublicKey, - content: Bytes, - }, - KeepAlive, - NotePreferred { - preferred: bool, - }, - NodeGone { - node_id: PublicKey, - }, - Ping { - data: [u8; 8], - }, - Pong { - data: [u8; 8], - }, - Health { - problem: Bytes, - }, - Restarting { - reconnect_in: u32, - try_for: u32, - }, + SendPacket { dst_key: PublicKey, packet: Bytes }, + RecvPacket { src_key: PublicKey, content: Bytes }, + NodeGone { node_id: PublicKey }, + Ping { data: [u8; 8] }, + Pong { data: [u8; 8] }, + Health { problem: Bytes }, + Restarting { reconnect_in: u32, try_for: u32 }, } impl Frame { pub(crate) fn typ(&self) -> FrameType { match self { - Frame::ClientInfo { .. } => FrameType::ClientInfo, Frame::SendPacket { .. } => FrameType::SendPacket, Frame::RecvPacket { .. } => FrameType::RecvPacket, - Frame::KeepAlive => FrameType::KeepAlive, - Frame::NotePreferred { .. } => FrameType::NotePreferred, Frame::NodeGone { .. } => FrameType::PeerGone, Frame::Ping { .. } => FrameType::Ping, Frame::Pong { .. } => FrameType::Pong, @@ -249,40 +194,12 @@ impl Frame { } } - /// Tries to decode a frame received over websockets. - /// - /// Specifically, bytes received from a binary websocket message frame. - #[allow(clippy::result_large_err)] - pub(crate) fn decode_from_ws_msg(bytes: Bytes, cache: &KeyCache) -> Result { - if bytes.is_empty() { - return Err(TooSmallSnafu.build()); - } - let typ = FrameType::from(bytes[0]); - let frame = Self::from_bytes(typ, bytes.slice(1..), cache)?; - Ok(frame) - } - /// Encodes this frame for sending over websockets. /// /// Specifically meant for being put into a binary websocket message frame. - pub(crate) fn encode_for_ws_msg(self, mut dst: O) -> O { + pub(crate) fn write_to(&self, mut dst: O) -> O { dst.put_u8(self.typ().into()); - self.write_to(dst) - } - - /// Writes it self to the given buffer. - fn write_to(&self, mut dst: O) -> O { match self { - Frame::ClientInfo { - client_public_key, - message, - signature, - } => { - dst.put(MAGIC.as_bytes()); - dst.put(client_public_key.as_ref()); - dst.put(&signature.to_bytes()[..]); - dst.put(&message[..]); - } Frame::SendPacket { dst_key, packet } => { dst.put(dst_key.as_ref()); dst.put(packet.as_ref()); @@ -291,14 +208,6 @@ impl Frame { dst.put(src_key.as_ref()); dst.put(content.as_ref()); } - Frame::KeepAlive => {} - Frame::NotePreferred { preferred } => { - if *preferred { - dst.put_u8(PREFERRED); - } else { - dst.put_u8(NOT_PREFERRED); - } - } Frame::NodeGone { node_id: peer } => { dst.put(peer.as_ref()); } @@ -322,35 +231,18 @@ impl Frame { dst } + /// Tries to decode a frame received over websockets. + /// + /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] - fn from_bytes( - frame_type: FrameType, - content: Bytes, - cache: &KeyCache, - ) -> Result { - let res = match frame_type { - FrameType::ClientInfo => { - if content.len() < PublicKey::LENGTH + Signature::BYTE_SIZE + MAGIC.len() { - return Err(InvalidFrameSnafu.build()); - } - if &content[..MAGIC.len()] != MAGIC.as_bytes() { - return Err(InvalidFrameSnafu.build()); - } + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + if bytes.is_empty() { + return Err(TooSmallSnafu.build()); + } + let frame_type = FrameType::from(bytes[0]); + let content = bytes.slice(1..); - let start = MAGIC.len(); - let client_public_key = - cache.key_from_slice(&content[start..start + PublicKey::LENGTH])?; - let start = start + PublicKey::LENGTH; - let signature = - Signature::from_slice(&content[start..start + Signature::BYTE_SIZE])?; - let start = start + Signature::BYTE_SIZE; - let message = content.slice(start..); - Self::ClientInfo { - client_public_key, - message, - signature, - } - } + let res = match frame_type { FrameType::SendPacket => { if content.len() < PublicKey::LENGTH { return Err(InvalidFrameSnafu.build()); @@ -378,23 +270,6 @@ impl Frame { let content = content.slice(PublicKey::LENGTH..); Self::RecvPacket { src_key, content } } - FrameType::KeepAlive => { - if !content.is_empty() { - return Err(InvalidFrameSnafu.build()); - } - Self::KeepAlive - } - FrameType::NotePreferred => { - if content.len() != 1 { - return Err(InvalidFrameSnafu.build()); - } - let preferred = match content[0] { - PREFERRED => true, - NOT_PREFERRED => false, - _ => return Err(InvalidFrameSnafu.build()), - }; - Self::NotePreferred { preferred } - } FrameType::PeerGone => { if content.len() != PublicKey::LENGTH { return Err(InvalidFrameSnafu.build()); @@ -450,34 +325,15 @@ impl Frame { mod tests { use data_encoding::HEXLOWER; use iroh_base::SecretKey; - use n0_snafu::{Result, ResultExt}; + use n0_snafu::Result; use super::*; #[test] fn test_frame_snapshot() -> Result { let client_key = SecretKey::from_bytes(&[42u8; 32]); - let client_info = ClientInfo { - version: PROTOCOL_VERSION, - }; - let message = postcard::to_stdvec(&client_info).context("encode")?; - let signature = client_key.sign(&message); let frames = vec![ - ( - Frame::ClientInfo { - client_public_key: client_key.public(), - message: Bytes::from(message), - signature, - }, - "02 52 45 4c 41 59 f0 9f 94 91 19 7f 6b 23 e1 6c - 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 - 03 34 03 9b fa 8b 3d 36 8d 61 88 e7 7b 22 f2 92 - ab 37 43 5d a8 de 0b c8 cb 84 e2 88 f4 e7 3b 35 - 82 a5 27 31 e9 ff 98 65 46 5c 87 e0 5e 8d 42 7d - f4 22 bb 6e 85 e1 c0 5f 6f 74 98 37 ba a4 a5 c7 - eb a3 23 0d 77 56 99 10 43 0e 03", - ), ( Frame::Health { problem: "Hello? Yes this is dog.".into(), @@ -485,8 +341,6 @@ mod tests { "0e 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 20 69 73 20 64 6f 67 2e", ), - (Frame::KeepAlive, "06"), - (Frame::NotePreferred { preferred: true }, "07 01"), ( Frame::NodeGone { node_id: client_key.public(), @@ -532,7 +386,7 @@ mod tests { for (frame, expected_hex) in frames { let mut bytes = Vec::new(); - frame.encode_for_ws_msg(&mut bytes); + frame.write_to(&mut bytes); let stripped: Vec = expected_hex .chars() .filter_map(|s| { @@ -574,24 +428,10 @@ mod proptests { /// Generates a random valid frame fn frame() -> impl Strategy { - let client_info = (secret_key()).prop_map(|secret_key| { - let info = ClientInfo { - version: PROTOCOL_VERSION, - }; - let msg = postcard::to_stdvec(&info).expect("using default ClientInfo"); - let signature = secret_key.sign(&msg); - Frame::ClientInfo { - client_public_key: secret_key.public(), - message: msg.into(), - signature, - } - }); let send_packet = (key(), data(32)).prop_map(|(dst_key, packet)| Frame::SendPacket { dst_key, packet }); let recv_packet = (key(), data(32)).prop_map(|(src_key, content)| Frame::RecvPacket { src_key, content }); - let keep_alive = Just(Frame::KeepAlive); - let note_preferred = any::().prop_map(|preferred| Frame::NotePreferred { preferred }); let peer_gone = key().prop_map(|peer| Frame::NodeGone { node_id: peer }); let ping = prop::array::uniform8(any::()).prop_map(|data| Frame::Ping { data }); let pong = prop::array::uniform8(any::()).prop_map(|data| Frame::Pong { data }); @@ -602,11 +442,8 @@ mod proptests { try_for, }); prop_oneof![ - client_info, send_packet, recv_packet, - keep_alive, - note_preferred, peer_gone, ping, pong, @@ -617,10 +454,10 @@ mod proptests { proptest! { #[test] - fn frame_ws_roundtrip(frame in frame()) { + fn frame_roundtrip(frame in frame()) { let mut encoded = Vec::new(); - frame.clone().encode_for_ws_msg(&mut encoded); - let decoded = Frame::decode_from_ws_msg(Bytes::from(encoded), &KeyCache::test()).unwrap(); + frame.clone().write_to(&mut encoded); + let decoded = Frame::from_bytes(Bytes::from(encoded), &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } } diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index c913cdfa608..32519ebfdb1 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -722,11 +722,7 @@ mod tests { dst_key: target, packet: data.clone(), }; - let frame_len = frame - .clone() - .encode_for_ws_msg(BytesMut::new()) - .freeze() - .len(); + let frame_len = frame.clone().write_to(BytesMut::new()).freeze().len(); assert_eq!(frame_len, LIMIT as usize); // Send a frame, it should arrive. diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 8093045fd96..e7a99923075 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -36,7 +36,7 @@ use crate::{ KeyCache, }; use crate::{ - protos::{handshake, io::HandshakeIo}, + protos::{handshake, io::HandshakeIo, relay::MAX_FRAME_SIZE}, server::streams::RateLimited, }; @@ -643,7 +643,9 @@ impl Inner { // Since we already did the HTTP upgrade in the previous step, // we use tokio-websockets to handle this connection // Create a server builder with default config - let builder = tokio_websockets::ServerBuilder::new(); + let builder = tokio_websockets::ServerBuilder::new().limits( + tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE)), + ); // Serve will create a WebSocketStream on an already upgraded connection let websocket = builder.serve(io); diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 47c516f5d5a..f58c234f22d 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -99,7 +99,7 @@ impl Sink for RelayedStream { Pin::new(&mut self.inner) .start_send(tokio_websockets::Message::binary({ let mut buf = BytesMut::new(); - item.encode_for_ws_msg(&mut buf); + item.write_to(&mut buf); tokio_websockets::Payload::from(buf.freeze()) })) .map_err(ws_to_io_err) @@ -144,7 +144,7 @@ impl Stream for RelayedStream { return Poll::Pending; } Poll::Ready(Some( - Frame::decode_from_ws_msg(msg.into_payload().into(), &self.key_cache) + Frame::from_bytes(msg.into_payload().into(), &self.key_cache) .map_err(Into::into), )) } diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs index 56ee46b50db..8031ff6177a 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/relay_actor.rs @@ -728,7 +728,7 @@ impl ActiveRelayActor { let problem = problem.as_deref().unwrap_or("unknown"); warn!("Relay server reports problem: {problem}"); } - ReceivedMessage::KeepAlive | ReceivedMessage::ServerRestarting { .. } => { + ReceivedMessage::ServerRestarting { .. } => { trace!("Ignoring {msg:?}") } } From 8a2ebd5dddf524d71319e14bc39ba73f8a0ebbf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 9 Jun 2025 19:40:41 +0200 Subject: [PATCH 18/80] Remove `Frame`, `ReceivedMessage` and `SendMessage` in favor of `ServerToClientMsg` and `ClientToServerMsg` --- iroh-relay/src/client.rs | 27 +- iroh-relay/src/client/conn.rs | 157 ++-------- iroh-relay/src/protos/relay.rs | 420 +++++++++++++++++++-------- iroh-relay/src/server.rs | 27 +- iroh-relay/src/server/client.rs | 78 +++-- iroh-relay/src/server/clients.rs | 25 +- iroh-relay/src/server/http_server.rs | 111 ++++--- iroh-relay/src/server/streams.rs | 33 +-- iroh/src/magicsock/relay_actor.rs | 31 +- 9 files changed, 502 insertions(+), 407 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 0ff544662e1..4bc41af5d9f 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -22,12 +22,15 @@ use tracing::warn; use tracing::{debug, event, trace, Level}; use url::Url; -pub use self::conn::{ReceivedMessage, RecvError, SendError, SendMessage}; +pub use self::conn::{RecvError, SendError}; #[cfg(not(wasm_browser))] use crate::dns::{DnsError, DnsResolver}; use crate::{ http::{Protocol, RELAY_PATH}, - protos::handshake, + protos::{ + handshake, + relay::{ClientToServerMsg, ServerToClientMsg}, + }, KeyCache, }; @@ -283,24 +286,24 @@ impl Client { } impl Stream for Client { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.conn).poll_next(cx) } } -impl Sink for Client { +impl Sink for Client { type Error = SendError; fn poll_ready( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { - >::poll_ready(Pin::new(&mut self.conn), cx) + >::poll_ready(Pin::new(&mut self.conn), cx) } - fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: ClientToServerMsg) -> Result<(), Self::Error> { Pin::new(&mut self.conn).start_send(item) } @@ -308,24 +311,24 @@ impl Sink for Client { mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { - >::poll_flush(Pin::new(&mut self.conn), cx) + >::poll_flush(Pin::new(&mut self.conn), cx) } fn poll_close( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { - >::poll_close(Pin::new(&mut self.conn), cx) + >::poll_close(Pin::new(&mut self.conn), cx) } } /// The send half of a relay client. #[derive(Debug)] pub struct ClientSink { - sink: SplitSink, + sink: SplitSink, } -impl Sink for ClientSink { +impl Sink for ClientSink { type Error = SendError; fn poll_ready( @@ -335,7 +338,7 @@ impl Sink for ClientSink { Pin::new(&mut self.sink).poll_ready(cx) } - fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: ClientToServerMsg) -> Result<(), Self::Error> { Pin::new(&mut self.sink).start_send(item) } @@ -369,7 +372,7 @@ impl ClientStream { } impl Stream for ClientStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_next(cx) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index d710e9fba4c..2c79b2f1e0c 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -5,15 +5,14 @@ use std::{ io, pin::Pin, - str::Utf8Error, task::{ready, Context, Poll}, }; -use bytes::{Bytes, BytesMut}; -use iroh_base::{NodeId, SecretKey}; -use n0_future::{time::Duration, Sink, Stream}; +use bytes::BytesMut; +use iroh_base::SecretKey; +use n0_future::{Sink, Stream}; use nested_enum_utils::common_fields; -use snafu::{Backtrace, ResultExt, Snafu}; +use snafu::{Backtrace, Snafu}; use tracing::debug; use super::KeyCache; @@ -25,7 +24,10 @@ use crate::{ use crate::{ protos::{ handshake, - relay::{Frame, RecvError as RecvRelayError, SendError as SendRelayError}, + relay::{ + ClientToServerMsg, RecvError as RecvRelayError, SendError as SendRelayError, + ServerToClientMsg, + }, }, MAX_PACKET_SIZE, }; @@ -77,8 +79,6 @@ pub enum RecvError { #[cfg(wasm_browser)] source: ws_stream_wasm::WsErr, }, - #[snafu(display("invalid protocol message encoding"))] - InvalidProtocolMessageEncoding { source: Utf8Error }, #[snafu(display("Unexpected frame received: {frame_type}"))] UnexpectedFrame { frame_type: crate::protos::relay::FrameType, @@ -113,6 +113,18 @@ pub(crate) enum Conn { } impl Conn { + #[cfg(test)] + pub(crate) fn test(io: tokio::io::DuplexStream) -> Self { + use crate::protos::relay::MAX_FRAME_SIZE; + + Self::Ws { + conn: tokio_websockets::ClientBuilder::new() + .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) + .take_over(MaybeTlsStream::Test(io)), + key_cache: KeyCache::test(), + } + } + /// Constructs a new websocket connection, including the initial server handshake. #[cfg(wasm_browser)] pub(crate) async fn new_ws_browser( @@ -154,7 +166,7 @@ impl Conn { } impl Stream for Conn { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match *self { @@ -176,9 +188,9 @@ impl Stream for Conn { ); return Poll::Pending; } - let frame = Frame::from_bytes(msg.into_payload().into(), key_cache)?; - let message = ReceivedMessage::try_from(frame); - Poll::Ready(Some(message)) + let message = + ServerToClientMsg::from_bytes(msg.into_payload().into(), key_cache); + Poll::Ready(Some(message.map_err(Into::into))) } Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), None => Poll::Ready(None), @@ -202,7 +214,7 @@ impl Stream for Conn { } } -impl Sink for Conn { +impl Sink for Conn { type Error = SendError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -216,12 +228,12 @@ impl Sink for Conn { } } - fn start_send(mut self: Pin<&mut Self>, item: SendMessage) -> Result<(), Self::Error> { - if let SendMessage::SendPacket(_, bytes) = &item { - let size = bytes.len(); + fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { + // TODO(matheus23): Check this in send message construction instead + if let ClientToServerMsg::SendPacket { packet, .. } = &frame { + let size = packet.len(); snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); } - let frame = Frame::from(item); match *self { #[cfg(not(wasm_browser))] Self::Ws { ref mut conn, .. } => Pin::new(conn) @@ -233,7 +245,9 @@ impl Sink for Conn { .map_err(Into::into), #[cfg(wasm_browser)] Self::WsBrowser { ref mut conn, .. } => Pin::new(conn) - .start_send(ws_stream_wasm::WsMessage::Binary(frame.encode_for_ws_msg())) + .start_send(ws_stream_wasm::WsMessage::Binary( + frame.write_to(Vec::new()), + )) .map_err(Into::into), } } @@ -260,110 +274,3 @@ impl Sink for Conn { } } } - -/// The messages received from a framed relay stream. -/// -/// This is a type-validated version of the `Frame`s on the `RelayCodec`. -#[derive(derive_more::Debug, Clone)] -pub enum ReceivedMessage { - /// Represents an incoming packet. - ReceivedPacket { - /// The [`NodeId`] of the packet sender. - remote_node_id: NodeId, - /// The received packet bytes. - #[debug(skip)] - data: Bytes, // TODO: ref - }, - /// Indicates that the client identified by the underlying public key had previously sent you a - /// packet but has now disconnected from the server. - NodeGone(NodeId), - /// Request from a client or server to reply to the - /// other side with a [`ReceivedMessage::Pong`] with the given payload. - Ping([u8; 8]), - /// Reply to a [`ReceivedMessage::Ping`] from a client or server - /// with the payload sent previously in the ping. - Pong([u8; 8]), - /// A one-way message from server to client, declaring the connection health state. - Health { - /// If set, is a description of why the connection is unhealthy. - /// - /// If `None` means the connection is healthy again. - /// - /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`] - /// until a problem exists. - problem: Option, - }, - /// A one-way message from server to client, advertising that the server is restarting. - ServerRestarting { - /// An advisory duration that the client should wait before attempting to reconnect. - /// It might be zero. It exists for the server to smear out the reconnects. - reconnect_in: Duration, - /// An advisory duration for how long the client should attempt to reconnect - /// before giving up and proceeding with its normal connection failure logic. The interval - /// between retries is undefined for now. A server should not send a TryFor duration more - /// than a few seconds. - try_for: Duration, - }, -} - -impl TryFrom for ReceivedMessage { - type Error = RecvError; - - fn try_from(frame: Frame) -> std::result::Result { - match frame { - Frame::NodeGone { node_id } => Ok(ReceivedMessage::NodeGone(node_id)), - Frame::RecvPacket { src_key, content } => { - let packet = ReceivedMessage::ReceivedPacket { - remote_node_id: src_key, - data: content, - }; - Ok(packet) - } - Frame::Ping { data } => Ok(ReceivedMessage::Ping(data)), - Frame::Pong { data } => Ok(ReceivedMessage::Pong(data)), - Frame::Health { problem } => { - let problem = std::str::from_utf8(&problem) - .context(InvalidProtocolMessageEncodingSnafu)? - .to_owned(); - let problem = Some(problem); - Ok(ReceivedMessage::Health { problem }) - } - Frame::Restarting { - reconnect_in, - try_for, - } => { - let reconnect_in = Duration::from_millis(reconnect_in as u64); - let try_for = Duration::from_millis(try_for as u64); - Ok(ReceivedMessage::ServerRestarting { - reconnect_in, - try_for, - }) - } - _ => Err(UnexpectedFrameSnafu { - frame_type: frame.typ(), - } - .build()), - } - } -} - -/// Messages we can send to a relay server. -#[derive(Debug)] -pub enum SendMessage { - /// Send a packet of data to the [`NodeId`]. - SendPacket(NodeId, Bytes), - /// Sends a ping message to the connected relay server. - Ping([u8; 8]), - /// Sends a pong message to the connected relay server. - Pong([u8; 8]), -} - -impl From for Frame { - fn from(source: SendMessage) -> Self { - match source { - SendMessage::SendPacket(dst_key, packet) => Frame::SendPacket { dst_key, packet }, - SendMessage::Ping(data) => Frame::Ping { data }, - SendMessage::Pong(data) => Frame::Pong { data }, - } - } -} diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 74b400fb762..1c376b27c32 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -13,14 +13,14 @@ //! * server then sends `FrameType::RecvPacket` to recipient use bytes::{BufMut, Bytes}; -use iroh_base::{PublicKey, SignatureError}; +use iroh_base::{NodeId, SignatureError}; #[cfg(feature = "server")] use n0_future::time::Duration; use n0_future::{time, Sink, SinkExt}; use nested_enum_utils::common_fields; use postcard::experimental::max_size::MaxSize; use serde::{Deserialize, Serialize}; -use snafu::{Backtrace, Snafu}; +use snafu::{Backtrace, ResultExt, Snafu}; use crate::{client::conn::SendError as ConnSendError, KeyCache}; @@ -64,7 +64,7 @@ pub enum FrameType { /// to B so B can forget that a reverse path exists on that connection to get back to A /// /// 32B pub key of peer that's gone - PeerGone = 8, + NodeGone = 8, /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. /// Messages with these frames will be ignored. /// 8 byte ping payload, to be echoed back in FrameType::Pong @@ -146,6 +146,8 @@ pub enum RecvError { InvalidFrame {}, #[snafu(display("Invalid frame type: {frame_type}"))] InvalidFrameType { frame_type: FrameType }, + #[snafu(display("invalid protocol message encoding"))] + InvalidProtocolMessageEncoding { source: std::str::Utf8Error }, #[snafu(display("Too few bytes"))] TooSmall {}, } @@ -155,9 +157,9 @@ pub enum RecvError { /// /// Does not flush. #[cfg(feature = "server")] -pub(crate) async fn write_frame + Unpin>( +pub(crate) async fn write_frame + Unpin>( mut writer: S, - frame: Frame, + frame: ServerToClientMsg, timeout: Option, ) -> Result<(), SendError> { if let Some(duration) = timeout { @@ -169,28 +171,80 @@ pub(crate) async fn write_frame + Unpin>( Ok(()) } -/// The frames in the [`RelayCodec`]. +/// TODO(matheus23): Docs +/// The messages received from a framed relay stream. +/// +/// This is a type-validated version of the `Frame`s on the `RelayCodec`. +#[derive(derive_more::Debug, Clone, PartialEq, Eq)] +pub enum ServerToClientMsg { + /// Represents an incoming packet. + ReceivedPacket { + /// The [`NodeId`] of the packet sender. + remote_node_id: NodeId, + /// The received packet bytes. + #[debug(skip)] + data: Bytes, + }, + /// Indicates that the client identified by the underlying public key had previously sent you a + /// packet but has now disconnected from the server. + NodeGone(NodeId), + /// A one-way message from server to client, declaring the connection health state. + Health { + /// If set, is a description of why the connection is unhealthy. + /// + /// If `None` means the connection is healthy again. + /// + /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`] + /// until a problem exists. + problem: Option, + }, + /// A one-way message from server to client, advertising that the server is restarting. + Restarting { + /// An advisory duration that the client should wait before attempting to reconnect. + /// It might be zero. It exists for the server to smear out the reconnects. + reconnect_in: Duration, + /// An advisory duration for how long the client should attempt to reconnect + /// before giving up and proceeding with its normal connection failure logic. The interval + /// between retries is undefined for now. A server should not send a TryFor duration more + /// than a few seconds. + try_for: Duration, + }, + /// TODO(matheus23) fix docs + /// Request from a client or server to reply to the + /// other side with a [`ReceivedMessage::Pong`] with the given payload. + Ping([u8; 8]), + /// TODO(matheus23) fix docs + /// Reply to a [`ReceivedMessage::Ping`] from a client or server + /// with the payload sent previously in the ping. + Pong([u8; 8]), +} + +/// TODO(matheus23): Docs #[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum Frame { - SendPacket { dst_key: PublicKey, packet: Bytes }, - RecvPacket { src_key: PublicKey, content: Bytes }, - NodeGone { node_id: PublicKey }, - Ping { data: [u8; 8] }, - Pong { data: [u8; 8] }, - Health { problem: Bytes }, - Restarting { reconnect_in: u32, try_for: u32 }, +pub enum ClientToServerMsg { + /// TODO + Ping([u8; 8]), + /// TODO + Pong([u8; 8]), + /// TODO + SendPacket { + /// TODO + dst_key: NodeId, + /// TODO + packet: Bytes, + }, } -impl Frame { - pub(crate) fn typ(&self) -> FrameType { +impl ServerToClientMsg { + /// TODO(matheus23): docs + pub fn typ(&self) -> FrameType { match self { - Frame::SendPacket { .. } => FrameType::SendPacket, - Frame::RecvPacket { .. } => FrameType::RecvPacket, - Frame::NodeGone { .. } => FrameType::PeerGone, - Frame::Ping { .. } => FrameType::Ping, - Frame::Pong { .. } => FrameType::Pong, - Frame::Health { .. } => FrameType::Health, - Frame::Restarting { .. } => FrameType::Restarting, + Self::ReceivedPacket { .. } => FrameType::RecvPacket, + Self::NodeGone { .. } => FrameType::NodeGone, + Self::Ping { .. } => FrameType::Ping, + Self::Pong { .. } => FrameType::Pong, + Self::Health { .. } => FrameType::Health, + Self::Restarting { .. } => FrameType::Restarting, } } @@ -200,32 +254,33 @@ impl Frame { pub(crate) fn write_to(&self, mut dst: O) -> O { dst.put_u8(self.typ().into()); match self { - Frame::SendPacket { dst_key, packet } => { - dst.put(dst_key.as_ref()); - dst.put(packet.as_ref()); - } - Frame::RecvPacket { src_key, content } => { + Self::ReceivedPacket { + remote_node_id: src_key, + data: content, + } => { dst.put(src_key.as_ref()); dst.put(content.as_ref()); } - Frame::NodeGone { node_id: peer } => { - dst.put(peer.as_ref()); + Self::NodeGone(node_id) => { + dst.put(node_id.as_ref()); } - Frame::Ping { data } => { + Self::Ping(data) => { dst.put(&data[..]); } - Frame::Pong { data } => { + Self::Pong(data) => { dst.put(&data[..]); } - Frame::Health { problem } => { - dst.put(problem.as_ref()); + Self::Health { problem } => { + if let Some(problem) = problem { + dst.put(problem.as_ref()); + } } - Frame::Restarting { + Self::Restarting { reconnect_in, try_for, } => { - dst.put_u32(*reconnect_in); - dst.put_u32(*try_for); + dst.put_u32(reconnect_in.as_millis() as u32); + dst.put_u32(try_for.as_millis() as u32); } } dst @@ -243,39 +298,29 @@ impl Frame { let content = bytes.slice(1..); let res = match frame_type { - FrameType::SendPacket => { - if content.len() < PublicKey::LENGTH { - return Err(InvalidFrameSnafu.build()); - } - let frame_len = content.len() - PublicKey::LENGTH; - if frame_len > MAX_PACKET_SIZE { - return Err(FrameTooLargeSnafu { frame_len }.build()); - } - - let dst_key = cache.key_from_slice(&content[..PublicKey::LENGTH])?; - let packet = content.slice(PublicKey::LENGTH..); - Self::SendPacket { dst_key, packet } - } FrameType::RecvPacket => { - if content.len() < PublicKey::LENGTH { + if content.len() < NodeId::LENGTH { return Err(InvalidFrameSnafu.build()); } - let frame_len = content.len() - PublicKey::LENGTH; + let frame_len = content.len() - NodeId::LENGTH; if frame_len > MAX_PACKET_SIZE { return Err(FrameTooLargeSnafu { frame_len }.build()); } - let src_key = cache.key_from_slice(&content[..PublicKey::LENGTH])?; - let content = content.slice(PublicKey::LENGTH..); - Self::RecvPacket { src_key, content } + let src_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let content = content.slice(NodeId::LENGTH..); + Self::ReceivedPacket { + remote_node_id: src_key, + data: content, + } } - FrameType::PeerGone => { - if content.len() != PublicKey::LENGTH { + FrameType::NodeGone => { + if content.len() != NodeId::LENGTH { return Err(InvalidFrameSnafu.build()); } - let peer = cache.key_from_slice(&content[..32])?; - Self::NodeGone { node_id: peer } + let node_id = cache.key_from_slice(&content[..32])?; + Self::NodeGone(node_id) } FrameType::Ping => { if content.len() != 8 { @@ -283,7 +328,7 @@ impl Frame { } let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); - Self::Ping { data } + Self::Ping(data) } FrameType::Pong => { if content.len() != 8 { @@ -291,9 +336,16 @@ impl Frame { } let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); - Self::Pong { data } + Self::Pong(data) + } + FrameType::Health => { + let problem = std::str::from_utf8(&content) + .context(InvalidProtocolMessageEncodingSnafu)? + .to_owned(); + // TODO(matheus23): Actually encode/decode the option + let problem = Some(problem); + Self::Health { problem } } - FrameType::Health => Self::Health { problem: content }, FrameType::Restarting => { if content.len() != 4 + 4 { return Err(InvalidFrameSnafu.build()); @@ -308,6 +360,8 @@ impl Frame { .try_into() .map_err(|_| InvalidFrameSnafu.build())?, ); + let reconnect_in = Duration::from_millis(reconnect_in as u64); + let try_for = Duration::from_millis(try_for as u64); Self::Restarting { reconnect_in, try_for, @@ -321,6 +375,84 @@ impl Frame { } } +impl ClientToServerMsg { + pub(crate) fn typ(&self) -> FrameType { + match self { + Self::SendPacket { .. } => FrameType::SendPacket, + Self::Ping { .. } => FrameType::Ping, + Self::Pong { .. } => FrameType::Pong, + } + } + + /// Encodes this frame for sending over websockets. + /// + /// Specifically meant for being put into a binary websocket message frame. + pub(crate) fn write_to(&self, mut dst: O) -> O { + dst.put_u8(self.typ().into()); + match self { + Self::SendPacket { dst_key, packet } => { + dst.put(dst_key.as_ref()); + dst.put(packet.as_ref()); + } + Self::Ping(data) => { + dst.put(&data[..]); + } + Self::Pong(data) => { + dst.put(&data[..]); + } + } + dst + } + + /// Tries to decode a frame received over websockets. + /// + /// Specifically, bytes received from a binary websocket message frame. + #[allow(clippy::result_large_err)] + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + if bytes.is_empty() { + return Err(TooSmallSnafu.build()); + } + let frame_type = FrameType::from(bytes[0]); + let content = bytes.slice(1..); + + let res = match frame_type { + FrameType::SendPacket => { + if content.len() < NodeId::LENGTH { + return Err(InvalidFrameSnafu.build()); + } + let frame_len = content.len() - NodeId::LENGTH; + if frame_len > MAX_PACKET_SIZE { + return Err(FrameTooLargeSnafu { frame_len }.build()); + } + + let dst_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let packet = content.slice(NodeId::LENGTH..); + Self::SendPacket { dst_key, packet } + } + FrameType::Ping => { + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Ping(data) + } + FrameType::Pong => { + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Pong(data) + } + _ => { + return Err(InvalidFrameTypeSnafu { frame_type }.build()); + } + }; + Ok(res) + } +} + #[cfg(test)] mod tests { use data_encoding::HEXLOWER; @@ -329,77 +461,97 @@ mod tests { use super::*; + fn check_expected_bytes(frames: Vec<(Vec, &str)>) { + for (bytes, expected_hex) in frames { + let stripped: Vec = expected_hex + .chars() + .filter_map(|s| { + if s.is_ascii_whitespace() { + None + } else { + Some(s as u8) + } + }) + .collect(); + let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); + assert_eq!(bytes, expected_bytes); + } + } + #[test] - fn test_frame_snapshot() -> Result { + fn test_server_client_frames_snapshot() -> Result { let client_key = SecretKey::from_bytes(&[42u8; 32]); - let frames = vec![ + check_expected_bytes(vec![ ( - Frame::Health { - problem: "Hello? Yes this is dog.".into(), - }, + ServerToClientMsg::Health { + problem: Some("Hello? Yes this is dog.".into()), + } + .write_to(Vec::new()), "0e 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 20 69 73 20 64 6f 67 2e", ), ( - Frame::NodeGone { - node_id: client_key.public(), - }, + ServerToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61", ), ( - Frame::Ping { data: [42u8; 8] }, + ServerToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), "0c 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - Frame::Pong { data: [42u8; 8] }, + ServerToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), "0d 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - Frame::RecvPacket { - src_key: client_key.public(), - content: "Hello World!".into(), - }, + ServerToClientMsg::ReceivedPacket { + remote_node_id: client_key.public(), + data: "Hello World!".into(), + } + .write_to(Vec::new()), "05 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ( - Frame::SendPacket { + ServerToClientMsg::Restarting { + reconnect_in: Duration::from_millis(10), + try_for: Duration::from_millis(20), + } + .write_to(Vec::new()), + "0f 00 00 00 0a 00 00 00 14", + ), + ]); + + Ok(()) + } + + #[test] + fn test_client_server_frames_snapshot() -> Result { + let client_key = SecretKey::from_bytes(&[42u8; 32]); + + check_expected_bytes(vec![ + ( + ClientToServerMsg::Ping([42u8; 8]).write_to(Vec::new()), + "0c 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ClientToServerMsg::Pong([42u8; 8]).write_to(Vec::new()), + "0d 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ClientToServerMsg::SendPacket { dst_key: client_key.public(), packet: "Goodbye!".into(), - }, + } + .write_to(Vec::new()), "04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 47 6f 6f 64 62 79 65 21", ), - ( - Frame::Restarting { - reconnect_in: 10, - try_for: 20, - }, - "0f 00 00 00 0a 00 00 00 14", - ), - ]; - - for (frame, expected_hex) in frames { - let mut bytes = Vec::new(); - frame.write_to(&mut bytes); - let stripped: Vec = expected_hex - .chars() - .filter_map(|s| { - if s.is_ascii_whitespace() { - None - } else { - Some(s as u8) - } - }) - .collect(); - let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); - assert_eq!(bytes, expected_bytes); - } + ]); Ok(()) } @@ -407,6 +559,7 @@ mod tests { #[cfg(test)] mod proptests { + use bytes::BytesMut; use iroh_base::SecretKey; use proptest::prelude::*; @@ -416,7 +569,7 @@ mod proptests { prop::array::uniform32(any::()).prop_map(SecretKey::from) } - fn key() -> impl Strategy { + fn key() -> impl Strategy { secret_key().prop_map(|key| key.public()) } @@ -427,37 +580,46 @@ mod proptests { } /// Generates a random valid frame - fn frame() -> impl Strategy { - let send_packet = - (key(), data(32)).prop_map(|(dst_key, packet)| Frame::SendPacket { dst_key, packet }); + fn server_client_frame() -> impl Strategy { let recv_packet = - (key(), data(32)).prop_map(|(src_key, content)| Frame::RecvPacket { src_key, content }); - let peer_gone = key().prop_map(|peer| Frame::NodeGone { node_id: peer }); - let ping = prop::array::uniform8(any::()).prop_map(|data| Frame::Ping { data }); - let pong = prop::array::uniform8(any::()).prop_map(|data| Frame::Pong { data }); - let health = data(0).prop_map(|problem| Frame::Health { problem }); - let restarting = - (any::(), any::()).prop_map(|(reconnect_in, try_for)| Frame::Restarting { - reconnect_in, - try_for, + (key(), data(32)).prop_map(|(src_key, content)| ServerToClientMsg::ReceivedPacket { + remote_node_id: src_key, + data: content, }); - prop_oneof![ - send_packet, - recv_packet, - peer_gone, - ping, - pong, - health, - restarting, - ] + let node_gone = key().prop_map(|node_id| ServerToClientMsg::NodeGone(node_id)); + let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); + // TODO(matheus23): Actually fix these + let health = data(0).prop_map(|_problem| ServerToClientMsg::Health { problem: None }); + let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { + ServerToClientMsg::Restarting { + reconnect_in: Duration::from_millis(reconnect_in.into()), + try_for: Duration::from_millis(try_for.into()), + } + }); + prop_oneof![recv_packet, node_gone, ping, pong, health, restarting] + } + + fn client_server_frame() -> impl Strategy { + let send_packet = (key(), data(32)) + .prop_map(|(dst_key, packet)| ClientToServerMsg::SendPacket { dst_key, packet }); + let ping = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Pong); + prop_oneof![send_packet, ping, pong] } proptest! { #[test] - fn frame_roundtrip(frame in frame()) { - let mut encoded = Vec::new(); - frame.clone().write_to(&mut encoded); - let decoded = Frame::from_bytes(Bytes::from(encoded), &KeyCache::test()).unwrap(); + fn server_client_frame_roundtrip(frame in server_client_frame()) { + let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let decoded = ServerToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + prop_assert_eq!(frame, decoded); + } + + #[test] + fn client_server_frame_roundtrip(frame in client_server_frame()) { + let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let decoded = ClientToServerMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } } diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index cb096856c41..afdfa207f2b 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -890,10 +890,13 @@ mod tests { NO_CONTENT_CHALLENGE_HEADER, NO_CONTENT_RESPONSE_HEADER, }; use crate::{ - client::{conn::ReceivedMessage, ClientBuilder, SendMessage}, + client::ClientBuilder, dns::DnsResolver, http::Protocol, - protos, + protos::{ + self, + relay::{ClientToServerMsg, ServerToClientMsg}, + }, }; async fn spawn_local_relay() -> std::result::Result { @@ -918,11 +921,14 @@ mod tests { client_b: &mut crate::client::Client, b_key: NodeId, msg: Bytes, - ) -> Result { + ) -> Result { // try resend 10 times for _ in 0..10 { client_a - .send(SendMessage::SendPacket(b_key, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), + }) .await?; let Ok(res) = tokio::time::timeout(Duration::from_millis(500), client_b.next()).await else { @@ -1040,7 +1046,7 @@ mod tests { // send message from a to b let msg = Bytes::from("hello, b"); let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - let ReceivedMessage::ReceivedPacket { + let ServerToClientMsg::ReceivedPacket { remote_node_id, data, } = res @@ -1056,7 +1062,7 @@ mod tests { let msg = Bytes::from("howdy, a"); let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - let ReceivedMessage::ReceivedPacket { + let ServerToClientMsg::ReceivedPacket { remote_node_id, data, } = res @@ -1148,7 +1154,7 @@ mod tests { // the next message should be the rejection of the connection tokio::time::timeout(Duration::from_millis(500), async move { match client_a.next().await.unwrap().unwrap() { - ReceivedMessage::Health { problem } => { + ServerToClientMsg::Health { problem } => { assert_eq!(problem, Some("not authenticated".to_string())); } msg => { @@ -1183,7 +1189,7 @@ mod tests { let msg = Bytes::from("hello, c"); let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?; - if let ReceivedMessage::ReceivedPacket { + if let ServerToClientMsg::ReceivedPacket { remote_node_id, data, } = res @@ -1224,7 +1230,10 @@ mod tests { let msg = Bytes::from("hello, b"); for _i in 0..1000 { client_a - .send(SendMessage::SendPacket(b_key, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), + }) .await?; } Ok(()) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 32519ebfdb1..414d22e6d7e 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -19,7 +19,10 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - relay::{write_frame, Frame, SendError as SendRelayError, PING_INTERVAL}, + relay::{ + write_frame, ClientToServerMsg, SendError as SendRelayError, ServerToClientMsg, + PING_INTERVAL, + }, }, server::{ clients::Clients, @@ -187,12 +190,6 @@ pub enum HandleFrameError { }, #[snafu(transparent)] Relay { source: SendRelayError }, - #[snafu(display("Server issue: {problem:?}"))] - Health { - problem: Bytes, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, - }, } /// Run error @@ -364,7 +361,7 @@ impl Actor { node_id = self.node_gone.recv() => { let node_id = node_id.ok_or(NodeGoneDropSnafu.build())?; trace!("node_id gone: {:?}", node_id); - self.write_frame(Frame::NodeGone { node_id }).await.context(NodeGoneWriteFrameSnafu)?; + self.write_frame(ServerToClientMsg::NodeGone(node_id)).await.context(NodeGoneWriteFrameSnafu)?; } _ = self.ping_tracker.timeout() => { trace!("pong timed out"); @@ -375,7 +372,7 @@ impl Actor { // new interval ping_interval.reset_after(next_interval()); let data = self.ping_tracker.new_ping(); - self.write_frame(Frame::Ping { data }).await.context(KeepAliveWriteFrameSnafu)?; + self.write_frame(ServerToClientMsg::Ping(data)).await.context(KeepAliveWriteFrameSnafu)?; } } @@ -390,7 +387,7 @@ impl Actor { /// Writes the given frame to the connection. /// /// Errors if the send does not happen within the `timeout` duration - async fn write_frame(&mut self, frame: Frame) -> Result<(), SendRelayError> { + async fn write_frame(&mut self, frame: ServerToClientMsg) -> Result<(), SendRelayError> { write_frame(&mut self.stream, frame, Some(self.timeout)).await } @@ -405,8 +402,11 @@ impl Actor { if let Ok(len) = content.len().try_into() { self.metrics.bytes_sent.inc_by(len); } - self.write_frame(Frame::RecvPacket { src_key, content }) - .await + self.write_frame(ServerToClientMsg::ReceivedPacket { + remote_node_id: src_key, + data: content, + }) + .await } async fn send_packet(&mut self, packet: Packet) -> Result<(), SendRelayError> { @@ -440,7 +440,7 @@ impl Actor { /// Handles frame read results. async fn handle_frame( &mut self, - maybe_frame: Option>, + maybe_frame: Option>, ) -> Result<(), HandleFrameError> { trace!(?maybe_frame, "handle incoming frame"); let frame = match maybe_frame { @@ -449,7 +449,7 @@ impl Actor { }; match frame { - Frame::SendPacket { dst_key, packet } => { + ClientToServerMsg::SendPacket { dst_key, packet } => { let packet_len = packet.len(); if let Err(err @ ForwardPacketError { .. }) = self.handle_frame_send_packet(dst_key, packet) @@ -458,19 +458,15 @@ impl Actor { } self.metrics.bytes_recv.inc_by(packet_len as u64); } - Frame::Ping { data } => { + ClientToServerMsg::Ping(data) => { self.metrics.got_ping.inc(); // TODO: add rate limiter - self.write_frame(Frame::Pong { data }).await?; + self.write_frame(ServerToClientMsg::Pong(data)).await?; self.metrics.sent_pong.inc(); } - Frame::Pong { data } => { + ClientToServerMsg::Pong(data) => { self.ping_tracker.pong_received(data); } - Frame::Health { problem } => return Err(HealthSnafu { problem }.build()), - _ => { - self.metrics.unknown_frames.inc(); - } } Ok(()) } @@ -561,15 +557,15 @@ mod tests { use tracing_test::traced_test; use super::*; - use crate::protos::relay::FrameType; + use crate::{client::conn::Conn, protos::relay::FrameType}; async fn recv_frame< E: snafu::Error + Sync + Send + 'static, - S: Stream> + Unpin, + S: Stream> + Unpin, >( frame_type: FrameType, mut stream: S, - ) -> Result { + ) -> Result { match stream.next().await { Some(Ok(frame)) => { if frame_type != frame.typ() { @@ -595,8 +591,8 @@ mod tests { let node_id = SecretKey::generate(rand::thread_rng()).public(); let (io, io_rw) = tokio::io::duplex(1024); - let mut io_rw = RelayedStream::test_client(io_rw); - let stream = RelayedStream::test_server(io); + let mut io_rw = Conn::test(io_rw); + let stream = RelayedStream::test(io); let clients = Clients::default(); let metrics = Arc::new(Metrics::default()); @@ -632,9 +628,9 @@ mod tests { let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, - Frame::RecvPacket { - src_key: node_id, - content: data.to_vec().into() + ServerToClientMsg::ReceivedPacket { + remote_node_id: node_id, + data: data.to_vec().into() } ); @@ -647,29 +643,29 @@ mod tests { let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, - Frame::RecvPacket { - src_key: node_id, - content: data.to_vec().into() + ServerToClientMsg::ReceivedPacket { + remote_node_id: node_id, + data: data.to_vec().into() } ); // send peer_gone println!("send peer gone"); peer_gone_s.send(node_id).await.context("send")?; - let frame = recv_frame(FrameType::PeerGone, &mut io_rw).await.e()?; - assert_eq!(frame, Frame::NodeGone { node_id }); + let frame = recv_frame(FrameType::NodeGone, &mut io_rw).await.e()?; + assert_eq!(frame, ServerToClientMsg::NodeGone(node_id)); // Read tests println!("--read"); // send ping, expect pong let data = b"pingpong"; - write_frame(&mut io_rw, Frame::Ping { data: *data }, None).await?; + io_rw.send(ClientToServerMsg::Ping(*data)).await?; // recv pong println!(" recv pong"); let frame = recv_frame(FrameType::Pong, &mut io_rw).await?; - assert_eq!(frame, Frame::Pong { data: *data }); + assert_eq!(frame, ServerToClientMsg::Pong(*data)); let target = SecretKey::generate(rand::thread_rng()).public(); @@ -677,7 +673,7 @@ mod tests { println!(" send packet"); let data = b"hello world!"; io_rw - .send(Frame::SendPacket { + .send(ClientToServerMsg::SendPacket { dst_key: target, packet: Bytes::from_static(data), }) @@ -691,7 +687,7 @@ mod tests { disco_data.extend_from_slice(target.as_bytes()); disco_data.extend_from_slice(data); io_rw - .send(Frame::SendPacket { + .send(ClientToServerMsg::SendPacket { dst_key: target, packet: disco_data.clone().into(), }) @@ -711,14 +707,14 @@ mod tests { // Build the rate limited stream. let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _); - let mut frame_writer = RelayedStream::test_client(io_write); + let mut frame_writer = Conn::test(io_write); // Rate limiter allowing LIMIT bytes/s - let mut stream = RelayedStream::test_server_limited(io_read, LIMIT / 10, LIMIT); + let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT); // Prepare a frame to send, assert its size. let data = Bytes::from_static(b"hello world!1eins"); let target = SecretKey::generate(rand::thread_rng()).public(); - let frame = Frame::SendPacket { + let frame = ClientToServerMsg::SendPacket { dst_key: target, packet: data.clone(), }; diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 3848d021fae..7fd087465a9 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -199,17 +199,18 @@ mod tests { use super::*; use crate::{ - protos::relay::{Frame, FrameType}, + client::conn::Conn, + protos::relay::{FrameType, ServerToClientMsg}, server::streams::RelayedStream, }; async fn recv_frame< E: snafu::Error + Sync + Send + 'static, - S: Stream> + Unpin, + S: Stream> + Unpin, >( frame_type: FrameType, mut stream: S, - ) -> Result { + ) -> Result { match stream.next().await { Some(Ok(frame)) => { if frame_type != frame.typ() { @@ -226,16 +227,16 @@ mod tests { } } - fn test_client_builder(key: NodeId) -> (Config, RelayedStream) { + fn test_client_builder(key: NodeId) -> (Config, Conn) { let (server, client) = tokio::io::duplex(1024); ( Config { node_id: key, - stream: RelayedStream::test_client(client), + stream: RelayedStream::test(server), write_timeout: Duration::from_secs(1), channel_capacity: 10, }, - RelayedStream::test_server(server), + Conn::test(client), ) } @@ -256,9 +257,9 @@ mod tests { let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, - Frame::RecvPacket { - src_key: b_key, - content: data.to_vec().into(), + ServerToClientMsg::ReceivedPacket { + remote_node_id: b_key, + data: data.to_vec().into(), } ); @@ -267,9 +268,9 @@ mod tests { let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, - Frame::RecvPacket { - src_key: b_key, - content: data.to_vec().into(), + ServerToClientMsg::ReceivedPacket { + remote_node_id: b_key, + data: data.to_vec().into(), } ); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index e7a99923075..32f73675c0c 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -26,7 +26,7 @@ use super::{clients::Clients, AccessConfig, SpawnError}; use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, http::{Protocol, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, - protos::relay::{Frame, PER_CLIENT_SEND_QUEUE_DEPTH}, + protos::relay::{ServerToClientMsg, PER_CLIENT_SEND_QUEUE_DEPTH}, server::{ client::Config, metrics::Metrics, @@ -666,9 +666,10 @@ impl Inner { }; trace!("accept: checking access: {:?}", self.access); + // TODO(matheus23): Maybe use new frame? if !self.access.is_allowed(client_key).await { - io.send(Frame::Health { - problem: Bytes::from_static(b"not authenticated"), + io.send(ServerToClientMsg::Health { + problem: Some("not authenticated".into()), }) .await?; io.flush().await?; @@ -865,11 +866,9 @@ mod tests { use super::*; use crate::{ - client::{ - conn::{Conn, ReceivedMessage, SendMessage}, - Client, ClientBuilder, ConnectError, - }, + client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, + protos::relay::ClientToServerMsg, }; pub(crate) fn make_tls_config() -> TlsConfig { @@ -927,19 +926,22 @@ mod tests { info!("created client {b_key:?}"); info!("ping a"); - client_a.send(SendMessage::Ping([1u8; 8])).await?; + client_a.send(ClientToServerMsg::Ping([1u8; 8])).await?; let pong = client_a.next().await.expect("eos")?; - assert!(matches!(pong, ReceivedMessage::Pong(_))); + assert!(matches!(pong, ServerToClientMsg::Pong { .. })); info!("ping b"); - client_b.send(SendMessage::Ping([2u8; 8])).await?; + client_b.send(ClientToServerMsg::Ping([2u8; 8])).await?; let pong = client_b.next().await.expect("eos")?; - assert!(matches!(pong, ReceivedMessage::Pong(_))); + assert!(matches!(pong, ServerToClientMsg::Pong { .. })); info!("sending message from a to b"); let msg = Bytes::from_static(b"hi there, client b!"); client_a - .send(SendMessage::SendPacket(b_key, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), + }) .await?; info!("waiting for message from a on b"); let (got_key, got_msg) = @@ -950,7 +952,10 @@ mod tests { info!("sending message from b to a"); let msg = Bytes::from_static(b"right back at ya, client b!"); client_b - .send(SendMessage::SendPacket(a_key, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: a_key, + packet: msg.clone(), + }) .await?; info!("waiting for message b on a"); let (got_key, got_msg) = @@ -979,7 +984,7 @@ mod tests { } fn process_msg( - msg: Option>, + msg: Option>, ) -> Option<(PublicKey, Bytes)> { match msg { Some(Err(e)) => { @@ -988,7 +993,7 @@ mod tests { } Some(Ok(msg)) => { info!("got message on: {msg:?}"); - if let ReceivedMessage::ReceivedPacket { + if let ServerToClientMsg::ReceivedPacket { remote_node_id: source, data, } = msg @@ -1041,19 +1046,22 @@ mod tests { info!("created client {b_key:?}"); info!("ping a"); - client_a.send(SendMessage::Ping([1u8; 8])).await?; + client_a.send(ClientToServerMsg::Ping([1u8; 8])).await?; let pong = client_a.next().await.expect("eos")?; - assert!(matches!(pong, ReceivedMessage::Pong(_))); + assert!(matches!(pong, ServerToClientMsg::Pong { .. })); info!("ping b"); - client_b.send(SendMessage::Ping([2u8; 8])).await?; + client_b.send(ClientToServerMsg::Ping([2u8; 8])).await?; let pong = client_b.next().await.expect("eos")?; - assert!(matches!(pong, ReceivedMessage::Pong(_))); + assert!(matches!(pong, ServerToClientMsg::Pong { .. })); info!("sending message from a to b"); let msg = Bytes::from_static(b"hi there, client b!"); client_a - .send(SendMessage::SendPacket(b_key, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), + }) .await?; info!("waiting for message from a on b"); let (got_key, got_msg) = @@ -1064,7 +1072,10 @@ mod tests { info!("sending message from b to a"); let msg = Bytes::from_static(b"right back at ya, client b!"); client_b - .send(SendMessage::SendPacket(a_key, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: a_key, + packet: msg.clone(), + }) .await?; info!("waiting for message b on a"); let (got_key, got_msg) = @@ -1128,10 +1139,13 @@ mod tests { info!("Send message from A to B."); let msg = Bytes::from_static(b"hello client b!!"); client_a - .send(SendMessage::SendPacket(public_key_b, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_b, + packet: msg.clone(), + }) .await?; match client_b.next().await.unwrap()? { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -1146,10 +1160,13 @@ mod tests { info!("Send message from B to A."); let msg = Bytes::from_static(b"nice to meet you client a!!"); client_b - .send(SendMessage::SendPacket(public_key_a, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_a, + packet: msg.clone(), + }) .await?; match client_a.next().await.unwrap()? { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -1167,10 +1184,10 @@ mod tests { info!("Fail to send message from A to B."); let res = client_a - .send(SendMessage::SendPacket( - public_key_b, - Bytes::from_static(b"try to send"), - )) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_b, + packet: Bytes::from_static(b"try to send"), + }) .await; assert!(res.is_err()); assert!(client_b.next().await.is_none()); @@ -1216,10 +1233,13 @@ mod tests { info!("Send message from A to B."); let msg = Bytes::from_static(b"hello client b!!"); client_a - .send(SendMessage::SendPacket(public_key_b, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_b, + packet: msg.clone(), + }) .await?; match client_b.next().await.expect("eos")? { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -1234,10 +1254,13 @@ mod tests { info!("Send message from B to A."); let msg = Bytes::from_static(b"nice to meet you client a!!"); client_b - .send(SendMessage::SendPacket(public_key_a, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_a, + packet: msg.clone(), + }) .await?; match client_a.next().await.expect("eos")? { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -1264,10 +1287,13 @@ mod tests { info!("Send message from A to B."); let msg = Bytes::from_static(b"are you still there, b?!"); client_a - .send(SendMessage::SendPacket(public_key_b, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_b, + packet: msg.clone(), + }) .await?; match new_client_b.next().await.expect("eos")? { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -1282,10 +1308,13 @@ mod tests { info!("Send message from B to A."); let msg = Bytes::from_static(b"just had a spot of trouble but I'm back now,a!!"); new_client_b - .send(SendMessage::SendPacket(public_key_a, msg.clone())) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_a, + packet: msg.clone(), + }) .await?; match client_a.next().await.expect("eos")? { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -1302,10 +1331,10 @@ mod tests { info!("Sending message from A to B fails"); let res = client_a - .send(SendMessage::SendPacket( - public_key_b, - Bytes::from_static(b"try to send"), - )) + .send(ClientToServerMsg::SendPacket { + dst_key: public_key_b, + packet: Bytes::from_static(b"try to send"), + }) .await; assert!(res.is_err()); assert!(new_client_b.next().await.is_none()); diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index f58c234f22d..fefe379b0ab 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -14,7 +14,7 @@ use tokio_websockets::WebSocketStream; use tracing::instrument; use crate::{ - protos::relay::{Frame, RecvError}, + protos::relay::{ClientToServerMsg, RecvError, ServerToClientMsg}, ExportKeyingMaterial, KeyCache, }; @@ -31,18 +31,7 @@ pub(crate) struct RelayedStream { #[cfg(test)] impl RelayedStream { - pub(crate) fn test_client(stream: tokio::io::DuplexStream) -> Self { - let stream = MaybeTlsStream::Test(stream); - let stream = RateLimited::unlimited(stream, Arc::new(Metrics::default())); - Self { - inner: tokio_websockets::ClientBuilder::new() - .limits(Self::limits()) - .take_over(stream), - key_cache: KeyCache::test(), - } - } - - pub(crate) fn test_server(stream: tokio::io::DuplexStream) -> Self { + pub(crate) fn test(stream: tokio::io::DuplexStream) -> Self { let stream = MaybeTlsStream::Test(stream); let stream = RateLimited::unlimited(stream, Arc::new(Metrics::default())); Self { @@ -53,7 +42,7 @@ impl RelayedStream { } } - pub(crate) fn test_server_limited( + pub(crate) fn test_limited( stream: tokio::io::DuplexStream, max_burst_bytes: u32, bytes_per_second: u32, @@ -86,7 +75,7 @@ fn ws_to_io_err(e: tokio_websockets::Error) -> std::io::Error { } } -impl Sink for RelayedStream { +impl Sink for RelayedStream { type Error = std::io::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -95,13 +84,11 @@ impl Sink for RelayedStream { .map_err(ws_to_io_err) } - fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: ServerToClientMsg) -> Result<(), Self::Error> { Pin::new(&mut self.inner) - .start_send(tokio_websockets::Message::binary({ - let mut buf = BytesMut::new(); - item.write_to(&mut buf); - tokio_websockets::Payload::from(buf.freeze()) - })) + .start_send(tokio_websockets::Message::binary( + tokio_websockets::Payload::from(item.write_to(BytesMut::new()).freeze()), + )) .map_err(ws_to_io_err) } @@ -129,7 +116,7 @@ pub enum StreamError { } impl Stream for RelayedStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Pin::new(&mut self.inner).poll_next(cx) { @@ -144,7 +131,7 @@ impl Stream for RelayedStream { return Poll::Pending; } Poll::Ready(Some( - Frame::from_bytes(msg.into_payload().into(), &self.key_cache) + ClientToServerMsg::from_bytes(msg.into_payload().into(), &self.key_cache) .map_err(Into::into), )) } diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs index 8031ff6177a..5bce662a64e 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/relay_actor.rs @@ -42,7 +42,8 @@ use bytes::{Bytes, BytesMut}; use iroh_base::{NodeId, PublicKey, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, - client::{Client, ConnectError, ReceivedMessage, RecvError, SendError, SendMessage}, + client::{Client, ConnectError, RecvError, SendError}, + protos::relay::{ClientToServerMsg, ServerToClientMsg}, PingTracker, MAX_PACKET_SIZE, }; use n0_future::{ @@ -551,7 +552,7 @@ impl ActiveRelayActor { let res = loop { if let Some(data) = state.pong_pending.take() { - let fut = client_sink.send(SendMessage::Pong(data)); + let fut = client_sink.send(ClientToServerMsg::Pong(data)); self.run_sending(fut, &mut state, &mut client_stream) .await?; } @@ -578,7 +579,7 @@ impl ActiveRelayActor { } _ = ping_interval.tick() => { let data = state.ping_tracker.new_ping(); - let fut = client_sink.send(SendMessage::Ping(data)); + let fut = client_sink.send(ClientToServerMsg::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } msg = self.inbox.recv() => { @@ -594,7 +595,7 @@ impl ActiveRelayActor { match client_stream.local_addr() { Some(addr) if local_ips.contains(&addr.ip()) => { let data = state.ping_tracker.new_ping(); - let fut = client_sink.send(SendMessage::Ping(data)); + let fut = client_sink.send(ClientToServerMsg::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } Some(_) => break Err(LocalIpInvalidSnafu.build()), @@ -610,7 +611,7 @@ impl ActiveRelayActor { ActiveRelayMessage::PingServer(sender) => { let data = rand::random(); state.test_pong = Some((data, sender)); - let fut = client_sink.send(SendMessage::Ping(data)); + let fut = client_sink.send(ClientToServerMsg::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } } @@ -638,11 +639,11 @@ impl ActiveRelayActor { datagrams.datagrams.clone(), ) .map(|p| { - Ok(SendMessage::SendPacket(p.node_id, p.payload)) + Ok(ClientToServerMsg::SendPacket { dst_key: p.node_id, packet: p.payload }) }) }); let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { - if let Ok(SendMessage::SendPacket(_node_id, payload)) = m { + if let Ok(ClientToServerMsg::SendPacket { dst_key: _node_id, packet: payload }) = m { metrics.send_relay.inc_by(payload.len() as _); } }); @@ -678,9 +679,9 @@ impl ActiveRelayActor { res.map_err(|err| state.map_err(err)) } - fn handle_relay_msg(&mut self, msg: ReceivedMessage, state: &mut ConnectedRelayState) { + fn handle_relay_msg(&mut self, msg: ServerToClientMsg, state: &mut ConnectedRelayState) { match msg { - ReceivedMessage::ReceivedPacket { + ServerToClientMsg::ReceivedPacket { remote_node_id, data, } => { @@ -706,11 +707,11 @@ impl ActiveRelayActor { } } } - ReceivedMessage::NodeGone(node_id) => { + ServerToClientMsg::NodeGone(node_id) => { state.nodes_present.remove(&node_id); } - ReceivedMessage::Ping(data) => state.pong_pending = Some(data), - ReceivedMessage::Pong(data) => { + ServerToClientMsg::Ping(data) => state.pong_pending = Some(data), + ServerToClientMsg::Pong(data) => { #[cfg(test)] { if let Some((expected_data, sender)) = state.test_pong.take() { @@ -724,11 +725,11 @@ impl ActiveRelayActor { state.ping_tracker.pong_received(data); state.established = true; } - ReceivedMessage::Health { problem } => { + ServerToClientMsg::Health { problem } => { let problem = problem.as_deref().unwrap_or("unknown"); warn!("Relay server reports problem: {problem}"); } - ReceivedMessage::ServerRestarting { .. } => { + ServerToClientMsg::Restarting { .. } => { trace!("Ignoring {msg:?}") } } @@ -1272,7 +1273,7 @@ where } } -/// Splits a single [`ReceivedMessage::ReceivedPacket`] frame into datagrams. +/// Splits a single [`ServerToClientMsg::ReceivedPacket`] frame into datagrams. /// /// This splits packets joined by [`PacketizeIter`] back into individual datagrams. See /// that struct for more details. From 87c68a3bd87b81314cebb8985741aa19a132cc85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 10 Jun 2025 07:37:52 +0200 Subject: [PATCH 19/80] Remove `Option` from `ServerToClientMsg::Health` --- iroh-relay/proptest-regressions/protos/relay.txt | 1 + iroh-relay/src/protos/relay.rs | 13 ++++++------- iroh-relay/src/server.rs | 2 +- iroh-relay/src/server/http_server.rs | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/iroh-relay/proptest-regressions/protos/relay.txt b/iroh-relay/proptest-regressions/protos/relay.txt index e718aef9a1e..38aafe254c7 100644 --- a/iroh-relay/proptest-regressions/protos/relay.txt +++ b/iroh-relay/proptest-regressions/protos/relay.txt @@ -5,3 +5,4 @@ # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. cc 9295f5287162dfb180e5826e563c2cea08b477b803ef412ff8351eb5c3eb45ef # shrinks to frame = KeepAlive +cc 753aabcf8ae2b4e4a52f451d58339aab85a4b61108afdf4b9600f97b3a33bf42 # shrinks to frame = Health { problem: None } diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 1c376b27c32..aedeb6c4596 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -196,7 +196,7 @@ pub enum ServerToClientMsg { /// /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`] /// until a problem exists. - problem: Option, + problem: String, }, /// A one-way message from server to client, advertising that the server is restarting. Restarting { @@ -271,9 +271,7 @@ impl ServerToClientMsg { dst.put(&data[..]); } Self::Health { problem } => { - if let Some(problem) = problem { - dst.put(problem.as_ref()); - } + dst.put(problem.as_ref()); } Self::Restarting { reconnect_in, @@ -343,7 +341,6 @@ impl ServerToClientMsg { .context(InvalidProtocolMessageEncodingSnafu)? .to_owned(); // TODO(matheus23): Actually encode/decode the option - let problem = Some(problem); Self::Health { problem } } FrameType::Restarting => { @@ -485,7 +482,7 @@ mod tests { check_expected_bytes(vec![ ( ServerToClientMsg::Health { - problem: Some("Hello? Yes this is dog.".into()), + problem: "Hello? Yes this is dog.".into(), } .write_to(Vec::new()), "0e 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 @@ -590,7 +587,9 @@ mod proptests { let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); // TODO(matheus23): Actually fix these - let health = data(0).prop_map(|_problem| ServerToClientMsg::Health { problem: None }); + let health = data(0).prop_map(|_problem| ServerToClientMsg::Health { + problem: "".to_string(), + }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { ServerToClientMsg::Restarting { reconnect_in: Duration::from_millis(reconnect_in.into()), diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index afdfa207f2b..7af4802f97d 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -1155,7 +1155,7 @@ mod tests { tokio::time::timeout(Duration::from_millis(500), async move { match client_a.next().await.unwrap().unwrap() { ServerToClientMsg::Health { problem } => { - assert_eq!(problem, Some("not authenticated".to_string())); + assert_eq!(problem, "not authenticated".to_string()); } msg => { panic!("other msg: {:?}", msg); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 32f73675c0c..5df31704804 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -669,7 +669,7 @@ impl Inner { // TODO(matheus23): Maybe use new frame? if !self.access.is_allowed(client_key).await { io.send(ServerToClientMsg::Health { - problem: Some("not authenticated".into()), + problem: "not authenticated".into(), }) .await?; io.flush().await?; From ae844ab864d267fa27d72632b1854284fe11e979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 10 Jun 2025 07:40:53 +0200 Subject: [PATCH 20/80] Fix iroh type --- iroh/src/magicsock/relay_actor.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs index 5bce662a64e..819b896801e 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/relay_actor.rs @@ -726,7 +726,6 @@ impl ActiveRelayActor { state.established = true; } ServerToClientMsg::Health { problem } => { - let problem = problem.as_deref().unwrap_or("unknown"); warn!("Relay server reports problem: {problem}"); } ServerToClientMsg::Restarting { .. } => { From 521bdfa0cc109ac6a4b2c8cd447145b888aab66e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 10 Jun 2025 14:13:05 +0200 Subject: [PATCH 21/80] Deduplicate `FrameType` --- iroh-relay/src/client/conn.rs | 4 +- iroh-relay/src/protos/handshake.rs | 28 +++++++++-- iroh-relay/src/protos/relay.rs | 78 +++--------------------------- iroh-relay/src/server/client.rs | 2 +- iroh-relay/src/server/clients.rs | 2 +- 5 files changed, 34 insertions(+), 80 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 2c79b2f1e0c..58759d9b629 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -80,9 +80,7 @@ pub enum RecvError { source: ws_stream_wasm::WsErr, }, #[snafu(display("Unexpected frame received: {frame_type}"))] - UnexpectedFrame { - frame_type: crate::protos::relay::FrameType, - }, + UnexpectedFrame { frame_type: handshake::FrameType }, } /// A connection to a relay server. diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index e2554168127..c98db1eb256 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -44,7 +44,7 @@ pub enum FrameType { /// to B so B can forget that a reverse path exists on that connection to get back to A /// /// 32B pub key of peer that's gone - PeerGone = 14, + NodeGone = 14, /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. /// Messages with these frames will be ignored. /// 8 byte ping payload, to be echoed back in FrameType::Pong @@ -135,7 +135,9 @@ impl BytesStreamSink for T where { } -trait Frame { +/// TODO(matheus23): Docs +pub trait Frame { + /// ... const TAG: FrameType; } @@ -194,6 +196,24 @@ pub enum Error { }, } +impl FrameType { + pub(crate) fn write_to(&self, mut dst: O) -> O { + VarInt::from(*self).encode(&mut dst); + dst + } + + // TODO(matheus23): Consolidate errors between handshake.rs and relay.rs + // Perhaps a shared error type `FramingError`? + pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { + let mut cursor = std::io::Cursor::new(&bytes); + let tag = VarInt::decode(&mut cursor).ok()?; + let tag_u32 = u32::try_from(u64::from(tag)).ok()?; + let frame_type = FrameType::from(tag_u32); + let content = bytes.slice(cursor.position() as usize..); + Some((frame_type, content)) + } +} + impl ServerChallenge { /// TODO(matheus23): docs #[cfg(feature = "server")] @@ -370,9 +390,8 @@ async fn write_frame( io: &mut impl BytesStreamSink, frame: F, ) -> Result<(), Error> { - let tag: VarInt = F::TAG.into(); let mut bytes = BytesMut::new(); - tag.encode(&mut bytes); + F::TAG.write_to(&mut bytes); let bytes = postcard::to_io(&frame, bytes.writer()) .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization .into_inner() @@ -392,6 +411,7 @@ async fn read_frame( .context(TimeoutSnafu)?? .ok_or_else(|| UnexpectedEndSnafu.build())?; + // TODO(matheus23) restructure: use FrameType::from_bytes, perhaps always use `FrameType` instead let mut cursor = std::io::Cursor::new(recv); let tag = VarInt::decode(&mut cursor) .map_err(|quinn_proto::coding::UnexpectedEnd| UnexpectedEndSnafu.build())?; diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index aedeb6c4596..467a278810f 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -18,12 +18,12 @@ use iroh_base::{NodeId, SignatureError}; use n0_future::time::Duration; use n0_future::{time, Sink, SinkExt}; use nested_enum_utils::common_fields; -use postcard::experimental::max_size::MaxSize; -use serde::{Deserialize, Serialize}; -use snafu::{Backtrace, ResultExt, Snafu}; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; use crate::{client::conn::SendError as ConnSendError, KeyCache}; +use super::handshake::FrameType; + /// The maximum size of a packet sent over relay. /// (This only includes the data bytes visible to magicsock, not /// including its on-wire framing overhead) @@ -46,60 +46,6 @@ pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15); #[cfg(feature = "server")] pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; -/// The one byte frame type at the beginning of the frame -/// header. The second field is a big-endian u32 describing the -/// length of the remaining frame (not including the initial 5 bytes) -#[derive(Debug, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::FromPrimitive, Clone, Copy)] -#[repr(u8)] -pub enum FrameType { - /// magic + 32b pub key + 24B nonce + bytes - ClientInfo = 2, - /// 32B dest pub key + packet bytes - SendPacket = 4, - /// v0/1 packet bytes, v2: 32B src pub key + packet bytes - RecvPacket = 5, - /// Sent from server to client to signal that a previous sender is no longer connected. - /// - /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` - /// to B so B can forget that a reverse path exists on that connection to get back to A - /// - /// 32B pub key of peer that's gone - NodeGone = 8, - /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. - /// Messages with these frames will be ignored. - /// 8 byte ping payload, to be echoed back in FrameType::Pong - Ping = 12, - /// 8 byte payload, the contents of ping being replied to - Pong = 13, - /// Sent from server to client to tell the client if their connection is - /// unhealthy somehow. - /// - /// Currently this is used to indicate that the connection was closed because of authentication issues. - Health = 14, - - /// Sent from server to client for the server to declare that it's restarting. - /// Payload is two big endian u32 durations in milliseconds: when to reconnect, - /// and how long to try total. - /// - /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` - Restarting = 15, - /// Unknown frame type - #[num_enum(default)] - Unknown = 255, -} - -impl std::fmt::Display for FrameType { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -#[derive(Debug, Serialize, Deserialize, MaxSize, PartialEq, Eq)] -pub(crate) struct ClientInfo { - /// The relay protocol version that the client was built with. - pub(crate) version: usize, -} - /// Protocol send errors. #[common_fields({ backtrace: Option, @@ -252,7 +198,7 @@ impl ServerToClientMsg { /// /// Specifically meant for being put into a binary websocket message frame. pub(crate) fn write_to(&self, mut dst: O) -> O { - dst.put_u8(self.typ().into()); + dst = self.typ().write_to(dst); match self { Self::ReceivedPacket { remote_node_id: src_key, @@ -289,12 +235,7 @@ impl ServerToClientMsg { /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - if bytes.is_empty() { - return Err(TooSmallSnafu.build()); - } - let frame_type = FrameType::from(bytes[0]); - let content = bytes.slice(1..); - + let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; let res = match frame_type { FrameType::RecvPacket => { if content.len() < NodeId::LENGTH { @@ -385,7 +326,7 @@ impl ClientToServerMsg { /// /// Specifically meant for being put into a binary websocket message frame. pub(crate) fn write_to(&self, mut dst: O) -> O { - dst.put_u8(self.typ().into()); + dst = self.typ().write_to(dst); match self { Self::SendPacket { dst_key, packet } => { dst.put(dst_key.as_ref()); @@ -406,12 +347,7 @@ impl ClientToServerMsg { /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - if bytes.is_empty() { - return Err(TooSmallSnafu.build()); - } - let frame_type = FrameType::from(bytes[0]); - let content = bytes.slice(1..); - + let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; let res = match frame_type { FrameType::SendPacket => { if content.len() < NodeId::LENGTH { diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 414d22e6d7e..bbc1d2840a3 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -557,7 +557,7 @@ mod tests { use tracing_test::traced_test; use super::*; - use crate::{client::conn::Conn, protos::relay::FrameType}; + use crate::{client::conn::Conn, protos::handshake::FrameType}; async fn recv_frame< E: snafu::Error + Sync + Send + 'static, diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 7fd087465a9..e9bb5c88ec0 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -200,7 +200,7 @@ mod tests { use super::*; use crate::{ client::conn::Conn, - protos::relay::{FrameType, ServerToClientMsg}, + protos::{handshake::FrameType, relay::ServerToClientMsg}, server::streams::RelayedStream, }; From 0dda03ef048901d7d79471e8272e25470a73659e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 10 Jun 2025 21:30:03 +0200 Subject: [PATCH 22/80] Remove `Protocol` enum --- iroh-relay/src/client.rs | 118 +++++++++---- iroh-relay/src/client/conn.rs | 155 +++++++---------- iroh-relay/src/client/tls.rs | 72 -------- iroh-relay/src/http.rs | 36 ---- iroh-relay/src/protos.rs | 2 +- iroh-relay/src/protos/handshake.rs | 17 +- iroh-relay/src/protos/{io.rs => streams.rs} | 33 +++- iroh-relay/src/server.rs | 3 - iroh-relay/src/server/http_server.rs | 183 ++++++++------------ iroh/src/endpoint.rs | 22 +-- iroh/src/lib.rs | 2 +- iroh/src/magicsock.rs | 8 +- iroh/src/magicsock/relay_actor.rs | 8 - 13 files changed, 256 insertions(+), 403 deletions(-) rename iroh-relay/src/protos/{io.rs => streams.rs} (80%) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 4bc41af5d9f..964ba2a39da 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -26,7 +26,7 @@ pub use self::conn::{RecvError, SendError}; #[cfg(not(wasm_browser))] use crate::dns::{DnsError, DnsResolver}; use crate::{ - http::{Protocol, RELAY_PATH}, + http::RELAY_PATH, protos::{ handshake, relay::{ClientToServerMsg, ServerToClientMsg}, @@ -126,8 +126,6 @@ pub struct ClientBuilder { address_family_selector: Option bool + Send + Sync>>, /// Server url. url: RelayUrl, - /// Relay protocol - protocol: Protocol, /// Allow self-signed certificates from relay servers #[cfg(any(test, feature = "test-utils"))] insecure_skip_cert_verify: bool, @@ -153,9 +151,6 @@ impl ClientBuilder { address_family_selector: None, url: url.into(), - // Resolves to websockets in browsers and relay otherwise - protocol: Protocol::default(), - #[cfg(any(test, feature = "test-utils"))] insecure_skip_cert_verify: false, @@ -167,13 +162,6 @@ impl ClientBuilder { } } - /// Sets whether to connect to the relay via websockets or not. - /// Set to use non-websocket, normal relaying by default. - pub fn protocol(mut self, protocol: Protocol) -> Self { - self.protocol = protocol; - self - } - /// Returns if we should prefer ipv6 /// it replaces the relayhttp.AddressFamilySelector we pass /// It provides the hint as to whether in an IPv4-vs-IPv6 race that @@ -210,34 +198,94 @@ impl ClientBuilder { } /// Establishes a new connection to the relay server. + #[cfg(not(wasm_browser))] pub async fn connect(&self) -> Result { - let (conn, local_addr) = match self.protocol { - #[cfg(wasm_browser)] - Protocol::Websocket => { - let conn = self.connect_ws().await?; - let local_addr = None; - (conn, local_addr) - } - #[cfg(not(wasm_browser))] - Protocol::Websocket => { - let (conn, local_addr) = self.connect_ws().await?; - (conn, Some(local_addr)) + use tls::MaybeTlsStreamBuilder; + + let mut dial_url = (*self.url).clone(); + dial_url.set_path(RELAY_PATH); + // The relay URL is exchanged with the http(s) scheme in tickets and similar. + // We need to use the ws:// or wss:// schemes when connecting with websockets, though. + dial_url + .set_scheme(match self.url.scheme() { + "http" => "ws", + "ws" => "ws", + _ => "wss", + }) + .map_err(|_| { + InvalidWebsocketUrlSnafu { + url: dial_url.clone(), + } + .build() + })?; + + debug!(%dial_url, "Dialing relay by websocket"); + + #[allow(unused_mut)] + let mut builder = MaybeTlsStreamBuilder::new(dial_url.clone(), self.dns_resolver.clone()) + .prefer_ipv6(self.prefer_ipv6()) + .proxy_url(self.proxy_url.clone()); + + #[cfg(any(test, feature = "test-utils"))] + if self.insecure_skip_cert_verify { + builder = builder.insecure_skip_cert_verify(self.insecure_skip_cert_verify); + } + + let stream = builder.connect().await?; + let local_addr = stream + .as_ref() + .local_addr() + .map_err(|_| NoLocalAddrSnafu.build())?; + let (conn, response) = tokio_websockets::ClientBuilder::new() + .uri(dial_url.as_str()) + .map_err(|_| { + InvalidRelayUrlSnafu { + url: dial_url.clone(), + } + .build() + })? + .connect_on(stream) + .await?; + + if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { + UnexpectedUpgradeStatusSnafu { + code: response.status(), } - }; + .fail()?; + } + + let conn = Conn::new(conn, self.key_cache.clone(), &self.secret_key).await?; event!( target: "events.net.relay.connected", Level::DEBUG, url = %self.url, - protocol = ?self.protocol, ); trace!("connect done"); - Ok(Client { conn, local_addr }) + + Ok(Client { + conn, + local_addr: Some(local_addr), + }) + } + + /// Reports whether IPv4 dials should be slightly + /// delayed to give IPv6 a better chance of winning dial races. + /// Implementations should only return true if IPv6 is expected + /// to succeed. (otherwise delaying IPv4 will delay the connection + /// overall) + #[cfg(not(wasm_browser))] + fn prefer_ipv6(&self) -> bool { + match self.address_family_selector { + Some(ref selector) => selector(), + None => false, + } } + /// Establishes a new connection to the relay server. #[cfg(wasm_browser)] - async fn connect_ws(&self) -> Result { + async fn connect(&self) -> Result { let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); // The relay URL is exchanged with the http(s) scheme in tickets and similar. @@ -260,7 +308,19 @@ impl ClientBuilder { let (_, ws_stream) = ws_stream_wasm::WsMeta::connect(dial_url.as_str(), None).await?; let conn = Conn::new_ws_browser(ws_stream, self.key_cache.clone(), &self.secret_key).await?; - Ok(conn) + + event!( + target: "events.net.relay.connected", + Level::DEBUG, + url = %self.url, + ); + + trace!("connect done"); + + Ok(Client { + conn, + local_addr: None, + }) } } diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 58759d9b629..338ec1b19eb 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -19,7 +19,7 @@ use super::KeyCache; #[cfg(not(wasm_browser))] use crate::{ client::streams::{MaybeTlsStream, ProxyStream}, - protos::io::HandshakeIo, + protos::streams::WsBytesFramed, }; use crate::{ protos::{ @@ -95,27 +95,21 @@ pub enum RecvError { /// The [`SendMessage`] and [`ReceivedMessage`] are safer wrappers enforcing some protocol /// invariants. #[derive(derive_more::Debug)] -pub(crate) enum Conn { +pub(crate) struct Conn { + #[debug("tokio_websockets::WebSocketStream")] #[cfg(not(wasm_browser))] - Ws { - #[debug("WebSocketStream>")] - conn: tokio_websockets::WebSocketStream>, - key_cache: KeyCache, - }, + pub(crate) conn: tokio_websockets::WebSocketStream>, + #[debug("ws_stream_wasm::WsStream")] #[cfg(wasm_browser)] - WsBrowser { - #[debug("WebSocketStream")] - conn: ws_stream_wasm::WsStream, - key_cache: KeyCache, - }, + pub(crate) conn: ws_stream_wasm::WsStream, + pub(crate) key_cache: KeyCache, } impl Conn { #[cfg(test)] pub(crate) fn test(io: tokio::io::DuplexStream) -> Self { use crate::protos::relay::MAX_FRAME_SIZE; - - Self::Ws { + Self { conn: tokio_websockets::ClientBuilder::new() .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) .take_over(MaybeTlsStream::Test(io)), @@ -125,38 +119,38 @@ impl Conn { /// Constructs a new websocket connection, including the initial server handshake. #[cfg(wasm_browser)] - pub(crate) async fn new_ws_browser( + pub(crate) async fn new( conn: ws_stream_wasm::WsStream, key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { - let mut io = HandshakeIo { io: conn }; + let mut io = WsBytesFramed { io: conn }; // exchange information with the server debug!("server_handshake: started"); handshake::clientside(&mut io, secret_key).await?; debug!("server_handshake: done"); - Ok(Self::WsBrowser { + Ok(Self { conn: io.io, key_cache, }) } #[cfg(not(wasm_browser))] - pub(crate) async fn new_ws( + pub(crate) async fn new( conn: tokio_websockets::WebSocketStream>, key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { - let mut io = HandshakeIo { io: conn }; + let mut io = WsBytesFramed { io: conn }; // exchange information with the server debug!("server_handshake: started"); handshake::clientside(&mut io, secret_key).await?; debug!("server_handshake: done"); - Ok(Self::Ws { + Ok(Self { conn: io.io, key_cache, }) @@ -167,47 +161,38 @@ impl Stream for Conn { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { + let msg = ready!(Pin::new(&mut self.conn).poll_next(cx)); + match msg { #[cfg(not(wasm_browser))] - Self::Ws { - ref mut conn, - ref key_cache, - } => match ready!(Pin::new(conn).poll_next(cx)) { - Some(Ok(msg)) => { - if msg.is_close() { - // Indicate the stream is done when we receive a close message. - // Note: We don't have to poll the stream to completion for it to close gracefully. - return Poll::Ready(None); - } - if !msg.is_binary() { - tracing::warn!( - ?msg, - "Got websocket message of unsupported type, skipping." - ); - return Poll::Pending; - } - let message = - ServerToClientMsg::from_bytes(msg.into_payload().into(), key_cache); - Poll::Ready(Some(message.map_err(Into::into))) - } - Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), - None => Poll::Ready(None), - }, - #[cfg(wasm_browser)] - Self::WsBrowser { - ref mut conn, - ref key_cache, - } => match ready!(Pin::new(conn).poll_next(cx)) { - Some(ws_stream_wasm::WsMessage::Binary(vec)) => { - let frame = Frame::decode_from_ws_msg(Bytes::from(vec), key_cache)?; - Poll::Ready(Some(ReceivedMessage::try_from(frame))) + Some(Ok(msg)) => { + if msg.is_close() { + // Indicate the stream is done when we receive a close message. + // Note: We don't have to poll the stream to completion for it to close gracefully. + return Poll::Ready(None); } - Some(msg) => { + if !msg.is_binary() { tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); - Poll::Pending + return Poll::Pending; } - None => Poll::Ready(None), - }, + let message = + ServerToClientMsg::from_bytes(msg.into_payload().into(), &self.key_cache); + Poll::Ready(Some(message.map_err(Into::into))) + } + #[cfg(not(wasm_browser))] + Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), + + #[cfg(wasm_browser)] + Some(ws_stream_wasm::WsMessage::Binary(vec)) => { + let frame = Frame::decode_from_ws_msg(Bytes::from(vec), &self.key_cache)?; + Poll::Ready(Some(ReceivedMessage::try_from(frame))) + } + #[cfg(wasm_browser)] + Some(msg) => { + tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); + Poll::Pending + } + + None => Poll::Ready(None), } } } @@ -216,14 +201,7 @@ impl Sink for Conn { type Error = SendError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_ready(cx).map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => { - Pin::new(conn).poll_ready(cx).map_err(Into::into) - } - } + Pin::new(&mut self.conn).poll_ready(cx).map_err(Into::into) } fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { @@ -232,43 +210,26 @@ impl Sink for Conn { let size = packet.len(); snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); } - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn) - .start_send(tokio_websockets::Message::binary({ - let mut buf = BytesMut::new(); - frame.write_to(&mut buf); - tokio_websockets::Payload::from(buf.freeze()) - })) - .map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => Pin::new(conn) - .start_send(ws_stream_wasm::WsMessage::Binary( - frame.write_to(Vec::new()), - )) - .map_err(Into::into), - } + + #[cfg(not(wasm_browser))] + let frame = tokio_websockets::Message::binary({ + let mut buf = BytesMut::new(); + frame.write_to(&mut buf); + tokio_websockets::Payload::from(buf.freeze()) + }); + #[cfg(wasm_browser)] + let frame = ws_stream_wasm::WsMessage::Binary(frame.write_to(Vec::new())); + + Pin::new(&mut self.conn) + .start_send(frame) + .map_err(Into::into) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_flush(cx).map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => { - Pin::new(conn).poll_flush(cx).map_err(Into::into) - } - } + Pin::new(&mut self.conn).poll_flush(cx).map_err(Into::into) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match *self { - #[cfg(not(wasm_browser))] - Self::Ws { ref mut conn, .. } => Pin::new(conn).poll_close(cx).map_err(Into::into), - #[cfg(wasm_browser)] - Self::WsBrowser { ref mut conn, .. } => { - Pin::new(conn).poll_close(cx).map_err(Into::into) - } - } + Pin::new(&mut self.conn).poll_close(cx).map_err(Into::into) } } diff --git a/iroh-relay/src/client/tls.rs b/iroh-relay/src/client/tls.rs index 447290968d3..177ef26b301 100644 --- a/iroh-relay/src/client/tls.rs +++ b/iroh-relay/src/client/tls.rs @@ -274,78 +274,6 @@ impl MaybeTlsStreamBuilder { } } -impl ClientBuilder { - pub(super) async fn connect_ws(&self) -> Result<(Conn, SocketAddr), ConnectError> { - let mut dial_url = (*self.url).clone(); - dial_url.set_path(RELAY_PATH); - // The relay URL is exchanged with the http(s) scheme in tickets and similar. - // We need to use the ws:// or wss:// schemes when connecting with websockets, though. - dial_url - .set_scheme(match self.url.scheme() { - "http" => "ws", - "ws" => "ws", - _ => "wss", - }) - .map_err(|_| { - InvalidWebsocketUrlSnafu { - url: dial_url.clone(), - } - .build() - })?; - - debug!(%dial_url, "Dialing relay by websocket"); - - #[allow(unused_mut)] - let mut builder = MaybeTlsStreamBuilder::new(dial_url.clone(), self.dns_resolver.clone()) - .prefer_ipv6(self.prefer_ipv6()) - .proxy_url(self.proxy_url.clone()); - - #[cfg(any(test, feature = "test-utils"))] - if self.insecure_skip_cert_verify { - builder = builder.insecure_skip_cert_verify(self.insecure_skip_cert_verify); - } - - let stream = builder.connect().await?; - let local_addr = stream - .as_ref() - .local_addr() - .map_err(|_| NoLocalAddrSnafu.build())?; - let (conn, response) = tokio_websockets::ClientBuilder::new() - .uri(dial_url.as_str()) - .map_err(|_| { - InvalidRelayUrlSnafu { - url: dial_url.clone(), - } - .build() - })? - .connect_on(stream) - .await?; - - if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { - UnexpectedUpgradeStatusSnafu { - code: response.status(), - } - .fail()?; - } - - let conn = Conn::new_ws(conn, self.key_cache.clone(), &self.secret_key).await?; - - Ok((conn, local_addr)) - } - - /// Reports whether IPv4 dials should be slightly - /// delayed to give IPv6 a better chance of winning dial races. - /// Implementations should only return true if IPv6 is expected - /// to succeed. (otherwise delaying IPv4 will delay the connection - /// overall) - fn prefer_ipv6(&self) -> bool { - match self.address_family_selector { - Some(ref selector) => selector(), - None => false, - } - } -} - fn url_port(url: &Url) -> Option { if let Some(port) = url.port() { return Some(port); diff --git a/iroh-relay/src/http.rs b/iroh-relay/src/http.rs index 415d22ea3d4..60a8a6d2df2 100644 --- a/iroh-relay/src/http.rs +++ b/iroh-relay/src/http.rs @@ -9,39 +9,3 @@ pub(crate) const SUPPORTED_WEBSOCKET_VERSION: &str = "13"; pub const RELAY_PATH: &str = "/relay"; /// The HTTP path under which the relay allows doing latency queries for testing. pub const RELAY_PROBE_PATH: &str = "/ping"; -/// The legacy HTTP path under which the relay used to accept relaying connections. -/// We keep this for backwards compatibility. - -/// The HTTP upgrade protocol used for relaying. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum Protocol { - /// Relays over websockets. - /// - /// Originally introduced to support browser connections. - Websocket, -} - -impl Default for Protocol { - fn default() -> Self { - Self::Websocket - } -} - -impl Protocol { - /// The HTTP upgrade header used or expected. - pub const fn upgrade_header(&self) -> &'static str { - match self { - Protocol::Websocket => WEBSOCKET_UPGRADE_PROTOCOL, - } - } - - /// Tries to match the value of an HTTP upgrade header to figure out which protocol should be initiated. - pub fn parse_header(header: &http::HeaderValue) -> Option { - let header_bytes = header.as_bytes(); - if header_bytes == Protocol::Websocket.upgrade_header().as_bytes() { - Some(Protocol::Websocket) - } else { - None - } - } -} diff --git a/iroh-relay/src/protos.rs b/iroh-relay/src/protos.rs index 1e1866dc966..92fa7611e20 100644 --- a/iroh-relay/src/protos.rs +++ b/iroh-relay/src/protos.rs @@ -2,6 +2,6 @@ pub mod disco; pub mod handshake; -pub mod io; pub mod relay; +pub mod streams; pub mod stun; diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index c98db1eb256..b3413b40e0b 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -4,7 +4,7 @@ use bytes::{BufMut, Bytes, BytesMut}; use iroh_base::{PublicKey, SecretKey, Signature}; use n0_future::{ time::{self, Elapsed}, - Sink, SinkExt, Stream, TryStreamExt, + SinkExt, TryStreamExt, }; use nested_enum_utils::common_fields; use quinn_proto::{coding::Codec, VarInt}; @@ -12,7 +12,7 @@ use quinn_proto::{coding::Codec, VarInt}; use rand::{CryptoRng, RngCore}; use snafu::{Backtrace, ResultExt, Snafu}; -use super::relay::SendError; +use super::{relay::SendError, streams::BytesStreamSink}; use crate::ExportKeyingMaterial; /// TODO(matheus23) docs @@ -22,7 +22,7 @@ pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; #[repr(u32)] #[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] pub enum FrameType { - /// The frame type for the client challenge request + /// The client frame type for the client challenge request ClientRequestChallenge = 1, /// The server frame type for the challenge response ServerChallenge = 2, @@ -124,17 +124,6 @@ pub(crate) struct ServerConfirmsAuth; #[cfg_attr(feature = "server", derive(serde::Serialize))] pub(crate) struct ServerDeniesAuth; -/// TODO(matheus23) docs -pub(crate) trait BytesStreamSink: - Stream> + Sink + Unpin -{ -} - -impl BytesStreamSink for T where - T: Stream> + Sink + Unpin -{ -} - /// TODO(matheus23): Docs pub trait Frame { /// ... diff --git a/iroh-relay/src/protos/io.rs b/iroh-relay/src/protos/streams.rs similarity index 80% rename from iroh-relay/src/protos/io.rs rename to iroh-relay/src/protos/streams.rs index c62061fda81..6b040bd78b7 100644 --- a/iroh-relay/src/protos/io.rs +++ b/iroh-relay/src/protos/streams.rs @@ -8,21 +8,38 @@ use bytes::Bytes; use n0_future::{ready, Sink, Stream}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::handshake::Error; use crate::ExportKeyingMaterial; #[derive(derive_more::Debug)] -pub(crate) struct HandshakeIo { +pub(crate) struct WsBytesFramed { #[cfg(not(wasm_browser))] - #[debug("WebSocketStream>")] + #[debug("WebSocketStream")] pub(crate) io: tokio_websockets::WebSocketStream, #[cfg(wasm_browser)] #[debug("WebSocketStream")] pub(crate) io: ws_stream_wasm::WsStream, + #[cfg(wasm_browser)] + _data: PhantomData, +} + +#[cfg(not(wasm_browser))] +type StreamError = tokio_websockets::Error; +#[cfg(wasm_browser)] +type StreamError = ws_stream_wasm::WsErr; + +/// TODO(matheus23) docs +pub(crate) trait BytesStreamSink: + Stream> + Sink + Unpin +{ +} + +impl BytesStreamSink for T where + T: Stream> + Sink + Unpin +{ } impl ExportKeyingMaterial - for HandshakeIo + for WsBytesFramed { #[cfg(wasm_browser)] fn export_keying_material>( @@ -47,8 +64,8 @@ impl ExportKeyingMate } } -impl Stream for HandshakeIo { - type Item = Result; +impl Stream for WsBytesFramed { + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -78,8 +95,8 @@ impl Stream for HandshakeIo { } } -impl Sink for HandshakeIo { - type Error = Error; +impl Sink for WsBytesFramed { + type Error = StreamError; fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { #[cfg(not(wasm_browser))] diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 7af4802f97d..6a2d4b54f60 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -892,7 +892,6 @@ mod tests { use crate::{ client::ClientBuilder, dns::DnsResolver, - http::Protocol, protos::{ self, relay::{ClientToServerMsg, ServerToClientMsg}, @@ -1028,7 +1027,6 @@ mod tests { let resolver = dns_resolver(); info!("client a build & connect"); let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver.clone()) - .protocol(Protocol::Websocket) .connect() .await?; @@ -1037,7 +1035,6 @@ mod tests { let b_key = b_secret_key.public(); info!("client b build & connect"); let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver.clone()) - .protocol(Protocol::Websocket) // another websocket client .connect() .await?; diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 5df31704804..91d0dafa211 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -25,7 +25,7 @@ use super::{clients::Clients, AccessConfig, SpawnError}; #[allow(deprecated)] use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, - http::{Protocol, RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, + http::{RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, protos::relay::{ServerToClientMsg, PER_CLIENT_SEND_QUEUE_DEPTH}, server::{ client::Config, @@ -36,7 +36,8 @@ use crate::{ KeyCache, }; use crate::{ - protos::{handshake, io::HandshakeIo, relay::MAX_FRAME_SIZE}, + http::WEBSOCKET_UPGRADE_PROTOCOL, + protos::{handshake, relay::MAX_FRAME_SIZE, streams::WsBytesFramed}, server::streams::RateLimited, }; @@ -460,47 +461,42 @@ impl RelayService { async move { { // Send a 400 to any request that doesn't have an `Upgrade` header. - let Some(protocol) = req.headers().get(UPGRADE).and_then(Protocol::parse_header) - else { + if req.headers().get(UPGRADE) + != Some(&HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL)) + { return Ok(builder .status(StatusCode::BAD_REQUEST) .body(body_empty()) .expect("valid body")); }; - let websocket_headers = if protocol == Protocol::Websocket { - let Some(key) = req.headers().get("Sec-WebSocket-Key").cloned() else { - warn!("missing header Sec-WebSocket-Key for websocket relay protocol"); - return Ok(builder - .status(StatusCode::BAD_REQUEST) - .body(body_empty()) - .expect("valid body")); - }; - - let Some(version) = req.headers().get("Sec-WebSocket-Version").cloned() else { - warn!("missing header Sec-WebSocket-Version for websocket relay protocol"); - return Ok(builder - .status(StatusCode::BAD_REQUEST) - .body(body_empty()) - .expect("valid body")); - }; - - if version.as_bytes() != SUPPORTED_WEBSOCKET_VERSION.as_bytes() { - warn!("invalid header Sec-WebSocket-Version: {:?}", version); - return Ok(builder - .status(StatusCode::BAD_REQUEST) - // It's convention to send back the version(s) we *do* support - .header("Sec-WebSocket-Version", SUPPORTED_WEBSOCKET_VERSION) - .body(body_empty()) - .expect("valid body")); - } + let Some(key) = req.headers().get("Sec-WebSocket-Key").cloned() else { + warn!("missing header Sec-WebSocket-Key for websocket relay protocol"); + return Ok(builder + .status(StatusCode::BAD_REQUEST) + .body(body_empty()) + .expect("valid body")); + }; - Some((key, version)) - } else { - None + let Some(version) = req.headers().get("Sec-WebSocket-Version").cloned() else { + warn!("missing header Sec-WebSocket-Version for websocket relay protocol"); + return Ok(builder + .status(StatusCode::BAD_REQUEST) + .body(body_empty()) + .expect("valid body")); }; - debug!(?protocol, "upgrading connection"); + if version.as_bytes() != SUPPORTED_WEBSOCKET_VERSION.as_bytes() { + warn!("invalid header Sec-WebSocket-Version: {:?}", version); + return Ok(builder + .status(StatusCode::BAD_REQUEST) + // It's convention to send back the version(s) we *do* support + .header("Sec-WebSocket-Version", SUPPORTED_WEBSOCKET_VERSION) + .body(body_empty()) + .expect("valid body")); + } + + debug!("upgrading connection"); // Setup a future that will eventually receive the upgraded // connection and talk a new protocol, and spawn the future @@ -513,15 +509,10 @@ impl RelayService { async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { - if let Err(err) = - this.0.relay_connection_handler(protocol, upgraded).await - { - warn!( - ?protocol, - "error accepting upgraded connection: {err:#}", - ); + if let Err(err) = this.0.relay_connection_handler(upgraded).await { + warn!("error accepting upgraded connection: {err:#}",); } else { - debug!(?protocol, "upgraded connection completed"); + debug!("upgraded connection completed"); }; } Err(err) => warn!("upgrade error: {err:#}"), @@ -531,20 +522,17 @@ impl RelayService { ); // Now return a 101 Response saying we agree to the upgrade to the - // HTTP_UPGRADE_PROTOCOL - builder = builder - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(UPGRADE, HeaderValue::from_static(protocol.upgrade_header())); - - if let Some((key, _version)) = websocket_headers { - Ok(builder - .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key)) - .header(CONNECTION, "upgrade") - .body(body_full("switching to websocket protocol")) - .expect("valid body")) - } else { - Ok(builder.body(body_empty()).expect("valid body")) - } + // websocket upgrade protocol + builder = builder.status(StatusCode::SWITCHING_PROTOCOLS).header( + UPGRADE, + HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), + ); + + Ok(builder + .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key)) + .header(CONNECTION, "upgrade") + .body(body_full("switching to websocket protocol")) + .expect("valid body")) } } .boxed() @@ -608,16 +596,15 @@ impl Inner { /// having sent off the connection this handler returns. async fn relay_connection_handler( &self, - protocol: Protocol, upgraded: Upgraded, ) -> Result<(), ConnectionHandlerError> { - debug!(?protocol, "relay_connection upgraded"); + debug!("relay_connection upgraded"); let (io, read_buf) = downcast_upgrade(upgraded)?; if !read_buf.is_empty() { return Err(BufferNotEmptySnafu { buf: read_buf }.build()); } - self.accept(protocol, io).await?; + self.accept(io).await?; Ok(()) } @@ -631,38 +618,32 @@ impl Inner { /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`AsyncWrite`]: tokio::io::AsyncWrite - async fn accept(&self, protocol: Protocol, io: MaybeTlsStream) -> Result<(), AcceptError> { + async fn accept(&self, io: MaybeTlsStream) -> Result<(), AcceptError> { use snafu::ResultExt; + trace!("accept: start"); + let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone()); - trace!(?protocol, "accept: start"); - let (client_key, mut io) = match protocol { - Protocol::Websocket => { - self.metrics.websocket_accepts.inc(); - // Since we already did the HTTP upgrade in the previous step, - // we use tokio-websockets to handle this connection - // Create a server builder with default config - let builder = tokio_websockets::ServerBuilder::new().limits( - tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE)), - ); - // Serve will create a WebSocketStream on an already upgraded connection - let websocket = builder.serve(io); + self.metrics.websocket_accepts.inc(); + // Since we already did the HTTP upgrade in the previous step, + // we use tokio-websockets to handle this connection + // Create a server builder with default config + let builder = tokio_websockets::ServerBuilder::new() + .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))); + // Serve will create a WebSocketStream on an already upgraded connection + let websocket = builder.serve(io); - let mut io = HandshakeIo { io: websocket }; + let mut io = WsBytesFramed { io: websocket }; - let client_info = handshake::serverside(&mut io, rand::rngs::OsRng) - .await - .context(HandshakeSnafu)?; - - ( - client_info.public_key, - RelayedStream { - inner: io.io, - key_cache: self.key_cache.clone(), - }, - ) - } + let client_info = handshake::serverside(&mut io, rand::rngs::OsRng) + .await + .context(HandshakeSnafu)?; + + let client_key = client_info.public_key; + let mut io = RelayedStream { + inner: io.io, + key_cache: self.key_cache.clone(), }; trace!("accept: checking access: {:?}", self.access); @@ -1095,7 +1076,7 @@ mod tests { async fn make_test_client(client: tokio::io::DuplexStream, key: &SecretKey) -> Result { let client = crate::client::streams::MaybeTlsStream::Test(client); let client = tokio_websockets::ClientBuilder::new().take_over(client); - let client = Conn::new_ws(client, KeyCache::test(), key).await?; + let client = Conn::new(client, KeyCache::test(), key).await?; Ok(client) } @@ -1117,10 +1098,8 @@ mod tests { let public_key_a = key_a.public(); let (client_a, rw_a) = tokio::io::duplex(10); let s = service.clone(); - let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a)) - .await - }); + let handler_task = + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a)).await }); let mut client_a = make_test_client(client_a, &key_a).await?; handler_task.await.context("join")??; @@ -1129,10 +1108,8 @@ mod tests { let public_key_b = key_b.public(); let (client_b, rw_b) = tokio::io::duplex(10); let s = service.clone(); - let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b)) - .await - }); + let handler_task = + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b)).await }); let mut client_b = make_test_client(client_b, &key_b).await?; handler_task.await.context("join")??; @@ -1211,10 +1188,8 @@ mod tests { let public_key_a = key_a.public(); let (client_a, rw_a) = tokio::io::duplex(10); let s = service.clone(); - let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_a)) - .await - }); + let handler_task = + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a)).await }); let mut client_a = make_test_client(client_a, &key_a).await?; handler_task.await.context("join")??; @@ -1223,10 +1198,8 @@ mod tests { let public_key_b = key_b.public(); let (client_b, rw_b) = tokio::io::duplex(10); let s = service.clone(); - let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(rw_b)) - .await - }); + let handler_task = + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b)).await }); let mut client_b = make_test_client(client_b, &key_b).await?; handler_task.await.context("join")??; @@ -1275,10 +1248,8 @@ mod tests { info!("Create client B and connect it to the server"); let (new_client_b, new_rw_b) = tokio::io::duplex(10); let s = service.clone(); - let handler_task = tokio::spawn(async move { - s.0.accept(Protocol::Websocket, MaybeTlsStream::Test(new_rw_b)) - .await - }); + let handler_task = + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(new_rw_b)).await }); let mut new_client_b = make_test_client(new_client_b, &key_b).await?; handler_task.await.context("join")??; diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index 46b5215f0be..2f65d8d1ab9 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -45,7 +45,6 @@ use crate::{ net_report::Report, tls, watcher::{self, Watcher}, - RelayProtocol, }; mod rtt_actor; @@ -127,7 +126,6 @@ pub enum PathSelection { pub struct Builder { secret_key: Option, relay_mode: RelayMode, - relay_protocol: iroh_relay::http::Protocol, alpn_protocols: Vec>, transport_config: quinn::TransportConfig, keylog: bool, @@ -154,7 +152,6 @@ impl Default for Builder { Self { secret_key: Default::default(), relay_mode: default_relay_mode(), - relay_protocol: iroh_relay::http::Protocol::default(), alpn_protocols: Default::default(), transport_config, keylog: Default::default(), @@ -212,7 +209,6 @@ impl Builder { addr_v6: self.addr_v6, secret_key, relay_map, - relay_protocol: self.relay_protocol, node_map: self.node_map, discovery, discovery_user_data: self.discovery_user_data, @@ -299,19 +295,6 @@ impl Builder { self } - /// Sets the protocol to use for relay connections. - /// - /// Options are either [`RelayProtocol::Websocket`] or [`RelayProtocol::Relay`]. - /// - /// `Websocket` is considered unstable between iroh versions at the moment. - /// The protocol can change in compatibility-breaking ways before iroh 1.0. - /// - /// Default is set to `Relay` at the moment, until we've stabilized the websocket protocol. - pub fn relay_conn_protocol(mut self, protocol: RelayProtocol) -> Self { - self.relay_protocol = protocol; - self - } - /// Removes all discovery services from the builder. pub fn clear_discovery(mut self) -> Self { self.discovery.clear(); @@ -2307,7 +2290,6 @@ mod tests { }; use iroh_base::{NodeAddr, NodeId, SecretKey}; - use iroh_relay::http::Protocol; use n0_future::{task::AbortOnDropHandle, StreamExt}; use n0_snafu::{Error, Result, ResultExt}; use quinn::ConnectionError; @@ -2601,16 +2583,14 @@ mod tests { #[tokio::test] #[traced_test] - async fn endpoint_send_relay_websockets() -> Result { + async fn endpoint_send_relay() -> Result { let (relay_map, _relay_url, _guard) = run_relay_server().await?; let client = Endpoint::builder() - .relay_conn_protocol(Protocol::Websocket) .insecure_skip_relay_cert_verify(true) .relay_mode(RelayMode::Custom(relay_map.clone())) .bind() .await?; let server = Endpoint::builder() - .relay_conn_protocol(Protocol::Websocket) .insecure_skip_relay_cert_verify(true) .relay_mode(RelayMode::Custom(relay_map)) .alpns(vec![TEST_ALPN.to_vec()]) diff --git a/iroh/src/lib.rs b/iroh/src/lib.rs index 2f44c3420e6..f190cec2b81 100644 --- a/iroh/src/lib.rs +++ b/iroh/src/lib.rs @@ -275,7 +275,7 @@ pub use endpoint::{Endpoint, RelayMode}; pub use iroh_base::{ KeyParsingError, NodeAddr, NodeId, PublicKey, RelayUrl, RelayUrlParseError, SecretKey, }; -pub use iroh_relay::{http::Protocol as RelayProtocol, node_info, RelayMap, RelayNode}; +pub use iroh_relay::{node_info, RelayMap, RelayNode}; #[cfg(any(test, feature = "test-utils"))] pub mod test_utils; diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 6bb68259829..4a490d1846e 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -120,9 +120,6 @@ pub(crate) struct Options { /// The [`RelayMap`] to use, leave empty to not use a relay server. pub(crate) relay_map: RelayMap, - /// Whether to use websockets or the nonstandard legacy protocol to connect to relays - pub(crate) relay_protocol: iroh_relay::http::Protocol, - /// An optional [`NodeMap`], to restore information about nodes. pub(crate) node_map: Option>, @@ -1767,7 +1764,6 @@ impl Handle { addr_v6, secret_key, relay_map, - relay_protocol, node_map, discovery, discovery_user_data, @@ -1874,7 +1870,7 @@ impl Handle { let mut actor_tasks = JoinSet::default(); - let relay_actor = RelayActor::new(msock.clone(), relay_datagram_recv_queue, relay_protocol); + let relay_actor = RelayActor::new(msock.clone(), relay_datagram_recv_queue); let relay_actor_cancel_token = relay_actor.cancel_token(); actor_tasks.spawn( async move { @@ -3571,7 +3567,6 @@ mod tests { addr_v6: None, secret_key, relay_map: RelayMap::empty(), - relay_protocol: iroh_relay::http::Protocol::default(), node_map: None, discovery: None, proxy_url: None, @@ -4174,7 +4169,6 @@ mod tests { addr_v6: None, secret_key: secret_key.clone(), relay_map: RelayMap::empty(), - relay_protocol: iroh_relay::http::Protocol::default(), node_map: None, discovery: None, discovery_user_data: None, diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs index 819b896801e..5d43a39dc37 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/relay_actor.rs @@ -214,7 +214,6 @@ struct RelayConnectionOptions { prefer_ipv6: Arc, #[cfg(any(test, feature = "test-utils"))] insecure_skip_cert_verify: bool, - protocol: iroh_relay::http::Protocol, } /// Possible reasons for a failed relay connection. @@ -314,7 +313,6 @@ impl ActiveRelayActor { prefer_ipv6, #[cfg(any(test, feature = "test-utils"))] insecure_skip_cert_verify, - protocol, } = opts; let mut builder = relay::client::ClientBuilder::new( @@ -323,7 +321,6 @@ impl ActiveRelayActor { #[cfg(not(wasm_browser))] dns_resolver, ) - .protocol(protocol) .address_family_selector(move || prefer_ipv6.load(Ordering::Relaxed)); if let Some(proxy_url) = proxy_url { builder = builder.proxy_url(proxy_url); @@ -872,14 +869,12 @@ pub(super) struct RelayActor { /// The tasks for the [`ActiveRelayActor`]s in `active_relays` above. active_relay_tasks: JoinSet<()>, cancel_token: CancellationToken, - protocol: iroh_relay::http::Protocol, } impl RelayActor { pub(super) fn new( msock: Arc, relay_datagram_recv_queue: Arc, - protocol: iroh_relay::http::Protocol, ) -> Self { let cancel_token = CancellationToken::new(); Self { @@ -888,7 +883,6 @@ impl RelayActor { active_relays: Default::default(), active_relay_tasks: JoinSet::new(), cancel_token, - protocol, } } @@ -1092,7 +1086,6 @@ impl RelayActor { prefer_ipv6: self.msock.ipv6_reported.clone(), #[cfg(any(test, feature = "test-utils"))] insecure_skip_cert_verify: self.msock.insecure_skip_relay_cert_verify, - protocol: self.protocol, }; // TODO: Replace 64 with PER_CLIENT_SEND_QUEUE_DEPTH once that's unused @@ -1411,7 +1404,6 @@ mod tests { proxy_url: None, prefer_ipv6: Arc::new(AtomicBool::new(true)), insecure_skip_cert_verify: true, - protocol: iroh_relay::http::Protocol::default(), }, stop_token, metrics: Default::default(), From bd8f56b918ad924af77ab53ef9a87bb54632e151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 10 Jun 2025 21:38:29 +0200 Subject: [PATCH 23/80] Fix snapshot tests --- iroh-relay/src/protos/relay.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 467a278810f..3bd38df1fd4 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -407,7 +407,7 @@ mod tests { }) .collect(); let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); - assert_eq!(bytes, expected_bytes); + assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes)); } } @@ -421,22 +421,22 @@ mod tests { problem: "Hello? Yes this is dog.".into(), } .write_to(Vec::new()), - "0e 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 + "11 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 20 69 73 20 64 6f 67 2e", ), ( ServerToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), - "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61", ), ( ServerToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0c 2a 2a 2a 2a 2a 2a 2a 2a", + "0f 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ServerToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), - "0d 2a 2a 2a 2a 2a 2a 2a 2a", + "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ServerToClientMsg::ReceivedPacket { @@ -444,7 +444,7 @@ mod tests { data: "Hello World!".into(), } .write_to(Vec::new()), - "05 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + "0b 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), @@ -454,7 +454,7 @@ mod tests { try_for: Duration::from_millis(20), } .write_to(Vec::new()), - "0f 00 00 00 0a 00 00 00 14", + "12 00 00 00 0a 00 00 00 14", ), ]); @@ -468,11 +468,11 @@ mod tests { check_expected_bytes(vec![ ( ClientToServerMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0c 2a 2a 2a 2a 2a 2a 2a 2a", + "0f 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ClientToServerMsg::Pong([42u8; 8]).write_to(Vec::new()), - "0d 2a 2a 2a 2a 2a 2a 2a 2a", + "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ClientToServerMsg::SendPacket { @@ -480,7 +480,7 @@ mod tests { packet: "Goodbye!".into(), } .write_to(Vec::new()), - "04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + "0a 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 47 6f 6f 64 62 79 65 21", ), From 35df94d1e0b5dc563ce79bdccc5bc4bf10d99021 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 10 Jun 2025 22:21:41 +0200 Subject: [PATCH 24/80] Move handshake & send/recv common code to relay.rs (rename relay.rs -> send_recv.rs) --- iroh-relay/src/client.rs | 2 +- iroh-relay/src/client/conn.rs | 11 +- iroh-relay/src/lib.rs | 2 +- iroh-relay/src/protos.rs | 1 + iroh-relay/src/protos/handshake.rs | 153 ++----- iroh-relay/src/protos/relay.rs | 613 +++------------------------ iroh-relay/src/protos/send_recv.rs | 559 ++++++++++++++++++++++++ iroh-relay/src/server.rs | 2 +- iroh-relay/src/server/client.rs | 4 +- iroh-relay/src/server/clients.rs | 2 +- iroh-relay/src/server/http_server.rs | 6 +- iroh-relay/src/server/streams.rs | 4 +- iroh/src/magicsock/relay_actor.rs | 2 +- 13 files changed, 672 insertions(+), 689 deletions(-) create mode 100644 iroh-relay/src/protos/send_recv.rs diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 964ba2a39da..c14cff5a0ce 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -29,7 +29,7 @@ use crate::{ http::RELAY_PATH, protos::{ handshake, - relay::{ClientToServerMsg, ServerToClientMsg}, + send_recv::{ClientToServerMsg, ServerToClientMsg}, }, KeyCache, }; diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 338ec1b19eb..9f51b9d936c 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -24,7 +24,7 @@ use crate::{ use crate::{ protos::{ handshake, - relay::{ + send_recv::{ ClientToServerMsg, RecvError as RecvRelayError, SendError as SendRelayError, ServerToClientMsg, }, @@ -42,9 +42,6 @@ use crate::{ #[derive(Debug, Snafu)] #[non_exhaustive] pub enum SendError { - #[cfg(not(wasm_browser))] - #[snafu(transparent)] - RelayIo { source: io::Error }, #[snafu(transparent)] WebsocketIo { #[cfg(not(wasm_browser))] @@ -79,8 +76,6 @@ pub enum RecvError { #[cfg(wasm_browser)] source: ws_stream_wasm::WsErr, }, - #[snafu(display("Unexpected frame received: {frame_type}"))] - UnexpectedFrame { frame_type: handshake::FrameType }, } /// A connection to a relay server. @@ -108,7 +103,7 @@ pub(crate) struct Conn { impl Conn { #[cfg(test)] pub(crate) fn test(io: tokio::io::DuplexStream) -> Self { - use crate::protos::relay::MAX_FRAME_SIZE; + use crate::protos::send_recv::MAX_FRAME_SIZE; Self { conn: tokio_websockets::ClientBuilder::new() .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) @@ -205,7 +200,7 @@ impl Sink for Conn { } fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { - // TODO(matheus23): Check this in send message construction instead + // TODO(matheus23): Check this in send message construction instead (and also check this in RecvPacket construction) if let ClientToServerMsg::SendPacket { packet, .. } = &frame { let size = packet.len(); snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); diff --git a/iroh-relay/src/lib.rs b/iroh-relay/src/lib.rs index 232b794b0e4..06070e648e7 100644 --- a/iroh-relay/src/lib.rs +++ b/iroh-relay/src/lib.rs @@ -48,7 +48,7 @@ pub(crate) use key_cache::KeyCache; pub mod dns; pub mod node_info; -pub use protos::relay::MAX_PACKET_SIZE; +pub use protos::send_recv::MAX_PACKET_SIZE; pub use self::{ ping_tracker::PingTracker, diff --git a/iroh-relay/src/protos.rs b/iroh-relay/src/protos.rs index 92fa7611e20..67110ca9a1a 100644 --- a/iroh-relay/src/protos.rs +++ b/iroh-relay/src/protos.rs @@ -3,5 +3,6 @@ pub mod disco; pub mod handshake; pub mod relay; +pub mod send_recv; pub mod streams; pub mod stun; diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index b3413b40e0b..16873daa09f 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -12,75 +12,9 @@ use quinn_proto::{coding::Codec, VarInt}; use rand::{CryptoRng, RngCore}; use snafu::{Backtrace, ResultExt, Snafu}; -use super::{relay::SendError, streams::BytesStreamSink}; +use super::{relay::FrameType, send_recv::SendError, streams::BytesStreamSink}; use crate::ExportKeyingMaterial; -/// TODO(matheus23) docs -pub(crate) const PROTOCOL_VERSION: &[u8] = b"1"; - -/// Possible frame types during handshaking -#[repr(u32)] -#[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] -pub enum FrameType { - /// The client frame type for the client challenge request - ClientRequestChallenge = 1, - /// The server frame type for the challenge response - ServerChallenge = 2, - /// The client frame type for the authentication frame - ClientAuth = 3, - /// The server frame type for authentication confirmation - ServerConfirmsAuth = 4, - /// The server frame type for authentication denial - ServerDeniesAuth = 5, - /// 32B dest pub key + packet bytes - SendPacket = 10, - /// v0/1 packet bytes, v2: 32B src pub key + packet bytes - RecvPacket = 11, - /// no payload, no-op (to be replaced with ping/pong) - KeepAlive = 12, - /// Sent from server to client to signal that a previous sender is no longer connected. - /// - /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` - /// to B so B can forget that a reverse path exists on that connection to get back to A - /// - /// 32B pub key of peer that's gone - NodeGone = 14, - /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. - /// Messages with these frames will be ignored. - /// 8 byte ping payload, to be echoed back in FrameType::Pong - Ping = 15, - /// 8 byte payload, the contents of ping being replied to - Pong = 16, - /// Sent from server to client to tell the client if their connection is - /// unhealthy somehow. - Health = 17, - - /// Sent from server to client for the server to declare that it's restarting. - /// Payload is two big endian u32 durations in milliseconds: when to reconnect, - /// and how long to try total. - /// - /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` - Restarting = 18, - /// The frame type was unknown. - /// - /// This frame is the result of parsing any future frame types that this implementation - /// does not yet understand. - #[num_enum(default)] - Unknown, -} - -impl std::fmt::Display for FrameType { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl From for VarInt { - fn from(value: FrameType) -> Self { - (value as u32).into() - } -} - /// Message that tells the server the client needs a challenge to authenticate. #[derive(derive_more::Debug, serde::Serialize)] #[cfg_attr(feature = "server", derive(serde::Deserialize))] @@ -109,9 +43,6 @@ pub(crate) struct ClientAuth { pub(crate) signature: [u8; 64], /// Part of the extracted key material, if that's what was signed. pub(crate) key_material_suffix: Option<[u8; 16]>, - /// Supported versions/protocol features for version negotiation - /// with other connected relay clients - pub(crate) versions: Vec>, } /// Confirmation of successful connection. @@ -124,9 +55,14 @@ pub(crate) struct ServerConfirmsAuth; #[cfg_attr(feature = "server", derive(serde::Serialize))] pub(crate) struct ServerDeniesAuth; -/// TODO(matheus23): Docs -pub trait Frame { - /// ... +/// Trait for getting the frame type tag for a frame. +/// +/// Used only in the handshake, as the frame we expect next +/// is fairly stateful. +/// Not used in the send/recv protocol, as any frame is +/// allowed to happen at any time there. +trait Frame { + /// The frame type this frame is identified by and prefixed with const TAG: FrameType; } @@ -173,10 +109,10 @@ pub enum Error { UnexpectedEnd {}, #[snafu(display("The relay denied our authentication"))] ServerDeniedAuth {}, - #[snafu(display("Unexpected tag, got {tag}, but expected one of {expected_tags:?}"))] - UnexpectedTag { - tag: VarInt, - expected_tags: Vec, + #[snafu(display("Unexpected tag, got {frame_type}, but expected one of {expected_types:?}"))] + UnexpectedFrameType { + frame_type: FrameType, + expected_types: Vec, }, #[snafu(display("Handshake failed while deserializing {frame_type} frame"))] DeserializationError { @@ -185,24 +121,6 @@ pub enum Error { }, } -impl FrameType { - pub(crate) fn write_to(&self, mut dst: O) -> O { - VarInt::from(*self).encode(&mut dst); - dst - } - - // TODO(matheus23): Consolidate errors between handshake.rs and relay.rs - // Perhaps a shared error type `FramingError`? - pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { - let mut cursor = std::io::Cursor::new(&bytes); - let tag = VarInt::decode(&mut cursor).ok()?; - let tag_u32 = u32::try_from(u64::from(tag)).ok()?; - let frame_type = FrameType::from(tag_u32); - let content = bytes.slice(cursor.position() as usize..); - Some((frame_type, content)) - } -} - impl ServerChallenge { /// TODO(matheus23): docs #[cfg(feature = "server")] @@ -227,7 +145,6 @@ impl ClientAuth { public_key: secret_key.public(), key_material_suffix: None, signature: secret_key.sign(&challenge.message_to_sign()).to_bytes(), - versions: vec![PROTOCOL_VERSION.to_vec()], } } @@ -260,7 +177,6 @@ impl ClientAuth { public_key, signature: secret_key.sign(&message).to_bytes(), key_material_suffix: Some(key_material[16..].try_into().expect("split right")), - versions: vec![PROTOCOL_VERSION.to_vec()], }) } @@ -295,7 +211,7 @@ pub(crate) async fn clientside( write_frame(io, ClientRequestChallenge).await?; } - let (tag, frame) = read_handshake_frame( + let (tag, frame) = read_frame( io, &[ ServerChallenge::TAG, @@ -312,7 +228,7 @@ pub(crate) async fn clientside( let client_info = ClientAuth::new_from_challenge(secret_key, &challenge); write_frame(io, client_info).await?; - read_handshake_frame( + read_frame( io, &[ServerConfirmsAuth::TAG, ServerDeniesAuth::TAG], time::Duration::from_secs(30), @@ -341,7 +257,7 @@ pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), rng: impl RngCore + CryptoRng, ) -> Result { - let (tag, frame) = read_handshake_frame( + let (tag, frame) = read_frame( io, &[ClientRequestChallenge::TAG, ClientAuth::TAG], time::Duration::from_secs(10), @@ -362,8 +278,7 @@ pub(crate) async fn serverside( let challenge = ServerChallenge::new(rng); write_frame(io, &challenge).await?; - let (_, frame) = - read_handshake_frame(io, &[ClientAuth::TAG], time::Duration::from_secs(10)).await?; + let (_, frame) = read_frame(io, &[ClientAuth::TAG], time::Duration::from_secs(10)).await?; let client_auth: ClientAuth = deserialize_frame(frame)?; if client_auth.verify_from_challenge(&challenge) { @@ -392,9 +307,9 @@ async fn write_frame( async fn read_frame( io: &mut impl BytesStreamSink, - expected_tags: &[VarInt], + expected_types: &[FrameType], timeout: time::Duration, -) -> Result<(VarInt, Bytes), Error> { +) -> Result<(FrameType, Bytes), Error> { let recv = time::timeout(timeout, io.try_next()) .await .context(TimeoutSnafu)?? @@ -402,34 +317,23 @@ async fn read_frame( // TODO(matheus23) restructure: use FrameType::from_bytes, perhaps always use `FrameType` instead let mut cursor = std::io::Cursor::new(recv); - let tag = VarInt::decode(&mut cursor) + let var_int = VarInt::decode(&mut cursor) .map_err(|quinn_proto::coding::UnexpectedEnd| UnexpectedEndSnafu.build())?; + let frame_type = u32::try_from(var_int.into_inner()) + .ok() + .map_or(FrameType::Unknown, FrameType::from); snafu::ensure!( - expected_tags.contains(&tag), - UnexpectedTagSnafu { - tag, - expected_tags: expected_tags.into_iter().cloned().collect::>() + expected_types.contains(&frame_type), + UnexpectedFrameTypeSnafu { + frame_type, + expected_types: expected_types.into_iter().cloned().collect::>() } ); let start = cursor.position() as usize; let payload = cursor.into_inner().slice(start..); - Ok((tag, payload)) -} - -async fn read_handshake_frame( - io: &mut impl BytesStreamSink, - expected_types: &[FrameType], - timeout: time::Duration, -) -> Result<(FrameType, Bytes), Error> { - let expected_tags = expected_types - .into_iter() - .map(|frame_type| VarInt::from(*frame_type)) - .collect::>(); - let (tag, frame) = read_frame(io, &expected_tags, timeout).await?; - let frame_type = u32::try_from(tag.into_inner()).map_or(FrameType::Unknown, FrameType::from); - Ok((frame_type, frame)) + Ok((frame_type, payload)) } fn deserialize_frame(frame: Bytes) -> Result { @@ -609,7 +513,6 @@ mod tests { assert_eq!(client_auth.public_key, decoded.public_key); assert_eq!(client_auth.key_material_suffix, decoded.key_material_suffix); assert_eq!(client_auth.signature, decoded.signature); - assert_eq!(client_auth.versions, decoded.versions); Ok(()) } diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 3bd38df1fd4..93caacfe2e5 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -1,561 +1,86 @@ -//! This module implements the relaying protocol used by the `server` and `client`. -//! -//! Protocol flow: -//! -//! Login: -//! * client connects -//! * -> client sends `FrameType::ClientInfo` -//! -//! Steady state: -//! * server occasionally sends `FrameType::KeepAlive` (or `FrameType::Ping`) -//! * client responds to any `FrameType::Ping` with a `FrameType::Pong` -//! * clients sends `FrameType::SendPacket` -//! * server then sends `FrameType::RecvPacket` to recipient +//! TODO(matheus23) docs use bytes::{BufMut, Bytes}; -use iroh_base::{NodeId, SignatureError}; -#[cfg(feature = "server")] -use n0_future::time::Duration; -use n0_future::{time, Sink, SinkExt}; -use nested_enum_utils::common_fields; -use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; - -use crate::{client::conn::SendError as ConnSendError, KeyCache}; - -use super::handshake::FrameType; - -/// The maximum size of a packet sent over relay. -/// (This only includes the data bytes visible to magicsock, not -/// including its on-wire framing overhead) -pub const MAX_PACKET_SIZE: usize = 64 * 1024; - -/// The maximum frame size. -/// -/// This is also the minimum burst size that a rate-limiter has to accept. -#[cfg(not(wasm_browser))] -pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; - -/// Interval in which we ping the relay server to ensure the connection is alive. -/// -/// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some -/// chance of recovering. -#[cfg(feature = "server")] -pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15); - -/// The number of packets buffered for sending per client -#[cfg(feature = "server")] -pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; - -/// Protocol send errors. -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -})] -#[allow(missing_docs)] -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum SendError { - #[snafu(transparent)] - Io { source: std::io::Error }, - #[snafu(transparent)] - Timeout { source: time::Elapsed }, - #[snafu(transparent)] - ConnSend { source: ConnSendError }, - #[snafu(transparent)] - SerDe { source: postcard::Error }, -} - -/// Protocol send errors. -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -})] -#[allow(missing_docs)] -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum RecvError { - #[snafu(transparent)] - Io { source: std::io::Error }, - #[snafu(display("unexpected frame: got {got}, expected {expected}"))] - UnexpectedFrame { got: FrameType, expected: FrameType }, - #[snafu(display("Frame is too large, has {frame_len} bytes"))] - FrameTooLarge { frame_len: usize }, - #[snafu(transparent)] - Timeout { source: time::Elapsed }, - #[snafu(transparent)] - SerDe { source: postcard::Error }, - #[snafu(transparent)] - InvalidSignature { source: SignatureError }, - #[snafu(display("Invalid frame encoding"))] - InvalidFrame {}, - #[snafu(display("Invalid frame type: {frame_type}"))] - InvalidFrameType { frame_type: FrameType }, - #[snafu(display("invalid protocol message encoding"))] - InvalidProtocolMessageEncoding { source: std::str::Utf8Error }, - #[snafu(display("Too few bytes"))] - TooSmall {}, -} - -/// Writes complete frame, errors if it is unable to write within the given `timeout`. -/// Ignores the timeout if `None` -/// -/// Does not flush. -#[cfg(feature = "server")] -pub(crate) async fn write_frame + Unpin>( - mut writer: S, - frame: ServerToClientMsg, - timeout: Option, -) -> Result<(), SendError> { - if let Some(duration) = timeout { - tokio::time::timeout(duration, writer.send(frame)).await??; - } else { - writer.send(frame).await?; - } - - Ok(()) -} - -/// TODO(matheus23): Docs -/// The messages received from a framed relay stream. -/// -/// This is a type-validated version of the `Frame`s on the `RelayCodec`. -#[derive(derive_more::Debug, Clone, PartialEq, Eq)] -pub enum ServerToClientMsg { - /// Represents an incoming packet. - ReceivedPacket { - /// The [`NodeId`] of the packet sender. - remote_node_id: NodeId, - /// The received packet bytes. - #[debug(skip)] - data: Bytes, - }, - /// Indicates that the client identified by the underlying public key had previously sent you a - /// packet but has now disconnected from the server. - NodeGone(NodeId), - /// A one-way message from server to client, declaring the connection health state. - Health { - /// If set, is a description of why the connection is unhealthy. - /// - /// If `None` means the connection is healthy again. - /// - /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`] - /// until a problem exists. - problem: String, - }, - /// A one-way message from server to client, advertising that the server is restarting. - Restarting { - /// An advisory duration that the client should wait before attempting to reconnect. - /// It might be zero. It exists for the server to smear out the reconnects. - reconnect_in: Duration, - /// An advisory duration for how long the client should attempt to reconnect - /// before giving up and proceeding with its normal connection failure logic. The interval - /// between retries is undefined for now. A server should not send a TryFor duration more - /// than a few seconds. - try_for: Duration, - }, - /// TODO(matheus23) fix docs - /// Request from a client or server to reply to the - /// other side with a [`ReceivedMessage::Pong`] with the given payload. - Ping([u8; 8]), - /// TODO(matheus23) fix docs - /// Reply to a [`ReceivedMessage::Ping`] from a client or server - /// with the payload sent previously in the ping. - Pong([u8; 8]), -} - -/// TODO(matheus23): Docs -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ClientToServerMsg { - /// TODO - Ping([u8; 8]), - /// TODO - Pong([u8; 8]), - /// TODO - SendPacket { - /// TODO - dst_key: NodeId, - /// TODO - packet: Bytes, - }, -} - -impl ServerToClientMsg { - /// TODO(matheus23): docs - pub fn typ(&self) -> FrameType { - match self { - Self::ReceivedPacket { .. } => FrameType::RecvPacket, - Self::NodeGone { .. } => FrameType::NodeGone, - Self::Ping { .. } => FrameType::Ping, - Self::Pong { .. } => FrameType::Pong, - Self::Health { .. } => FrameType::Health, - Self::Restarting { .. } => FrameType::Restarting, - } - } - - /// Encodes this frame for sending over websockets. +use quinn_proto::{coding::Codec, VarInt}; + +/// Possible frame types during handshaking +#[repr(u32)] +#[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] +// needs to be pub due to being exposed in error types +pub enum FrameType { + /// The client frame type for the client challenge request + ClientRequestChallenge = 1, + /// The server frame type for the challenge response + ServerChallenge = 2, + /// The client frame type for the authentication frame + ClientAuth = 3, + /// The server frame type for authentication confirmation + ServerConfirmsAuth = 4, + /// The server frame type for authentication denial + ServerDeniesAuth = 5, + /// 32B dest pub key + packet bytes + SendPacket = 10, + /// v0/1 packet bytes, v2: 32B src pub key + packet bytes + RecvPacket = 11, + /// no payload, no-op (to be replaced with ping/pong) + KeepAlive = 12, + /// Sent from server to client to signal that a previous sender is no longer connected. /// - /// Specifically meant for being put into a binary websocket message frame. - pub(crate) fn write_to(&self, mut dst: O) -> O { - dst = self.typ().write_to(dst); - match self { - Self::ReceivedPacket { - remote_node_id: src_key, - data: content, - } => { - dst.put(src_key.as_ref()); - dst.put(content.as_ref()); - } - Self::NodeGone(node_id) => { - dst.put(node_id.as_ref()); - } - Self::Ping(data) => { - dst.put(&data[..]); - } - Self::Pong(data) => { - dst.put(&data[..]); - } - Self::Health { problem } => { - dst.put(problem.as_ref()); - } - Self::Restarting { - reconnect_in, - try_for, - } => { - dst.put_u32(reconnect_in.as_millis() as u32); - dst.put_u32(try_for.as_millis() as u32); - } - } - dst - } - - /// Tries to decode a frame received over websockets. + /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` + /// to B so B can forget that a reverse path exists on that connection to get back to A /// - /// Specifically, bytes received from a binary websocket message frame. - #[allow(clippy::result_large_err)] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; - let res = match frame_type { - FrameType::RecvPacket => { - if content.len() < NodeId::LENGTH { - return Err(InvalidFrameSnafu.build()); - } - - let frame_len = content.len() - NodeId::LENGTH; - if frame_len > MAX_PACKET_SIZE { - return Err(FrameTooLargeSnafu { frame_len }.build()); - } - - let src_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; - let content = content.slice(NodeId::LENGTH..); - Self::ReceivedPacket { - remote_node_id: src_key, - data: content, - } - } - FrameType::NodeGone => { - if content.len() != NodeId::LENGTH { - return Err(InvalidFrameSnafu.build()); - } - let node_id = cache.key_from_slice(&content[..32])?; - Self::NodeGone(node_id) - } - FrameType::Ping => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Ping(data) - } - FrameType::Pong => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Pong(data) - } - FrameType::Health => { - let problem = std::str::from_utf8(&content) - .context(InvalidProtocolMessageEncodingSnafu)? - .to_owned(); - // TODO(matheus23): Actually encode/decode the option - Self::Health { problem } - } - FrameType::Restarting => { - if content.len() != 4 + 4 { - return Err(InvalidFrameSnafu.build()); - } - let reconnect_in = u32::from_be_bytes( - content[..4] - .try_into() - .map_err(|_| InvalidFrameSnafu.build())?, - ); - let try_for = u32::from_be_bytes( - content[4..] - .try_into() - .map_err(|_| InvalidFrameSnafu.build())?, - ); - let reconnect_in = Duration::from_millis(reconnect_in as u64); - let try_for = Duration::from_millis(try_for as u64); - Self::Restarting { - reconnect_in, - try_for, - } - } - _ => { - return Err(InvalidFrameTypeSnafu { frame_type }.build()); - } - }; - Ok(res) - } -} - -impl ClientToServerMsg { - pub(crate) fn typ(&self) -> FrameType { - match self { - Self::SendPacket { .. } => FrameType::SendPacket, - Self::Ping { .. } => FrameType::Ping, - Self::Pong { .. } => FrameType::Pong, - } - } - - /// Encodes this frame for sending over websockets. + /// 32B pub key of peer that's gone + NodeGone = 14, + /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. + /// Messages with these frames will be ignored. + /// 8 byte ping payload, to be echoed back in FrameType::Pong + Ping = 15, + /// 8 byte payload, the contents of ping being replied to + Pong = 16, + /// Sent from server to client to tell the client if their connection is + /// unhealthy somehow. + Health = 17, + + /// Sent from server to client for the server to declare that it's restarting. + /// Payload is two big endian u32 durations in milliseconds: when to reconnect, + /// and how long to try total. /// - /// Specifically meant for being put into a binary websocket message frame. - pub(crate) fn write_to(&self, mut dst: O) -> O { - dst = self.typ().write_to(dst); - match self { - Self::SendPacket { dst_key, packet } => { - dst.put(dst_key.as_ref()); - dst.put(packet.as_ref()); - } - Self::Ping(data) => { - dst.put(&data[..]); - } - Self::Pong(data) => { - dst.put(&data[..]); - } - } - dst - } - - /// Tries to decode a frame received over websockets. + /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` + Restarting = 18, + /// The frame type was unknown. /// - /// Specifically, bytes received from a binary websocket message frame. - #[allow(clippy::result_large_err)] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; - let res = match frame_type { - FrameType::SendPacket => { - if content.len() < NodeId::LENGTH { - return Err(InvalidFrameSnafu.build()); - } - let frame_len = content.len() - NodeId::LENGTH; - if frame_len > MAX_PACKET_SIZE { - return Err(FrameTooLargeSnafu { frame_len }.build()); - } - - let dst_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; - let packet = content.slice(NodeId::LENGTH..); - Self::SendPacket { dst_key, packet } - } - FrameType::Ping => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Ping(data) - } - FrameType::Pong => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Pong(data) - } - _ => { - return Err(InvalidFrameTypeSnafu { frame_type }.build()); - } - }; - Ok(res) - } + /// This frame is the result of parsing any future frame types that this implementation + /// does not yet understand. + #[num_enum(default)] + Unknown, } -#[cfg(test)] -mod tests { - use data_encoding::HEXLOWER; - use iroh_base::SecretKey; - use n0_snafu::Result; - - use super::*; - - fn check_expected_bytes(frames: Vec<(Vec, &str)>) { - for (bytes, expected_hex) in frames { - let stripped: Vec = expected_hex - .chars() - .filter_map(|s| { - if s.is_ascii_whitespace() { - None - } else { - Some(s as u8) - } - }) - .collect(); - let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); - assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes)); - } - } - - #[test] - fn test_server_client_frames_snapshot() -> Result { - let client_key = SecretKey::from_bytes(&[42u8; 32]); - - check_expected_bytes(vec![ - ( - ServerToClientMsg::Health { - problem: "Hello? Yes this is dog.".into(), - } - .write_to(Vec::new()), - "11 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 - 20 69 73 20 64 6f 67 2e", - ), - ( - ServerToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), - "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61", - ), - ( - ServerToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0f 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - ServerToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), - "10 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - ServerToClientMsg::ReceivedPacket { - remote_node_id: client_key.public(), - data: "Hello World!".into(), - } - .write_to(Vec::new()), - "0b 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", - ), - ( - ServerToClientMsg::Restarting { - reconnect_in: Duration::from_millis(10), - try_for: Duration::from_millis(20), - } - .write_to(Vec::new()), - "12 00 00 00 0a 00 00 00 14", - ), - ]); - - Ok(()) - } - - #[test] - fn test_client_server_frames_snapshot() -> Result { - let client_key = SecretKey::from_bytes(&[42u8; 32]); - - check_expected_bytes(vec![ - ( - ClientToServerMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0f 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - ClientToServerMsg::Pong([42u8; 8]).write_to(Vec::new()), - "10 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - ClientToServerMsg::SendPacket { - dst_key: client_key.public(), - packet: "Goodbye!".into(), - } - .write_to(Vec::new()), - "0a 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61 47 6f 6f 64 62 79 65 21", - ), - ]); - - Ok(()) +impl std::fmt::Display for FrameType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{self:?}") } } -#[cfg(test)] -mod proptests { - use bytes::BytesMut; - use iroh_base::SecretKey; - use proptest::prelude::*; - - use super::*; - - fn secret_key() -> impl Strategy { - prop::array::uniform32(any::()).prop_map(SecretKey::from) - } - - fn key() -> impl Strategy { - secret_key().prop_map(|key| key.public()) - } - - /// Generates random data, up to the maximum packet size minus the given number of bytes - fn data(consumed: usize) -> impl Strategy { - let len = MAX_PACKET_SIZE - consumed; - prop::collection::vec(any::(), 0..len).prop_map(Bytes::from) - } - - /// Generates a random valid frame - fn server_client_frame() -> impl Strategy { - let recv_packet = - (key(), data(32)).prop_map(|(src_key, content)| ServerToClientMsg::ReceivedPacket { - remote_node_id: src_key, - data: content, - }); - let node_gone = key().prop_map(|node_id| ServerToClientMsg::NodeGone(node_id)); - let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); - let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); - // TODO(matheus23): Actually fix these - let health = data(0).prop_map(|_problem| ServerToClientMsg::Health { - problem: "".to_string(), - }); - let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { - ServerToClientMsg::Restarting { - reconnect_in: Duration::from_millis(reconnect_in.into()), - try_for: Duration::from_millis(try_for.into()), - } - }); - prop_oneof![recv_packet, node_gone, ping, pong, health, restarting] +impl FrameType { + pub(crate) fn write_to(&self, mut dst: O) -> O { + VarInt::from(*self).encode(&mut dst); + dst } - fn client_server_frame() -> impl Strategy { - let send_packet = (key(), data(32)) - .prop_map(|(dst_key, packet)| ClientToServerMsg::SendPacket { dst_key, packet }); - let ping = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Ping); - let pong = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Pong); - prop_oneof![send_packet, ping, pong] + // TODO(matheus23): Consolidate errors between handshake.rs and relay.rs + // Perhaps a shared error type `FramingError`? + pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { + let mut cursor = std::io::Cursor::new(&bytes); + let tag = VarInt::decode(&mut cursor).ok()?; + let tag_u32 = u32::try_from(u64::from(tag)).ok()?; + let frame_type = FrameType::from(tag_u32); + let content = bytes.slice(cursor.position() as usize..); + Some((frame_type, content)) } +} - proptest! { - #[test] - fn server_client_frame_roundtrip(frame in server_client_frame()) { - let encoded = frame.clone().write_to(BytesMut::new()).freeze(); - let decoded = ServerToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); - prop_assert_eq!(frame, decoded); - } - - #[test] - fn client_server_frame_roundtrip(frame in client_server_frame()) { - let encoded = frame.clone().write_to(BytesMut::new()).freeze(); - let decoded = ClientToServerMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); - prop_assert_eq!(frame, decoded); - } +impl From for VarInt { + fn from(value: FrameType) -> Self { + (value as u32).into() } } diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs new file mode 100644 index 00000000000..48402ca0537 --- /dev/null +++ b/iroh-relay/src/protos/send_recv.rs @@ -0,0 +1,559 @@ +//! This module implements the relaying protocol used by the `server` and `client`. +//! +//! Protocol flow: +//! +//! Login: +//! * client connects +//! * -> client sends `FrameType::ClientInfo` +//! +//! Steady state: +//! * server occasionally sends `FrameType::KeepAlive` (or `FrameType::Ping`) +//! * client responds to any `FrameType::Ping` with a `FrameType::Pong` +//! * clients sends `FrameType::SendPacket` +//! * server then sends `FrameType::RecvPacket` to recipient + +use bytes::{BufMut, Bytes}; +use iroh_base::{NodeId, SignatureError}; +#[cfg(feature = "server")] +use n0_future::time::Duration; +use n0_future::{time, Sink, SinkExt}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; + +use crate::KeyCache; + +use super::relay::FrameType; + +/// The maximum size of a packet sent over relay. +/// (This only includes the data bytes visible to magicsock, not +/// including its on-wire framing overhead) +pub const MAX_PACKET_SIZE: usize = 64 * 1024; + +/// The maximum frame size. +/// +/// This is also the minimum burst size that a rate-limiter has to accept. +#[cfg(not(wasm_browser))] +pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; + +/// Interval in which we ping the relay server to ensure the connection is alive. +/// +/// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some +/// chance of recovering. +#[cfg(feature = "server")] +pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15); + +/// The number of packets buffered for sending per client +#[cfg(feature = "server")] +pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; + +/// Protocol send errors. +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum SendError { + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(transparent)] + Timeout { source: time::Elapsed }, + #[snafu(transparent)] + SerDe { source: postcard::Error }, +} + +/// Protocol send errors. +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum RecvError { + #[snafu(transparent)] + Io { source: std::io::Error }, + #[snafu(display("unexpected frame: got {got}, expected {expected}"))] + UnexpectedFrame { got: FrameType, expected: FrameType }, + #[snafu(display("Frame is too large, has {frame_len} bytes"))] + FrameTooLarge { frame_len: usize }, + #[snafu(transparent)] + Timeout { source: time::Elapsed }, + #[snafu(transparent)] + SerDe { source: postcard::Error }, + #[snafu(transparent)] + InvalidSignature { source: SignatureError }, + #[snafu(display("Invalid frame encoding"))] + InvalidFrame {}, + #[snafu(display("Invalid frame type: {frame_type}"))] + InvalidFrameType { frame_type: FrameType }, + #[snafu(display("invalid protocol message encoding"))] + InvalidProtocolMessageEncoding { source: std::str::Utf8Error }, + #[snafu(display("Too few bytes"))] + TooSmall {}, +} + +/// Writes complete frame, errors if it is unable to write within the given `timeout`. +/// Ignores the timeout if `None` +/// +/// Does not flush. +#[cfg(feature = "server")] +pub(crate) async fn write_frame + Unpin>( + mut writer: S, + frame: ServerToClientMsg, + timeout: Option, +) -> Result<(), SendError> { + if let Some(duration) = timeout { + tokio::time::timeout(duration, writer.send(frame)).await??; + } else { + writer.send(frame).await?; + } + + Ok(()) +} + +/// TODO(matheus23): Docs +/// The messages received from a framed relay stream. +/// +/// This is a type-validated version of the `Frame`s on the `RelayCodec`. +#[derive(derive_more::Debug, Clone, PartialEq, Eq)] +pub enum ServerToClientMsg { + /// Represents an incoming packet. + ReceivedPacket { + /// The [`NodeId`] of the packet sender. + remote_node_id: NodeId, + /// The received packet bytes. + #[debug(skip)] + data: Bytes, + }, + /// Indicates that the client identified by the underlying public key had previously sent you a + /// packet but has now disconnected from the server. + NodeGone(NodeId), + /// A one-way message from server to client, declaring the connection health state. + Health { + /// If set, is a description of why the connection is unhealthy. + /// + /// If `None` means the connection is healthy again. + /// + /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`] + /// until a problem exists. + problem: String, + }, + /// A one-way message from server to client, advertising that the server is restarting. + Restarting { + /// An advisory duration that the client should wait before attempting to reconnect. + /// It might be zero. It exists for the server to smear out the reconnects. + reconnect_in: Duration, + /// An advisory duration for how long the client should attempt to reconnect + /// before giving up and proceeding with its normal connection failure logic. The interval + /// between retries is undefined for now. A server should not send a TryFor duration more + /// than a few seconds. + try_for: Duration, + }, + /// TODO(matheus23) fix docs + /// Request from a client or server to reply to the + /// other side with a [`ReceivedMessage::Pong`] with the given payload. + Ping([u8; 8]), + /// TODO(matheus23) fix docs + /// Reply to a [`ReceivedMessage::Ping`] from a client or server + /// with the payload sent previously in the ping. + Pong([u8; 8]), +} + +/// TODO(matheus23): Docs +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ClientToServerMsg { + /// TODO + Ping([u8; 8]), + /// TODO + Pong([u8; 8]), + /// TODO + SendPacket { + /// TODO + dst_key: NodeId, + /// TODO + packet: Bytes, + }, +} + +impl ServerToClientMsg { + /// TODO(matheus23): docs + pub fn typ(&self) -> FrameType { + match self { + Self::ReceivedPacket { .. } => FrameType::RecvPacket, + Self::NodeGone { .. } => FrameType::NodeGone, + Self::Ping { .. } => FrameType::Ping, + Self::Pong { .. } => FrameType::Pong, + Self::Health { .. } => FrameType::Health, + Self::Restarting { .. } => FrameType::Restarting, + } + } + + /// Encodes this frame for sending over websockets. + /// + /// Specifically meant for being put into a binary websocket message frame. + pub(crate) fn write_to(&self, mut dst: O) -> O { + dst = self.typ().write_to(dst); + match self { + Self::ReceivedPacket { + remote_node_id: src_key, + data: content, + } => { + dst.put(src_key.as_ref()); + dst.put(content.as_ref()); + } + Self::NodeGone(node_id) => { + dst.put(node_id.as_ref()); + } + Self::Ping(data) => { + dst.put(&data[..]); + } + Self::Pong(data) => { + dst.put(&data[..]); + } + Self::Health { problem } => { + dst.put(problem.as_ref()); + } + Self::Restarting { + reconnect_in, + try_for, + } => { + dst.put_u32(reconnect_in.as_millis() as u32); + dst.put_u32(try_for.as_millis() as u32); + } + } + dst + } + + /// Tries to decode a frame received over websockets. + /// + /// Specifically, bytes received from a binary websocket message frame. + #[allow(clippy::result_large_err)] + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + let res = match frame_type { + FrameType::RecvPacket => { + if content.len() < NodeId::LENGTH { + return Err(InvalidFrameSnafu.build()); + } + + let frame_len = content.len() - NodeId::LENGTH; + if frame_len > MAX_PACKET_SIZE { + return Err(FrameTooLargeSnafu { frame_len }.build()); + } + + let src_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let content = content.slice(NodeId::LENGTH..); + Self::ReceivedPacket { + remote_node_id: src_key, + data: content, + } + } + FrameType::NodeGone => { + if content.len() != NodeId::LENGTH { + return Err(InvalidFrameSnafu.build()); + } + let node_id = cache.key_from_slice(&content[..32])?; + Self::NodeGone(node_id) + } + FrameType::Ping => { + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Ping(data) + } + FrameType::Pong => { + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Pong(data) + } + FrameType::Health => { + let problem = std::str::from_utf8(&content) + .context(InvalidProtocolMessageEncodingSnafu)? + .to_owned(); + // TODO(matheus23): Actually encode/decode the option + Self::Health { problem } + } + FrameType::Restarting => { + if content.len() != 4 + 4 { + return Err(InvalidFrameSnafu.build()); + } + let reconnect_in = u32::from_be_bytes( + content[..4] + .try_into() + .map_err(|_| InvalidFrameSnafu.build())?, + ); + let try_for = u32::from_be_bytes( + content[4..] + .try_into() + .map_err(|_| InvalidFrameSnafu.build())?, + ); + let reconnect_in = Duration::from_millis(reconnect_in as u64); + let try_for = Duration::from_millis(try_for as u64); + Self::Restarting { + reconnect_in, + try_for, + } + } + _ => { + return Err(InvalidFrameTypeSnafu { frame_type }.build()); + } + }; + Ok(res) + } +} + +impl ClientToServerMsg { + pub(crate) fn typ(&self) -> FrameType { + match self { + Self::SendPacket { .. } => FrameType::SendPacket, + Self::Ping { .. } => FrameType::Ping, + Self::Pong { .. } => FrameType::Pong, + } + } + + /// Encodes this frame for sending over websockets. + /// + /// Specifically meant for being put into a binary websocket message frame. + pub(crate) fn write_to(&self, mut dst: O) -> O { + dst = self.typ().write_to(dst); + match self { + Self::SendPacket { dst_key, packet } => { + dst.put(dst_key.as_ref()); + dst.put(packet.as_ref()); + } + Self::Ping(data) => { + dst.put(&data[..]); + } + Self::Pong(data) => { + dst.put(&data[..]); + } + } + dst + } + + /// Tries to decode a frame received over websockets. + /// + /// Specifically, bytes received from a binary websocket message frame. + #[allow(clippy::result_large_err)] + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + let res = match frame_type { + FrameType::SendPacket => { + if content.len() < NodeId::LENGTH { + return Err(InvalidFrameSnafu.build()); + } + let frame_len = content.len() - NodeId::LENGTH; + if frame_len > MAX_PACKET_SIZE { + return Err(FrameTooLargeSnafu { frame_len }.build()); + } + + let dst_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let packet = content.slice(NodeId::LENGTH..); + Self::SendPacket { dst_key, packet } + } + FrameType::Ping => { + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Ping(data) + } + FrameType::Pong => { + if content.len() != 8 { + return Err(InvalidFrameSnafu.build()); + } + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Pong(data) + } + _ => { + return Err(InvalidFrameTypeSnafu { frame_type }.build()); + } + }; + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use data_encoding::HEXLOWER; + use iroh_base::SecretKey; + use n0_snafu::Result; + + use super::*; + + fn check_expected_bytes(frames: Vec<(Vec, &str)>) { + for (bytes, expected_hex) in frames { + let stripped: Vec = expected_hex + .chars() + .filter_map(|s| { + if s.is_ascii_whitespace() { + None + } else { + Some(s as u8) + } + }) + .collect(); + let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); + assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes)); + } + } + + #[test] + fn test_server_client_frames_snapshot() -> Result { + let client_key = SecretKey::from_bytes(&[42u8; 32]); + + check_expected_bytes(vec![ + ( + ServerToClientMsg::Health { + problem: "Hello? Yes this is dog.".into(), + } + .write_to(Vec::new()), + "11 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 + 20 69 73 20 64 6f 67 2e", + ), + ( + ServerToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), + "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61", + ), + ( + ServerToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), + "0f 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ServerToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), + "10 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ServerToClientMsg::ReceivedPacket { + remote_node_id: client_key.public(), + data: "Hello World!".into(), + } + .write_to(Vec::new()), + "0b 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + ), + ( + ServerToClientMsg::Restarting { + reconnect_in: Duration::from_millis(10), + try_for: Duration::from_millis(20), + } + .write_to(Vec::new()), + "12 00 00 00 0a 00 00 00 14", + ), + ]); + + Ok(()) + } + + #[test] + fn test_client_server_frames_snapshot() -> Result { + let client_key = SecretKey::from_bytes(&[42u8; 32]); + + check_expected_bytes(vec![ + ( + ClientToServerMsg::Ping([42u8; 8]).write_to(Vec::new()), + "0f 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ClientToServerMsg::Pong([42u8; 8]).write_to(Vec::new()), + "10 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ClientToServerMsg::SendPacket { + dst_key: client_key.public(), + packet: "Goodbye!".into(), + } + .write_to(Vec::new()), + "0a 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61 47 6f 6f 64 62 79 65 21", + ), + ]); + + Ok(()) + } +} + +#[cfg(test)] +mod proptests { + use bytes::BytesMut; + use iroh_base::SecretKey; + use proptest::prelude::*; + + use super::*; + + fn secret_key() -> impl Strategy { + prop::array::uniform32(any::()).prop_map(SecretKey::from) + } + + fn key() -> impl Strategy { + secret_key().prop_map(|key| key.public()) + } + + /// Generates random data, up to the maximum packet size minus the given number of bytes + fn data(consumed: usize) -> impl Strategy { + let len = MAX_PACKET_SIZE - consumed; + prop::collection::vec(any::(), 0..len).prop_map(Bytes::from) + } + + /// Generates a random valid frame + fn server_client_frame() -> impl Strategy { + let recv_packet = + (key(), data(32)).prop_map(|(src_key, content)| ServerToClientMsg::ReceivedPacket { + remote_node_id: src_key, + data: content, + }); + let node_gone = key().prop_map(|node_id| ServerToClientMsg::NodeGone(node_id)); + let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); + // TODO(matheus23): Actually fix these + let health = data(0).prop_map(|_problem| ServerToClientMsg::Health { + problem: "".to_string(), + }); + let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { + ServerToClientMsg::Restarting { + reconnect_in: Duration::from_millis(reconnect_in.into()), + try_for: Duration::from_millis(try_for.into()), + } + }); + prop_oneof![recv_packet, node_gone, ping, pong, health, restarting] + } + + fn client_server_frame() -> impl Strategy { + let send_packet = (key(), data(32)) + .prop_map(|(dst_key, packet)| ClientToServerMsg::SendPacket { dst_key, packet }); + let ping = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Pong); + prop_oneof![send_packet, ping, pong] + } + + proptest! { + #[test] + fn server_client_frame_roundtrip(frame in server_client_frame()) { + let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let decoded = ServerToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + prop_assert_eq!(frame, decoded); + } + + #[test] + fn client_server_frame_roundtrip(frame in client_server_frame()) { + let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let decoded = ClientToServerMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + prop_assert_eq!(frame, decoded); + } + } +} diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 6a2d4b54f60..e1a4c7eb257 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -894,7 +894,7 @@ mod tests { dns::DnsResolver, protos::{ self, - relay::{ClientToServerMsg, ServerToClientMsg}, + send_recv::{ClientToServerMsg, ServerToClientMsg}, }, }; diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index bbc1d2840a3..514d26a3464 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -19,7 +19,7 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - relay::{ + send_recv::{ write_frame, ClientToServerMsg, SendError as SendRelayError, ServerToClientMsg, PING_INTERVAL, }, @@ -557,7 +557,7 @@ mod tests { use tracing_test::traced_test; use super::*; - use crate::{client::conn::Conn, protos::handshake::FrameType}; + use crate::{client::conn::Conn, protos::relay::FrameType}; async fn recv_frame< E: snafu::Error + Sync + Send + 'static, diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index e9bb5c88ec0..885f62fc21d 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -200,7 +200,7 @@ mod tests { use super::*; use crate::{ client::conn::Conn, - protos::{handshake::FrameType, relay::ServerToClientMsg}, + protos::{relay::FrameType, send_recv::ServerToClientMsg}, server::streams::RelayedStream, }; diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 91d0dafa211..b78c3675b9a 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -26,7 +26,7 @@ use super::{clients::Clients, AccessConfig, SpawnError}; use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, http::{RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, - protos::relay::{ServerToClientMsg, PER_CLIENT_SEND_QUEUE_DEPTH}, + protos::send_recv::{ServerToClientMsg, PER_CLIENT_SEND_QUEUE_DEPTH}, server::{ client::Config, metrics::Metrics, @@ -37,7 +37,7 @@ use crate::{ }; use crate::{ http::WEBSOCKET_UPGRADE_PROTOCOL, - protos::{handshake, relay::MAX_FRAME_SIZE, streams::WsBytesFramed}, + protos::{handshake, send_recv::MAX_FRAME_SIZE, streams::WsBytesFramed}, server::streams::RateLimited, }; @@ -849,7 +849,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::relay::ClientToServerMsg, + protos::send_recv::ClientToServerMsg, }; pub(crate) fn make_tls_config() -> TlsConfig { diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index fefe379b0ab..57ab128bda6 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -14,7 +14,7 @@ use tokio_websockets::WebSocketStream; use tracing::instrument; use crate::{ - protos::relay::{ClientToServerMsg, RecvError, ServerToClientMsg}, + protos::send_recv::{ClientToServerMsg, RecvError, ServerToClientMsg}, ExportKeyingMaterial, KeyCache, }; @@ -64,7 +64,7 @@ impl RelayedStream { fn limits() -> tokio_websockets::Limits { tokio_websockets::Limits::default() - .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)) + .max_payload_len(Some(crate::protos::send_recv::MAX_FRAME_SIZE)) } } diff --git a/iroh/src/magicsock/relay_actor.rs b/iroh/src/magicsock/relay_actor.rs index 5d43a39dc37..4e2a6103b2a 100644 --- a/iroh/src/magicsock/relay_actor.rs +++ b/iroh/src/magicsock/relay_actor.rs @@ -43,7 +43,7 @@ use iroh_base::{NodeId, PublicKey, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, - protos::relay::{ClientToServerMsg, ServerToClientMsg}, + protos::send_recv::{ClientToServerMsg, ServerToClientMsg}, PingTracker, MAX_PACKET_SIZE, }; use n0_future::{ From 7a7c09cb0eedf208b0e988bdc56e7dca98ef4133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 30 Jun 2025 12:25:41 +0200 Subject: [PATCH 25/80] Resolve warnings --- iroh-relay/src/client.rs | 3 +++ iroh-relay/src/http.rs | 1 + iroh-relay/src/protos/handshake.rs | 6 +++++- iroh-relay/src/protos/send_recv.rs | 6 ++++-- iroh-relay/src/server.rs | 5 +---- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index c14cff5a0ce..5c4f399820f 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -202,6 +202,8 @@ impl ClientBuilder { pub async fn connect(&self) -> Result { use tls::MaybeTlsStreamBuilder; + use crate::protos::send_recv::MAX_FRAME_SIZE; + let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); // The relay URL is exchanged with the http(s) scheme in tickets and similar. @@ -244,6 +246,7 @@ impl ClientBuilder { } .build() })? + .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) .connect_on(stream) .await?; diff --git a/iroh-relay/src/http.rs b/iroh-relay/src/http.rs index 60a8a6d2df2..cca8c82aee0 100644 --- a/iroh-relay/src/http.rs +++ b/iroh-relay/src/http.rs @@ -1,5 +1,6 @@ //! HTTP-specific constants for the relay server and client. +#[cfg(feature = "server")] pub(crate) const WEBSOCKET_UPGRADE_PROTOCOL: &str = "websocket"; #[cfg(feature = "server")] // only used in the server for now pub(crate) const SUPPORTED_WEBSOCKET_VERSION: &str = "13"; diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 16873daa09f..4bd5faa192f 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -1,7 +1,9 @@ //! TODO(matheus23) docs use bytes::{BufMut, Bytes, BytesMut}; -use iroh_base::{PublicKey, SecretKey, Signature}; +#[cfg(feature = "server")] +use iroh_base::Signature; +use iroh_base::{PublicKey, SecretKey}; use n0_future::{ time::{self, Elapsed}, SinkExt, TryStreamExt, @@ -149,6 +151,7 @@ impl ClientAuth { } /// TODO(matheus23): docs + #[cfg(feature = "server")] pub(crate) fn verify_from_challenge(&self, challenge: &ServerChallenge) -> bool { self.public_key .verify( @@ -180,6 +183,7 @@ impl ClientAuth { }) } + #[cfg(feature = "server")] pub(crate) fn verify_from_key_export(&self, io: &mut impl ExportKeyingMaterial) -> bool { let Some(key_material) = io.export_keying_material( [0u8; 32], diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index 48402ca0537..d15e53396b5 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -14,9 +14,9 @@ use bytes::{BufMut, Bytes}; use iroh_base::{NodeId, SignatureError}; +use n0_future::time::{self, Duration}; #[cfg(feature = "server")] -use n0_future::time::Duration; -use n0_future::{time, Sink, SinkExt}; +use n0_future::{Sink, SinkExt}; use nested_enum_utils::common_fields; use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; @@ -195,6 +195,7 @@ impl ServerToClientMsg { /// Encodes this frame for sending over websockets. /// /// Specifically meant for being put into a binary websocket message frame. + #[cfg(feature = "server")] pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { @@ -344,6 +345,7 @@ impl ClientToServerMsg { /// /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] + #[cfg(feature = "server")] pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; let res = match frame_type { diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 83571416a87..81c29faccef 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -765,10 +765,7 @@ mod tests { use crate::{ client::ClientBuilder, dns::DnsResolver, - protos::{ - self, - send_recv::{ClientToServerMsg, ServerToClientMsg}, - }, + protos::send_recv::{ClientToServerMsg, ServerToClientMsg}, }; async fn spawn_local_relay() -> std::result::Result { From fe9641b817ec60d3274f950469268e5aa4951c7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 30 Jun 2025 12:26:46 +0200 Subject: [PATCH 26/80] `cargo make format` --- iroh-relay/src/protos/send_recv.rs | 3 +-- iroh-relay/src/server/streams.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index d15e53396b5..e5e4a77ad5b 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -20,9 +20,8 @@ use n0_future::{Sink, SinkExt}; use nested_enum_utils::common_fields; use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; -use crate::KeyCache; - use super::relay::FrameType; +use crate::KeyCache; /// The maximum size of a packet sent over relay. /// (This only includes the data bytes visible to magicsock, not diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 57ab128bda6..0bd59778bf7 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -13,13 +13,12 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_websockets::WebSocketStream; use tracing::instrument; +use super::{ClientRateLimit, Metrics}; use crate::{ protos::send_recv::{ClientToServerMsg, RecvError, ServerToClientMsg}, ExportKeyingMaterial, KeyCache, }; -use super::{ClientRateLimit, Metrics}; - /// A Stream and Sink for [`Frame`]s connected to a single relay client. /// /// The stream receives message from the client while the sink sends them to the client. From 7a5550ac81d7a01bb70a53687be4c69aede2c748 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 30 Jun 2025 15:03:38 +0200 Subject: [PATCH 27/80] Send ECN bits and use stride instead of custom split protocol --- iroh-relay/src/client/conn.rs | 4 +- iroh-relay/src/protos/relay.rs | 4 +- iroh-relay/src/protos/send_recv.rs | 182 ++++++++++----- iroh-relay/src/server.rs | 43 ++-- iroh-relay/src/server/client.rs | 72 +++--- iroh-relay/src/server/clients.rs | 27 +-- iroh-relay/src/server/http_server.rs | 151 +++++++------ iroh/src/magicsock.rs | 4 - iroh/src/magicsock/transports/relay.rs | 78 ++----- iroh/src/magicsock/transports/relay/actor.rs | 225 +++---------------- 10 files changed, 331 insertions(+), 459 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 9f51b9d936c..3511b2fd8a8 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -201,8 +201,8 @@ impl Sink for Conn { fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { // TODO(matheus23): Check this in send message construction instead (and also check this in RecvPacket construction) - if let ClientToServerMsg::SendPacket { packet, .. } = &frame { - let size = packet.len(); + if let ClientToServerMsg::SendDatagrams { datagrams, .. } = &frame { + let size = datagrams.contents.len(); snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); } diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 93caacfe2e5..bcdaf72d33c 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -18,9 +18,9 @@ pub enum FrameType { ServerConfirmsAuth = 4, /// The server frame type for authentication denial ServerDeniesAuth = 5, - /// 32B dest pub key + packet bytes + /// 32B dest pub key + packet bytes TODO(matheus23): Fix docs SendPacket = 10, - /// v0/1 packet bytes, v2: 32B src pub key + packet bytes + /// v0/1 packet bytes, v2: 32B src pub key + packet bytes TODO(matheus23): Fix docs RecvPacket = 11, /// no payload, no-op (to be replaced with ping/pong) KeepAlive = 12, diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index e5e4a77ad5b..3f548606e94 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -121,12 +121,11 @@ pub(crate) async fn write_frame, + /// The segment size if this transmission contains multiple datagrams. + /// This is `None` if the transmit only contains a single datagram + pub segment_size: Option, + /// The contents of the datagram(s) + #[debug(skip)] + pub contents: Bytes, +} + +impl> From for Datagrams { + fn from(bytes: T) -> Self { + Self { + ecn: None, + segment_size: None, + contents: Bytes::copy_from_slice(bytes.as_ref()), + } + } +} + +impl Datagrams { + fn write_to(&self, mut dst: O) -> O { + let ecn = self.ecn.map_or(0, |ecn| ecn as u8); + let segment_size = self.segment_size.unwrap_or_default(); + dst.put_u8(ecn); + dst.put_u16(segment_size); + dst.put(self.contents.as_ref()); + dst + } + + fn from_bytes(bytes: Bytes) -> Result { + // 1 bytes ECN, 2 bytes segment size + snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); + + let ecn_byte = bytes[0]; + let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte); + + let segment_size = u16::from_be_bytes(bytes[1..3].try_into().expect("length checked")); + let segment_size = if segment_size == 0 { + None + } else { + Some(segment_size) + }; + + let contents = bytes.slice(3..); + + Ok(Self { + ecn, + segment_size, + contents, + }) + } +} + impl ServerToClientMsg { /// TODO(matheus23): docs pub fn typ(&self) -> FrameType { match self { - Self::ReceivedPacket { .. } => FrameType::RecvPacket, + Self::ReceivedDatagrams { .. } => FrameType::RecvPacket, Self::NodeGone { .. } => FrameType::NodeGone, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, @@ -198,12 +254,12 @@ impl ServerToClientMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::ReceivedPacket { - remote_node_id: src_key, - data: content, + Self::ReceivedDatagrams { + remote_node_id, + datagrams, } => { - dst.put(src_key.as_ref()); - dst.put(content.as_ref()); + dst.put(remote_node_id.as_ref()); + dst = datagrams.write_to(dst); } Self::NodeGone(node_id) => { dst.put(node_id.as_ref()); @@ -236,41 +292,34 @@ impl ServerToClientMsg { let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; let res = match frame_type { FrameType::RecvPacket => { - if content.len() < NodeId::LENGTH { - return Err(InvalidFrameSnafu.build()); - } + snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); let frame_len = content.len() - NodeId::LENGTH; - if frame_len > MAX_PACKET_SIZE { - return Err(FrameTooLargeSnafu { frame_len }.build()); - } + snafu::ensure!( + frame_len <= MAX_PACKET_SIZE, + FrameTooLargeSnafu { frame_len } + ); - let src_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; - let content = content.slice(NodeId::LENGTH..); - Self::ReceivedPacket { - remote_node_id: src_key, - data: content, + let remote_node_id = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + Self::ReceivedDatagrams { + remote_node_id, + datagrams, } } FrameType::NodeGone => { - if content.len() != NodeId::LENGTH { - return Err(InvalidFrameSnafu.build()); - } - let node_id = cache.key_from_slice(&content[..32])?; + snafu::ensure!(content.len() == NodeId::LENGTH, InvalidFrameSnafu); + let node_id = cache.key_from_slice(content.as_ref())?; Self::NodeGone(node_id) } FrameType::Ping => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); Self::Ping(data) } FrameType::Pong => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); Self::Pong(data) @@ -283,9 +332,7 @@ impl ServerToClientMsg { Self::Health { problem } } FrameType::Restarting => { - if content.len() != 4 + 4 { - return Err(InvalidFrameSnafu.build()); - } + snafu::ensure!(content.len() == 4 + 4, InvalidFrameSnafu); let reconnect_in = u32::from_be_bytes( content[..4] .try_into() @@ -314,7 +361,7 @@ impl ServerToClientMsg { impl ClientToServerMsg { pub(crate) fn typ(&self) -> FrameType { match self { - Self::SendPacket { .. } => FrameType::SendPacket, + Self::SendDatagrams { .. } => FrameType::SendPacket, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, } @@ -326,9 +373,12 @@ impl ClientToServerMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::SendPacket { dst_key, packet } => { - dst.put(dst_key.as_ref()); - dst.put(packet.as_ref()); + Self::SendDatagrams { + dst_node_id, + datagrams, + } => { + dst.put(dst_node_id.as_ref()); + dst = datagrams.write_to(dst); } Self::Ping(data) => { dst.put(&data[..]); @@ -357,9 +407,12 @@ impl ClientToServerMsg { return Err(FrameTooLargeSnafu { frame_len }.build()); } - let dst_key = cache.key_from_slice(&content[..NodeId::LENGTH])?; - let packet = content.slice(NodeId::LENGTH..); - Self::SendPacket { dst_key, packet } + let dst_node_id = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + Self::SendDatagrams { + dst_node_id, + datagrams, + } } FrameType::Ping => { if content.len() != 8 { @@ -438,9 +491,13 @@ mod tests { "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id: client_key.public(), - data: "Hello World!".into(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: Some(6), + contents: "Hello World!".into(), + }, } .write_to(Vec::new()), "0b 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e @@ -474,9 +531,13 @@ mod tests { "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ClientToServerMsg::SendPacket { - dst_key: client_key.public(), - packet: "Goodbye!".into(), + ClientToServerMsg::SendDatagrams { + dst_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: Some(6), + contents: "Hello World!".into(), + }, } .write_to(Vec::new()), "0a 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e @@ -511,13 +572,18 @@ mod proptests { prop::collection::vec(any::(), 0..len).prop_map(Bytes::from) } + fn datagrams(data: impl Strategy) -> impl Strategy { + data.prop_map(|_content| todo!()) + } + /// Generates a random valid frame fn server_client_frame() -> impl Strategy { - let recv_packet = - (key(), data(32)).prop_map(|(src_key, content)| ServerToClientMsg::ReceivedPacket { - remote_node_id: src_key, - data: content, - }); + let recv_packet = (key(), datagrams(data(32))).prop_map(|(remote_node_id, datagrams)| { + ServerToClientMsg::ReceivedDatagrams { + remote_node_id, + datagrams, + } + }); let node_gone = key().prop_map(|node_id| ServerToClientMsg::NodeGone(node_id)); let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); @@ -535,8 +601,12 @@ mod proptests { } fn client_server_frame() -> impl Strategy { - let send_packet = (key(), data(32)) - .prop_map(|(dst_key, packet)| ClientToServerMsg::SendPacket { dst_key, packet }); + let send_packet = (key(), datagrams(data(32))).prop_map(|(dst_node_id, datagrams)| { + ClientToServerMsg::SendDatagrams { + dst_node_id, + datagrams, + } + }); let ping = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Pong); prop_oneof![send_packet, ping, pong] diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 81c29faccef..7b2f050173b 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -750,7 +750,6 @@ impl hyper::service::Service> for CaptivePortalService { mod tests { use std::{net::Ipv4Addr, time::Duration}; - use bytes::Bytes; use http::StatusCode; use iroh_base::{NodeId, RelayUrl, SecretKey}; use n0_future::{FutureExt, SinkExt, StreamExt}; @@ -765,7 +764,7 @@ mod tests { use crate::{ client::ClientBuilder, dns::DnsResolver, - protos::send_recv::{ClientToServerMsg, ServerToClientMsg}, + protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, }; async fn spawn_local_relay() -> std::result::Result { @@ -788,14 +787,14 @@ mod tests { client_a: &mut crate::client::Client, client_b: &mut crate::client::Client, b_key: NodeId, - msg: Bytes, + msg: Datagrams, ) -> Result { // try resend 10 times for _ in 0..10 { client_a - .send(ClientToServerMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; let Ok(res) = tokio::time::timeout(Duration::from_millis(500), client_b.next()).await @@ -909,34 +908,34 @@ mod tests { info!("sending a -> b"); // send message from a to b - let msg = Bytes::from("hello, b"); + let msg = Datagrams::from("hello, b"); let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - let ServerToClientMsg::ReceivedPacket { + let ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } = res else { panic!("client_b received unexpected message {res:?}"); }; assert_eq!(a_key, remote_node_id); - assert_eq!(msg, data); + assert_eq!(msg, datagrams); info!("sending b -> a"); // send message from b to a - let msg = Bytes::from("howdy, a"); + let msg = Datagrams::from("howdy, a"); let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - let ServerToClientMsg::ReceivedPacket { + let ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } = res else { panic!("client_a received unexpected message {res:?}"); }; assert_eq!(b_key, remote_node_id); - assert_eq!(msg, data); + assert_eq!(msg, datagrams); Ok(()) } @@ -1018,16 +1017,16 @@ mod tests { .await?; // send message from b to c - let msg = Bytes::from("hello, c"); + let msg = Datagrams::from("hello, c"); let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?; - if let ServerToClientMsg::ReceivedPacket { + if let ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } = res { assert_eq!(b_key, remote_node_id); - assert_eq!(msg, data); + assert_eq!(msg, datagrams); } else { panic!("client_c received unexpected message {res:?}"); } @@ -1059,12 +1058,12 @@ mod tests { // send messages from a to b, without b receiving anything. // we should still keep succeeding to send, even if the packet won't be forwarded // by the relay server because the server's send queue for b fills up. - let msg = Bytes::from("hello, b"); + let msg = Datagrams::from("hello, b"); for _i in 0..1000 { client_a - .send(ClientToServerMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; } diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 514d26a3464..9bb0793efc3 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -2,7 +2,6 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; -use bytes::Bytes; use iroh_base::NodeId; use n0_future::{SinkExt, StreamExt}; use nested_enum_utils::common_fields; @@ -20,8 +19,8 @@ use crate::{ protos::{ disco, send_recv::{ - write_frame, ClientToServerMsg, SendError as SendRelayError, ServerToClientMsg, - PING_INTERVAL, + write_frame, ClientToServerMsg, Datagrams, SendError as SendRelayError, + ServerToClientMsg, PING_INTERVAL, }, }, server::{ @@ -38,7 +37,7 @@ pub(super) struct Packet { /// The sender of the packet src: NodeId, /// The data packet bytes. - data: Bytes, + data: Datagrams, } /// Configuration for a [`Client`]. @@ -153,7 +152,7 @@ impl Client { pub(super) fn try_send_packet( &self, src: NodeId, - data: Bytes, + data: Datagrams, ) -> Result<(), TrySendError> { self.send_queue.try_send(Packet { src, data }) } @@ -161,7 +160,7 @@ impl Client { pub(super) fn try_send_disco_packet( &self, src: NodeId, - data: Bytes, + data: Datagrams, ) -> Result<(), TrySendError> { self.disco_send_queue.try_send(Packet { src, data }) } @@ -396,15 +395,15 @@ impl Actor { /// Errors if the send does not happen within the `timeout` duration /// Does not flush. async fn send_raw(&mut self, packet: Packet) -> Result<(), SendRelayError> { - let src_key = packet.src; - let content = packet.data; + let remote_node_id = packet.src; + let datagrams = packet.data; - if let Ok(len) = content.len().try_into() { + if let Ok(len) = datagrams.contents.len().try_into() { self.metrics.bytes_sent.inc_by(len); } - self.write_frame(ServerToClientMsg::ReceivedPacket { - remote_node_id: src_key, - data: content, + self.write_frame(ServerToClientMsg::ReceivedDatagrams { + remote_node_id, + datagrams, }) .await } @@ -449,10 +448,13 @@ impl Actor { }; match frame { - ClientToServerMsg::SendPacket { dst_key, packet } => { - let packet_len = packet.len(); + ClientToServerMsg::SendDatagrams { + dst_node_id: dst_key, + datagrams, + } => { + let packet_len = datagrams.contents.len(); if let Err(err @ ForwardPacketError { .. }) = - self.handle_frame_send_packet(dst_key, packet) + self.handle_frame_send_packet(dst_key, datagrams) { warn!("failed to handle send packet frame: {err:#}"); } @@ -471,8 +473,12 @@ impl Actor { Ok(()) } - fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<(), ForwardPacketError> { - if disco::looks_like_disco_wrapper(&data) { + fn handle_frame_send_packet( + &self, + dst: NodeId, + data: Datagrams, + ) -> Result<(), ForwardPacketError> { + if disco::looks_like_disco_wrapper(&data.contents) { self.metrics.disco_packets_recv.inc(); self.clients .send_disco_packet(dst, data, self.node_id, &self.metrics)?; @@ -549,7 +555,7 @@ impl ClientCounter { #[cfg(test)] mod tests { - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::Stream; use n0_snafu::{Result, ResultExt}; @@ -622,15 +628,15 @@ mod tests { println!(" send packet"); let packet = Packet { src: node_id, - data: Bytes::from(&data[..]), + data: Datagrams::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id: node_id, - data: data.to_vec().into() + datagrams: data.to_vec().into() } ); @@ -643,9 +649,9 @@ mod tests { let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id: node_id, - data: data.to_vec().into() + datagrams: data.to_vec().into() } ); @@ -673,9 +679,9 @@ mod tests { println!(" send packet"); let data = b"hello world!"; io_rw - .send(ClientToServerMsg::SendPacket { - dst_key: target, - packet: Bytes::from_static(data), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: target, + datagrams: Datagrams::from(data), }) .await .context("send")?; @@ -687,9 +693,9 @@ mod tests { disco_data.extend_from_slice(target.as_bytes()); disco_data.extend_from_slice(data); io_rw - .send(ClientToServerMsg::SendPacket { - dst_key: target, - packet: disco_data.clone().into(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: target, + datagrams: disco_data.clone().into(), }) .await .context("send")?; @@ -712,11 +718,11 @@ mod tests { let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT); // Prepare a frame to send, assert its size. - let data = Bytes::from_static(b"hello world!1eins"); + let data = Datagrams::from(b"hello world!1eins"); let target = SecretKey::generate(rand::thread_rng()).public(); - let frame = ClientToServerMsg::SendPacket { - dst_key: target, - packet: data.clone(), + let frame = ClientToServerMsg::SendDatagrams { + dst_node_id: target, + datagrams: data.clone(), }; let frame_len = frame.clone().write_to(BytesMut::new()).freeze().len(); assert_eq!(frame_len, LIMIT as usize); diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 885f62fc21d..3b8b5731805 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -9,16 +9,18 @@ use std::{ }, }; -use bytes::Bytes; use dashmap::DashMap; use iroh_base::NodeId; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, trace}; use super::client::{Client, Config, ForwardPacketError}; -use crate::server::{ - client::{PacketScope, SendError}, - metrics::Metrics, +use crate::{ + protos::send_recv::Datagrams, + server::{ + client::{PacketScope, SendError}, + metrics::Metrics, + }, }; /// Manages the connections to all currently connected clients. @@ -108,7 +110,7 @@ impl Clients { pub(super) fn send_packet( &self, dst: NodeId, - data: Bytes, + data: Datagrams, src: NodeId, metrics: &Metrics, ) -> Result<(), ForwardPacketError> { @@ -148,7 +150,7 @@ impl Clients { pub(super) fn send_disco_packet( &self, dst: NodeId, - data: Bytes, + data: Datagrams, src: NodeId, metrics: &Metrics, ) -> Result<(), ForwardPacketError> { @@ -192,7 +194,6 @@ impl Clients { mod tests { use std::time::Duration; - use bytes::Bytes; use iroh_base::SecretKey; use n0_future::{Stream, StreamExt}; use n0_snafu::{Result, ResultExt}; @@ -253,24 +254,24 @@ mod tests { // send packet let data = b"hello world!"; - clients.send_packet(a_key, Bytes::from(&data[..]), b_key, &metrics)?; + clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id: b_key, - data: data.to_vec().into(), + datagrams: data.to_vec().into(), } ); // send disco packet - clients.send_disco_packet(a_key, Bytes::from(&data[..]), b_key, &metrics)?; + clients.send_disco_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id: b_key, - data: data.to_vec().into(), + datagrams: data.to_vec().into(), } ); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index a1f590070e1..d52204f4b8a 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -836,7 +836,6 @@ impl std::ops::DerefMut for Handlers { mod tests { use std::sync::Arc; - use bytes::Bytes; use iroh_base::{PublicKey, SecretKey}; use n0_future::{SinkExt, StreamExt}; use n0_snafu::{Result, ResultExt}; @@ -849,7 +848,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::send_recv::ClientToServerMsg, + protos::send_recv::{ClientToServerMsg, Datagrams}, }; pub(crate) fn make_tls_config() -> TlsConfig { @@ -917,11 +916,11 @@ mod tests { assert!(matches!(pong, ServerToClientMsg::Pong { .. })); info!("sending message from a to b"); - let msg = Bytes::from_static(b"hi there, client b!"); + let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToServerMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message from a on b"); @@ -931,11 +930,11 @@ mod tests { assert_eq!(msg, got_msg); info!("sending message from b to a"); - let msg = Bytes::from_static(b"right back at ya, client b!"); + let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToServerMsg::SendPacket { - dst_key: a_key, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: a_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message b on a"); @@ -966,7 +965,7 @@ mod tests { fn process_msg( msg: Option>, - ) -> Option<(PublicKey, Bytes)> { + ) -> Option<(PublicKey, Datagrams)> { match msg { Some(Err(e)) => { info!("client `recv` error {e}"); @@ -974,12 +973,12 @@ mod tests { } Some(Ok(msg)) => { info!("got message on: {msg:?}"); - if let ServerToClientMsg::ReceivedPacket { + if let ServerToClientMsg::ReceivedDatagrams { remote_node_id: source, - data, + datagrams, } = msg { - Some((source, data)) + Some((source, datagrams)) } else { None } @@ -1037,11 +1036,11 @@ mod tests { assert!(matches!(pong, ServerToClientMsg::Pong { .. })); info!("sending message from a to b"); - let msg = Bytes::from_static(b"hi there, client b!"); + let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToServerMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message from a on b"); @@ -1051,11 +1050,11 @@ mod tests { assert_eq!(msg, got_msg); info!("sending message from b to a"); - let msg = Bytes::from_static(b"right back at ya, client b!"); + let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToServerMsg::SendPacket { - dst_key: a_key, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: a_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message b on a"); @@ -1114,44 +1113,44 @@ mod tests { handler_task.await.context("join")??; info!("Send message from A to B."); - let msg = Bytes::from_static(b"hello client b!!"); + let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_b, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_b, + datagrams: msg.clone(), }) .await?; match client_b.next().await.unwrap()? { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { assert_eq!(public_key_a, remote_node_id); - assert_eq!(&msg[..], data); + assert_eq!(msg, datagrams); } msg => { - whatever!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedDatagrams msg, got {msg:?}"); } } info!("Send message from B to A."); - let msg = Bytes::from_static(b"nice to meet you client a!!"); + let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_a, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_a, + datagrams: msg.clone(), }) .await?; match client_a.next().await.unwrap()? { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { assert_eq!(public_key_b, remote_node_id); - assert_eq!(&msg[..], data); + assert_eq!(msg, datagrams); } msg => { - whatever!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedDatagrams msg, got {msg:?}"); } } @@ -1161,9 +1160,9 @@ mod tests { info!("Fail to send message from A to B."); let res = client_a - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_b, - packet: Bytes::from_static(b"try to send"), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_b, + datagrams: Datagrams::from(b"try to send"), }) .await; assert!(res.is_err()); @@ -1204,44 +1203,44 @@ mod tests { handler_task.await.context("join")??; info!("Send message from A to B."); - let msg = Bytes::from_static(b"hello client b!!"); + let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_b, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_b, + datagrams: msg.clone(), }) .await?; match client_b.next().await.expect("eos")? { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { assert_eq!(public_key_a, remote_node_id); - assert_eq!(&msg[..], data); + assert_eq!(msg, datagrams); } msg => { - whatever!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedDatagrams msg, got {msg:?}"); } } info!("Send message from B to A."); - let msg = Bytes::from_static(b"nice to meet you client a!!"); + let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_a, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_a, + datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { assert_eq!(public_key_b, remote_node_id); - assert_eq!(&msg[..], data); + assert_eq!(msg, datagrams); } msg => { - whatever!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedDatagrams msg, got {msg:?}"); } } @@ -1256,44 +1255,44 @@ mod tests { // assert!(client_b.recv().await.is_err()); info!("Send message from A to B."); - let msg = Bytes::from_static(b"are you still there, b?!"); + let msg = Datagrams::from(b"are you still there, b?!"); client_a - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_b, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_b, + datagrams: msg.clone(), }) .await?; match new_client_b.next().await.expect("eos")? { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { assert_eq!(public_key_a, remote_node_id); - assert_eq!(&msg[..], data); + assert_eq!(msg, datagrams); } msg => { - whatever!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedDatagrams msg, got {msg:?}"); } } info!("Send message from B to A."); - let msg = Bytes::from_static(b"just had a spot of trouble but I'm back now,a!!"); + let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!"); new_client_b - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_a, - packet: msg.clone(), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_a, + datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { assert_eq!(public_key_b, remote_node_id); - assert_eq!(&msg[..], data); + assert_eq!(msg, datagrams); } msg => { - whatever!("expected ReceivedPacket msg, got {msg:?}"); + whatever!("expected ReceivedDatagrams msg, got {msg:?}"); } } @@ -1302,9 +1301,9 @@ mod tests { info!("Sending message from A to B fails"); let res = client_a - .send(ClientToServerMsg::SendPacket { - dst_key: public_key_b, - packet: Bytes::from_static(b"try to send"), + .send(ClientToServerMsg::SendDatagrams { + dst_node_id: public_key_b, + datagrams: Datagrams::from(b"try to send"), }) .await; assert!(res.is_err()); diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 9d68bc0966f..e24d66c01f6 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -148,10 +148,6 @@ pub(crate) struct Options { pub(crate) metrics: EndpointMetrics, } -/// Contents of a relay message. Use a SmallVec to avoid allocations for the very -/// common case of a single packet. -type RelayContents = SmallVec<[Bytes; 1]>; - /// Handle for [`MagicSock`]. /// /// Dereferences to [`MagicSock`], and handles closing. diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index 9345bed6af2..86974247548 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -5,18 +5,17 @@ use std::{ use bytes::Bytes; use iroh_base::{NodeId, RelayUrl}; +use iroh_relay::protos::send_recv::Datagrams; use n0_future::{ ready, task::{self, AbortOnDropHandle}, }; use n0_watcher::{Watchable, Watcher as _}; -use smallvec::SmallVec; use tokio::sync::mpsc; use tokio_util::sync::PollSender; use tracing::{error, info_span, trace, warn, Instrument}; use super::{Addr, Transmit}; -use crate::magicsock::RelayContents; mod actor; @@ -100,9 +99,12 @@ impl RelayTransport { } }; - buf_out[..dm.buf.len()].copy_from_slice(&dm.buf); - meta_out.len = dm.buf.len(); - meta_out.stride = dm.buf.len(); + buf_out[..dm.datagrams.contents.len()].copy_from_slice(&dm.datagrams.contents); + meta_out.len = dm.datagrams.contents.len(); + meta_out.stride = dm + .datagrams + .segment_size + .map_or(dm.datagrams.contents.len(), |s| s as usize); meta_out.ecn = None; meta_out.dst_ip = None; // TODO: insert the relay url for this relay @@ -186,7 +188,7 @@ impl RelaySender { dest_node: NodeId, transmit: &Transmit<'_>, ) -> io::Result<()> { - let contents = split_packets(transmit); + let contents = datagrams_from_transmit(transmit); let item = RelaySendItem { remote_node: dest_node, @@ -228,7 +230,7 @@ impl RelaySender { trace!(node = %dest_node.fmt_short(), relay_url = %dest_url, "send relay: message queued"); - let contents = split_packets(transmit); + let contents = datagrams_from_transmit(transmit); let item = RelaySendItem { remote_node: dest_node, url: dest_url.clone(), @@ -266,7 +268,7 @@ impl RelaySender { dest_node: NodeId, transmit: &Transmit<'_>, ) -> io::Result<()> { - let contents = split_packets(transmit); + let contents = datagrams_from_transmit(transmit); let item = RelaySendItem { remote_node: dest_node, @@ -313,17 +315,16 @@ impl RelaySender { /// size, the contents will be sent as a single packet. // TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to // figure out where they allocate the Vec. -fn split_packets(transmit: &Transmit<'_>) -> RelayContents { - let mut res = SmallVec::with_capacity(1); - let contents = transmit.contents; - if let Some(segment_size) = transmit.segment_size { - for chunk in contents.chunks(segment_size) { - res.push(Bytes::from(chunk.to_vec())); - } - } else { - res.push(Bytes::from(contents.to_vec())); +fn datagrams_from_transmit(transmit: &Transmit<'_>) -> Datagrams { + Datagrams { + ecn: transmit.ecn.map(|ecn| match ecn { + quinn_udp::EcnCodepoint::Ect0 => quinn_proto::EcnCodepoint::Ect0, + quinn_udp::EcnCodepoint::Ect1 => quinn_proto::EcnCodepoint::Ect1, + quinn_udp::EcnCodepoint::Ce => quinn_proto::EcnCodepoint::Ce, + }), + segment_size: transmit.segment_size.map(|ss| ss as u16), + contents: Bytes::copy_from_slice(transmit.contents), } - res } #[cfg(test)] @@ -337,43 +338,6 @@ mod tests { use super::*; use crate::defaults::staging; - #[test] - fn test_split_packets() { - fn mk_transmit(contents: &[u8], segment_size: Option) -> Transmit<'_> { - Transmit { - ecn: None, - contents, - segment_size, - } - } - fn mk_expected(parts: impl IntoIterator) -> RelayContents { - parts - .into_iter() - .map(|p| p.as_bytes().to_vec().into()) - .collect() - } - // no split - assert_eq!( - split_packets(&mk_transmit(b"hello", None)), - mk_expected(["hello"]) - ); - // split without rest - assert_eq!( - split_packets(&mk_transmit(b"helloworld", Some(5))), - mk_expected(["hello", "world"]) - ); - // split with rest and second transmit - assert_eq!( - split_packets(&mk_transmit(b"hello world", Some(5))), - mk_expected(["hello", " worl", "d"]) // spellchecker:disable-line - ); - // split that results in 1 packet - assert_eq!( - split_packets(&mk_transmit(b"hello world", Some(1000))), - mk_expected(["hello world"]) - ); - } - #[tokio::test(flavor = "multi_thread")] async fn test_relay_datagram_queue() { let capacity = 16; @@ -387,7 +351,7 @@ mod tests { let mut expected_msgs: BTreeSet = (0..capacity).collect(); while !expected_msgs.is_empty() { let datagram: RelayRecvDatagram = receiver.recv().await.unwrap(); - let msg_num = usize::from_le_bytes(datagram.buf.as_ref().try_into().unwrap()); + let msg_num = usize::from_le_bytes(datagram.datagrams.contents.as_ref().try_into().unwrap()); debug!("Received {msg_num}"); if !expected_msgs.remove(&msg_num) { @@ -407,7 +371,7 @@ mod tests { .try_send(RelayRecvDatagram { url, src: NodeId::from_bytes(&[0u8; 32]).unwrap(), - buf: Bytes::copy_from_slice(&i.to_le_bytes()), + datagrams: Datagrams::from(&i.to_le_bytes()), }) .unwrap(); } diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index 5e5eed10ecb..70a0a1fb818 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -38,12 +38,11 @@ use std::{ }; use backon::{Backoff, BackoffBuilder, ExponentialBuilder}; -use bytes::{Bytes, BytesMut}; use iroh_base::{NodeId, PublicKey, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, - protos::send_recv::{ClientToServerMsg, ServerToClientMsg}, + protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, PingTracker, MAX_PACKET_SIZE, }; use n0_future::{ @@ -62,11 +61,7 @@ use url::Url; #[cfg(not(wasm_browser))] use crate::dns::DnsResolver; -use crate::{ - magicsock::{Metrics as MagicsockMetrics, RelayContents}, - net_report::Report, - util::MaybeFuture, -}; +use crate::{magicsock::Metrics as MagicsockMetrics, net_report::Report, util::MaybeFuture}; /// How long a non-home relay connection needs to be idle (last written to) before we close it. const RELAY_INACTIVE_CLEANUP_TIME: Duration = Duration::from_secs(60); @@ -626,24 +621,18 @@ impl ActiveRelayActor { self.reset_inactive_timeout(); // TODO: This allocation is *very* unfortunate. But so is the // allocation *inside* of PacketizeIter... - let dgrams = std::mem::replace( + let batch = std::mem::replace( &mut send_datagrams_buf, Vec::with_capacity(SEND_DATAGRAM_BATCH_SIZE), ); // TODO(frando): can we avoid the clone here? let metrics = self.metrics.clone(); - let packet_iter = dgrams.into_iter().flat_map(|datagrams| { - PacketizeIter::<_, MAX_PAYLOAD_SIZE>::new( - datagrams.remote_node, - datagrams.datagrams.clone(), - ) - .map(|p| { - Ok(ClientToServerMsg::SendPacket { dst_key: p.node_id, packet: p.payload }) - }) - }); + let packet_iter = batch.into_iter().map(|item| { + Ok(ClientToServerMsg::SendDatagrams { dst_node_id: item.remote_node, datagrams: item.datagrams }) + }); let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { - if let Ok(ClientToServerMsg::SendPacket { dst_key: _node_id, packet: payload }) = m { - metrics.send_relay.inc_by(payload.len() as _); + if let Ok(ClientToServerMsg::SendDatagrams { dst_node_id: _node_id, datagrams }) = m { + metrics.send_relay.inc_by(datagrams.contents.len() as _); } }); let fut = client_sink.send_all(&mut packet_stream); @@ -680,11 +669,11 @@ impl ActiveRelayActor { fn handle_relay_msg(&mut self, msg: ServerToClientMsg, state: &mut ConnectedRelayState) { match msg { - ServerToClientMsg::ReceivedPacket { + ServerToClientMsg::ReceivedDatagrams { remote_node_id, - data, + datagrams, } => { - trace!(len = %data.len(), "received msg"); + trace!(len = %datagrams.contents.len(), "received msg"); // If this is a new sender, register a route for this peer. if state .last_packet_src @@ -696,14 +685,12 @@ impl ActiveRelayActor { state.last_packet_src = Some(remote_node_id); state.nodes_present.insert(remote_node_id); } - for datagram in PacketSplitIter::new(self.url.clone(), remote_node_id, data) { - let Ok(datagram) = datagram else { - warn!("Invalid packet split"); - break; - }; - if let Err(err) = self.relay_datagrams_recv.try_send(datagram) { - warn!("Dropping received relay packet: {err:#}"); - } + if let Err(err) = self.relay_datagrams_recv.try_send(RelayRecvDatagram { + url: self.url.clone(), + src: remote_node_id, + datagrams, + }) { + warn!("Dropping received relay packet: {err:#}"); } } ServerToClientMsg::NodeGone(node_id) => { @@ -852,7 +839,7 @@ pub(crate) struct RelaySendItem { /// The home relay of the remote node. pub(crate) url: RelayUrl, /// One or more datagrams to send. - pub(crate) datagrams: RelayContents, + pub(crate) datagrams: Datagrams, } pub(super) struct RelayActor { @@ -1232,18 +1219,6 @@ struct ActiveRelayHandle { datagrams_send_queue: mpsc::Sender, } -/// A packet to send over the relay. -/// -/// This is nothing but a newtype, it should be constructed using [`PacketizeIter`]. This -/// is a packet of one or more datagrams, each prefixed with a u16-be length. This is what -/// the `Frame::SendPacket` of the `DerpCodec` transports and is produced by -/// [`PacketizeIter`] and transformed back into datagrams using [`PacketSplitIter`]. -#[derive(Debug, PartialEq, Eq)] -struct RelaySendPacket { - node_id: NodeId, - payload: Bytes, -} - /// A single datagram received from a relay server. /// /// This could be either a QUIC or DISCO packet. @@ -1251,115 +1226,7 @@ struct RelaySendPacket { pub(crate) struct RelayRecvDatagram { pub(crate) url: RelayUrl, pub(crate) src: NodeId, - pub(crate) buf: Bytes, -} - -/// Combines datagrams into a single DISCO frame of at most MAX_PACKET_SIZE. -/// -/// The disco `iroh_relay::protos::Frame::SendPacket` frame can contain more then a single -/// datagram. Each datagram in this frame is prefixed with a little-endian 2-byte length -/// prefix. This occurs when Quinn sends a GSO transmit containing more than one datagram, -/// which are split using `split_packets`. -/// -/// The [`PacketSplitIter`] does the inverse and splits such packets back into individual -/// datagrams. -struct PacketizeIter { - node_id: NodeId, - iter: std::iter::Peekable, - buffer: BytesMut, -} - -impl PacketizeIter { - /// Create a new new PacketizeIter from something that can be turned into an - /// iterator of slices, like a `Vec`. - fn new(node_id: NodeId, iter: impl IntoIterator) -> Self { - Self { - node_id, - iter: iter.into_iter().peekable(), - buffer: BytesMut::with_capacity(N), - } - } -} - -impl Iterator for PacketizeIter -where - I::Item: AsRef<[u8]>, -{ - type Item = RelaySendPacket; - - fn next(&mut self) -> Option { - use bytes::BufMut; - while let Some(next_bytes) = self.iter.peek() { - let next_bytes = next_bytes.as_ref(); - assert!(next_bytes.len() + 2 <= N); - let next_length: u16 = next_bytes.len().try_into().expect("items < 64k size"); - if self.buffer.len() + next_bytes.len() + 2 > N { - break; - } - self.buffer.put_u16_le(next_length); - self.buffer.put_slice(next_bytes); - self.iter.next(); - } - if !self.buffer.is_empty() { - Some(RelaySendPacket { - node_id: self.node_id, - payload: self.buffer.split().freeze(), - }) - } else { - None - } - } -} - -/// Splits a single [`ServerToClientMsg::ReceivedPacket`] frame into datagrams. -/// -/// This splits packets joined by [`PacketizeIter`] back into individual datagrams. See -/// that struct for more details. -#[derive(Debug)] -struct PacketSplitIter { - url: RelayUrl, - src: NodeId, - bytes: Bytes, -} - -impl PacketSplitIter { - /// Create a new PacketSplitIter from a packet. - fn new(url: RelayUrl, src: NodeId, bytes: Bytes) -> Self { - Self { url, src, bytes } - } - - fn fail(&mut self) -> Option> { - self.bytes.clear(); - Some(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - ))) - } -} - -impl Iterator for PacketSplitIter { - type Item = std::io::Result; - - fn next(&mut self) -> Option { - use bytes::Buf; - if self.bytes.has_remaining() { - if self.bytes.remaining() < 2 { - return self.fail(); - } - let len = self.bytes.get_u16_le() as usize; - if self.bytes.remaining() < len { - return self.fail(); - } - let buf = self.bytes.split_to(len); - Some(Ok(RelayRecvDatagram { - url: self.url.clone(), - src: self.src, - buf, - })) - } else { - None - } - } + pub(crate) datagrams: Datagrams, } #[cfg(test)] @@ -1369,11 +1236,9 @@ mod tests { time::Duration, }; - use bytes::Bytes; use iroh_base::{NodeId, RelayUrl, SecretKey}; - use iroh_relay::PingTracker; + use iroh_relay::{protos::send_recv::Datagrams, PingTracker}; use n0_snafu::{Error, Result, ResultExt}; - use smallvec::smallvec; use tokio::sync::{mpsc, oneshot}; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{info, info_span, Instrument}; @@ -1381,42 +1246,11 @@ mod tests { use super::{ ActiveRelayActor, ActiveRelayActorOptions, ActiveRelayMessage, ActiveRelayPrioMessage, - PacketizeIter, RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, MAX_PACKET_SIZE, - RELAY_INACTIVE_CLEANUP_TIME, UNDELIVERABLE_DATAGRAM_TIMEOUT, + RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, RELAY_INACTIVE_CLEANUP_TIME, + UNDELIVERABLE_DATAGRAM_TIMEOUT, }; use crate::{dns::DnsResolver, test_utils}; - #[test] - fn test_packetize_iter() { - let node_id = SecretKey::generate(rand::thread_rng()).public(); - let empty_vec: Vec = Vec::new(); - let mut iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, empty_vec); - assert_eq!(None, iter.next()); - - let single_vec = vec!["Hello"]; - let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, single_vec); - let result = iter.collect::>(); - assert_eq!(1, result.len()); - assert_eq!( - &[5, 0, b'H', b'e', b'l', b'l', b'o'], - &result[0].payload[..] - ); - - let spacer = vec![0u8; MAX_PACKET_SIZE - 10]; - let multiple_vec = vec![&b"Hello"[..], &spacer, &b"World"[..]]; - let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, multiple_vec); - let result = iter.collect::>(); - assert_eq!(2, result.len()); - assert_eq!( - &[5, 0, b'H', b'e', b'l', b'l', b'o'], - &result[0].payload[..7] - ); - assert_eq!( - &[5, 0, b'W', b'o', b'r', b'l', b'd'], - &result[1].payload[..] - ); - } - /// Starts a new [`ActiveRelayActor`]. #[allow(clippy::too_many_arguments)] fn start_active_relay_actor( @@ -1477,12 +1311,16 @@ mod tests { loop { let datagram = recv_datagram_rx.recv().await; if let Some(recv) = datagram { - let RelayRecvDatagram { url: _, src, buf } = recv; + let RelayRecvDatagram { + url: _, + src, + datagrams, + } = recv; info!(from = src.fmt_short(), "Received datagram"); let send = RelaySendItem { remote_node: src, url: relay_url.clone(), - datagrams: smallvec![buf], + datagrams, }; send_datagram_tx.send(send).await.ok(); } @@ -1516,7 +1354,6 @@ mod tests { tx: &mpsc::Sender, rx: &mut mpsc::Receiver, ) -> Result<()> { - assert!(item.datagrams.len() == 1); tokio::time::timeout(Duration::from_secs(10), async move { loop { let res = tokio::time::timeout(UNDELIVERABLE_DATAGRAM_TIMEOUT, async { @@ -1524,10 +1361,10 @@ mod tests { let RelayRecvDatagram { url: _, src: _, - buf, + datagrams, } = rx.recv().await.unwrap(); - assert_eq!(buf.as_ref(), item.datagrams[0]); + assert_eq!(datagrams, item.datagrams); Ok::<_, Error>(()) }) @@ -1570,7 +1407,7 @@ mod tests { let hello_send_item = RelaySendItem { remote_node: peer_node, url: relay_url.clone(), - datagrams: smallvec![Bytes::from_static(b"hello")], + datagrams: Datagrams::from(b"hello"), }; send_recv_echo( hello_send_item.clone(), From f254fa46ab7c0c32ecbe2c623494604d07079577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 30 Jun 2025 16:47:49 +0200 Subject: [PATCH 28/80] Fix iroh-relay tests --- .../proptest-regressions/protos/send_recv.txt | 8 +++ iroh-relay/src/protos/send_recv.rs | 69 ++++++++++++++----- iroh-relay/src/server/client.rs | 2 +- iroh/src/magicsock/transports/relay/actor.rs | 7 +- 4 files changed, 62 insertions(+), 24 deletions(-) create mode 100644 iroh-relay/proptest-regressions/protos/send_recv.txt diff --git a/iroh-relay/proptest-regressions/protos/send_recv.txt b/iroh-relay/proptest-regressions/protos/send_recv.txt new file mode 100644 index 00000000000..54a1d9d3588 --- /dev/null +++ b/iroh-relay/proptest-regressions/protos/send_recv.txt @@ -0,0 +1,8 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 8f4f94b7c917bb0f52d31b529c3d580728ea57954ca45d91768cf4ae745e6eb9 # shrinks to frame = ReceivedDatagrams { remote_node_id: PublicKey(3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29), datagrams: Datagrams { ecn: None, segment_size: Some(43846), .. } } +cc e40ca61e22386f1c76f717f2a6dbba367ea05d906317b3f979b46031567edbca # shrinks to frame = SendDatagrams { dst_node_id: PublicKey(3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29), datagrams: Datagrams { ecn: None, segment_size: Some(44811), .. } } diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index 3f548606e94..037108c28c5 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -28,6 +28,12 @@ use crate::KeyCache; /// including its on-wire framing overhead) pub const MAX_PACKET_SIZE: usize = 64 * 1024; +/// Maximum size a datagram payload is allowed to be. +/// +/// This is [`MAX_PACKET_SIZE`] minus the length of an encoded public key minus 3 bytes, +/// one for ECN, and two for the segment size. +pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 3; + /// The maximum frame size. /// /// This is also the minimum burst size that a rate-limiter has to accept. @@ -500,9 +506,18 @@ mod tests { }, } .write_to(Vec::new()), - "0b 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // segment size + // hello world contents bytes + "0b + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 00 06 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ( ServerToClientMsg::Restarting { @@ -540,9 +555,18 @@ mod tests { }, } .write_to(Vec::new()), - "0a 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61 47 6f 6f 64 62 79 65 21", + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // segment size + // hello world contents + "0a + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 00 06 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ]); @@ -566,19 +590,31 @@ mod proptests { secret_key().prop_map(|key| key.public()) } - /// Generates random data, up to the maximum packet size minus the given number of bytes - fn data(consumed: usize) -> impl Strategy { - let len = MAX_PACKET_SIZE - consumed; - prop::collection::vec(any::(), 0..len).prop_map(Bytes::from) + fn ecn() -> impl Strategy> { + (0..=3).prop_map(|n| match n { + 1 => Some(quinn_proto::EcnCodepoint::Ce), + 2 => Some(quinn_proto::EcnCodepoint::Ect0), + 3 => Some(quinn_proto::EcnCodepoint::Ect1), + _ => None, + }) } - fn datagrams(data: impl Strategy) -> impl Strategy { - data.prop_map(|_content| todo!()) + fn datagrams() -> impl Strategy { + ( + ecn(), + prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE), + prop::collection::vec(any::(), 0..MAX_PAYLOAD_SIZE), + ) + .prop_map(|(ecn, segment_size, data)| Datagrams { + ecn, + segment_size: segment_size.map(|ss| std::cmp::min(data.len(), ss) as u16), + contents: Bytes::from(data), + }) } /// Generates a random valid frame fn server_client_frame() -> impl Strategy { - let recv_packet = (key(), datagrams(data(32))).prop_map(|(remote_node_id, datagrams)| { + let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { ServerToClientMsg::ReceivedDatagrams { remote_node_id, datagrams, @@ -587,10 +623,7 @@ mod proptests { let node_gone = key().prop_map(|node_id| ServerToClientMsg::NodeGone(node_id)); let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); - // TODO(matheus23): Actually fix these - let health = data(0).prop_map(|_problem| ServerToClientMsg::Health { - problem: "".to_string(), - }); + let health = ".{0,65536}".prop_map(|problem| ServerToClientMsg::Health { problem }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { ServerToClientMsg::Restarting { reconnect_in: Duration::from_millis(reconnect_in.into()), @@ -601,7 +634,7 @@ mod proptests { } fn client_server_frame() -> impl Strategy { - let send_packet = (key(), datagrams(data(32))).prop_map(|(dst_node_id, datagrams)| { + let send_packet = (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| { ClientToServerMsg::SendDatagrams { dst_node_id, datagrams, diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 9bb0793efc3..44e4d0b3972 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -718,7 +718,7 @@ mod tests { let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT); // Prepare a frame to send, assert its size. - let data = Datagrams::from(b"hello world!1eins"); + let data = Datagrams::from(b"hello world!!1"); let target = SecretKey::generate(rand::thread_rng()).public(); let frame = ClientToServerMsg::SendDatagrams { dst_node_id: target, diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index 70a0a1fb818..d03737100fc 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -38,12 +38,12 @@ use std::{ }; use backon::{Backoff, BackoffBuilder, ExponentialBuilder}; -use iroh_base::{NodeId, PublicKey, RelayUrl, SecretKey}; +use iroh_base::{NodeId, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, - PingTracker, MAX_PACKET_SIZE, + PingTracker, }; use n0_future::{ task::JoinSet, @@ -66,9 +66,6 @@ use crate::{magicsock::Metrics as MagicsockMetrics, net_report::Report, util::Ma /// How long a non-home relay connection needs to be idle (last written to) before we close it. const RELAY_INACTIVE_CLEANUP_TIME: Duration = Duration::from_secs(60); -/// Maximum size a datagram payload is allowed to be. -const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - PublicKey::LENGTH; - /// Interval in which we ping the relay server to ensure the connection is alive. /// /// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some From 1464cc9aee6f07ebebaf5f42884ad2a6514cf5f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 30 Jun 2025 17:07:51 +0200 Subject: [PATCH 29/80] Handle some TODOs --- iroh-relay/src/client/conn.rs | 13 ++++----- iroh-relay/src/protos/handshake.rs | 14 ++-------- iroh-relay/src/protos/relay.rs | 10 +++---- iroh-relay/src/protos/send_recv.rs | 43 +++++++++++++----------------- iroh-relay/src/server/client.rs | 4 +-- iroh-relay/src/server/clients.rs | 4 +-- 6 files changed, 33 insertions(+), 55 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 3511b2fd8a8..bf6e46e77be 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -26,7 +26,7 @@ use crate::{ handshake, send_recv::{ ClientToServerMsg, RecvError as RecvRelayError, SendError as SendRelayError, - ServerToClientMsg, + ServerToClientMsg, MAX_PAYLOAD_SIZE, }, }, MAX_PACKET_SIZE, @@ -200,18 +200,15 @@ impl Sink for Conn { } fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { - // TODO(matheus23): Check this in send message construction instead (and also check this in RecvPacket construction) if let ClientToServerMsg::SendDatagrams { datagrams, .. } = &frame { let size = datagrams.contents.len(); - snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); + snafu::ensure!(size <= MAX_PAYLOAD_SIZE, ExceedsMaxPacketSizeSnafu { size }); } #[cfg(not(wasm_browser))] - let frame = tokio_websockets::Message::binary({ - let mut buf = BytesMut::new(); - frame.write_to(&mut buf); - tokio_websockets::Payload::from(buf.freeze()) - }); + let frame = tokio_websockets::Message::binary(tokio_websockets::Payload::from( + frame.write_to(BytesMut::new()).freeze(), + )); #[cfg(wasm_browser)] let frame = ws_stream_wasm::WsMessage::Binary(frame.write_to(Vec::new())); diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 4bd5faa192f..754eb2dee53 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -9,10 +9,9 @@ use n0_future::{ SinkExt, TryStreamExt, }; use nested_enum_utils::common_fields; -use quinn_proto::{coding::Codec, VarInt}; #[cfg(feature = "server")] use rand::{CryptoRng, RngCore}; -use snafu::{Backtrace, ResultExt, Snafu}; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; use super::{relay::FrameType, send_recv::SendError, streams::BytesStreamSink}; use crate::ExportKeyingMaterial; @@ -319,13 +318,7 @@ async fn read_frame( .context(TimeoutSnafu)?? .ok_or_else(|| UnexpectedEndSnafu.build())?; - // TODO(matheus23) restructure: use FrameType::from_bytes, perhaps always use `FrameType` instead - let mut cursor = std::io::Cursor::new(recv); - let var_int = VarInt::decode(&mut cursor) - .map_err(|quinn_proto::coding::UnexpectedEnd| UnexpectedEndSnafu.build())?; - let frame_type = u32::try_from(var_int.into_inner()) - .ok() - .map_or(FrameType::Unknown, FrameType::from); + let (frame_type, payload) = FrameType::from_bytes(recv).context(UnexpectedEndSnafu)?; snafu::ensure!( expected_types.contains(&frame_type), UnexpectedFrameTypeSnafu { @@ -334,9 +327,6 @@ async fn read_frame( } ); - let start = cursor.position() as usize; - let payload = cursor.into_inner().slice(start..); - Ok((frame_type, payload)) } diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index bcdaf72d33c..51b0fb0edc1 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -18,12 +18,10 @@ pub enum FrameType { ServerConfirmsAuth = 4, /// The server frame type for authentication denial ServerDeniesAuth = 5, - /// 32B dest pub key + packet bytes TODO(matheus23): Fix docs - SendPacket = 10, - /// v0/1 packet bytes, v2: 32B src pub key + packet bytes TODO(matheus23): Fix docs - RecvPacket = 11, - /// no payload, no-op (to be replaced with ping/pong) - KeepAlive = 12, + /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents + SendDatagrams = 10, + /// 32B src pub key + ECN byte + segment size u16 + datagrams contents + RecvDatagrams = 11, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index 037108c28c5..65b2736a244 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -244,7 +244,7 @@ impl ServerToClientMsg { /// TODO(matheus23): docs pub fn typ(&self) -> FrameType { match self { - Self::ReceivedDatagrams { .. } => FrameType::RecvPacket, + Self::ReceivedDatagrams { .. } => FrameType::RecvDatagrams, Self::NodeGone { .. } => FrameType::NodeGone, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, @@ -296,16 +296,16 @@ impl ServerToClientMsg { #[allow(clippy::result_large_err)] pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + let frame_len = content.len(); + snafu::ensure!( + frame_len <= MAX_PACKET_SIZE, + FrameTooLargeSnafu { frame_len } + ); + let res = match frame_type { - FrameType::RecvPacket => { + FrameType::RecvDatagrams => { snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); - let frame_len = content.len() - NodeId::LENGTH; - snafu::ensure!( - frame_len <= MAX_PACKET_SIZE, - FrameTooLargeSnafu { frame_len } - ); - let remote_node_id = cache.key_from_slice(&content[..NodeId::LENGTH])?; let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; Self::ReceivedDatagrams { @@ -334,7 +334,6 @@ impl ServerToClientMsg { let problem = std::str::from_utf8(&content) .context(InvalidProtocolMessageEncodingSnafu)? .to_owned(); - // TODO(matheus23): Actually encode/decode the option Self::Health { problem } } FrameType::Restarting => { @@ -367,7 +366,7 @@ impl ServerToClientMsg { impl ClientToServerMsg { pub(crate) fn typ(&self) -> FrameType { match self { - Self::SendDatagrams { .. } => FrameType::SendPacket, + Self::SendDatagrams { .. } => FrameType::SendDatagrams, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, } @@ -403,16 +402,14 @@ impl ClientToServerMsg { #[cfg(feature = "server")] pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; - let res = match frame_type { - FrameType::SendPacket => { - if content.len() < NodeId::LENGTH { - return Err(InvalidFrameSnafu.build()); - } - let frame_len = content.len() - NodeId::LENGTH; - if frame_len > MAX_PACKET_SIZE { - return Err(FrameTooLargeSnafu { frame_len }.build()); - } + let frame_len = content.len(); + snafu::ensure!( + frame_len <= MAX_PACKET_SIZE, + FrameTooLargeSnafu { frame_len } + ); + let res = match frame_type { + FrameType::SendDatagrams => { let dst_node_id = cache.key_from_slice(&content[..NodeId::LENGTH])?; let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; Self::SendDatagrams { @@ -421,17 +418,13 @@ impl ClientToServerMsg { } } FrameType::Ping => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); Self::Ping(data) } FrameType::Pong => { - if content.len() != 8 { - return Err(InvalidFrameSnafu.build()); - } + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); let mut data = [0u8; 8]; data.copy_from_slice(&content[..8]); Self::Pong(data) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 44e4d0b3972..c7f8745cd72 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -631,7 +631,7 @@ mod tests { data: Datagrams::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; - let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; + let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; assert_eq!( frame, ServerToClientMsg::ReceivedDatagrams { @@ -646,7 +646,7 @@ mod tests { .send(packet.clone()) .await .context("send")?; - let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; + let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; assert_eq!( frame, ServerToClientMsg::ReceivedDatagrams { diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 3b8b5731805..382c7fb5fdf 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -255,7 +255,7 @@ mod tests { // send packet let data = b"hello world!"; clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; + let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; assert_eq!( frame, ServerToClientMsg::ReceivedDatagrams { @@ -266,7 +266,7 @@ mod tests { // send disco packet clients.send_disco_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; + let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; assert_eq!( frame, ServerToClientMsg::ReceivedDatagrams { From 1589c63d7c5146ecf9739e602bd499efdb17d641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 30 Jun 2025 18:04:31 +0200 Subject: [PATCH 30/80] Fix Wasm --- iroh-relay/src/client.rs | 5 +- iroh-relay/src/client/conn.rs | 17 +++---- iroh-relay/src/protos/handshake.rs | 7 ++- iroh-relay/src/protos/streams.rs | 76 ++++++++++++++++++++++++------ 4 files changed, 78 insertions(+), 27 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 5c4f399820f..0eadaab254d 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -288,7 +288,7 @@ impl ClientBuilder { /// Establishes a new connection to the relay server. #[cfg(wasm_browser)] - async fn connect(&self) -> Result { + pub async fn connect(&self) -> Result { let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); // The relay URL is exchanged with the http(s) scheme in tickets and similar. @@ -309,8 +309,7 @@ impl ClientBuilder { debug!(%dial_url, "Dialing relay by websocket"); let (_, ws_stream) = ws_stream_wasm::WsMeta::connect(dial_url.as_str(), None).await?; - let conn = - Conn::new_ws_browser(ws_stream, self.key_cache.clone(), &self.secret_key).await?; + let conn = Conn::new(ws_stream, self.key_cache.clone(), &self.secret_key).await?; event!( target: "events.net.relay.connected", diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index bf6e46e77be..a197f36adef 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -8,6 +8,7 @@ use std::{ task::{ready, Context, Poll}, }; +#[cfg(not(wasm_browser))] use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::{Sink, Stream}; @@ -17,10 +18,7 @@ use tracing::debug; use super::KeyCache; #[cfg(not(wasm_browser))] -use crate::{ - client::streams::{MaybeTlsStream, ProxyStream}, - protos::streams::WsBytesFramed, -}; +use crate::client::streams::{MaybeTlsStream, ProxyStream}; use crate::{ protos::{ handshake, @@ -28,6 +26,7 @@ use crate::{ ClientToServerMsg, RecvError as RecvRelayError, SendError as SendRelayError, ServerToClientMsg, MAX_PAYLOAD_SIZE, }, + streams::WsBytesFramed, }, MAX_PACKET_SIZE, }; @@ -119,6 +118,8 @@ impl Conn { key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { + // We use a phantom type param of ProxyStream for wrapping, just because it's easier to cfg-out code for wasm. + // It's a little ugly though. let mut io = WsBytesFramed { io: conn }; // exchange information with the server @@ -177,10 +178,10 @@ impl Stream for Conn { Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), #[cfg(wasm_browser)] - Some(ws_stream_wasm::WsMessage::Binary(vec)) => { - let frame = Frame::decode_from_ws_msg(Bytes::from(vec), &self.key_cache)?; - Poll::Ready(Some(ReceivedMessage::try_from(frame))) - } + Some(ws_stream_wasm::WsMessage::Binary(vec)) => Poll::Ready(Some( + ServerToClientMsg::from_bytes(bytes::Bytes::from(vec), &self.key_cache) + .map_err(Into::into), + )), #[cfg(wasm_browser)] Some(msg) => { tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 754eb2dee53..c4299f018b3 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -101,7 +101,12 @@ impl Frame for ServerDeniesAuth { #[non_exhaustive] pub enum Error { #[snafu(transparent)] - Websocket { source: tokio_websockets::Error }, + Websocket { + #[cfg(not(wasm_browser))] + source: tokio_websockets::Error, + #[cfg(wasm_browser)] + source: ws_stream_wasm::WsErr, + }, #[snafu(transparent)] Legacy { source: SendError }, #[snafu(display("Handshake timeout reached"))] diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 6b040bd78b7..6536d3ffe4d 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -10,20 +10,23 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::ExportKeyingMaterial; +#[cfg(not(wasm_browser))] #[derive(derive_more::Debug)] pub(crate) struct WsBytesFramed { - #[cfg(not(wasm_browser))] #[debug("WebSocketStream")] pub(crate) io: tokio_websockets::WebSocketStream, - #[cfg(wasm_browser)] +} + +#[cfg(wasm_browser)] +#[derive(derive_more::Debug)] +pub(crate) struct WsBytesFramed { #[debug("WebSocketStream")] pub(crate) io: ws_stream_wasm::WsStream, - #[cfg(wasm_browser)] - _data: PhantomData, } #[cfg(not(wasm_browser))] type StreamError = tokio_websockets::Error; + #[cfg(wasm_browser)] type StreamError = ws_stream_wasm::WsErr; @@ -38,32 +41,35 @@ impl BytesStreamSink for T where { } +#[cfg(not(wasm_browser))] impl ExportKeyingMaterial for WsBytesFramed { - #[cfg(wasm_browser)] fn export_keying_material>( &self, output: T, label: &[u8], context: Option<&[u8]>, ) -> Option { - None + self.io + .get_ref() + .export_keying_material(output, label, context) } +} - #[cfg(not(wasm_browser))] +#[cfg(wasm_browser)] +impl ExportKeyingMaterial for WsBytesFramed { fn export_keying_material>( &self, - output: T, - label: &[u8], - context: Option<&[u8]>, + _output: T, + _label: &[u8], + _context: Option<&[u8]>, ) -> Option { - self.io - .get_ref() - .export_keying_material(output, label, context) + None } } +#[cfg(not(wasm_browser))] impl Stream for WsBytesFramed { type Item = Result; @@ -95,13 +101,53 @@ impl Stream for WsBytesFramed { } } +#[cfg(wasm_browser)] +impl Stream for WsBytesFramed { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(Pin::new(&mut self.io).poll_next(cx)) { + None => return Poll::Ready(None), + Some(ws_stream_wasm::WsMessage::Binary(msg)) => { + return Poll::Ready(Some(Ok(msg.into()))) + } + Some(msg) => { + tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); + continue; + } + } + } + } +} + +#[cfg(not(wasm_browser))] impl Sink for WsBytesFramed { type Error = StreamError; fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { - #[cfg(not(wasm_browser))] let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes)); - #[cfg(wasm_browser)] + Pin::new(&mut self.io).start_send(msg).map_err(Into::into) + } + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_ready(cx).map_err(Into::into) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_flush(cx).map_err(Into::into) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io).poll_close(cx).map_err(Into::into) + } +} + +#[cfg(wasm_browser)] +impl Sink for WsBytesFramed { + type Error = StreamError; + + fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { let msg = ws_stream_wasm::WsMessage::Binary(bytes.to_vec()); Pin::new(&mut self.io).start_send(msg).map_err(Into::into) } From 6659c90359fbf3841ca784e51133d643b8df24f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 1 Jul 2025 09:47:08 +0200 Subject: [PATCH 31/80] Fix proptest & docfixes & clippy --- .../proptest-regressions/protos/send_recv.txt | 1 + iroh-relay/src/client/conn.rs | 9 ++---- iroh-relay/src/client/tls.rs | 10 +------ iroh-relay/src/protos/handshake.rs | 18 +++++------- iroh-relay/src/protos/send_recv.rs | 29 ++++++++++++------- iroh-relay/src/protos/streams.rs | 4 +-- iroh-relay/src/server/streams.rs | 6 ++-- 7 files changed, 36 insertions(+), 41 deletions(-) diff --git a/iroh-relay/proptest-regressions/protos/send_recv.txt b/iroh-relay/proptest-regressions/protos/send_recv.txt index 54a1d9d3588..7ceca20bbb7 100644 --- a/iroh-relay/proptest-regressions/protos/send_recv.txt +++ b/iroh-relay/proptest-regressions/protos/send_recv.txt @@ -6,3 +6,4 @@ # everyone who runs the test benefits from these saved cases. cc 8f4f94b7c917bb0f52d31b529c3d580728ea57954ca45d91768cf4ae745e6eb9 # shrinks to frame = ReceivedDatagrams { remote_node_id: PublicKey(3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29), datagrams: Datagrams { ecn: None, segment_size: Some(43846), .. } } cc e40ca61e22386f1c76f717f2a6dbba367ea05d906317b3f979b46031567edbca # shrinks to frame = SendDatagrams { dst_node_id: PublicKey(3b6a27bcceb6a42d62a3a8d02a6f0d73653215771de243a63ac048a18b59da29), datagrams: Datagrams { ecn: None, segment_size: Some(44811), .. } } +cc 435ec32fc803db22bf4688a6356878073752b58fcd0b4422876fb3ab2a622684 # shrinks to a huge frame = Health { .. } diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index a197f36adef..d27140a1857 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -81,13 +81,8 @@ pub enum RecvError { /// /// This holds a connection to a relay server. It is: /// -/// - A [`Stream`] for [`ReceivedMessage`] to receive from the server. -/// - A [`Sink`] for [`SendMessage`] to send to the server. -/// - A [`Sink`] for [`Frame`] to send to the server. -/// -/// The [`Frame`] sink is a more internal interface, it allows performing the handshake. -/// The [`SendMessage`] and [`ReceivedMessage`] are safer wrappers enforcing some protocol -/// invariants. +/// - A [`Stream`] for [`ServerToClientMsg`] to receive from the server. +/// - A [`Sink`] for [`ClientToServerMsg`] to send to the server. #[derive(derive_more::Debug)] pub(crate) struct Conn { #[debug("tokio_websockets::WebSocketStream")] diff --git a/iroh-relay/src/client/tls.rs b/iroh-relay/src/client/tls.rs index a880b38ad01..e70e4373770 100644 --- a/iroh-relay/src/client/tls.rs +++ b/iroh-relay/src/client/tls.rs @@ -1,16 +1,8 @@ //! Functionality related to lower-level tls-based connection establishment. //! -//! Primarily to support [`ClientBuilder::connect_relay`]. +//! Primarily to support [`ClientBuilder::connect`]. //! //! This doesn't work in the browser - thus separated into its own file. -//! -//! `connect_relay` uses a custom HTTP upgrade header value (see [`HTTP_UPGRADE_PROTOCOL`]), -//! as opposed to [`WEBSOCKET_UPGRADE_PROTOCOL`]. -//! -//! `connect_ws` however reuses websockets for framing. -//! -//! [`HTTP_UPGRADE_PROTOCOL`]: crate::http::HTTP_UPGRADE_PROTOCOL -//! [`WEBSOCKET_UPGRADE_PROTOCOL`]: crate::http::WEBSOCKET_UPGRADE_PROTOCOL // Based on tailscale/derp/derphttp/derphttp_client.go diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index c4299f018b3..db5ed82f7a8 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -13,7 +13,7 @@ use nested_enum_utils::common_fields; use rand::{CryptoRng, RngCore}; use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; -use super::{relay::FrameType, send_recv::SendError, streams::BytesStreamSink}; +use super::{relay::FrameType, streams::BytesStreamSink}; use crate::ExportKeyingMaterial; /// Message that tells the server the client needs a challenge to authenticate. @@ -107,8 +107,6 @@ pub enum Error { #[cfg(wasm_browser)] source: ws_stream_wasm::WsErr, }, - #[snafu(transparent)] - Legacy { source: SendError }, #[snafu(display("Handshake timeout reached"))] Timeout { source: Elapsed }, #[snafu(display("Handshake stream ended prematurely"))] @@ -253,7 +251,7 @@ pub(crate) async fn clientside( } FrameType::ServerDeniesAuth => { let _denial: ServerDeniesAuth = deserialize_frame(frame)?; - return Err(ServerDeniedAuthSnafu.build()); + Err(ServerDeniedAuthSnafu.build()) } _ => unreachable!(), } @@ -328,7 +326,7 @@ async fn read_frame( expected_types.contains(&frame_type), UnexpectedFrameTypeSnafu { frame_type, - expected_types: expected_types.into_iter().cloned().collect::>() + expected_types: expected_types.to_vec() } ); @@ -436,18 +434,18 @@ mod tests { let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) - .map_err(|e| tokio_websockets::Error::Io(e).into()) - .sink_map_err(|e| tokio_websockets::Error::Io(e).into()) + .map_err(tokio_websockets::Error::Io) + .sink_map_err(tokio_websockets::Error::Io) .with_shared_secret(client_shared_secret); let mut server_io = Framed::new(server, LengthDelimitedCodec::new()) .map_ok(BytesMut::freeze) - .map_err(|e| tokio_websockets::Error::Io(e).into()) - .sink_map_err(|e| tokio_websockets::Error::Io(e).into()) + .map_err(tokio_websockets::Error::Io) + .sink_map_err(tokio_websockets::Error::Io) .with_shared_secret(server_shared_secret); let (_, client_auth) = n0_future::future::try_zip( async { - super::clientside(&mut client_io, &secret_key) + super::clientside(&mut client_io, secret_key) .await .context("clientside") }, diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index 65b2736a244..588aeffd622 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -142,7 +142,7 @@ pub enum ServerToClientMsg { /// /// If `None` means the connection is healthy again. /// - /// The default condition is healthy, so the server doesn't broadcast a [`ReceivedMessage::Health`] + /// The default condition is healthy, so the server doesn't broadcast a [`ServerToClientMsg::Health`] /// until a problem exists. problem: String, }, @@ -157,12 +157,10 @@ pub enum ServerToClientMsg { /// than a few seconds. try_for: Duration, }, - /// TODO(matheus23) fix docs - /// Request from a client or server to reply to the - /// other side with a [`ReceivedMessage::Pong`] with the given payload. + /// Request from the server to reply to the + /// other side with a [`ClientToServerMsg::Pong`] with the given payload. Ping([u8; 8]), - /// TODO(matheus23) fix docs - /// Reply to a [`ReceivedMessage::Ping`] from a client or server + /// Reply to a [`ClientToServerMsg::Ping`] from a client /// with the payload sent previously in the ping. Pong([u8; 8]), } @@ -170,9 +168,11 @@ pub enum ServerToClientMsg { /// TODO(matheus23): Docs #[derive(Debug, Clone, PartialEq, Eq)] pub enum ClientToServerMsg { - /// TODO + /// Request from the client to the server to reply to the + /// other side with a [`ServerToClientMsg::Pong`] with the given payload. Ping([u8; 8]), - /// TODO + /// Reply to a [`ServerToClientMsg::Ping`] from a server + /// with the payload sent previously in the ping. Pong([u8; 8]), /// TODO SendDatagrams { @@ -183,7 +183,10 @@ pub enum ClientToServerMsg { }, } -/// TODO(matheus23): Docs +/// One or multiple datagrams being transferred via the relay. +/// +/// This type is modeled after [`quinn_proto::Transmit`] +/// (or even more similarly `quinn_udp::Transmit`, but we don't depend on that library here). #[derive(derive_more::Debug, Clone, PartialEq, Eq)] pub struct Datagrams { /// Explicit congestion notification bits @@ -613,10 +616,14 @@ mod proptests { datagrams, } }); - let node_gone = key().prop_map(|node_id| ServerToClientMsg::NodeGone(node_id)); + let node_gone = key().prop_map(ServerToClientMsg::NodeGone); let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); - let health = ".{0,65536}".prop_map(|problem| ServerToClientMsg::Health { problem }); + let health = ".{0,65536}" + .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { + s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes + }) + .prop_map(|problem| ServerToClientMsg::Health { problem }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { ServerToClientMsg::Restarting { reconnect_in: Duration::from_millis(reconnect_in.into()), diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 6536d3ffe4d..12df4e860c1 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -77,7 +77,7 @@ impl Stream for WsBytesFramed { loop { match ready!(Pin::new(&mut self.io).poll_next(cx)) { None => return Poll::Ready(None), - Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), Some(Ok(msg)) => { if msg.is_close() { // Indicate the stream is done when we receive a close message. @@ -127,7 +127,7 @@ impl Sink for WsBytesFramed { fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { let msg = tokio_websockets::Message::binary(tokio_websockets::Payload::from(bytes)); - Pin::new(&mut self.io).start_send(msg).map_err(Into::into) + Pin::new(&mut self.io).start_send(msg) } fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 0bd59778bf7..ac1a39bcc3d 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -19,9 +19,11 @@ use crate::{ ExportKeyingMaterial, KeyCache, }; -/// A Stream and Sink for [`Frame`]s connected to a single relay client. +/// The relay's connection to a client. /// -/// The stream receives message from the client while the sink sends them to the client. +/// This implements +/// - a [`Stream`] of [`ClientToServerMsg`]s that are received from the client, +/// - a [`Sink`] of [`ServerToClientMsg`]s that can be sent to the client. #[derive(Debug)] pub(crate) struct RelayedStream { pub(crate) inner: WebSocketStream>, From a3282f5fa2134c6fabbc29e3414830e5c523a79a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 1 Jul 2025 14:05:41 +0200 Subject: [PATCH 32/80] Run integraiton test against philipp.iroh.link relay --- iroh/tests/integration.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index e181bdf3745..053e14aa75f 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -9,9 +9,11 @@ //! //! In the past we've hit relay rate-limits from all the tests in our CI, but I expect //! we won't hit these with only this integration test. +use std::str::FromStr; + use iroh::{ discovery::{pkarr::PkarrResolver, Discovery}, - Endpoint, + Endpoint, RelayMap, RelayMode, RelayUrl, }; use n0_future::{ task, @@ -37,8 +39,19 @@ async fn simple_node_id_based_connection_transfer() -> Result { std::panic::set_hook(Box::new(console_error_panic_hook::hook)); setup_logging(); - let client = Endpoint::builder().discovery_n0().bind().await?; + // TODO(matheus23): Replace this with actual production relays eventually + let relay_map = RelayMode::Custom(RelayMap::from_iter([RelayUrl::from_str( + "https://philipp.iroh.link.", + ) + .e()?])); + + let client = Endpoint::builder() + .relay_mode(relay_map.clone()) + .discovery_n0() + .bind() + .await?; let server = Endpoint::builder() + .relay_mode(relay_map) .discovery_n0() .alpns(vec![ECHO_ALPN.to_vec()]) .bind() @@ -134,5 +147,7 @@ fn setup_logging() { #[cfg(not(wasm_browser))] fn setup_logging() { - tracing_subscriber::fmt().init(); + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); } From 4f55f72b8b4e2184517130ff0eb6fa23d8e5b843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 2 Jul 2025 17:05:38 +0200 Subject: [PATCH 33/80] 1-RTT faster handshake by sending client auth with HTTP header --- iroh-relay/src/client.rs | 20 +++++++--- iroh-relay/src/http.rs | 5 +++ iroh-relay/src/protos/handshake.rs | 60 ++++++++++++++++------------ iroh-relay/src/protos/streams.rs | 1 + iroh-relay/src/server/http_server.rs | 31 +++++++++----- iroh/src/net_report/reportgen.rs | 2 +- iroh/tests/integration.rs | 1 + 7 files changed, 78 insertions(+), 42 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 0eadaab254d..d64dc6478b9 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -202,7 +202,10 @@ impl ClientBuilder { pub async fn connect(&self) -> Result { use tls::MaybeTlsStreamBuilder; - use crate::protos::send_recv::MAX_FRAME_SIZE; + use crate::{ + http::CLIENT_AUTH_HEADER, + protos::{handshake::ClientAuth, send_recv::MAX_FRAME_SIZE}, + }; let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); @@ -238,7 +241,7 @@ impl ClientBuilder { .as_ref() .local_addr() .map_err(|_| NoLocalAddrSnafu.build())?; - let (conn, response) = tokio_websockets::ClientBuilder::new() + let mut builder = tokio_websockets::ClientBuilder::new() .uri(dial_url.as_str()) .map_err(|_| { InvalidRelayUrlSnafu { @@ -246,9 +249,16 @@ impl ClientBuilder { } .build() })? - .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) - .connect_on(stream) - .await?; + .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))); + if let Some(client_auth) = ClientAuth::new_from_key_export(&self.secret_key, &stream) { + debug!("Using TLS key export for relay client authentication"); + builder = builder + .add_header(CLIENT_AUTH_HEADER, client_auth.to_header_value()) + .expect( + "impossible: CLIENT_AUTH_HEADER isn't a disallowed header value for websockets", + ); + } + let (conn, response) = builder.connect_on(stream).await?; if response.status() != hyper::StatusCode::SWITCHING_PROTOCOLS { UnexpectedUpgradeStatusSnafu { diff --git a/iroh-relay/src/http.rs b/iroh-relay/src/http.rs index cca8c82aee0..5957f816808 100644 --- a/iroh-relay/src/http.rs +++ b/iroh-relay/src/http.rs @@ -1,5 +1,7 @@ //! HTTP-specific constants for the relay server and client. +use http::HeaderName; + #[cfg(feature = "server")] pub(crate) const WEBSOCKET_UPGRADE_PROTOCOL: &str = "websocket"; #[cfg(feature = "server")] // only used in the server for now @@ -10,3 +12,6 @@ pub(crate) const SUPPORTED_WEBSOCKET_VERSION: &str = "13"; pub const RELAY_PATH: &str = "/relay"; /// The HTTP path under which the relay allows doing latency queries for testing. pub const RELAY_PROBE_PATH: &str = "/ping"; + +/// The HTTP header name for relay client authentication +pub const CLIENT_AUTH_HEADER: HeaderName = HeaderName::from_static("x-iroh-relay-client-auth"); diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index db5ed82f7a8..9a269421d35 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -1,6 +1,7 @@ //! TODO(matheus23) docs use bytes::{BufMut, Bytes, BytesMut}; +use http::HeaderValue; #[cfg(feature = "server")] use iroh_base::Signature; use iroh_base::{PublicKey, SecretKey}; @@ -123,6 +124,9 @@ pub enum Error { frame_type: FrameType, source: postcard::Error, }, + #[cfg(feature = "server")] + /// Failed to deserialize client auth header + ClientAuthHeaderInvalid { value: HeaderValue }, } impl ServerChallenge { @@ -165,7 +169,7 @@ impl ClientAuth { pub(crate) fn new_from_key_export( secret_key: &SecretKey, - io: &mut impl ExportKeyingMaterial, + io: &impl ExportKeyingMaterial, ) -> Option { let public_key = secret_key.public(); let key_material = io.export_keying_material( @@ -185,6 +189,14 @@ impl ClientAuth { }) } + pub(crate) fn to_header_value(self) -> HeaderValue { + HeaderValue::from_str( + &data_encoding::BASE64URL_NOPAD + .encode(&postcard::to_allocvec(&self).expect("encoding never fails")), + ) + .expect("BASE64URL_NOPAD encoding contained invisible ascii characters") + } + #[cfg(feature = "server")] pub(crate) fn verify_from_key_export(&self, io: &mut impl ExportKeyingMaterial) -> bool { let Some(key_material) = io.export_keying_material( @@ -210,20 +222,9 @@ pub(crate) async fn clientside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), secret_key: &SecretKey, ) -> Result { - if let Some(client_auth) = ClientAuth::new_from_key_export(secret_key, io) { - write_frame(io, client_auth).await?; - } else { - // we can't use key exporting, so request a challenge. - write_frame(io, ClientRequestChallenge).await?; - } - let (tag, frame) = read_frame( io, - &[ - ServerChallenge::TAG, - ServerConfirmsAuth::TAG, - ServerDeniesAuth::TAG, - ], + &[ServerChallenge::TAG, ServerConfirmsAuth::TAG], time::Duration::from_secs(30), ) .await?; @@ -261,24 +262,30 @@ pub(crate) async fn clientside( #[cfg(feature = "server")] pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), + client_auth_header: Option, rng: impl RngCore + CryptoRng, ) -> Result { - let (tag, frame) = read_frame( - io, - &[ClientRequestChallenge::TAG, ClientAuth::TAG], - time::Duration::from_secs(10), - ) - .await?; + if let Some(client_auth_header) = client_auth_header { + let client_auth_bytes = data_encoding::BASE64URL_NOPAD + .decode(client_auth_header.as_ref()) + .map_err(|_| { + ClientAuthHeaderInvalidSnafu { + value: client_auth_header.clone(), + } + .build() + })?; + + let client_auth: ClientAuth = postcard::from_bytes(&client_auth_bytes).map_err(|_| { + ClientAuthHeaderInvalidSnafu { + value: client_auth_header.clone(), + } + .build() + })?; - // it might be fast-path authentication using TLS exported key material - if tag == ClientAuth::TAG { - let client_auth: ClientAuth = deserialize_frame(frame)?; if client_auth.verify_from_key_export(io) { write_frame(io, ServerConfirmsAuth).await?; return Ok(client_auth); } - } else { - let _frame: ClientRequestChallenge = deserialize_frame(frame)?; } let challenge = ServerChallenge::new(rng); @@ -443,6 +450,9 @@ mod tests { .sink_map_err(tokio_websockets::Error::Io) .with_shared_secret(server_shared_secret); + let client_auth_header = ClientAuth::new_from_key_export(secret_key, &mut client_io) + .map(ClientAuth::to_header_value); + let (_, client_auth) = n0_future::future::try_zip( async { super::clientside(&mut client_io, secret_key) @@ -450,7 +460,7 @@ mod tests { .context("clientside") }, async { - super::serverside(&mut server_io, rand::rngs::OsRng) + super::serverside(&mut server_io, client_auth_header, rand::rngs::OsRng) .await .context("serverside") }, diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 12df4e860c1..371f005c044 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -6,6 +6,7 @@ use std::{ use bytes::Bytes; use n0_future::{ready, Sink, Stream}; +#[cfg(not(wasm_browser))] use tokio::io::{AsyncRead, AsyncWrite}; use crate::ExportKeyingMaterial; diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index d52204f4b8a..eabb9d76565 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -36,7 +36,7 @@ use crate::{ KeyCache, }; use crate::{ - http::WEBSOCKET_UPGRADE_PROTOCOL, + http::{CLIENT_AUTH_HEADER, WEBSOCKET_UPGRADE_PROTOCOL}, protos::{handshake, send_recv::MAX_FRAME_SIZE, streams::WsBytesFramed}, server::streams::RateLimited, }; @@ -496,7 +496,7 @@ impl RelayService { .expect("valid body")); } - debug!("upgrading connection"); + let client_auth_header = req.headers().get(CLIENT_AUTH_HEADER).cloned(); // Setup a future that will eventually receive the upgraded // connection and talk a new protocol, and spawn the future @@ -509,7 +509,11 @@ impl RelayService { async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { - if let Err(err) = this.0.relay_connection_handler(upgraded).await { + if let Err(err) = this + .0 + .relay_connection_handler(upgraded, client_auth_header) + .await + { warn!("error accepting upgraded connection: {err:#}",); } else { debug!("upgraded connection completed"); @@ -597,6 +601,7 @@ impl Inner { async fn relay_connection_handler( &self, upgraded: Upgraded, + client_auth_header: Option, ) -> Result<(), ConnectionHandlerError> { debug!("relay_connection upgraded"); let (io, read_buf) = downcast_upgrade(upgraded)?; @@ -604,7 +609,7 @@ impl Inner { return Err(BufferNotEmptySnafu { buf: read_buf }.build()); } - self.accept(io).await?; + self.accept(io, client_auth_header).await?; Ok(()) } @@ -618,7 +623,11 @@ impl Inner { /// /// [`AsyncRead`]: tokio::io::AsyncRead /// [`AsyncWrite`]: tokio::io::AsyncWrite - async fn accept(&self, io: MaybeTlsStream) -> Result<(), AcceptError> { + async fn accept( + &self, + io: MaybeTlsStream, + client_auth_header: Option, + ) -> Result<(), AcceptError> { use snafu::ResultExt; trace!("accept: start"); @@ -636,7 +645,7 @@ impl Inner { let mut io = WsBytesFramed { io: websocket }; - let client_info = handshake::serverside(&mut io, rand::rngs::OsRng) + let client_info = handshake::serverside(&mut io, client_auth_header, rand::rngs::OsRng) .await .context(HandshakeSnafu)?; @@ -1098,7 +1107,7 @@ mod tests { let (client_a, rw_a) = tokio::io::duplex(10); let s = service.clone(); let handler_task = - tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a)).await }); + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a), None).await }); let mut client_a = make_test_client(client_a, &key_a).await?; handler_task.await.context("join")??; @@ -1108,7 +1117,7 @@ mod tests { let (client_b, rw_b) = tokio::io::duplex(10); let s = service.clone(); let handler_task = - tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b)).await }); + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b), None).await }); let mut client_b = make_test_client(client_b, &key_b).await?; handler_task.await.context("join")??; @@ -1188,7 +1197,7 @@ mod tests { let (client_a, rw_a) = tokio::io::duplex(10); let s = service.clone(); let handler_task = - tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a)).await }); + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_a), None).await }); let mut client_a = make_test_client(client_a, &key_a).await?; handler_task.await.context("join")??; @@ -1198,7 +1207,7 @@ mod tests { let (client_b, rw_b) = tokio::io::duplex(10); let s = service.clone(); let handler_task = - tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b)).await }); + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(rw_b), None).await }); let mut client_b = make_test_client(client_b, &key_b).await?; handler_task.await.context("join")??; @@ -1248,7 +1257,7 @@ mod tests { let (new_client_b, new_rw_b) = tokio::io::duplex(10); let s = service.clone(); let handler_task = - tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(new_rw_b)).await }); + tokio::spawn(async move { s.0.accept(MaybeTlsStream::Test(new_rw_b), None).await }); let mut new_client_b = make_test_client(new_client_b, &key_b).await?; handler_task.await.context("join")??; diff --git a/iroh/src/net_report/reportgen.rs b/iroh/src/net_report/reportgen.rs index 06e46f1e381..9260ab60666 100644 --- a/iroh/src/net_report/reportgen.rs +++ b/iroh/src/net_report/reportgen.rs @@ -379,7 +379,7 @@ impl Actor { let res = match res { Some(Ok(Ok(report))) => Ok(report), Some(Ok(Err(err))) => { - warn!("probe failed: {:#}", err); + warn!("probe failed: {:#?}", err); Err(probes_error::ProbeFailureSnafu {}.into_error(err)) } Some(Err(time::Elapsed { .. })) => { diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index 053e14aa75f..7534aaa32d3 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -133,6 +133,7 @@ async fn simple_node_id_based_connection_transfer() -> Result { #[cfg(wasm_browser)] fn setup_logging() { tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_str("trace").expect("hardcoded")) .with_max_level(tracing::level_filters::LevelFilter::DEBUG) .with_writer( // To avoide trace events in the browser from showing their JS backtrace From a3f76af66cf1d6eeab760fb2b5ab2dad9d369031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 2 Jul 2025 18:02:58 +0200 Subject: [PATCH 34/80] Split up `ClientAuth` frame type --- iroh-relay/src/client.rs | 4 +- iroh-relay/src/protos/handshake.rs | 175 ++++++++++++++++----------- iroh-relay/src/protos/relay.rs | 2 - iroh-relay/src/server/http_server.rs | 10 +- 4 files changed, 114 insertions(+), 77 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index d64dc6478b9..a32f569d208 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -204,7 +204,7 @@ impl ClientBuilder { use crate::{ http::CLIENT_AUTH_HEADER, - protos::{handshake::ClientAuth, send_recv::MAX_FRAME_SIZE}, + protos::{handshake::KeyMaterialClientAuth, send_recv::MAX_FRAME_SIZE}, }; let mut dial_url = (*self.url).clone(); @@ -250,7 +250,7 @@ impl ClientBuilder { .build() })? .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))); - if let Some(client_auth) = ClientAuth::new_from_key_export(&self.secret_key, &stream) { + if let Some(client_auth) = KeyMaterialClientAuth::new(&self.secret_key, &stream) { debug!("Using TLS key export for relay client authentication"); builder = builder .add_header(CLIENT_AUTH_HEADER, client_auth.to_header_value()) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 9a269421d35..f1ea9c33c2b 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -17,10 +17,20 @@ use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; use super::{relay::FrameType, streams::BytesStreamSink}; use crate::ExportKeyingMaterial; -/// Message that tells the server the client needs a challenge to authenticate. +/// Authentication message from the client. #[derive(derive_more::Debug, serde::Serialize)] #[cfg_attr(feature = "server", derive(serde::Deserialize))] -pub(crate) struct ClientRequestChallenge; +pub(crate) struct KeyMaterialClientAuth { + /// The client's public key + pub(crate) public_key: PublicKey, + /// A signature of (a hash of) extracted key material. + #[serde(with = "serde_bytes")] + pub(crate) signature: [u8; 64], + /// Part of the extracted key material. + /// + /// Allows making sure we have the same underlying key material. + pub(crate) key_material_suffix: [u8; 16], +} /// A challenge for the client to sign with their secret key for NodeId authentication. #[derive(derive_more::Debug, serde::Deserialize)] @@ -33,18 +43,17 @@ pub(crate) struct ServerChallenge { /// Authentication message from the client. /// -/// Also serves to inform the server about the client's send message version, -/// which will be passed on to other connecting clients. +/// Used when authentication via [`KeyMaterialClientAuth`] didn't work. #[derive(derive_more::Debug, serde::Serialize)] #[cfg_attr(feature = "server", derive(serde::Deserialize))] pub(crate) struct ClientAuth { /// The client's public key, a.k.a. the `NodeId` pub(crate) public_key: PublicKey, - /// A signature of the server challenge, serves as authentication. + /// A signature of (a hash of) the [`ServerChallenge`]. + /// + /// This is what provides the authentication. #[serde(with = "serde_bytes")] pub(crate) signature: [u8; 64], - /// Part of the extracted key material, if that's what was signed. - pub(crate) key_material_suffix: Option<[u8; 16]>, } /// Confirmation of successful connection. @@ -72,10 +81,6 @@ impl Frame for &T { const TAG: FrameType = T::TAG; } -impl Frame for ClientRequestChallenge { - const TAG: FrameType = FrameType::ClientRequestChallenge; -} - impl Frame for ServerChallenge { const TAG: FrameType = FrameType::ServerChallenge; } @@ -148,17 +153,16 @@ impl ServerChallenge { impl ClientAuth { /// TODO(matheus23): docs - pub(crate) fn new_from_challenge(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self { + pub(crate) fn new(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self { Self { public_key: secret_key.public(), - key_material_suffix: None, signature: secret_key.sign(&challenge.message_to_sign()).to_bytes(), } } /// TODO(matheus23): docs #[cfg(feature = "server")] - pub(crate) fn verify_from_challenge(&self, challenge: &ServerChallenge) -> bool { + pub(crate) fn verify(&self, challenge: &ServerChallenge) -> bool { self.public_key .verify( &challenge.message_to_sign(), @@ -166,26 +170,20 @@ impl ClientAuth { ) .is_ok() } +} - pub(crate) fn new_from_key_export( - secret_key: &SecretKey, - io: &impl ExportKeyingMaterial, - ) -> Option { +impl KeyMaterialClientAuth { + pub(crate) fn new(secret_key: &SecretKey, io: &impl ExportKeyingMaterial) -> Option { let public_key = secret_key.public(); let key_material = io.export_keying_material( [0u8; 32], b"iroh-relay handshake v1", Some(secret_key.public().as_bytes()), )?; - - let message = blake3::derive_key( - "iroh-relay handshake v1 key material signature", - &key_material[..16], - ); - Some(ClientAuth { + Some(Self { public_key, - signature: secret_key.sign(&message).to_bytes(), - key_material_suffix: Some(key_material[16..].try_into().expect("split right")), + signature: secret_key.sign(&key_material[..16]).to_bytes(), + key_material_suffix: key_material[16..].try_into().expect("split right"), }) } @@ -198,7 +196,7 @@ impl ClientAuth { } #[cfg(feature = "server")] - pub(crate) fn verify_from_key_export(&self, io: &mut impl ExportKeyingMaterial) -> bool { + pub(crate) fn verify(&self, io: &impl ExportKeyingMaterial) -> bool { let Some(key_material) = io.export_keying_material( [0u8; 32], b"iroh-relay handshake v1", @@ -207,13 +205,11 @@ impl ClientAuth { return false; }; - let message = blake3::derive_key( - "iroh-relay handshake v1 key material signature", - &key_material[..16], - ); - self.public_key - .verify(&message, &Signature::from_bytes(&self.signature)) - .is_ok() + key_material[16..] == self.key_material_suffix + && self + .public_key + .verify(&key_material[..16], &Signature::from_bytes(&self.signature)) + .is_ok() } } @@ -232,7 +228,7 @@ pub(crate) async fn clientside( let (tag, frame) = if tag == ServerChallenge::TAG { let challenge: ServerChallenge = deserialize_frame(frame)?; - let client_info = ClientAuth::new_from_challenge(secret_key, &challenge); + let client_info = ClientAuth::new(secret_key, &challenge); write_frame(io, client_info).await?; read_frame( @@ -258,13 +254,20 @@ pub(crate) async fn clientside( } } +#[cfg(feature = "server")] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Auth { + SignedChallenge, + SignedKeyMaterial, +} + /// TODO(matheus23) docs #[cfg(feature = "server")] pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), client_auth_header: Option, rng: impl RngCore + CryptoRng, -) -> Result { +) -> Result<(PublicKey, Auth), Error> { if let Some(client_auth_header) = client_auth_header { let client_auth_bytes = data_encoding::BASE64URL_NOPAD .decode(client_auth_header.as_ref()) @@ -275,16 +278,17 @@ pub(crate) async fn serverside( .build() })?; - let client_auth: ClientAuth = postcard::from_bytes(&client_auth_bytes).map_err(|_| { - ClientAuthHeaderInvalidSnafu { - value: client_auth_header.clone(), - } - .build() - })?; + let client_auth: KeyMaterialClientAuth = + postcard::from_bytes(&client_auth_bytes).map_err(|_| { + ClientAuthHeaderInvalidSnafu { + value: client_auth_header.clone(), + } + .build() + })?; - if client_auth.verify_from_key_export(io) { + if client_auth.verify(io) { write_frame(io, ServerConfirmsAuth).await?; - return Ok(client_auth); + return Ok((client_auth.public_key, Auth::SignedKeyMaterial)); } } @@ -294,13 +298,13 @@ pub(crate) async fn serverside( let (_, frame) = read_frame(io, &[ClientAuth::TAG], time::Duration::from_secs(10)).await?; let client_auth: ClientAuth = deserialize_frame(frame)?; - if client_auth.verify_from_challenge(&challenge) { + if client_auth.verify(&challenge) { write_frame(io, ServerConfirmsAuth).await?; } else { write_frame(io, ServerDeniesAuth).await?; } - Ok(client_auth) + Ok((client_auth.public_key, Auth::SignedChallenge)) } async fn write_frame( @@ -347,12 +351,12 @@ fn deserialize_frame(frame: Bytes) -> Re #[cfg(all(test, feature = "server"))] mod tests { use bytes::BytesMut; - use iroh_base::SecretKey; + use iroh_base::{PublicKey, SecretKey}; use n0_future::{Sink, SinkExt, Stream, TryStreamExt}; use n0_snafu::{Result, ResultExt}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; - use super::{ClientAuth, ServerChallenge}; + use super::{Auth, ClientAuth, KeyMaterialClientAuth, ServerChallenge}; use crate::ExportKeyingMaterial; struct TestKeyingMaterial { @@ -436,7 +440,7 @@ mod tests { secret_key: &SecretKey, client_shared_secret: Option, server_shared_secret: Option, - ) -> Result { + ) -> Result<(PublicKey, Auth)> { let (client, server) = tokio::io::duplex(1024); let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) @@ -450,10 +454,10 @@ mod tests { .sink_map_err(tokio_websockets::Error::Io) .with_shared_secret(server_shared_secret); - let client_auth_header = ClientAuth::new_from_key_export(secret_key, &mut client_io) - .map(ClientAuth::to_header_value); + let client_auth_header = KeyMaterialClientAuth::new(secret_key, &mut client_io) + .map(KeyMaterialClientAuth::to_header_value); - let (_, client_auth) = n0_future::future::try_zip( + let (_, auth) = n0_future::future::try_zip( async { super::clientside(&mut client_io, secret_key) .await @@ -467,24 +471,24 @@ mod tests { ) .await?; - Ok(client_auth) + Ok(auth) } #[tokio::test] async fn test_handshake_via_shared_secrets() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); - let auth = simulate_handshake(&secret_key, Some(42), Some(42)).await?; - assert_eq!(auth.public_key, secret_key.public()); - assert!(auth.key_material_suffix.is_some()); // it got verified via shared key material + let (public_key, auth) = simulate_handshake(&secret_key, Some(42), Some(42)).await?; + assert_eq!(public_key, secret_key.public()); + assert_eq!(auth, Auth::SignedKeyMaterial); // it got verified via shared key material Ok(()) } #[tokio::test] async fn test_handshake_via_challenge() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); - let auth = simulate_handshake(&secret_key, None, None).await?; - assert_eq!(auth.public_key, secret_key.public()); - assert!(auth.key_material_suffix.is_none()); + let (public_key, auth) = simulate_handshake(&secret_key, None, None).await?; + assert_eq!(public_key, secret_key.public()); + assert_eq!(auth, Auth::SignedChallenge); Ok(()) } @@ -492,9 +496,9 @@ mod tests { async fn test_handshake_mismatching_shared_secrets() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret - let auth = simulate_handshake(&secret_key, Some(10), Some(99)).await?; - assert_eq!(auth.public_key, secret_key.public()); - assert!(auth.key_material_suffix.is_none()); + let (public_key, auth) = simulate_handshake(&secret_key, Some(10), Some(99)).await?; + assert_eq!(public_key, secret_key.public()); + assert_eq!(auth, Auth::SignedChallenge); Ok(()) } @@ -502,9 +506,9 @@ mod tests { async fn test_handshake_challenge_fallback() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); // clients might not have access to shared secrets - let auth = simulate_handshake(&secret_key, None, Some(99)).await?; - assert_eq!(auth.public_key, secret_key.public()); - assert!(auth.key_material_suffix.is_none()); + let (public_key, auth) = simulate_handshake(&secret_key, None, Some(99)).await?; + assert_eq!(public_key, secret_key.public()); + assert_eq!(auth, Auth::SignedChallenge); Ok(()) } @@ -512,13 +516,33 @@ mod tests { fn test_client_auth_roundtrip() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); let challenge = ServerChallenge::new(rand::rngs::OsRng); - let client_auth = ClientAuth::new_from_challenge(&secret_key, &challenge); + let client_auth = ClientAuth::new(&secret_key, &challenge); let bytes = postcard::to_allocvec(&client_auth).e()?; let decoded: ClientAuth = postcard::from_bytes(&bytes).e()?; assert_eq!(client_auth.public_key, decoded.public_key); - assert_eq!(client_auth.key_material_suffix, decoded.key_material_suffix); + assert_eq!(client_auth.signature, decoded.signature); + + Ok(()) + } + + #[test] + fn test_km_client_auth_roundtrip() -> Result { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let client_auth = KeyMaterialClientAuth::new( + &secret_key, + &TestKeyingMaterial { + inner: (), + shared_secret: Some(42), + }, + ) + .e()?; + + let bytes = postcard::to_allocvec(&client_auth).e()?; + let decoded: KeyMaterialClientAuth = postcard::from_bytes(&bytes).e()?; + + assert_eq!(client_auth.public_key, decoded.public_key); assert_eq!(client_auth.signature, decoded.signature); Ok(()) @@ -528,8 +552,21 @@ mod tests { fn test_challenge_verification() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); let challenge = ServerChallenge::new(rand::rngs::OsRng); - let client_auth = ClientAuth::new_from_challenge(&secret_key, &challenge); - assert!(client_auth.verify_from_challenge(&challenge)); + let client_auth = ClientAuth::new(&secret_key, &challenge); + assert!(client_auth.verify(&challenge)); + + Ok(()) + } + + #[test] + fn test_key_material_verification() -> Result { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let io = TestKeyingMaterial { + inner: (), + shared_secret: Some(42), + }; + let client_auth = KeyMaterialClientAuth::new(&secret_key, &io).e()?; + assert!(client_auth.verify(&io)); Ok(()) } diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 51b0fb0edc1..28f1a2ac2ef 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -8,8 +8,6 @@ use quinn_proto::{coding::Codec, VarInt}; #[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] // needs to be pub due to being exposed in error types pub enum FrameType { - /// The client frame type for the client challenge request - ClientRequestChallenge = 1, /// The server frame type for the challenge response ServerChallenge = 2, /// The client frame type for the authentication frame diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index eabb9d76565..dd58bc38414 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -645,11 +645,13 @@ impl Inner { let mut io = WsBytesFramed { io: websocket }; - let client_info = handshake::serverside(&mut io, client_auth_header, rand::rngs::OsRng) - .await - .context(HandshakeSnafu)?; + let (client_key, auth_type) = + handshake::serverside(&mut io, client_auth_header, rand::rngs::OsRng) + .await + .context(HandshakeSnafu)?; + + trace!(?auth_type, "accept: verified authentication"); - let client_key = client_info.public_key; let mut io = RelayedStream { inner: io.io, key_cache: self.key_cache.clone(), From 310b9c5293c0e8b9fecbb833f0ba13eaee1a7c3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 2 Jul 2025 18:11:24 +0200 Subject: [PATCH 35/80] Rename to appease clippy (clippy is right tho) --- iroh-relay/src/client.rs | 2 +- iroh-relay/src/protos/handshake.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index a32f569d208..53156705ed5 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -253,7 +253,7 @@ impl ClientBuilder { if let Some(client_auth) = KeyMaterialClientAuth::new(&self.secret_key, &stream) { debug!("Using TLS key export for relay client authentication"); builder = builder - .add_header(CLIENT_AUTH_HEADER, client_auth.to_header_value()) + .add_header(CLIENT_AUTH_HEADER, client_auth.into_header_value()) .expect( "impossible: CLIENT_AUTH_HEADER isn't a disallowed header value for websockets", ); diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index f1ea9c33c2b..15d3ee9bc49 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -187,7 +187,7 @@ impl KeyMaterialClientAuth { }) } - pub(crate) fn to_header_value(self) -> HeaderValue { + pub(crate) fn into_header_value(self) -> HeaderValue { HeaderValue::from_str( &data_encoding::BASE64URL_NOPAD .encode(&postcard::to_allocvec(&self).expect("encoding never fails")), @@ -455,7 +455,7 @@ mod tests { .with_shared_secret(server_shared_secret); let client_auth_header = KeyMaterialClientAuth::new(secret_key, &mut client_io) - .map(KeyMaterialClientAuth::to_header_value); + .map(KeyMaterialClientAuth::into_header_value); let (_, auth) = n0_future::future::try_zip( async { From 6f2f50afeecaed3c4f524d57e413729a772f08ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 3 Jul 2025 10:25:16 +0200 Subject: [PATCH 36/80] Handle terminated streams in relay server actor more gracefully --- iroh-relay/src/protos/handshake.rs | 2 +- iroh-relay/src/server/client.rs | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 15d3ee9bc49..5188647100a 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -454,7 +454,7 @@ mod tests { .sink_map_err(tokio_websockets::Error::Io) .with_shared_secret(server_shared_secret); - let client_auth_header = KeyMaterialClientAuth::new(secret_key, &mut client_io) + let client_auth_header = KeyMaterialClientAuth::new(secret_key, &client_io) .map(KeyMaterialClientAuth::into_header_value); let (_, auth) = n0_future::future::try_zip( diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index c7f8745cd72..8c83e62b6e2 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -111,7 +111,7 @@ impl Client { // start io loop let io_done = done.clone(); let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!( - "client connection actor", + "client-connection-actor", remote_node = %node_id.fmt_short(), connection_id = connection_id ))); @@ -305,6 +305,12 @@ impl Actor { self.metrics.unique_client_keys.inc(); } match self.run_inner(done).await { + Err(RunError::HandleFrame { + source: HandleFrameError::StreamTerminated { .. }, + .. + }) => { + debug!("client stream closed, exiting"); + } Err(e) => { warn!("actor errored {e:#?}, exiting"); } From 12230e023bfd8338744755bad4090e1b67b84eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 3 Jul 2025 10:35:00 +0200 Subject: [PATCH 37/80] Increase timeout waiting for relay actor termination --- iroh/src/magicsock/transports/relay/actor.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index d03737100fc..b729ce87ab2 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -1534,7 +1534,15 @@ mod tests { tokio::time::advance(RELAY_INACTIVE_CLEANUP_TIME).await; tokio::time::resume(); assert!( - tokio::time::timeout(Duration::from_secs(1), task) + // About the 5 second timeout: + // The time advancing above can set off a bunch of async work at once, since suddenly multiple + // timers end up firing when time is resumed. + // This can cause work to build up, and might slow down slower machines (especially in CI). + // With a 1s timeout, I was still seeing proper procedures in the logs ("Inactive for 60s, exiting."), + // But it didn't quite get to the final log line "exiting." yet. Instead, there was a bunch of ping/pong + // logs in between. + // So increasing the timeout instead. + tokio::time::timeout(Duration::from_secs(5), task) .await .is_ok(), "actor task still running" From 05f5e402ed8033757ba61fb76a4015c4b1ed8bf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 3 Jul 2025 10:48:26 +0200 Subject: [PATCH 38/80] Convert `test_active_relay_inactive` into more of a tokio paused time test --- iroh-relay/src/server/client.rs | 8 +------- iroh/src/magicsock/transports/relay/actor.rs | 20 +++++--------------- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 8c83e62b6e2..34a7a6ae01f 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -305,14 +305,8 @@ impl Actor { self.metrics.unique_client_keys.inc(); } match self.run_inner(done).await { - Err(RunError::HandleFrame { - source: HandleFrameError::StreamTerminated { .. }, - .. - }) => { - debug!("client stream closed, exiting"); - } Err(e) => { - warn!("actor errored {e:#?}, exiting"); + warn!("actor errored {e:#}, exiting"); } Ok(()) => { debug!("actor finished, exiting"); diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index b729ce87ab2..285377c3591 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -1498,11 +1498,11 @@ mod tests { ); // Wait until the actor is connected to the relay server. - tokio::time::timeout(Duration::from_secs(5), async { + tokio::time::timeout(Duration::from_millis(200), async { loop { let (tx, rx) = oneshot::channel(); inbox_tx.send(ActiveRelayMessage::PingServer(tx)).await.ok(); - if tokio::time::timeout(Duration::from_millis(200), rx) + if tokio::time::timeout(Duration::from_millis(100), rx) .await .map(|resp| resp.is_ok()) .unwrap_or_default() @@ -1514,12 +1514,12 @@ mod tests { .await .context("timeout")?; + // From now on, we pause time + tokio::time::pause(); // We now have an idling ActiveRelayActor. If we advance time just a little it // should stay alive. info!("Stepping time forwards by RELAY_INACTIVE_CLEANUP_TIME / 2"); - tokio::time::pause(); tokio::time::advance(RELAY_INACTIVE_CLEANUP_TIME / 2).await; - tokio::time::resume(); assert!( tokio::time::timeout(Duration::from_millis(100), &mut task) @@ -1530,19 +1530,9 @@ mod tests { // If we advance time a lot it should finish. info!("Stepping time forwards by RELAY_INACTIVE_CLEANUP_TIME"); - tokio::time::pause(); tokio::time::advance(RELAY_INACTIVE_CLEANUP_TIME).await; - tokio::time::resume(); assert!( - // About the 5 second timeout: - // The time advancing above can set off a bunch of async work at once, since suddenly multiple - // timers end up firing when time is resumed. - // This can cause work to build up, and might slow down slower machines (especially in CI). - // With a 1s timeout, I was still seeing proper procedures in the logs ("Inactive for 60s, exiting."), - // But it didn't quite get to the final log line "exiting." yet. Instead, there was a bunch of ping/pong - // logs in between. - // So increasing the timeout instead. - tokio::time::timeout(Duration::from_secs(5), task) + tokio::time::timeout(Duration::from_millis(100), task) .await .is_ok(), "actor task still running" From ed94a52a20455c15601b94ebc50e20cba34db8bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 10:42:25 +0200 Subject: [PATCH 39/80] Rework error handling and access control --- iroh-relay/src/client.rs | 6 +- iroh-relay/src/client/conn.rs | 99 ++++----------- iroh-relay/src/protos/handshake.rs | 179 ++++++++++++++++++++++----- iroh-relay/src/protos/send_recv.rs | 69 +++-------- iroh-relay/src/protos/streams.rs | 4 +- iroh-relay/src/server.rs | 29 ++--- iroh-relay/src/server/client.rs | 56 +++++---- iroh-relay/src/server/http_server.rs | 67 +++------- iroh-relay/src/server/streams.rs | 86 +++++-------- 9 files changed, 280 insertions(+), 315 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 53156705ed5..b43274b3b7d 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -372,7 +372,7 @@ impl Sink for Client { mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { - >::poll_ready(Pin::new(&mut self.conn), cx) + Pin::new(&mut self.conn).poll_ready(cx) } fn start_send(mut self: Pin<&mut Self>, item: ClientToServerMsg) -> Result<(), Self::Error> { @@ -383,14 +383,14 @@ impl Sink for Client { mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { - >::poll_flush(Pin::new(&mut self.conn), cx) + Pin::new(&mut self.conn).poll_flush(cx) } fn poll_close( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll> { - >::poll_close(Pin::new(&mut self.conn), cx) + Pin::new(&mut self.conn).poll_close(cx) } } diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index d27140a1857..e282d8fce0f 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -3,7 +3,6 @@ //! based on tailscale/derp/derp_client.go use std::{ - io, pin::Pin, task::{ready, Context, Poll}, }; @@ -23,8 +22,7 @@ use crate::{ protos::{ handshake, send_recv::{ - ClientToServerMsg, RecvError as RecvRelayError, SendError as SendRelayError, - ServerToClientMsg, MAX_PAYLOAD_SIZE, + ClientToServerMsg, Error as RecvRelayError, ServerToClientMsg, MAX_PAYLOAD_SIZE, }, streams::WsBytesFramed, }, @@ -42,7 +40,7 @@ use crate::{ #[non_exhaustive] pub enum SendError { #[snafu(transparent)] - WebsocketIo { + StreamError { #[cfg(not(wasm_browser))] source: tokio_websockets::Error, #[cfg(wasm_browser)] @@ -63,13 +61,9 @@ pub enum SendError { #[non_exhaustive] pub enum RecvError { #[snafu(transparent)] - Io { source: io::Error }, + Protocol { source: RecvRelayError }, #[snafu(transparent)] - ProtocolSend { source: SendRelayError }, - #[snafu(transparent)] - ProtocolRecv { source: RecvRelayError }, - #[snafu(transparent)] - Websocket { + StreamError { #[cfg(not(wasm_browser))] source: tokio_websockets::Error, #[cfg(wasm_browser)] @@ -87,10 +81,10 @@ pub enum RecvError { pub(crate) struct Conn { #[debug("tokio_websockets::WebSocketStream")] #[cfg(not(wasm_browser))] - pub(crate) conn: tokio_websockets::WebSocketStream>, + pub(crate) conn: WsBytesFramed>, #[debug("ws_stream_wasm::WsStream")] #[cfg(wasm_browser)] - pub(crate) conn: ws_stream_wasm::WsStream, + pub(crate) conn: WsBytesFramed, pub(crate) key_cache: KeyCache, } @@ -99,52 +93,34 @@ impl Conn { pub(crate) fn test(io: tokio::io::DuplexStream) -> Self { use crate::protos::send_recv::MAX_FRAME_SIZE; Self { - conn: tokio_websockets::ClientBuilder::new() - .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) - .take_over(MaybeTlsStream::Test(io)), + conn: WsBytesFramed { + io: tokio_websockets::ClientBuilder::new() + .limits( + tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE)), + ) + .take_over(MaybeTlsStream::Test(io)), + }, key_cache: KeyCache::test(), } } /// Constructs a new websocket connection, including the initial server handshake. - #[cfg(wasm_browser)] pub(crate) async fn new( - conn: ws_stream_wasm::WsStream, + #[cfg(not(wasm_browser))] io: tokio_websockets::WebSocketStream< + MaybeTlsStream, + >, + #[cfg(wasm_browser)] io: ws_stream_wasm::WsStream, key_cache: KeyCache, secret_key: &SecretKey, ) -> Result { - // We use a phantom type param of ProxyStream for wrapping, just because it's easier to cfg-out code for wasm. - // It's a little ugly though. - let mut io = WsBytesFramed { io: conn }; + let mut conn = WsBytesFramed { io }; // exchange information with the server debug!("server_handshake: started"); - handshake::clientside(&mut io, secret_key).await?; + handshake::clientside(&mut conn, secret_key).await?; debug!("server_handshake: done"); - Ok(Self { - conn: io.io, - key_cache, - }) - } - - #[cfg(not(wasm_browser))] - pub(crate) async fn new( - conn: tokio_websockets::WebSocketStream>, - key_cache: KeyCache, - secret_key: &SecretKey, - ) -> Result { - let mut io = WsBytesFramed { io: conn }; - - // exchange information with the server - debug!("server_handshake: started"); - handshake::clientside(&mut io, secret_key).await?; - debug!("server_handshake: done"); - - Ok(Self { - conn: io.io, - key_cache, - }) + Ok(Self { conn, key_cache }) } } @@ -154,35 +130,11 @@ impl Stream for Conn { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let msg = ready!(Pin::new(&mut self.conn).poll_next(cx)); match msg { - #[cfg(not(wasm_browser))] Some(Ok(msg)) => { - if msg.is_close() { - // Indicate the stream is done when we receive a close message. - // Note: We don't have to poll the stream to completion for it to close gracefully. - return Poll::Ready(None); - } - if !msg.is_binary() { - tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); - return Poll::Pending; - } - let message = - ServerToClientMsg::from_bytes(msg.into_payload().into(), &self.key_cache); + let message = ServerToClientMsg::from_bytes(msg, &self.key_cache); Poll::Ready(Some(message.map_err(Into::into))) } - #[cfg(not(wasm_browser))] Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), - - #[cfg(wasm_browser)] - Some(ws_stream_wasm::WsMessage::Binary(vec)) => Poll::Ready(Some( - ServerToClientMsg::from_bytes(bytes::Bytes::from(vec), &self.key_cache) - .map_err(Into::into), - )), - #[cfg(wasm_browser)] - Some(msg) => { - tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); - Poll::Pending - } - None => Poll::Ready(None), } } @@ -201,15 +153,8 @@ impl Sink for Conn { snafu::ensure!(size <= MAX_PAYLOAD_SIZE, ExceedsMaxPacketSizeSnafu { size }); } - #[cfg(not(wasm_browser))] - let frame = tokio_websockets::Message::binary(tokio_websockets::Payload::from( - frame.write_to(BytesMut::new()).freeze(), - )); - #[cfg(wasm_browser)] - let frame = ws_stream_wasm::WsMessage::Binary(frame.write_to(Vec::new())); - Pin::new(&mut self.conn) - .start_send(frame) + .start_send(frame.write_to(BytesMut::new()).freeze()) .map_err(Into::into) } diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 5188647100a..f00612d9a48 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -1,5 +1,4 @@ //! TODO(matheus23) docs - use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderValue; #[cfg(feature = "server")] @@ -62,9 +61,11 @@ pub(crate) struct ClientAuth { pub(crate) struct ServerConfirmsAuth; /// Denial of connection. The client couldn't be verified as authentic. -#[derive(derive_more::Debug, serde::Deserialize)] +#[derive(derive_more::Debug, Clone, serde::Deserialize)] #[cfg_attr(feature = "server", derive(serde::Serialize))] -pub(crate) struct ServerDeniesAuth; +pub(crate) struct ServerDeniesAuth { + reason: String, +} /// Trait for getting the frame type tag for a frame. /// @@ -117,8 +118,8 @@ pub enum Error { Timeout { source: Elapsed }, #[snafu(display("Handshake stream ended prematurely"))] UnexpectedEnd {}, - #[snafu(display("The relay denied our authentication"))] - ServerDeniedAuth {}, + #[snafu(display("The relay denied our authentication ({reason})"))] + ServerDeniedAuth { reason: String }, #[snafu(display("Unexpected tag, got {frame_type}, but expected one of {expected_types:?}"))] UnexpectedFrameType { frame_type: FrameType, @@ -247,16 +248,26 @@ pub(crate) async fn clientside( Ok(confirmation) } FrameType::ServerDeniesAuth => { - let _denial: ServerDeniesAuth = deserialize_frame(frame)?; - Err(ServerDeniedAuthSnafu.build()) + let denial: ServerDeniesAuth = deserialize_frame(frame)?; + Err(ServerDeniedAuthSnafu { + reason: denial.reason, + } + .build()) } _ => unreachable!(), } } #[cfg(feature = "server")] -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum Auth { +#[derive(Debug)] +pub(crate) struct SuccessfulAuthentication { + pub(crate) client_key: PublicKey, + pub(crate) mechanism: Mechanism, +} + +#[cfg(feature = "server")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Mechanism { SignedChallenge, SignedKeyMaterial, } @@ -267,7 +278,7 @@ pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), client_auth_header: Option, rng: impl RngCore + CryptoRng, -) -> Result<(PublicKey, Auth), Error> { +) -> Result { if let Some(client_auth_header) = client_auth_header { let client_auth_bytes = data_encoding::BASE64URL_NOPAD .decode(client_auth_header.as_ref()) @@ -287,8 +298,11 @@ pub(crate) async fn serverside( })?; if client_auth.verify(io) { - write_frame(io, ServerConfirmsAuth).await?; - return Ok((client_auth.public_key, Auth::SignedKeyMaterial)); + tracing::trace!(?client_auth.public_key, "authentication succeeded via keying material"); + return Ok(SuccessfulAuthentication { + client_key: client_auth.public_key, + mechanism: Mechanism::SignedKeyMaterial, + }); } } @@ -299,12 +313,47 @@ pub(crate) async fn serverside( let client_auth: ClientAuth = deserialize_frame(frame)?; if client_auth.verify(&challenge) { - write_frame(io, ServerConfirmsAuth).await?; + tracing::trace!(?client_auth.public_key, "authentication succeeded via challenge"); + Ok(SuccessfulAuthentication { + client_key: client_auth.public_key, + mechanism: Mechanism::SignedChallenge, + }) } else { - write_frame(io, ServerDeniesAuth).await?; + tracing::trace!(?client_auth.public_key, "authentication failed"); + let denial = ServerDeniesAuth { + reason: "signature invalid".into(), + }; + write_frame(io, denial.clone()).await?; + Err(ServerDeniedAuthSnafu { + reason: denial.reason, + } + .build()) } +} - Ok((client_auth.public_key, Auth::SignedChallenge)) +#[cfg(feature = "server")] +impl SuccessfulAuthentication { + pub async fn authorize( + self, + io: &mut (impl BytesStreamSink + ExportKeyingMaterial), + is_authorized: bool, + ) -> Result { + if is_authorized { + tracing::trace!("authorizing client"); + write_frame(io, ServerConfirmsAuth).await?; + Ok(self.client_key) + } else { + tracing::trace!("denying client auth"); + let denial = ServerDeniesAuth { + reason: "not authorized".into(), + }; + write_frame(io, denial.clone()).await?; + Err(ServerDeniedAuthSnafu { + reason: denial.reason, + } + .build()) + } + } } async fn write_frame( @@ -312,6 +361,7 @@ async fn write_frame( frame: F, ) -> Result<(), Error> { let mut bytes = BytesMut::new(); + tracing::trace!(frame_type = %F::TAG, "Writing frame"); F::TAG.write_to(&mut bytes); let bytes = postcard::to_io(&frame, bytes.writer()) .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization @@ -333,6 +383,7 @@ async fn read_frame( .ok_or_else(|| UnexpectedEndSnafu.build())?; let (frame_type, payload) = FrameType::from_bytes(recv).context(UnexpectedEndSnafu)?; + tracing::trace!(%frame_type, "Reading frame"); snafu::ensure!( expected_types.contains(&frame_type), UnexpectedFrameTypeSnafu { @@ -355,8 +406,12 @@ mod tests { use n0_future::{Sink, SinkExt, Stream, TryStreamExt}; use n0_snafu::{Result, ResultExt}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; + use tracing::{info_span, Instrument}; + use tracing_test::traced_test; - use super::{Auth, ClientAuth, KeyMaterialClientAuth, ServerChallenge}; + use super::{ + ClientAuth, KeyMaterialClientAuth, Mechanism, ServerChallenge, ServerConfirmsAuth, + }; use crate::ExportKeyingMaterial; struct TestKeyingMaterial { @@ -440,7 +495,8 @@ mod tests { secret_key: &SecretKey, client_shared_secret: Option, server_shared_secret: Option, - ) -> Result<(PublicKey, Auth)> { + restricted_to: Option, + ) -> (Result, Result<(PublicKey, Mechanism)>) { let (client, server) = tokio::io::duplex(1024); let mut client_io = Framed::new(client, LengthDelimitedCodec::new()) @@ -457,58 +513,113 @@ mod tests { let client_auth_header = KeyMaterialClientAuth::new(secret_key, &client_io) .map(KeyMaterialClientAuth::into_header_value); - let (_, auth) = n0_future::future::try_zip( + n0_future::future::zip( async { super::clientside(&mut client_io, secret_key) .await .context("clientside") - }, + } + .instrument(info_span!("clientside")), async { - super::serverside(&mut server_io, client_auth_header, rand::rngs::OsRng) - .await - .context("serverside") - }, + let auth_n = + super::serverside(&mut server_io, client_auth_header, rand::rngs::OsRng) + .await + .context("serverside")?; + let mechanism = auth_n.mechanism; + let is_authorized = restricted_to.map_or(true, |key| key == auth_n.client_key); + let key = auth_n.authorize(&mut server_io, is_authorized).await?; + Ok((key, mechanism)) + } + .instrument(info_span!("serverside")), ) - .await?; - - Ok(auth) + .await } #[tokio::test] + #[traced_test] async fn test_handshake_via_shared_secrets() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); - let (public_key, auth) = simulate_handshake(&secret_key, Some(42), Some(42)).await?; + let (client, server) = simulate_handshake(&secret_key, Some(42), Some(42), None).await; + client?; + let (public_key, auth) = server?; assert_eq!(public_key, secret_key.public()); - assert_eq!(auth, Auth::SignedKeyMaterial); // it got verified via shared key material + assert_eq!(auth, Mechanism::SignedKeyMaterial); // it got verified via shared key material Ok(()) } #[tokio::test] + #[traced_test] async fn test_handshake_via_challenge() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); - let (public_key, auth) = simulate_handshake(&secret_key, None, None).await?; + let (client, server) = simulate_handshake(&secret_key, None, None, None).await; + client?; + let (public_key, auth) = server?; assert_eq!(public_key, secret_key.public()); - assert_eq!(auth, Auth::SignedChallenge); + assert_eq!(auth, Mechanism::SignedChallenge); Ok(()) } #[tokio::test] + #[traced_test] async fn test_handshake_mismatching_shared_secrets() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); // mismatching shared secrets *might* happen with HTTPS proxies that don't also middle-man the shared secret - let (public_key, auth) = simulate_handshake(&secret_key, Some(10), Some(99)).await?; + let (client, server) = simulate_handshake(&secret_key, Some(10), Some(99), None).await; + client?; + let (public_key, auth) = server?; assert_eq!(public_key, secret_key.public()); - assert_eq!(auth, Auth::SignedChallenge); + assert_eq!(auth, Mechanism::SignedChallenge); Ok(()) } #[tokio::test] + #[traced_test] async fn test_handshake_challenge_fallback() -> Result { let secret_key = SecretKey::generate(rand::rngs::OsRng); // clients might not have access to shared secrets - let (public_key, auth) = simulate_handshake(&secret_key, None, Some(99)).await?; + let (client, server) = simulate_handshake(&secret_key, None, Some(99), None).await; + client?; + let (public_key, auth) = server?; assert_eq!(public_key, secret_key.public()); - assert_eq!(auth, Auth::SignedChallenge); + assert_eq!(auth, Mechanism::SignedChallenge); + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_handshake_with_auth_positive() -> Result { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let public_key = secret_key.public(); + let (client, server) = simulate_handshake(&secret_key, None, None, Some(public_key)).await; + client?; + let (public_key, _) = server?; + assert_eq!(public_key, secret_key.public()); + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_handshake_with_auth_negative() -> Result { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let public_key = secret_key.public(); + let wrong_secret_key = SecretKey::generate(rand::rngs::OsRng); + let (client, server) = + simulate_handshake(&wrong_secret_key, None, None, Some(public_key)).await; + assert!(client.is_err()); + assert!(server.is_err()); + Ok(()) + } + + #[tokio::test] + #[traced_test] + async fn test_handshake_via_shared_secret_with_auth_negative() -> Result { + let secret_key = SecretKey::generate(rand::rngs::OsRng); + let public_key = secret_key.public(); + let wrong_secret_key = SecretKey::generate(rand::rngs::OsRng); + let (client, server) = + simulate_handshake(&wrong_secret_key, Some(42), Some(42), Some(public_key)).await; + assert!(client.is_err()); + assert!(server.is_err()); Ok(()) } diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index 588aeffd622..dc4c25bc6ca 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -15,8 +15,6 @@ use bytes::{BufMut, Bytes}; use iroh_base::{NodeId, SignatureError}; use n0_future::time::{self, Duration}; -#[cfg(feature = "server")] -use n0_future::{Sink, SinkExt}; use nested_enum_utils::common_fields; use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; @@ -60,27 +58,7 @@ pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; #[allow(missing_docs)] #[derive(Debug, Snafu)] #[non_exhaustive] -pub enum SendError { - #[snafu(transparent)] - Io { source: std::io::Error }, - #[snafu(transparent)] - Timeout { source: time::Elapsed }, - #[snafu(transparent)] - SerDe { source: postcard::Error }, -} - -/// Protocol send errors. -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -})] -#[allow(missing_docs)] -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum RecvError { - #[snafu(transparent)] - Io { source: std::io::Error }, +pub enum Error { #[snafu(display("unexpected frame: got {got}, expected {expected}"))] UnexpectedFrame { got: FrameType, expected: FrameType }, #[snafu(display("Frame is too large, has {frame_len} bytes"))] @@ -89,37 +67,18 @@ pub enum RecvError { Timeout { source: time::Elapsed }, #[snafu(transparent)] SerDe { source: postcard::Error }, - #[snafu(transparent)] - InvalidSignature { source: SignatureError }, + #[snafu(display("Invalid public key"))] + InvalidPublicKey { source: SignatureError }, #[snafu(display("Invalid frame encoding"))] InvalidFrame {}, #[snafu(display("Invalid frame type: {frame_type}"))] InvalidFrameType { frame_type: FrameType }, - #[snafu(display("invalid protocol message encoding"))] + #[snafu(display("Invalid protocol message encoding"))] InvalidProtocolMessageEncoding { source: std::str::Utf8Error }, #[snafu(display("Too few bytes"))] TooSmall {}, } -/// Writes complete frame, errors if it is unable to write within the given `timeout`. -/// Ignores the timeout if `None` -/// -/// Does not flush. -#[cfg(feature = "server")] -pub(crate) async fn write_frame + Unpin>( - mut writer: S, - frame: ServerToClientMsg, - timeout: Option, -) -> Result<(), SendError> { - if let Some(duration) = timeout { - tokio::time::timeout(duration, writer.send(frame)).await??; - } else { - writer.send(frame).await?; - } - - Ok(()) -} - /// TODO(matheus23): Docs /// The messages received from a framed relay stream. /// @@ -219,7 +178,7 @@ impl Datagrams { dst } - fn from_bytes(bytes: Bytes) -> Result { + fn from_bytes(bytes: Bytes) -> Result { // 1 bytes ECN, 2 bytes segment size snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); @@ -297,7 +256,7 @@ impl ServerToClientMsg { /// /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; let frame_len = content.len(); snafu::ensure!( @@ -309,7 +268,9 @@ impl ServerToClientMsg { FrameType::RecvDatagrams => { snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); - let remote_node_id = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let remote_node_id = cache + .key_from_slice(&content[..NodeId::LENGTH]) + .context(InvalidPublicKeySnafu)?; let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; Self::ReceivedDatagrams { remote_node_id, @@ -318,7 +279,9 @@ impl ServerToClientMsg { } FrameType::NodeGone => { snafu::ensure!(content.len() == NodeId::LENGTH, InvalidFrameSnafu); - let node_id = cache.key_from_slice(content.as_ref())?; + let node_id = cache + .key_from_slice(content.as_ref()) + .context(InvalidPublicKeySnafu)?; Self::NodeGone(node_id) } FrameType::Ping => { @@ -403,7 +366,7 @@ impl ClientToServerMsg { /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] #[cfg(feature = "server")] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; let frame_len = content.len(); snafu::ensure!( @@ -413,7 +376,9 @@ impl ClientToServerMsg { let res = match frame_type { FrameType::SendDatagrams => { - let dst_node_id = cache.key_from_slice(&content[..NodeId::LENGTH])?; + let dst_node_id = cache + .key_from_slice(&content[..NodeId::LENGTH]) + .context(InvalidPublicKeySnafu)?; let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; Self::SendDatagrams { dst_node_id, @@ -441,6 +406,7 @@ impl ClientToServerMsg { } #[cfg(test)] +#[cfg(feature = "server")] mod tests { use data_encoding::HEXLOWER; use iroh_base::SecretKey; @@ -571,6 +537,7 @@ mod tests { } #[cfg(test)] +#[cfg(feature = "server")] mod proptests { use bytes::BytesMut; use iroh_base::SecretKey; diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 371f005c044..2c326e8f6a2 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -26,10 +26,10 @@ pub(crate) struct WsBytesFramed { } #[cfg(not(wasm_browser))] -type StreamError = tokio_websockets::Error; +pub(crate) type StreamError = tokio_websockets::Error; #[cfg(wasm_browser)] -type StreamError = ws_stream_wasm::WsErr; +pub(crate) type StreamError = ws_stream_wasm::WsErr; /// TODO(matheus23) docs pub(crate) trait BytesStreamSink: diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 7b2f050173b..bde8707b895 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -753,7 +753,7 @@ mod tests { use http::StatusCode; use iroh_base::{NodeId, RelayUrl, SecretKey}; use n0_future::{FutureExt, SinkExt, StreamExt}; - use n0_snafu::{Result, ResultExt}; + use n0_snafu::Result; use tracing::{info, instrument}; use tracing_test::traced_test; @@ -762,9 +762,12 @@ mod tests { NO_CONTENT_CHALLENGE_HEADER, NO_CONTENT_RESPONSE_HEADER, }; use crate::{ - client::ClientBuilder, + client::{ClientBuilder, ConnectError}, dns::DnsResolver, - protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, + protos::{ + handshake, + send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, + }, }; async fn spawn_local_relay() -> std::result::Result { @@ -978,23 +981,13 @@ mod tests { // set up client a let resolver = dns_resolver(); - let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver) + let result = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver) .connect() - .await?; + .await; - // the next message should be the rejection of the connection - tokio::time::timeout(Duration::from_millis(500), async move { - match client_a.next().await.unwrap().unwrap() { - ServerToClientMsg::Health { problem } => { - assert_eq!(problem, "not authenticated".to_string()); - } - msg => { - panic!("other msg: {msg:?}"); - } - } - }) - .await - .context("timeout")?; + assert!( + matches!(result, Err(ConnectError::Handshake { source: handshake::Error::ServerDeniedAuth { reason, .. }, .. }) if reason == "not authorized") + ); // test that another client has access diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 34a7a6ae01f..412fb86a766 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -18,15 +18,13 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - send_recv::{ - write_frame, ClientToServerMsg, Datagrams, SendError as SendRelayError, - ServerToClientMsg, PING_INTERVAL, - }, + send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg, PING_INTERVAL}, + streams::StreamError, }, server::{ clients::Clients, metrics::Metrics, - streams::{RelayedStream, StreamError}, + streams::{RecvError as StreamRecvError, RelayedStream}, }, PingTracker, }; @@ -170,7 +168,7 @@ impl Client { } } -/// Handle frame error +/// Receive frame error #[common_fields({ backtrace: Option, })] @@ -180,15 +178,26 @@ impl Client { pub enum HandleFrameError { #[snafu(transparent)] ForwardPacket { source: ForwardPacketError }, - #[snafu(transparent)] - Streams { source: StreamError }, #[snafu(display("Stream terminated"))] - StreamTerminated { - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, - }, + StreamTerminated {}, + #[snafu(transparent)] + Recv { source: StreamRecvError }, + #[snafu(transparent)] + Send { source: SendFrameError }, +} + +/// Send frame error +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum SendFrameError { #[snafu(transparent)] - Relay { source: SendRelayError }, + Stream { source: StreamError }, + #[snafu(transparent)] + Timeout { source: tokio::time::error::Elapsed }, } /// Run error @@ -215,7 +224,7 @@ pub enum RunError { }, #[snafu(display("Failed to send disco packet"))] DiscoPacketSend { - source: SendRelayError, + source: SendFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, @@ -226,7 +235,7 @@ pub enum RunError { }, #[snafu(display("Failed to send packet"))] PacketSend { - source: SendRelayError, + source: SendFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, @@ -237,13 +246,13 @@ pub enum RunError { }, #[snafu(display("NodeGone write frame failed"))] NodeGoneWriteFrame { - source: SendRelayError, + source: SendFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, #[snafu(display("Keep alive write frame failed"))] KeepAliveWriteFrame { - source: SendRelayError, + source: SendFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, @@ -386,15 +395,16 @@ impl Actor { /// Writes the given frame to the connection. /// /// Errors if the send does not happen within the `timeout` duration - async fn write_frame(&mut self, frame: ServerToClientMsg) -> Result<(), SendRelayError> { - write_frame(&mut self.stream, frame, Some(self.timeout)).await + async fn write_frame(&mut self, frame: ServerToClientMsg) -> Result<(), SendFrameError> { + tokio::time::timeout(self.timeout, self.stream.send(frame)).await??; + Ok(()) } /// Writes contents to the client in a `RECV_PACKET` frame. /// /// Errors if the send does not happen within the `timeout` duration /// Does not flush. - async fn send_raw(&mut self, packet: Packet) -> Result<(), SendRelayError> { + async fn send_raw(&mut self, packet: Packet) -> Result<(), SendFrameError> { let remote_node_id = packet.src; let datagrams = packet.data; @@ -408,7 +418,7 @@ impl Actor { .await } - async fn send_packet(&mut self, packet: Packet) -> Result<(), SendRelayError> { + async fn send_packet(&mut self, packet: Packet) -> Result<(), SendFrameError> { trace!("send packet"); match self.send_raw(packet).await { Ok(()) => { @@ -422,7 +432,7 @@ impl Actor { } } - async fn send_disco_packet(&mut self, packet: Packet) -> Result<(), SendRelayError> { + async fn send_disco_packet(&mut self, packet: Packet) -> Result<(), SendFrameError> { trace!("send disco packet"); match self.send_raw(packet).await { Ok(()) => { @@ -439,7 +449,7 @@ impl Actor { /// Handles frame read results. async fn handle_frame( &mut self, - maybe_frame: Option>, + maybe_frame: Option>, ) -> Result<(), HandleFrameError> { trace!(?maybe_frame, "handle incoming frame"); let frame = match maybe_frame { diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index dd58bc38414..47d574e30d1 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -12,8 +12,7 @@ use hyper::{ upgrade::Upgraded, HeaderMap, Method, Request, Response, StatusCode, }; -use iroh_base::PublicKey; -use n0_future::{time::Elapsed, FutureExt, SinkExt}; +use n0_future::{time::Elapsed, FutureExt}; use nested_enum_utils::common_fields; use snafu::{Backtrace, ResultExt, Snafu}; use tokio::net::{TcpListener, TcpStream}; @@ -26,7 +25,7 @@ use super::{clients::Clients, AccessConfig, SpawnError}; use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, http::{RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, - protos::send_recv::{ServerToClientMsg, PER_CLIENT_SEND_QUEUE_DEPTH}, + protos::send_recv::PER_CLIENT_SEND_QUEUE_DEPTH, server::{ client::Config, metrics::Metrics, @@ -204,31 +203,6 @@ pub enum ServeConnectionError { }, } -/// Server accept errors. -#[common_fields({ - backtrace: Option, -})] -#[allow(missing_docs)] -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum AcceptError { - #[snafu(display("Handshake failed"))] - Handshake { - #[allow(clippy::result_large_err)] - source: handshake::Error, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, - }, - #[snafu(transparent)] - Io { source: std::io::Error }, - #[snafu(display("Client not authenticated: {key:?}"))] - ClientNotAuthenticated { - key: PublicKey, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, - }, -} - /// Server connection errors, includes errors that can happen on `accept`. #[common_fields({ backtrace: Option, @@ -238,7 +212,7 @@ pub enum AcceptError { #[non_exhaustive] pub enum ConnectionHandlerError { #[snafu(transparent)] - Accept { source: AcceptError }, + Accept { source: handshake::Error }, #[snafu(display("Could not downcast the upgraded connection to MaybeTlsStream"))] DowncastUpgrade { #[snafu(implicit)] @@ -627,9 +601,7 @@ impl Inner { &self, io: MaybeTlsStream, client_auth_header: Option, - ) -> Result<(), AcceptError> { - use snafu::ResultExt; - + ) -> Result<(), handshake::Error> { trace!("accept: start"); let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone()); @@ -645,29 +617,20 @@ impl Inner { let mut io = WsBytesFramed { io: websocket }; - let (client_key, auth_type) = - handshake::serverside(&mut io, client_auth_header, rand::rngs::OsRng) - .await - .context(HandshakeSnafu)?; + let authentication = + handshake::serverside(&mut io, client_auth_header, rand::rngs::OsRng).await?; - trace!(?auth_type, "accept: verified authentication"); + trace!(?authentication.mechanism, "accept: verified authentication"); - let mut io = RelayedStream { - inner: io.io, - key_cache: self.key_cache.clone(), - }; + let is_authorized = self.access.is_allowed(authentication.client_key).await; + let client_key = authentication.authorize(&mut io, is_authorized).await?; - trace!("accept: checking access: {:?}", self.access); - // TODO(matheus23): Maybe use new frame? - if !self.access.is_allowed(client_key).await { - io.send(ServerToClientMsg::Health { - problem: "not authenticated".into(), - }) - .await?; - io.flush().await?; + trace!("accept: verified authorization"); - return Err(ClientNotAuthenticatedSnafu { key: client_key }.build()); - } + let io = RelayedStream { + inner: io, + key_cache: self.key_cache.clone(), + }; trace!("accept: build client conn"); let client_conn_builder = Config { @@ -859,7 +822,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::send_recv::{ClientToServerMsg, Datagrams}, + protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, }; pub(crate) fn make_tls_config() -> TlsConfig { diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index ac1a39bcc3d..c77a3b819f5 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -10,12 +10,14 @@ use bytes::BytesMut; use n0_future::{ready, time, FutureExt, Sink, Stream}; use snafu::Snafu; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_websockets::WebSocketStream; use tracing::instrument; use super::{ClientRateLimit, Metrics}; use crate::{ - protos::send_recv::{ClientToServerMsg, RecvError, ServerToClientMsg}, + protos::{ + send_recv::{ClientToServerMsg, Error as ProtoError, ServerToClientMsg}, + streams::{StreamError, WsBytesFramed}, + }, ExportKeyingMaterial, KeyCache, }; @@ -26,7 +28,7 @@ use crate::{ /// - a [`Sink`] of [`ServerToClientMsg`]s that can be sent to the client. #[derive(Debug)] pub(crate) struct RelayedStream { - pub(crate) inner: WebSocketStream>, + pub(crate) inner: WsBytesFramed>, pub(crate) key_cache: KeyCache, } @@ -36,9 +38,11 @@ impl RelayedStream { let stream = MaybeTlsStream::Test(stream); let stream = RateLimited::unlimited(stream, Arc::new(Metrics::default())); Self { - inner: tokio_websockets::ServerBuilder::new() - .limits(Self::limits()) - .serve(stream), + inner: WsBytesFramed { + io: tokio_websockets::ServerBuilder::new() + .limits(Self::limits()) + .serve(stream), + }, key_cache: KeyCache::test(), } } @@ -56,9 +60,11 @@ impl RelayedStream { Arc::new(Metrics::default()), ); Self { - inner: tokio_websockets::ServerBuilder::new() - .limits(Self::limits()) - .serve(stream), + inner: WsBytesFramed { + io: tokio_websockets::ServerBuilder::new() + .limits(Self::limits()) + .serve(stream), + }, key_cache: KeyCache::test(), } } @@ -69,77 +75,47 @@ impl RelayedStream { } } -fn ws_to_io_err(e: tokio_websockets::Error) -> std::io::Error { - match e { - tokio_websockets::Error::Io(io_err) => io_err, - _ => std::io::Error::other(e.to_string()), - } -} - impl Sink for RelayedStream { - type Error = std::io::Error; + type Error = StreamError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner) - .poll_ready(cx) - .map_err(ws_to_io_err) + Pin::new(&mut self.inner).poll_ready(cx) } fn start_send(mut self: Pin<&mut Self>, item: ServerToClientMsg) -> Result<(), Self::Error> { - Pin::new(&mut self.inner) - .start_send(tokio_websockets::Message::binary( - tokio_websockets::Payload::from(item.write_to(BytesMut::new()).freeze()), - )) - .map_err(ws_to_io_err) + Pin::new(&mut self.inner).start_send(item.write_to(BytesMut::new()).freeze()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner) - .poll_flush(cx) - .map_err(ws_to_io_err) + Pin::new(&mut self.inner).poll_flush(cx) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner) - .poll_close(cx) - .map_err(ws_to_io_err) + Pin::new(&mut self.inner).poll_close(cx) } } -/// Relay stream errors +/// Relay receive errors #[derive(Debug, Snafu)] #[non_exhaustive] -pub enum StreamError { +pub enum RecvError { #[snafu(transparent)] - Proto { source: RecvError }, + Proto { source: ProtoError }, #[snafu(transparent)] - Ws { source: tokio_websockets::Error }, + StreamError { source: StreamError }, } impl Stream for RelayedStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.inner).poll_next(cx) { - Poll::Ready(Some(Ok(msg))) => { - if msg.is_close() { - // Indicate the stream is done when we receive a close message. - // Note: We don't have to poll the stream to completion for it to close gracefully. - return Poll::Ready(None); - } - if !msg.is_binary() { - tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); - return Poll::Pending; - } - Poll::Ready(Some( - ClientToServerMsg::from_bytes(msg.into_payload().into(), &self.key_cache) - .map_err(Into::into), - )) + Poll::Ready(match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(msg)) => { + Some(ClientToServerMsg::from_bytes(msg, &self.key_cache).map_err(Into::into)) } - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } + Some(Err(e)) => Some(Err(e.into())), + None => None, + }) } } From 1e7c8c813545d187ad9150ca321f6b27fd13b648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 10:59:48 +0200 Subject: [PATCH 40/80] Fix Wasm & clippy --- iroh-relay/src/client/conn.rs | 1 - iroh-relay/src/lib.rs | 1 + iroh-relay/src/protos/handshake.rs | 4 +++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index e282d8fce0f..8bc5940d1c9 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -7,7 +7,6 @@ use std::{ task::{ready, Context, Poll}, }; -#[cfg(not(wasm_browser))] use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::{Sink, Stream}; diff --git a/iroh-relay/src/lib.rs b/iroh-relay/src/lib.rs index ad556bfeecf..a073fd002ce 100644 --- a/iroh-relay/src/lib.rs +++ b/iroh-relay/src/lib.rs @@ -56,6 +56,7 @@ pub use self::{ }; pub(crate) trait ExportKeyingMaterial { + #[cfg_attr(wasm_browser, allow(unused))] fn export_keying_material>( &self, output: T, diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index f00612d9a48..eeb5b319cd4 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -19,6 +19,7 @@ use crate::ExportKeyingMaterial; /// Authentication message from the client. #[derive(derive_more::Debug, serde::Serialize)] #[cfg_attr(feature = "server", derive(serde::Deserialize))] +#[cfg_attr(wasm_browser, allow(unused))] pub(crate) struct KeyMaterialClientAuth { /// The client's public key pub(crate) public_key: PublicKey, @@ -173,6 +174,7 @@ impl ClientAuth { } } +#[cfg_attr(wasm_browser, allow(unused))] impl KeyMaterialClientAuth { pub(crate) fn new(secret_key: &SecretKey, io: &impl ExportKeyingMaterial) -> Option { let public_key = secret_key.public(); @@ -526,7 +528,7 @@ mod tests { .await .context("serverside")?; let mechanism = auth_n.mechanism; - let is_authorized = restricted_to.map_or(true, |key| key == auth_n.client_key); + let is_authorized = restricted_to.is_none_or(|key| key == auth_n.client_key); let key = auth_n.authorize(&mut server_io, is_authorized).await?; Ok((key, mechanism)) } From 4f9627d7e949ca0aa0dfa1eb9a2c8e45251a44e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 16:36:13 +0200 Subject: [PATCH 41/80] Documentation --- iroh-relay/src/client/conn.rs | 2 +- iroh-relay/src/protos/handshake.rs | 75 ++++++++++++++++++-- iroh-relay/src/protos/relay.rs | 10 ++- iroh-relay/src/protos/send_recv.rs | 59 ++++++++------- iroh-relay/src/protos/streams.rs | 5 +- iroh-relay/src/server.rs | 10 +-- iroh-relay/src/server/client.rs | 16 ++--- iroh-relay/src/server/clients.rs | 4 +- iroh-relay/src/server/http_server.rs | 63 +++++++++------- iroh-relay/src/server/metrics.rs | 2 - iroh-relay/src/server/streams.rs | 62 +++++++++++----- iroh/src/magicsock/transports/relay/actor.rs | 6 +- 12 files changed, 208 insertions(+), 106 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 8bc5940d1c9..5b517fb6157 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -147,7 +147,7 @@ impl Sink for Conn { } fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { - if let ClientToServerMsg::SendDatagrams { datagrams, .. } = &frame { + if let ClientToServerMsg::Datagrams { datagrams, .. } = &frame { let size = datagrams.contents.len(); snafu::ensure!(size <= MAX_PAYLOAD_SIZE, ExceedsMaxPacketSizeSnafu { size }); } diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index eeb5b319cd4..8ba6a87b712 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -1,4 +1,29 @@ -//! TODO(matheus23) docs +//! Implements the handshake protocol that iroh's relays and iroh clients that connect to it go through. +//! +//! The purpose of the handshake is to +//! 1. Inform the relay of the client's NodeId +//! 2. Check that the connecting client owns the secret key for its NodeId ("is authentic"/"authentication") +//! 3. Possibly check that the client has access to this relay, if the relay requires authorization. +//! +//! Additional complexity comes from the fact that there's two ways that clients can authenticate with +//! relays. +//! +//! One way is via an explicitly sent challenge: +//! +//! 1. Once a websocket connection is opened, a client recieves a challenge (the [`ServerChallenge`] frame) +//! 2. The client sends back what is essentially a signature of that challenge with their secret key +//! that matches the NodeId they have, as well as the NodeId (the [`ClientAuth`] frame) +//! +//! The second way is very similar to the [Concealed HTTP Auth RFC], and involves send a header that +//! contains a signature of some shared keying material extracted from TLS ([RFC 5705]). +//! +//! The second way can save a full round trip, because the challenge doesn't have to be sent to the client +//! first, however, it won't always work, as it relies on the keying material extraction feature of TLS, +//! which is not available in browsers (but might be in the future?) and might break when there's an +//! HTTPS proxy that doesn't properly deal with this TLS feature. +//! +//! [Concealed HTTP Auth RFC]: https://datatracker.ietf.org/doc/rfc9729/ +//! [RFC 5705]: https://datatracker.ietf.org/doc/html/rfc5705 use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderValue; #[cfg(feature = "server")] @@ -137,7 +162,7 @@ pub enum Error { } impl ServerChallenge { - /// TODO(matheus23): docs + /// Generates a new challenge. #[cfg(feature = "server")] pub(crate) fn new(mut rng: impl RngCore + CryptoRng) -> Self { let mut challenge = [0u8; 16]; @@ -145,6 +170,7 @@ impl ServerChallenge { Self { challenge } } + /// The actual message bytes to sign (and verify against) for this challenge. fn message_to_sign(&self) -> [u8; 32] { blake3::derive_key( "iroh-relay handshake v1 challenge signature", @@ -154,7 +180,7 @@ impl ServerChallenge { } impl ClientAuth { - /// TODO(matheus23): docs + /// Generates a signature for given challenge from the server. pub(crate) fn new(secret_key: &SecretKey, challenge: &ServerChallenge) -> Self { Self { public_key: secret_key.public(), @@ -162,7 +188,7 @@ impl ClientAuth { } } - /// TODO(matheus23): docs + /// Verifies this client's authentication given the challenge this was sent in response to. #[cfg(feature = "server")] pub(crate) fn verify(&self, challenge: &ServerChallenge) -> bool { self.public_key @@ -176,6 +202,8 @@ impl ClientAuth { #[cfg_attr(wasm_browser, allow(unused))] impl KeyMaterialClientAuth { + /// Generates a client's authentication, similar to [`ClientAuth`], but by using TLS keying material + /// instead of a received challenge. pub(crate) fn new(secret_key: &SecretKey, io: &impl ExportKeyingMaterial) -> Option { let public_key = secret_key.public(); let key_material = io.export_keying_material( @@ -190,6 +218,7 @@ impl KeyMaterialClientAuth { }) } + /// Generate the base64url-nopad-encoded header value. pub(crate) fn into_header_value(self) -> HeaderValue { HeaderValue::from_str( &data_encoding::BASE64URL_NOPAD @@ -198,6 +227,13 @@ impl KeyMaterialClientAuth { .expect("BASE64URL_NOPAD encoding contained invisible ascii characters") } + /// Verifies this client auth on the server side using the same key material. + /// + /// This might return false for a couple of reasons: + /// 1. The exported keying material might not be the same between both ends of the TLS session + /// (e.g. there's an HTTPS proxy in between that doesn't think/care about the TLS keying material exporter). + /// This situation is detected when the key material suffix mismatches. + /// 2. The signature itself doesn't verify. #[cfg(feature = "server")] pub(crate) fn verify(&self, io: &impl ExportKeyingMaterial) -> bool { let Some(key_material) = io.export_keying_material( @@ -216,7 +252,13 @@ impl KeyMaterialClientAuth { } } -/// TODO(matheus23) docs +/// Runs the client side of the handshake protocol. +/// +/// See the module docs for details on the protocol. +/// This is already after having potentially transferred a [`KeyMaterialClientAuth`], +/// but before having received a response for whether that worked or not. +/// +/// This requires access to the client's secret key to sign a challenge. pub(crate) async fn clientside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), secret_key: &SecretKey, @@ -260,21 +302,42 @@ pub(crate) async fn clientside( } } +/// This represents successful authentication for the client with the `client_key` public key +/// via the authentication [`Mechanism`] `mechanism`. +/// +/// You must call [`SuccessfulAuthentication::authorize`] to finish the protocol. #[cfg(feature = "server")] #[derive(Debug)] +#[must_use = "the protocol is not finished unless `authorize` is called"] pub(crate) struct SuccessfulAuthentication { pub(crate) client_key: PublicKey, pub(crate) mechanism: Mechanism, } +/// The mechanism that was used for authentication. #[cfg(feature = "server")] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum Mechanism { + /// Authentication was performed by verifying a signature of a challenge we sent SignedChallenge, + /// Authentication was performed by verifying a signature of shared extracted TLS keying material SignedKeyMaterial, } -/// TODO(matheus23) docs +/// Runs the server side of the handshaking protocol. +/// +/// See the module documentation for an overview of the handshaking protocol. +/// +/// This takes `rng` to generate cryptographic randomness for the authentication challenge. +/// +/// This also takes the `client_auth_header`, if present, to perform authentication without +/// requiring sending a challenge, saving a round-trip, if possible. +/// +/// If this fails, the protocol falls back to doing a normal extra round trip with a challenge. +/// +/// The return value [`SuccessfulAuthentication`] still needs to be resolved by calling +/// [`SuccessfulAuthentication::authorize`] to finish the whole authorization protocol +/// (otherwise the client won't be notified about auth success or failure). #[cfg(feature = "server")] pub(crate) async fn serverside( io: &mut (impl BytesStreamSink + ExportKeyingMaterial), diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 28f1a2ac2ef..31dc01cff79 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -1,4 +1,7 @@ -//! TODO(matheus23) docs +//! Common types between the [`super::handshake`] and [`super::send_recv`] protocols. +//! +//! Hosts the [`FrameType`] enum to make sure we're not accidentally re-using frame type +//! integers for different frames. use bytes::{BufMut, Bytes}; use quinn_proto::{coding::Codec, VarInt}; @@ -58,13 +61,14 @@ impl std::fmt::Display for FrameType { } impl FrameType { + /// Writes the frame type to the buffer (as a QUIC-encoded varint). pub(crate) fn write_to(&self, mut dst: O) -> O { VarInt::from(*self).encode(&mut dst); dst } - // TODO(matheus23): Consolidate errors between handshake.rs and relay.rs - // Perhaps a shared error type `FramingError`? + /// Parses the frame type (as a QUIC-encoded varint) from the first couple of bytes given + /// and returns the frame type and the rest. pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { let mut cursor = std::io::Cursor::new(&bytes); let tag = VarInt::decode(&mut cursor).ok()?; diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index dc4c25bc6ca..b24e07544c4 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -79,44 +79,41 @@ pub enum Error { TooSmall {}, } -/// TODO(matheus23): Docs -/// The messages received from a framed relay stream. -/// -/// This is a type-validated version of the `Frame`s on the `RelayCodec`. +/// The messages that a relay sends to clients or the clients receive from the relay. #[derive(derive_more::Debug, Clone, PartialEq, Eq)] pub enum ServerToClientMsg { - /// Represents an incoming packet. - ReceivedDatagrams { - /// The [`NodeId`] of the packet sender. + /// Represents datagrams sent from relays (originally sent to them by another client). + Datagrams { + /// The [`NodeId`] of the original sender. remote_node_id: NodeId, - /// The datagrams and related metadata we received + /// The datagrams and related metadata. datagrams: Datagrams, }, /// Indicates that the client identified by the underlying public key had previously sent you a - /// packet but has now disconnected from the server. + /// packet but has now disconnected from the relay. NodeGone(NodeId), - /// A one-way message from server to client, declaring the connection health state. + /// A one-way message from relay to client, declaring the connection health state. Health { /// If set, is a description of why the connection is unhealthy. /// /// If `None` means the connection is healthy again. /// - /// The default condition is healthy, so the server doesn't broadcast a [`ServerToClientMsg::Health`] + /// The default condition is healthy, so the relay doesn't broadcast a [`ServerToClientMsg::Health`] /// until a problem exists. problem: String, }, - /// A one-way message from server to client, advertising that the server is restarting. + /// A one-way message from relay to client, advertising that the relay is restarting. Restarting { /// An advisory duration that the client should wait before attempting to reconnect. - /// It might be zero. It exists for the server to smear out the reconnects. + /// It might be zero. It exists for the relay to smear out the reconnects. reconnect_in: Duration, /// An advisory duration for how long the client should attempt to reconnect /// before giving up and proceeding with its normal connection failure logic. The interval - /// between retries is undefined for now. A server should not send a TryFor duration more + /// between retries is undefined for now. A relay should not send a `try_for` duration more /// than a few seconds. try_for: Duration, }, - /// Request from the server to reply to the + /// Request from the relay to reply to the /// other side with a [`ClientToServerMsg::Pong`] with the given payload. Ping([u8; 8]), /// Reply to a [`ClientToServerMsg::Ping`] from a client @@ -124,7 +121,7 @@ pub enum ServerToClientMsg { Pong([u8; 8]), } -/// TODO(matheus23): Docs +/// Messages that clients send to relays. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ClientToServerMsg { /// Request from the client to the server to reply to the @@ -133,11 +130,11 @@ pub enum ClientToServerMsg { /// Reply to a [`ServerToClientMsg::Ping`] from a server /// with the payload sent previously in the ping. Pong([u8; 8]), - /// TODO - SendDatagrams { - /// TODO + /// Request from the client to relay datagrams to given remote node. + Datagrams { + /// The remote node to relay to. dst_node_id: NodeId, - /// TODO + /// The datagrams and related metadata to relay. datagrams: Datagrams, }, } @@ -203,10 +200,10 @@ impl Datagrams { } impl ServerToClientMsg { - /// TODO(matheus23): docs + /// Returns this frame's corresponding frame type. pub fn typ(&self) -> FrameType { match self { - Self::ReceivedDatagrams { .. } => FrameType::RecvDatagrams, + Self::Datagrams { .. } => FrameType::RecvDatagrams, Self::NodeGone { .. } => FrameType::NodeGone, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, @@ -222,7 +219,7 @@ impl ServerToClientMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::ReceivedDatagrams { + Self::Datagrams { remote_node_id, datagrams, } => { @@ -272,7 +269,7 @@ impl ServerToClientMsg { .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; - Self::ReceivedDatagrams { + Self::Datagrams { remote_node_id, datagrams, } @@ -332,7 +329,7 @@ impl ServerToClientMsg { impl ClientToServerMsg { pub(crate) fn typ(&self) -> FrameType { match self { - Self::SendDatagrams { .. } => FrameType::SendDatagrams, + Self::Datagrams { .. } => FrameType::SendDatagrams, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, } @@ -344,7 +341,7 @@ impl ClientToServerMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::SendDatagrams { + Self::Datagrams { dst_node_id, datagrams, } => { @@ -380,7 +377,7 @@ impl ClientToServerMsg { .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; - Self::SendDatagrams { + Self::Datagrams { dst_node_id, datagrams, } @@ -459,7 +456,7 @@ mod tests { "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id: client_key.public(), datagrams: Datagrams { ecn: Some(quinn::EcnCodepoint::Ce), @@ -508,7 +505,7 @@ mod tests { "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ClientToServerMsg::SendDatagrams { + ClientToServerMsg::Datagrams { dst_node_id: client_key.public(), datagrams: Datagrams { ecn: Some(quinn::EcnCodepoint::Ce), @@ -578,7 +575,7 @@ mod proptests { /// Generates a random valid frame fn server_client_frame() -> impl Strategy { let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } @@ -602,7 +599,7 @@ mod proptests { fn client_server_frame() -> impl Strategy { let send_packet = (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| { - ClientToServerMsg::SendDatagrams { + ClientToServerMsg::Datagrams { dst_node_id, datagrams, } diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 2c326e8f6a2..90c90aeaf8f 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -1,4 +1,5 @@ -//! TODO(matheus23) docs +//! Implements logic for abstracing over a websocket stream that allows sending only [`Bytes`]-based +//! messages. use std::{ pin::Pin, task::{Context, Poll}, @@ -31,7 +32,7 @@ pub(crate) type StreamError = tokio_websockets::Error; #[cfg(wasm_browser)] pub(crate) type StreamError = ws_stream_wasm::WsErr; -/// TODO(matheus23) docs +/// Shorthand for a type that implements both a websocket-based stream & sink for [`Bytes`]. pub(crate) trait BytesStreamSink: Stream> + Sink + Unpin { diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index bde8707b895..defa392057f 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -795,7 +795,7 @@ mod tests { // try resend 10 times for _ in 0..10 { client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) @@ -913,7 +913,7 @@ mod tests { // send message from a to b let msg = Datagrams::from("hello, b"); let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - let ServerToClientMsg::ReceivedDatagrams { + let ServerToClientMsg::Datagrams { remote_node_id, datagrams, } = res @@ -929,7 +929,7 @@ mod tests { let msg = Datagrams::from("howdy, a"); let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - let ServerToClientMsg::ReceivedDatagrams { + let ServerToClientMsg::Datagrams { remote_node_id, datagrams, } = res @@ -1013,7 +1013,7 @@ mod tests { let msg = Datagrams::from("hello, c"); let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?; - if let ServerToClientMsg::ReceivedDatagrams { + if let ServerToClientMsg::Datagrams { remote_node_id, datagrams, } = res @@ -1054,7 +1054,7 @@ mod tests { let msg = Datagrams::from("hello, b"); for _i in 0..1000 { client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 412fb86a766..5e07dfb4478 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -411,7 +411,7 @@ impl Actor { if let Ok(len) = datagrams.contents.len().try_into() { self.metrics.bytes_sent.inc_by(len); } - self.write_frame(ServerToClientMsg::ReceivedDatagrams { + self.write_frame(ServerToClientMsg::Datagrams { remote_node_id, datagrams, }) @@ -458,7 +458,7 @@ impl Actor { }; match frame { - ClientToServerMsg::SendDatagrams { + ClientToServerMsg::Datagrams { dst_node_id: dst_key, datagrams, } => { @@ -644,7 +644,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; assert_eq!( frame, - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id: node_id, datagrams: data.to_vec().into() } @@ -659,7 +659,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; assert_eq!( frame, - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id: node_id, datagrams: data.to_vec().into() } @@ -689,7 +689,7 @@ mod tests { println!(" send packet"); let data = b"hello world!"; io_rw - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: target, datagrams: Datagrams::from(data), }) @@ -703,7 +703,7 @@ mod tests { disco_data.extend_from_slice(target.as_bytes()); disco_data.extend_from_slice(data); io_rw - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: target, datagrams: disco_data.clone().into(), }) @@ -725,12 +725,12 @@ mod tests { let (io_read, io_write) = tokio::io::duplex((LIMIT * MAX_FRAMES) as _); let mut frame_writer = Conn::test(io_write); // Rate limiter allowing LIMIT bytes/s - let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT); + let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?; // Prepare a frame to send, assert its size. let data = Datagrams::from(b"hello world!!1"); let target = SecretKey::generate(rand::thread_rng()).public(); - let frame = ClientToServerMsg::SendDatagrams { + let frame = ClientToServerMsg::Datagrams { dst_node_id: target, datagrams: data.clone(), }; diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 382c7fb5fdf..4ec2cd73600 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -258,7 +258,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; assert_eq!( frame, - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id: b_key, datagrams: data.to_vec().into(), } @@ -269,7 +269,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; assert_eq!( frame, - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id: b_key, datagrams: data.to_vec().into(), } diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 47d574e30d1..d6be4ca4bd3 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -20,7 +20,7 @@ use tokio_rustls_acme::AcmeAcceptor; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument}; -use super::{clients::Clients, AccessConfig, SpawnError}; +use super::{clients::Clients, streams::InvalidBucketConfig, AccessConfig, SpawnError}; #[allow(deprecated)] use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, @@ -203,6 +203,20 @@ pub enum ServeConnectionError { }, } +/// Server accept errors. +#[common_fields({ + backtrace: Option, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum AcceptError { + #[snafu(transparent)] + Handshake { source: handshake::Error }, + #[snafu(display("rate limiting misconfigured"))] + RateLimitingMisconfigured { source: InvalidBucketConfig }, +} + /// Server connection errors, includes errors that can happen on `accept`. #[common_fields({ backtrace: Option, @@ -212,7 +226,7 @@ pub enum ServeConnectionError { #[non_exhaustive] pub enum ConnectionHandlerError { #[snafu(transparent)] - Accept { source: handshake::Error }, + Accept { source: AcceptError }, #[snafu(display("Could not downcast the upgraded connection to MaybeTlsStream"))] DowncastUpgrade { #[snafu(implicit)] @@ -601,12 +615,13 @@ impl Inner { &self, io: MaybeTlsStream, client_auth_header: Option, - ) -> Result<(), handshake::Error> { + ) -> Result<(), AcceptError> { trace!("accept: start"); - let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone()); + let io = RateLimited::from_cfg(self.rate_limit, io, self.metrics.clone()) + .context(RateLimitingMisconfiguredSnafu)?; - self.metrics.websocket_accepts.inc(); + self.metrics.accepts.inc(); // Since we already did the HTTP upgrade in the previous step, // we use tokio-websockets to handle this connection // Create a server builder with default config @@ -892,7 +907,7 @@ mod tests { info!("sending message from a to b"); let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) @@ -906,7 +921,7 @@ mod tests { info!("sending message from b to a"); let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: a_key, datagrams: msg.clone(), }) @@ -947,7 +962,7 @@ mod tests { } Some(Ok(msg)) => { info!("got message on: {msg:?}"); - if let ServerToClientMsg::ReceivedDatagrams { + if let ServerToClientMsg::Datagrams { remote_node_id: source, datagrams, } = msg @@ -1012,7 +1027,7 @@ mod tests { info!("sending message from a to b"); let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) @@ -1026,7 +1041,7 @@ mod tests { info!("sending message from b to a"); let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: a_key, datagrams: msg.clone(), }) @@ -1089,13 +1104,13 @@ mod tests { info!("Send message from A to B."); let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_b, datagrams: msg.clone(), }) .await?; match client_b.next().await.unwrap()? { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1110,13 +1125,13 @@ mod tests { info!("Send message from B to A."); let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_a, datagrams: msg.clone(), }) .await?; match client_a.next().await.unwrap()? { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1134,7 +1149,7 @@ mod tests { info!("Fail to send message from A to B."); let res = client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_b, datagrams: Datagrams::from(b"try to send"), }) @@ -1179,13 +1194,13 @@ mod tests { info!("Send message from A to B."); let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_b, datagrams: msg.clone(), }) .await?; match client_b.next().await.expect("eos")? { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1200,13 +1215,13 @@ mod tests { info!("Send message from B to A."); let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_a, datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1231,13 +1246,13 @@ mod tests { info!("Send message from A to B."); let msg = Datagrams::from(b"are you still there, b?!"); client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_b, datagrams: msg.clone(), }) .await?; match new_client_b.next().await.expect("eos")? { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1252,13 +1267,13 @@ mod tests { info!("Send message from B to A."); let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!"); new_client_b - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_a, datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1275,7 +1290,7 @@ mod tests { info!("Sending message from A to B fails"); let res = client_a - .send(ClientToServerMsg::SendDatagrams { + .send(ClientToServerMsg::Datagrams { dst_node_id: public_key_b, datagrams: Datagrams::from(b"try to send"), }) diff --git a/iroh-relay/src/server/metrics.rs b/iroh-relay/src/server/metrics.rs index b2b74da674c..9e75f116dea 100644 --- a/iroh-relay/src/server/metrics.rs +++ b/iroh-relay/src/server/metrics.rs @@ -75,8 +75,6 @@ pub struct Metrics { /// Number of accepted websocket connections pub websocket_accepts: Counter, - /// Number of accepted 'iroh derp http' connection upgrades - pub relay_accepts: Counter, // TODO: enable when we can have multiple connections for one node id // pub duplicate_client_keys: Counter, // pub duplicate_client_conns: Counter, diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index c77a3b819f5..09e2c2edd77 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -8,7 +8,7 @@ use std::{ use bytes::BytesMut; use n0_future::{ready, time, FutureExt, Sink, Stream}; -use snafu::Snafu; +use snafu::{Backtrace, Snafu}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::instrument; @@ -51,22 +51,22 @@ impl RelayedStream { stream: tokio::io::DuplexStream, max_burst_bytes: u32, bytes_per_second: u32, - ) -> Self { + ) -> Result { let stream = MaybeTlsStream::Test(stream); let stream = RateLimited::new( stream, max_burst_bytes, bytes_per_second, Arc::new(Metrics::default()), - ); - Self { + )?; + Ok(Self { inner: WsBytesFramed { io: tokio_websockets::ServerBuilder::new() .limits(Self::limits()) .serve(stream), }, key_cache: KeyCache::test(), - } + }) } fn limits() -> tokio_websockets::Limits { @@ -249,20 +249,39 @@ struct Bucket { refill: i64, } +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +pub struct InvalidBucketConfig { + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, + max: i64, + bytes_per_second: i64, + refill_period: time::Duration, +} + impl Bucket { - fn new(max: i64, bytes_per_second: i64, refill_period: time::Duration) -> Self { - // TODO(matheus23) convert to errors - debug_assert!(max > 0); - debug_assert!(bytes_per_second > 0); - debug_assert_ne!(refill_period.as_millis(), 0); + fn new( + max: i64, + bytes_per_second: i64, + refill_period: time::Duration, + ) -> Result { // milliseconds is the tokio timer resolution - Self { + snafu::ensure!( + max > 0 && bytes_per_second > 0 && refill_period.as_millis() != 0, + InvalidBucketConfigSnafu { + max, + bytes_per_second, + refill_period, + }, + ); + Ok(Self { fill: max, max, last_fill: time::Instant::now(), refill_period, refill: bytes_per_second * refill_period.as_millis() as i64 / 1000, - } + }) } fn update_state(&mut self) { @@ -297,14 +316,18 @@ impl Bucket { } impl RateLimited { - pub(crate) fn from_cfg(cfg: Option, io: S, metrics: Arc) -> Self { + pub(crate) fn from_cfg( + cfg: Option, + io: S, + metrics: Arc, + ) -> Result { match cfg { Some(cfg) => { let bytes_per_second = cfg.bytes_per_second.into(); let max_burst_bytes = cfg.max_burst_bytes.map_or(bytes_per_second / 10, u32::from); Self::new(io, max_burst_bytes, bytes_per_second, metrics) } - None => Self::unlimited(io, metrics), + None => Ok(Self::unlimited(io, metrics)), } } @@ -313,18 +336,18 @@ impl RateLimited { max_burst_bytes: u32, bytes_per_second: u32, metrics: Arc, - ) -> Self { - Self { + ) -> Result { + Ok(Self { inner, bucket: Some(Bucket::new( max_burst_bytes as i64, bytes_per_second as i64, time::Duration::from_millis(100), - )), + )?), bucket_refilled: None, limited_once: false, metrics, - } + }) } pub(crate) fn unlimited(inner: S, metrics: Arc) -> Self { @@ -447,7 +470,7 @@ mod tests { bytes_per_second / 10, bytes_per_second, Arc::new(Metrics::default()), - ); + )?; let before = time::Instant::now(); n0_future::future::try_zip( @@ -471,6 +494,7 @@ mod tests { assert_ne!(duration.as_millis(), 0); let actual_bytes_per_second = send_total as f64 / duration.as_secs_f64(); + println!("{actual_bytes_per_second}"); assert_eq!(actual_bytes_per_second.round() as u32, bytes_per_second); Ok(()) diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index 285377c3591..627cf97f119 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -625,10 +625,10 @@ impl ActiveRelayActor { // TODO(frando): can we avoid the clone here? let metrics = self.metrics.clone(); let packet_iter = batch.into_iter().map(|item| { - Ok(ClientToServerMsg::SendDatagrams { dst_node_id: item.remote_node, datagrams: item.datagrams }) + Ok(ClientToServerMsg::Datagrams { dst_node_id: item.remote_node, datagrams: item.datagrams }) }); let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { - if let Ok(ClientToServerMsg::SendDatagrams { dst_node_id: _node_id, datagrams }) = m { + if let Ok(ClientToServerMsg::Datagrams { dst_node_id: _node_id, datagrams }) = m { metrics.send_relay.inc_by(datagrams.contents.len() as _); } }); @@ -666,7 +666,7 @@ impl ActiveRelayActor { fn handle_relay_msg(&mut self, msg: ServerToClientMsg, state: &mut ConnectedRelayState) { match msg { - ServerToClientMsg::ReceivedDatagrams { + ServerToClientMsg::Datagrams { remote_node_id, datagrams, } => { From f4528b3b99caab383d16025ee82db4be67a526d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 16:38:04 +0200 Subject: [PATCH 42/80] Rename `Server` to `Relay` in general --- iroh-relay/src/client.rs | 16 ++--- iroh-relay/src/client/conn.rs | 16 ++--- iroh-relay/src/protos/send_recv.rs | 67 ++++++++++---------- iroh-relay/src/server.rs | 14 ++-- iroh-relay/src/server/client.rs | 40 ++++++------ iroh-relay/src/server/clients.rs | 10 +-- iroh-relay/src/server/http_server.rs | 58 ++++++++--------- iroh-relay/src/server/streams.rs | 14 ++-- iroh/src/magicsock/transports/relay/actor.rs | 28 ++++---- 9 files changed, 131 insertions(+), 132 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index b43274b3b7d..9c8d2bb5d2a 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -29,7 +29,7 @@ use crate::{ http::RELAY_PATH, protos::{ handshake, - send_recv::{ClientToServerMsg, ServerToClientMsg}, + send_recv::{ClientToRelayMsg, RelayToClientMsg}, }, KeyCache, }; @@ -358,14 +358,14 @@ impl Client { } impl Stream for Client { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.conn).poll_next(cx) } } -impl Sink for Client { +impl Sink for Client { type Error = SendError; fn poll_ready( @@ -375,7 +375,7 @@ impl Sink for Client { Pin::new(&mut self.conn).poll_ready(cx) } - fn start_send(mut self: Pin<&mut Self>, item: ClientToServerMsg) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: ClientToRelayMsg) -> Result<(), Self::Error> { Pin::new(&mut self.conn).start_send(item) } @@ -397,10 +397,10 @@ impl Sink for Client { /// The send half of a relay client. #[derive(Debug)] pub struct ClientSink { - sink: SplitSink, + sink: SplitSink, } -impl Sink for ClientSink { +impl Sink for ClientSink { type Error = SendError; fn poll_ready( @@ -410,7 +410,7 @@ impl Sink for ClientSink { Pin::new(&mut self.sink).poll_ready(cx) } - fn start_send(mut self: Pin<&mut Self>, item: ClientToServerMsg) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: ClientToRelayMsg) -> Result<(), Self::Error> { Pin::new(&mut self.sink).start_send(item) } @@ -444,7 +444,7 @@ impl ClientStream { } impl Stream for ClientStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_next(cx) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 5b517fb6157..e1a3cd36b01 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -21,7 +21,7 @@ use crate::{ protos::{ handshake, send_recv::{ - ClientToServerMsg, Error as RecvRelayError, ServerToClientMsg, MAX_PAYLOAD_SIZE, + ClientToRelayMsg, Error as RecvRelayError, RelayToClientMsg, MAX_PAYLOAD_SIZE, }, streams::WsBytesFramed, }, @@ -74,8 +74,8 @@ pub enum RecvError { /// /// This holds a connection to a relay server. It is: /// -/// - A [`Stream`] for [`ServerToClientMsg`] to receive from the server. -/// - A [`Sink`] for [`ClientToServerMsg`] to send to the server. +/// - A [`Stream`] for [`RelayToClientMsg`] to receive from the server. +/// - A [`Sink`] for [`ClientToRelayMsg`] to send to the server. #[derive(derive_more::Debug)] pub(crate) struct Conn { #[debug("tokio_websockets::WebSocketStream")] @@ -124,13 +124,13 @@ impl Conn { } impl Stream for Conn { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let msg = ready!(Pin::new(&mut self.conn).poll_next(cx)); match msg { Some(Ok(msg)) => { - let message = ServerToClientMsg::from_bytes(msg, &self.key_cache); + let message = RelayToClientMsg::from_bytes(msg, &self.key_cache); Poll::Ready(Some(message.map_err(Into::into))) } Some(Err(e)) => Poll::Ready(Some(Err(e.into()))), @@ -139,15 +139,15 @@ impl Stream for Conn { } } -impl Sink for Conn { +impl Sink for Conn { type Error = SendError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.conn).poll_ready(cx).map_err(Into::into) } - fn start_send(mut self: Pin<&mut Self>, frame: ClientToServerMsg) -> Result<(), Self::Error> { - if let ClientToServerMsg::Datagrams { datagrams, .. } = &frame { + fn start_send(mut self: Pin<&mut Self>, frame: ClientToRelayMsg) -> Result<(), Self::Error> { + if let ClientToRelayMsg::Datagrams { datagrams, .. } = &frame { let size = datagrams.contents.len(); snafu::ensure!(size <= MAX_PAYLOAD_SIZE, ExceedsMaxPacketSizeSnafu { size }); } diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs index b24e07544c4..fe7357b4b52 100644 --- a/iroh-relay/src/protos/send_recv.rs +++ b/iroh-relay/src/protos/send_recv.rs @@ -81,7 +81,7 @@ pub enum Error { /// The messages that a relay sends to clients or the clients receive from the relay. #[derive(derive_more::Debug, Clone, PartialEq, Eq)] -pub enum ServerToClientMsg { +pub enum RelayToClientMsg { /// Represents datagrams sent from relays (originally sent to them by another client). Datagrams { /// The [`NodeId`] of the original sender. @@ -98,7 +98,7 @@ pub enum ServerToClientMsg { /// /// If `None` means the connection is healthy again. /// - /// The default condition is healthy, so the relay doesn't broadcast a [`ServerToClientMsg::Health`] + /// The default condition is healthy, so the relay doesn't broadcast a [`RelayToClientMsg::Health`] /// until a problem exists. problem: String, }, @@ -114,20 +114,20 @@ pub enum ServerToClientMsg { try_for: Duration, }, /// Request from the relay to reply to the - /// other side with a [`ClientToServerMsg::Pong`] with the given payload. + /// other side with a [`ClientToRelayMsg::Pong`] with the given payload. Ping([u8; 8]), - /// Reply to a [`ClientToServerMsg::Ping`] from a client + /// Reply to a [`ClientToRelayMsg::Ping`] from a client /// with the payload sent previously in the ping. Pong([u8; 8]), } /// Messages that clients send to relays. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ClientToServerMsg { +pub enum ClientToRelayMsg { /// Request from the client to the server to reply to the - /// other side with a [`ServerToClientMsg::Pong`] with the given payload. + /// other side with a [`RelayToClientMsg::Pong`] with the given payload. Ping([u8; 8]), - /// Reply to a [`ServerToClientMsg::Ping`] from a server + /// Reply to a [`RelayToClientMsg::Ping`] from a server /// with the payload sent previously in the ping. Pong([u8; 8]), /// Request from the client to relay datagrams to given remote node. @@ -199,7 +199,7 @@ impl Datagrams { } } -impl ServerToClientMsg { +impl RelayToClientMsg { /// Returns this frame's corresponding frame type. pub fn typ(&self) -> FrameType { match self { @@ -326,7 +326,7 @@ impl ServerToClientMsg { } } -impl ClientToServerMsg { +impl ClientToRelayMsg { pub(crate) fn typ(&self) -> FrameType { match self { Self::Datagrams { .. } => FrameType::SendDatagrams, @@ -434,7 +434,7 @@ mod tests { check_expected_bytes(vec![ ( - ServerToClientMsg::Health { + RelayToClientMsg::Health { problem: "Hello? Yes this is dog.".into(), } .write_to(Vec::new()), @@ -442,21 +442,21 @@ mod tests { 20 69 73 20 64 6f 67 2e", ), ( - ServerToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), + RelayToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61", ), ( - ServerToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), + RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), "0f 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ServerToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), + RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id: client_key.public(), datagrams: Datagrams { ecn: Some(quinn::EcnCodepoint::Ce), @@ -479,7 +479,7 @@ mod tests { 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ( - ServerToClientMsg::Restarting { + RelayToClientMsg::Restarting { reconnect_in: Duration::from_millis(10), try_for: Duration::from_millis(20), } @@ -497,15 +497,15 @@ mod tests { check_expected_bytes(vec![ ( - ClientToServerMsg::Ping([42u8; 8]).write_to(Vec::new()), + ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()), "0f 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ClientToServerMsg::Pong([42u8; 8]).write_to(Vec::new()), + ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()), "10 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ClientToServerMsg::Datagrams { + ClientToRelayMsg::Datagrams { dst_node_id: client_key.public(), datagrams: Datagrams { ecn: Some(quinn::EcnCodepoint::Ce), @@ -573,23 +573,23 @@ mod proptests { } /// Generates a random valid frame - fn server_client_frame() -> impl Strategy { + fn server_client_frame() -> impl Strategy { let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } }); - let node_gone = key().prop_map(ServerToClientMsg::NodeGone); - let ping = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Ping); - let pong = prop::array::uniform8(any::()).prop_map(ServerToClientMsg::Pong); + let node_gone = key().prop_map(RelayToClientMsg::NodeGone); + let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); let health = ".{0,65536}" .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes }) - .prop_map(|problem| ServerToClientMsg::Health { problem }); + .prop_map(|problem| RelayToClientMsg::Health { problem }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { - ServerToClientMsg::Restarting { + RelayToClientMsg::Restarting { reconnect_in: Duration::from_millis(reconnect_in.into()), try_for: Duration::from_millis(try_for.into()), } @@ -597,15 +597,14 @@ mod proptests { prop_oneof![recv_packet, node_gone, ping, pong, health, restarting] } - fn client_server_frame() -> impl Strategy { - let send_packet = (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| { - ClientToServerMsg::Datagrams { + fn client_server_frame() -> impl Strategy { + let send_packet = + (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| ClientToRelayMsg::Datagrams { dst_node_id, datagrams, - } - }); - let ping = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Ping); - let pong = prop::array::uniform8(any::()).prop_map(ClientToServerMsg::Pong); + }); + let ping = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Pong); prop_oneof![send_packet, ping, pong] } @@ -613,14 +612,14 @@ mod proptests { #[test] fn server_client_frame_roundtrip(frame in server_client_frame()) { let encoded = frame.clone().write_to(BytesMut::new()).freeze(); - let decoded = ServerToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } #[test] fn client_server_frame_roundtrip(frame in client_server_frame()) { let encoded = frame.clone().write_to(BytesMut::new()).freeze(); - let decoded = ClientToServerMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } } diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index defa392057f..94760cc7054 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -766,7 +766,7 @@ mod tests { dns::DnsResolver, protos::{ handshake, - send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, + send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, }, }; @@ -791,11 +791,11 @@ mod tests { client_b: &mut crate::client::Client, b_key: NodeId, msg: Datagrams, - ) -> Result { + ) -> Result { // try resend 10 times for _ in 0..10 { client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) @@ -913,7 +913,7 @@ mod tests { // send message from a to b let msg = Datagrams::from("hello, b"); let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - let ServerToClientMsg::Datagrams { + let RelayToClientMsg::Datagrams { remote_node_id, datagrams, } = res @@ -929,7 +929,7 @@ mod tests { let msg = Datagrams::from("howdy, a"); let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - let ServerToClientMsg::Datagrams { + let RelayToClientMsg::Datagrams { remote_node_id, datagrams, } = res @@ -1013,7 +1013,7 @@ mod tests { let msg = Datagrams::from("hello, c"); let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?; - if let ServerToClientMsg::Datagrams { + if let RelayToClientMsg::Datagrams { remote_node_id, datagrams, } = res @@ -1054,7 +1054,7 @@ mod tests { let msg = Datagrams::from("hello, b"); for _i in 0..1000 { client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 5e07dfb4478..7b44e2d85bc 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -18,7 +18,7 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg, PING_INTERVAL}, + send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg, PING_INTERVAL}, streams::StreamError, }, server::{ @@ -369,7 +369,7 @@ impl Actor { node_id = self.node_gone.recv() => { let node_id = node_id.ok_or(NodeGoneDropSnafu.build())?; trace!("node_id gone: {:?}", node_id); - self.write_frame(ServerToClientMsg::NodeGone(node_id)).await.context(NodeGoneWriteFrameSnafu)?; + self.write_frame(RelayToClientMsg::NodeGone(node_id)).await.context(NodeGoneWriteFrameSnafu)?; } _ = self.ping_tracker.timeout() => { trace!("pong timed out"); @@ -380,7 +380,7 @@ impl Actor { // new interval ping_interval.reset_after(next_interval()); let data = self.ping_tracker.new_ping(); - self.write_frame(ServerToClientMsg::Ping(data)).await.context(KeepAliveWriteFrameSnafu)?; + self.write_frame(RelayToClientMsg::Ping(data)).await.context(KeepAliveWriteFrameSnafu)?; } } @@ -395,7 +395,7 @@ impl Actor { /// Writes the given frame to the connection. /// /// Errors if the send does not happen within the `timeout` duration - async fn write_frame(&mut self, frame: ServerToClientMsg) -> Result<(), SendFrameError> { + async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), SendFrameError> { tokio::time::timeout(self.timeout, self.stream.send(frame)).await??; Ok(()) } @@ -411,7 +411,7 @@ impl Actor { if let Ok(len) = datagrams.contents.len().try_into() { self.metrics.bytes_sent.inc_by(len); } - self.write_frame(ServerToClientMsg::Datagrams { + self.write_frame(RelayToClientMsg::Datagrams { remote_node_id, datagrams, }) @@ -449,7 +449,7 @@ impl Actor { /// Handles frame read results. async fn handle_frame( &mut self, - maybe_frame: Option>, + maybe_frame: Option>, ) -> Result<(), HandleFrameError> { trace!(?maybe_frame, "handle incoming frame"); let frame = match maybe_frame { @@ -458,7 +458,7 @@ impl Actor { }; match frame { - ClientToServerMsg::Datagrams { + ClientToRelayMsg::Datagrams { dst_node_id: dst_key, datagrams, } => { @@ -470,13 +470,13 @@ impl Actor { } self.metrics.bytes_recv.inc_by(packet_len as u64); } - ClientToServerMsg::Ping(data) => { + ClientToRelayMsg::Ping(data) => { self.metrics.got_ping.inc(); // TODO: add rate limiter - self.write_frame(ServerToClientMsg::Pong(data)).await?; + self.write_frame(RelayToClientMsg::Pong(data)).await?; self.metrics.sent_pong.inc(); } - ClientToServerMsg::Pong(data) => { + ClientToRelayMsg::Pong(data) => { self.ping_tracker.pong_received(data); } } @@ -577,11 +577,11 @@ mod tests { async fn recv_frame< E: snafu::Error + Sync + Send + 'static, - S: Stream> + Unpin, + S: Stream> + Unpin, >( frame_type: FrameType, mut stream: S, - ) -> Result { + ) -> Result { match stream.next().await { Some(Ok(frame)) => { if frame_type != frame.typ() { @@ -644,7 +644,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; assert_eq!( frame, - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id: node_id, datagrams: data.to_vec().into() } @@ -659,7 +659,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; assert_eq!( frame, - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id: node_id, datagrams: data.to_vec().into() } @@ -669,19 +669,19 @@ mod tests { println!("send peer gone"); peer_gone_s.send(node_id).await.context("send")?; let frame = recv_frame(FrameType::NodeGone, &mut io_rw).await.e()?; - assert_eq!(frame, ServerToClientMsg::NodeGone(node_id)); + assert_eq!(frame, RelayToClientMsg::NodeGone(node_id)); // Read tests println!("--read"); // send ping, expect pong let data = b"pingpong"; - io_rw.send(ClientToServerMsg::Ping(*data)).await?; + io_rw.send(ClientToRelayMsg::Ping(*data)).await?; // recv pong println!(" recv pong"); let frame = recv_frame(FrameType::Pong, &mut io_rw).await?; - assert_eq!(frame, ServerToClientMsg::Pong(*data)); + assert_eq!(frame, RelayToClientMsg::Pong(*data)); let target = SecretKey::generate(rand::thread_rng()).public(); @@ -689,7 +689,7 @@ mod tests { println!(" send packet"); let data = b"hello world!"; io_rw - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: target, datagrams: Datagrams::from(data), }) @@ -703,7 +703,7 @@ mod tests { disco_data.extend_from_slice(target.as_bytes()); disco_data.extend_from_slice(data); io_rw - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: target, datagrams: disco_data.clone().into(), }) @@ -730,7 +730,7 @@ mod tests { // Prepare a frame to send, assert its size. let data = Datagrams::from(b"hello world!!1"); let target = SecretKey::generate(rand::thread_rng()).public(); - let frame = ClientToServerMsg::Datagrams { + let frame = ClientToRelayMsg::Datagrams { dst_node_id: target, datagrams: data.clone(), }; diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 4ec2cd73600..e3555e28b05 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -201,17 +201,17 @@ mod tests { use super::*; use crate::{ client::conn::Conn, - protos::{relay::FrameType, send_recv::ServerToClientMsg}, + protos::{relay::FrameType, send_recv::RelayToClientMsg}, server::streams::RelayedStream, }; async fn recv_frame< E: snafu::Error + Sync + Send + 'static, - S: Stream> + Unpin, + S: Stream> + Unpin, >( frame_type: FrameType, mut stream: S, - ) -> Result { + ) -> Result { match stream.next().await { Some(Ok(frame)) => { if frame_type != frame.typ() { @@ -258,7 +258,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; assert_eq!( frame, - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id: b_key, datagrams: data.to_vec().into(), } @@ -269,7 +269,7 @@ mod tests { let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; assert_eq!( frame, - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id: b_key, datagrams: data.to_vec().into(), } diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index d6be4ca4bd3..69e3ab34e63 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -837,7 +837,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, + protos::send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, }; pub(crate) fn make_tls_config() -> TlsConfig { @@ -895,19 +895,19 @@ mod tests { info!("created client {b_key:?}"); info!("ping a"); - client_a.send(ClientToServerMsg::Ping([1u8; 8])).await?; + client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?; let pong = client_a.next().await.expect("eos")?; - assert!(matches!(pong, ServerToClientMsg::Pong { .. })); + assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("ping b"); - client_b.send(ClientToServerMsg::Ping([2u8; 8])).await?; + client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?; let pong = client_b.next().await.expect("eos")?; - assert!(matches!(pong, ServerToClientMsg::Pong { .. })); + assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("sending message from a to b"); let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) @@ -921,7 +921,7 @@ mod tests { info!("sending message from b to a"); let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: a_key, datagrams: msg.clone(), }) @@ -953,7 +953,7 @@ mod tests { } fn process_msg( - msg: Option>, + msg: Option>, ) -> Option<(PublicKey, Datagrams)> { match msg { Some(Err(e)) => { @@ -962,7 +962,7 @@ mod tests { } Some(Ok(msg)) => { info!("got message on: {msg:?}"); - if let ServerToClientMsg::Datagrams { + if let RelayToClientMsg::Datagrams { remote_node_id: source, datagrams, } = msg @@ -1015,19 +1015,19 @@ mod tests { info!("created client {b_key:?}"); info!("ping a"); - client_a.send(ClientToServerMsg::Ping([1u8; 8])).await?; + client_a.send(ClientToRelayMsg::Ping([1u8; 8])).await?; let pong = client_a.next().await.expect("eos")?; - assert!(matches!(pong, ServerToClientMsg::Pong { .. })); + assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("ping b"); - client_b.send(ClientToServerMsg::Ping([2u8; 8])).await?; + client_b.send(ClientToRelayMsg::Ping([2u8; 8])).await?; let pong = client_b.next().await.expect("eos")?; - assert!(matches!(pong, ServerToClientMsg::Pong { .. })); + assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("sending message from a to b"); let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: b_key, datagrams: msg.clone(), }) @@ -1041,7 +1041,7 @@ mod tests { info!("sending message from b to a"); let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: a_key, datagrams: msg.clone(), }) @@ -1104,13 +1104,13 @@ mod tests { info!("Send message from A to B."); let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_b, datagrams: msg.clone(), }) .await?; match client_b.next().await.unwrap()? { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1125,13 +1125,13 @@ mod tests { info!("Send message from B to A."); let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_a, datagrams: msg.clone(), }) .await?; match client_a.next().await.unwrap()? { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1149,7 +1149,7 @@ mod tests { info!("Fail to send message from A to B."); let res = client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_b, datagrams: Datagrams::from(b"try to send"), }) @@ -1194,13 +1194,13 @@ mod tests { info!("Send message from A to B."); let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_b, datagrams: msg.clone(), }) .await?; match client_b.next().await.expect("eos")? { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1215,13 +1215,13 @@ mod tests { info!("Send message from B to A."); let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_a, datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1246,13 +1246,13 @@ mod tests { info!("Send message from A to B."); let msg = Datagrams::from(b"are you still there, b?!"); client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_b, datagrams: msg.clone(), }) .await?; match new_client_b.next().await.expect("eos")? { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1267,13 +1267,13 @@ mod tests { info!("Send message from B to A."); let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!"); new_client_b - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_a, datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -1290,7 +1290,7 @@ mod tests { info!("Sending message from A to B fails"); let res = client_a - .send(ClientToServerMsg::Datagrams { + .send(ClientToRelayMsg::Datagrams { dst_node_id: public_key_b, datagrams: Datagrams::from(b"try to send"), }) diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 09e2c2edd77..e6dbdcf9318 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -15,7 +15,7 @@ use tracing::instrument; use super::{ClientRateLimit, Metrics}; use crate::{ protos::{ - send_recv::{ClientToServerMsg, Error as ProtoError, ServerToClientMsg}, + send_recv::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg}, streams::{StreamError, WsBytesFramed}, }, ExportKeyingMaterial, KeyCache, @@ -24,8 +24,8 @@ use crate::{ /// The relay's connection to a client. /// /// This implements -/// - a [`Stream`] of [`ClientToServerMsg`]s that are received from the client, -/// - a [`Sink`] of [`ServerToClientMsg`]s that can be sent to the client. +/// - a [`Stream`] of [`ClientToRelayMsg`]s that are received from the client, +/// - a [`Sink`] of [`RelayToClientMsg`]s that can be sent to the client. #[derive(Debug)] pub(crate) struct RelayedStream { pub(crate) inner: WsBytesFramed>, @@ -75,14 +75,14 @@ impl RelayedStream { } } -impl Sink for RelayedStream { +impl Sink for RelayedStream { type Error = StreamError; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_ready(cx) } - fn start_send(mut self: Pin<&mut Self>, item: ServerToClientMsg) -> Result<(), Self::Error> { + fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> { Pin::new(&mut self.inner).start_send(item.write_to(BytesMut::new()).freeze()) } @@ -106,12 +106,12 @@ pub enum RecvError { } impl Stream for RelayedStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Poll::Ready(match ready!(Pin::new(&mut self.inner).poll_next(cx)) { Some(Ok(msg)) => { - Some(ClientToServerMsg::from_bytes(msg, &self.key_cache).map_err(Into::into)) + Some(ClientToRelayMsg::from_bytes(msg, &self.key_cache).map_err(Into::into)) } Some(Err(e)) => Some(Err(e.into())), None => None, diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index 627cf97f119..b51bface9ab 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -42,7 +42,7 @@ use iroh_base::{NodeId, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, - protos::send_recv::{ClientToServerMsg, Datagrams, ServerToClientMsg}, + protos::send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, PingTracker, }; use n0_future::{ @@ -543,7 +543,7 @@ impl ActiveRelayActor { let res = loop { if let Some(data) = state.pong_pending.take() { - let fut = client_sink.send(ClientToServerMsg::Pong(data)); + let fut = client_sink.send(ClientToRelayMsg::Pong(data)); self.run_sending(fut, &mut state, &mut client_stream) .await?; } @@ -570,7 +570,7 @@ impl ActiveRelayActor { } _ = ping_interval.tick() => { let data = state.ping_tracker.new_ping(); - let fut = client_sink.send(ClientToServerMsg::Ping(data)); + let fut = client_sink.send(ClientToRelayMsg::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } msg = self.inbox.recv() => { @@ -586,7 +586,7 @@ impl ActiveRelayActor { match client_stream.local_addr() { Some(addr) if local_ips.contains(&addr.ip()) => { let data = state.ping_tracker.new_ping(); - let fut = client_sink.send(ClientToServerMsg::Ping(data)); + let fut = client_sink.send(ClientToRelayMsg::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } Some(_) => break Err(LocalIpInvalidSnafu.build()), @@ -602,7 +602,7 @@ impl ActiveRelayActor { ActiveRelayMessage::PingServer(sender) => { let data = rand::random(); state.test_pong = Some((data, sender)); - let fut = client_sink.send(ClientToServerMsg::Ping(data)); + let fut = client_sink.send(ClientToRelayMsg::Ping(data)); self.run_sending(fut, &mut state, &mut client_stream).await?; } } @@ -625,10 +625,10 @@ impl ActiveRelayActor { // TODO(frando): can we avoid the clone here? let metrics = self.metrics.clone(); let packet_iter = batch.into_iter().map(|item| { - Ok(ClientToServerMsg::Datagrams { dst_node_id: item.remote_node, datagrams: item.datagrams }) + Ok(ClientToRelayMsg::Datagrams { dst_node_id: item.remote_node, datagrams: item.datagrams }) }); let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { - if let Ok(ClientToServerMsg::Datagrams { dst_node_id: _node_id, datagrams }) = m { + if let Ok(ClientToRelayMsg::Datagrams { dst_node_id: _node_id, datagrams }) = m { metrics.send_relay.inc_by(datagrams.contents.len() as _); } }); @@ -664,9 +664,9 @@ impl ActiveRelayActor { res.map_err(|err| state.map_err(err)) } - fn handle_relay_msg(&mut self, msg: ServerToClientMsg, state: &mut ConnectedRelayState) { + fn handle_relay_msg(&mut self, msg: RelayToClientMsg, state: &mut ConnectedRelayState) { match msg { - ServerToClientMsg::Datagrams { + RelayToClientMsg::Datagrams { remote_node_id, datagrams, } => { @@ -690,11 +690,11 @@ impl ActiveRelayActor { warn!("Dropping received relay packet: {err:#}"); } } - ServerToClientMsg::NodeGone(node_id) => { + RelayToClientMsg::NodeGone(node_id) => { state.nodes_present.remove(&node_id); } - ServerToClientMsg::Ping(data) => state.pong_pending = Some(data), - ServerToClientMsg::Pong(data) => { + RelayToClientMsg::Ping(data) => state.pong_pending = Some(data), + RelayToClientMsg::Pong(data) => { #[cfg(test)] { if let Some((expected_data, sender)) = state.test_pong.take() { @@ -708,10 +708,10 @@ impl ActiveRelayActor { state.ping_tracker.pong_received(data); state.established = true; } - ServerToClientMsg::Health { problem } => { + RelayToClientMsg::Health { problem } => { warn!("Relay server reports problem: {problem}"); } - ServerToClientMsg::Restarting { .. } => { + RelayToClientMsg::Restarting { .. } => { trace!("Ignoring {msg:?}") } } From 9446e5f566550950c666ceefefa978832940e325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 16:39:43 +0200 Subject: [PATCH 43/80] Typos --- iroh-relay/src/protos/handshake.rs | 2 +- iroh-relay/src/protos/relay.rs | 2 +- iroh-relay/src/protos/streams.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 8ba6a87b712..0ff582720e3 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -10,7 +10,7 @@ //! //! One way is via an explicitly sent challenge: //! -//! 1. Once a websocket connection is opened, a client recieves a challenge (the [`ServerChallenge`] frame) +//! 1. Once a websocket connection is opened, a client receives a challenge (the [`ServerChallenge`] frame) //! 2. The client sends back what is essentially a signature of that challenge with their secret key //! that matches the NodeId they have, as well as the NodeId (the [`ClientAuth`] frame) //! diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 31dc01cff79..b44901468af 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -1,6 +1,6 @@ //! Common types between the [`super::handshake`] and [`super::send_recv`] protocols. //! -//! Hosts the [`FrameType`] enum to make sure we're not accidentally re-using frame type +//! Hosts the [`FrameType`] enum to make sure we're not accidentally reusing frame type //! integers for different frames. use bytes::{BufMut, Bytes}; diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 90c90aeaf8f..31b4d143ec5 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -1,4 +1,4 @@ -//! Implements logic for abstracing over a websocket stream that allows sending only [`Bytes`]-based +//! Implements logic for abstracting over a websocket stream that allows sending only [`Bytes`]-based //! messages. use std::{ pin::Pin, From d5b4317e49318dbabff3037b2d1e39da4d7bea9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 17:28:03 +0200 Subject: [PATCH 44/80] `protos::relay` -> `protos::common` -> `protos::send_recv` -> `protos::relay` --- iroh-relay/src/client.rs | 4 +- iroh-relay/src/client/conn.rs | 4 +- iroh-relay/src/lib.rs | 2 +- iroh-relay/src/protos.rs | 2 +- iroh-relay/src/protos/common.rs | 86 +++ iroh-relay/src/protos/handshake.rs | 2 +- iroh-relay/src/protos/relay.rs | 671 +++++++++++++++++-- iroh-relay/src/protos/send_recv.rs | 626 ----------------- iroh-relay/src/server.rs | 2 +- iroh-relay/src/server/client.rs | 12 +- iroh-relay/src/server/clients.rs | 8 +- iroh-relay/src/server/http_server.rs | 6 +- iroh-relay/src/server/streams.rs | 4 +- iroh/src/magicsock/transports/relay.rs | 2 +- iroh/src/magicsock/transports/relay/actor.rs | 4 +- 15 files changed, 717 insertions(+), 718 deletions(-) create mode 100644 iroh-relay/src/protos/common.rs delete mode 100644 iroh-relay/src/protos/send_recv.rs diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 9c8d2bb5d2a..86b9502393a 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -29,7 +29,7 @@ use crate::{ http::RELAY_PATH, protos::{ handshake, - send_recv::{ClientToRelayMsg, RelayToClientMsg}, + relay::{ClientToRelayMsg, RelayToClientMsg}, }, KeyCache, }; @@ -204,7 +204,7 @@ impl ClientBuilder { use crate::{ http::CLIENT_AUTH_HEADER, - protos::{handshake::KeyMaterialClientAuth, send_recv::MAX_FRAME_SIZE}, + protos::{handshake::KeyMaterialClientAuth, relay::MAX_FRAME_SIZE}, }; let mut dial_url = (*self.url).clone(); diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index e1a3cd36b01..a48739080af 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -20,7 +20,7 @@ use crate::client::streams::{MaybeTlsStream, ProxyStream}; use crate::{ protos::{ handshake, - send_recv::{ + relay::{ ClientToRelayMsg, Error as RecvRelayError, RelayToClientMsg, MAX_PAYLOAD_SIZE, }, streams::WsBytesFramed, @@ -90,7 +90,7 @@ pub(crate) struct Conn { impl Conn { #[cfg(test)] pub(crate) fn test(io: tokio::io::DuplexStream) -> Self { - use crate::protos::send_recv::MAX_FRAME_SIZE; + use crate::protos::relay::MAX_FRAME_SIZE; Self { conn: WsBytesFramed { io: tokio_websockets::ClientBuilder::new() diff --git a/iroh-relay/src/lib.rs b/iroh-relay/src/lib.rs index a073fd002ce..d8440f0baab 100644 --- a/iroh-relay/src/lib.rs +++ b/iroh-relay/src/lib.rs @@ -48,7 +48,7 @@ pub(crate) use key_cache::KeyCache; pub mod dns; pub mod node_info; -pub use protos::send_recv::MAX_PACKET_SIZE; +pub use protos::relay::MAX_PACKET_SIZE; pub use self::{ ping_tracker::PingTracker, diff --git a/iroh-relay/src/protos.rs b/iroh-relay/src/protos.rs index 89d631f47e5..043e9de434f 100644 --- a/iroh-relay/src/protos.rs +++ b/iroh-relay/src/protos.rs @@ -1,7 +1,7 @@ //! Protocols used by the iroh-relay +pub mod common; pub mod disco; pub mod handshake; pub mod relay; -pub mod send_recv; pub mod streams; diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs new file mode 100644 index 00000000000..4ced6113473 --- /dev/null +++ b/iroh-relay/src/protos/common.rs @@ -0,0 +1,86 @@ +//! Common types between the [`super::handshake`] and [`super::send_recv`] protocols. +//! +//! Hosts the [`FrameType`] enum to make sure we're not accidentally reusing frame type +//! integers for different frames. + +use bytes::{BufMut, Bytes}; +use quinn_proto::{coding::Codec, VarInt}; + +/// Possible frame types during handshaking +#[repr(u32)] +#[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] +// needs to be pub due to being exposed in error types +pub enum FrameType { + /// The server frame type for the challenge response + ServerChallenge = 2, + /// The client frame type for the authentication frame + ClientAuth = 3, + /// The server frame type for authentication confirmation + ServerConfirmsAuth = 4, + /// The server frame type for authentication denial + ServerDeniesAuth = 5, + /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents + ClientToRelayDatagrams = 10, + /// 32B src pub key + ECN byte + segment size u16 + datagrams contents + RelayToClientDatagrams = 11, + /// Sent from server to client to signal that a previous sender is no longer connected. + /// + /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` + /// to B so B can forget that a reverse path exists on that connection to get back to A + /// + /// 32B pub key of peer that's gone + NodeGone = 14, + /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. + /// Messages with these frames will be ignored. + /// 8 byte ping payload, to be echoed back in FrameType::Pong + Ping = 15, + /// 8 byte payload, the contents of ping being replied to + Pong = 16, + /// Sent from server to client to tell the client if their connection is + /// unhealthy somehow. + Health = 17, + + /// Sent from server to client for the server to declare that it's restarting. + /// Payload is two big endian u32 durations in milliseconds: when to reconnect, + /// and how long to try total. + /// + /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` + Restarting = 18, + /// The frame type was unknown. + /// + /// This frame is the result of parsing any future frame types that this implementation + /// does not yet understand. + #[num_enum(default)] + Unknown, +} + +impl std::fmt::Display for FrameType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl FrameType { + /// Writes the frame type to the buffer (as a QUIC-encoded varint). + pub(crate) fn write_to(&self, mut dst: O) -> O { + VarInt::from(*self).encode(&mut dst); + dst + } + + /// Parses the frame type (as a QUIC-encoded varint) from the first couple of bytes given + /// and returns the frame type and the rest. + pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { + let mut cursor = std::io::Cursor::new(&bytes); + let tag = VarInt::decode(&mut cursor).ok()?; + let tag_u32 = u32::try_from(u64::from(tag)).ok()?; + let frame_type = FrameType::from(tag_u32); + let content = bytes.slice(cursor.position() as usize..); + Some((frame_type, content)) + } +} + +impl From for VarInt { + fn from(value: FrameType) -> Self { + (value as u32).into() + } +} diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 0ff582720e3..1c09e5a0a54 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -38,7 +38,7 @@ use nested_enum_utils::common_fields; use rand::{CryptoRng, RngCore}; use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; -use super::{relay::FrameType, streams::BytesStreamSink}; +use super::{common::FrameType, streams::BytesStreamSink}; use crate::ExportKeyingMaterial; /// Authentication message from the client. diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index b44901468af..2ff18691db0 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -1,86 +1,621 @@ -//! Common types between the [`super::handshake`] and [`super::send_recv`] protocols. +//! This module implements the send/recv relaying protocol. //! -//! Hosts the [`FrameType`] enum to make sure we're not accidentally reusing frame type -//! integers for different frames. +//! Protocol flow: +//! * server occasionally sends [`FrameType::Ping`] +//! * client responds to any [`FrameType::Ping`] with a [`FrameType::Pong`] +//! * clients sends [`FrameType::ClientToRelayDatagrams`] +//! * server then sends [`FrameType::RelayToClientDatagrams`] to recipient +//! * server sends [`FrameType::NodeGone`] when the other client disconnects use bytes::{BufMut, Bytes}; -use quinn_proto::{coding::Codec, VarInt}; - -/// Possible frame types during handshaking -#[repr(u32)] -#[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] -// needs to be pub due to being exposed in error types -pub enum FrameType { - /// The server frame type for the challenge response - ServerChallenge = 2, - /// The client frame type for the authentication frame - ClientAuth = 3, - /// The server frame type for authentication confirmation - ServerConfirmsAuth = 4, - /// The server frame type for authentication denial - ServerDeniesAuth = 5, - /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - SendDatagrams = 10, - /// 32B src pub key + ECN byte + segment size u16 + datagrams contents - RecvDatagrams = 11, - /// Sent from server to client to signal that a previous sender is no longer connected. - /// - /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` - /// to B so B can forget that a reverse path exists on that connection to get back to A - /// - /// 32B pub key of peer that's gone - NodeGone = 14, - /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. - /// Messages with these frames will be ignored. - /// 8 byte ping payload, to be echoed back in FrameType::Pong - Ping = 15, - /// 8 byte payload, the contents of ping being replied to - Pong = 16, - /// Sent from server to client to tell the client if their connection is - /// unhealthy somehow. - Health = 17, - - /// Sent from server to client for the server to declare that it's restarting. - /// Payload is two big endian u32 durations in milliseconds: when to reconnect, - /// and how long to try total. +use iroh_base::{NodeId, SignatureError}; +use n0_future::time::{self, Duration}; +use nested_enum_utils::common_fields; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; + +use super::common::FrameType; +use crate::KeyCache; + +/// The maximum size of a packet sent over relay. +/// (This only includes the data bytes visible to magicsock, not +/// including its on-wire framing overhead) +pub const MAX_PACKET_SIZE: usize = 64 * 1024; + +/// Maximum size a datagram payload is allowed to be. +/// +/// This is [`MAX_PACKET_SIZE`] minus the length of an encoded public key minus 3 bytes, +/// one for ECN, and two for the segment size. +pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 3; + +/// The maximum frame size. +/// +/// This is also the minimum burst size that a rate-limiter has to accept. +#[cfg(not(wasm_browser))] +pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; + +/// Interval in which we ping the relay server to ensure the connection is alive. +/// +/// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some +/// chance of recovering. +#[cfg(feature = "server")] +pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15); + +/// The number of packets buffered for sending per client +#[cfg(feature = "server")] +pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; + +/// Protocol send errors. +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum Error { + #[snafu(display("unexpected frame: got {got}, expected {expected}"))] + UnexpectedFrame { got: FrameType, expected: FrameType }, + #[snafu(display("Frame is too large, has {frame_len} bytes"))] + FrameTooLarge { frame_len: usize }, + #[snafu(transparent)] + Timeout { source: time::Elapsed }, + #[snafu(transparent)] + SerDe { source: postcard::Error }, + #[snafu(display("Invalid public key"))] + InvalidPublicKey { source: SignatureError }, + #[snafu(display("Invalid frame encoding"))] + InvalidFrame {}, + #[snafu(display("Invalid frame type: {frame_type}"))] + InvalidFrameType { frame_type: FrameType }, + #[snafu(display("Invalid protocol message encoding"))] + InvalidProtocolMessageEncoding { source: std::str::Utf8Error }, + #[snafu(display("Too few bytes"))] + TooSmall {}, +} + +/// The messages that a relay sends to clients or the clients receive from the relay. +#[derive(derive_more::Debug, Clone, PartialEq, Eq)] +pub enum RelayToClientMsg { + /// Represents datagrams sent from relays (originally sent to them by another client). + Datagrams { + /// The [`NodeId`] of the original sender. + remote_node_id: NodeId, + /// The datagrams and related metadata. + datagrams: Datagrams, + }, + /// Indicates that the client identified by the underlying public key had previously sent you a + /// packet but has now disconnected from the relay. + NodeGone(NodeId), + /// A one-way message from relay to client, declaring the connection health state. + Health { + /// If set, is a description of why the connection is unhealthy. + /// + /// If `None` means the connection is healthy again. + /// + /// The default condition is healthy, so the relay doesn't broadcast a [`RelayToClientMsg::Health`] + /// until a problem exists. + problem: String, + }, + /// A one-way message from relay to client, advertising that the relay is restarting. + Restarting { + /// An advisory duration that the client should wait before attempting to reconnect. + /// It might be zero. It exists for the relay to smear out the reconnects. + reconnect_in: Duration, + /// An advisory duration for how long the client should attempt to reconnect + /// before giving up and proceeding with its normal connection failure logic. The interval + /// between retries is undefined for now. A relay should not send a `try_for` duration more + /// than a few seconds. + try_for: Duration, + }, + /// Request from the relay to reply to the + /// other side with a [`ClientToRelayMsg::Pong`] with the given payload. + Ping([u8; 8]), + /// Reply to a [`ClientToRelayMsg::Ping`] from a client + /// with the payload sent previously in the ping. + Pong([u8; 8]), +} + +/// Messages that clients send to relays. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ClientToRelayMsg { + /// Request from the client to the server to reply to the + /// other side with a [`RelayToClientMsg::Pong`] with the given payload. + Ping([u8; 8]), + /// Reply to a [`RelayToClientMsg::Ping`] from a server + /// with the payload sent previously in the ping. + Pong([u8; 8]), + /// Request from the client to relay datagrams to given remote node. + Datagrams { + /// The remote node to relay to. + dst_node_id: NodeId, + /// The datagrams and related metadata to relay. + datagrams: Datagrams, + }, +} + +/// One or multiple datagrams being transferred via the relay. +/// +/// This type is modeled after [`quinn_proto::Transmit`] +/// (or even more similarly `quinn_udp::Transmit`, but we don't depend on that library here). +#[derive(derive_more::Debug, Clone, PartialEq, Eq)] +pub struct Datagrams { + /// Explicit congestion notification bits + pub ecn: Option, + /// The segment size if this transmission contains multiple datagrams. + /// This is `None` if the transmit only contains a single datagram + pub segment_size: Option, + /// The contents of the datagram(s) + #[debug(skip)] + pub contents: Bytes, +} + +impl> From for Datagrams { + fn from(bytes: T) -> Self { + Self { + ecn: None, + segment_size: None, + contents: Bytes::copy_from_slice(bytes.as_ref()), + } + } +} + +impl Datagrams { + fn write_to(&self, mut dst: O) -> O { + let ecn = self.ecn.map_or(0, |ecn| ecn as u8); + let segment_size = self.segment_size.unwrap_or_default(); + dst.put_u8(ecn); + dst.put_u16(segment_size); + dst.put(self.contents.as_ref()); + dst + } + + fn from_bytes(bytes: Bytes) -> Result { + // 1 bytes ECN, 2 bytes segment size + snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); + + let ecn_byte = bytes[0]; + let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte); + + let segment_size = u16::from_be_bytes(bytes[1..3].try_into().expect("length checked")); + let segment_size = if segment_size == 0 { + None + } else { + Some(segment_size) + }; + + let contents = bytes.slice(3..); + + Ok(Self { + ecn, + segment_size, + contents, + }) + } +} + +impl RelayToClientMsg { + /// Returns this frame's corresponding frame type. + pub fn typ(&self) -> FrameType { + match self { + Self::Datagrams { .. } => FrameType::RelayToClientDatagrams, + Self::NodeGone { .. } => FrameType::NodeGone, + Self::Ping { .. } => FrameType::Ping, + Self::Pong { .. } => FrameType::Pong, + Self::Health { .. } => FrameType::Health, + Self::Restarting { .. } => FrameType::Restarting, + } + } + + /// Encodes this frame for sending over websockets. /// - /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` - Restarting = 18, - /// The frame type was unknown. + /// Specifically meant for being put into a binary websocket message frame. + #[cfg(feature = "server")] + pub(crate) fn write_to(&self, mut dst: O) -> O { + dst = self.typ().write_to(dst); + match self { + Self::Datagrams { + remote_node_id, + datagrams, + } => { + dst.put(remote_node_id.as_ref()); + dst = datagrams.write_to(dst); + } + Self::NodeGone(node_id) => { + dst.put(node_id.as_ref()); + } + Self::Ping(data) => { + dst.put(&data[..]); + } + Self::Pong(data) => { + dst.put(&data[..]); + } + Self::Health { problem } => { + dst.put(problem.as_ref()); + } + Self::Restarting { + reconnect_in, + try_for, + } => { + dst.put_u32(reconnect_in.as_millis() as u32); + dst.put_u32(try_for.as_millis() as u32); + } + } + dst + } + + /// Tries to decode a frame received over websockets. /// - /// This frame is the result of parsing any future frame types that this implementation - /// does not yet understand. - #[num_enum(default)] - Unknown, -} + /// Specifically, bytes received from a binary websocket message frame. + #[allow(clippy::result_large_err)] + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + let frame_len = content.len(); + snafu::ensure!( + frame_len <= MAX_PACKET_SIZE, + FrameTooLargeSnafu { frame_len } + ); + + let res = match frame_type { + FrameType::RelayToClientDatagrams => { + snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); -impl std::fmt::Display for FrameType { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{self:?}") + let remote_node_id = cache + .key_from_slice(&content[..NodeId::LENGTH]) + .context(InvalidPublicKeySnafu)?; + let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + Self::Datagrams { + remote_node_id, + datagrams, + } + } + FrameType::NodeGone => { + snafu::ensure!(content.len() == NodeId::LENGTH, InvalidFrameSnafu); + let node_id = cache + .key_from_slice(content.as_ref()) + .context(InvalidPublicKeySnafu)?; + Self::NodeGone(node_id) + } + FrameType::Ping => { + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Ping(data) + } + FrameType::Pong => { + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Pong(data) + } + FrameType::Health => { + let problem = std::str::from_utf8(&content) + .context(InvalidProtocolMessageEncodingSnafu)? + .to_owned(); + Self::Health { problem } + } + FrameType::Restarting => { + snafu::ensure!(content.len() == 4 + 4, InvalidFrameSnafu); + let reconnect_in = u32::from_be_bytes( + content[..4] + .try_into() + .map_err(|_| InvalidFrameSnafu.build())?, + ); + let try_for = u32::from_be_bytes( + content[4..] + .try_into() + .map_err(|_| InvalidFrameSnafu.build())?, + ); + let reconnect_in = Duration::from_millis(reconnect_in as u64); + let try_for = Duration::from_millis(try_for as u64); + Self::Restarting { + reconnect_in, + try_for, + } + } + _ => { + return Err(InvalidFrameTypeSnafu { frame_type }.build()); + } + }; + Ok(res) } } -impl FrameType { - /// Writes the frame type to the buffer (as a QUIC-encoded varint). +impl ClientToRelayMsg { + pub(crate) fn typ(&self) -> FrameType { + match self { + Self::Datagrams { .. } => FrameType::ClientToRelayDatagrams, + Self::Ping { .. } => FrameType::Ping, + Self::Pong { .. } => FrameType::Pong, + } + } + + /// Encodes this frame for sending over websockets. + /// + /// Specifically meant for being put into a binary websocket message frame. pub(crate) fn write_to(&self, mut dst: O) -> O { - VarInt::from(*self).encode(&mut dst); + dst = self.typ().write_to(dst); + match self { + Self::Datagrams { + dst_node_id, + datagrams, + } => { + dst.put(dst_node_id.as_ref()); + dst = datagrams.write_to(dst); + } + Self::Ping(data) => { + dst.put(&data[..]); + } + Self::Pong(data) => { + dst.put(&data[..]); + } + } dst } - /// Parses the frame type (as a QUIC-encoded varint) from the first couple of bytes given - /// and returns the frame type and the rest. - pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { - let mut cursor = std::io::Cursor::new(&bytes); - let tag = VarInt::decode(&mut cursor).ok()?; - let tag_u32 = u32::try_from(u64::from(tag)).ok()?; - let frame_type = FrameType::from(tag_u32); - let content = bytes.slice(cursor.position() as usize..); - Some((frame_type, content)) + /// Tries to decode a frame received over websockets. + /// + /// Specifically, bytes received from a binary websocket message frame. + #[allow(clippy::result_large_err)] + #[cfg(feature = "server")] + pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { + let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + let frame_len = content.len(); + snafu::ensure!( + frame_len <= MAX_PACKET_SIZE, + FrameTooLargeSnafu { frame_len } + ); + + let res = match frame_type { + FrameType::ClientToRelayDatagrams => { + let dst_node_id = cache + .key_from_slice(&content[..NodeId::LENGTH]) + .context(InvalidPublicKeySnafu)?; + let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + Self::Datagrams { + dst_node_id, + datagrams, + } + } + FrameType::Ping => { + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Ping(data) + } + FrameType::Pong => { + snafu::ensure!(content.len() == 8, InvalidFrameSnafu); + let mut data = [0u8; 8]; + data.copy_from_slice(&content[..8]); + Self::Pong(data) + } + _ => { + return Err(InvalidFrameTypeSnafu { frame_type }.build()); + } + }; + Ok(res) } } -impl From for VarInt { - fn from(value: FrameType) -> Self { - (value as u32).into() +#[cfg(test)] +#[cfg(feature = "server")] +mod tests { + use data_encoding::HEXLOWER; + use iroh_base::SecretKey; + use n0_snafu::Result; + + use super::*; + + fn check_expected_bytes(frames: Vec<(Vec, &str)>) { + for (bytes, expected_hex) in frames { + let stripped: Vec = expected_hex + .chars() + .filter_map(|s| { + if s.is_ascii_whitespace() { + None + } else { + Some(s as u8) + } + }) + .collect(); + let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); + assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes)); + } + } + + #[test] + fn test_server_client_frames_snapshot() -> Result { + let client_key = SecretKey::from_bytes(&[42u8; 32]); + + check_expected_bytes(vec![ + ( + RelayToClientMsg::Health { + problem: "Hello? Yes this is dog.".into(), + } + .write_to(Vec::new()), + "11 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 + 20 69 73 20 64 6f 67 2e", + ), + ( + RelayToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), + "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61", + ), + ( + RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), + "0f 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), + "10 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + RelayToClientMsg::Datagrams { + remote_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: Some(6), + contents: "Hello World!".into(), + }, + } + .write_to(Vec::new()), + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // segment size + // hello world contents bytes + "0b + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 00 06 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + ), + ( + RelayToClientMsg::Restarting { + reconnect_in: Duration::from_millis(10), + try_for: Duration::from_millis(20), + } + .write_to(Vec::new()), + "12 00 00 00 0a 00 00 00 14", + ), + ]); + + Ok(()) + } + + #[test] + fn test_client_server_frames_snapshot() -> Result { + let client_key = SecretKey::from_bytes(&[42u8; 32]); + + check_expected_bytes(vec![ + ( + ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()), + "0f 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()), + "10 2a 2a 2a 2a 2a 2a 2a 2a", + ), + ( + ClientToRelayMsg::Datagrams { + dst_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: Some(6), + contents: "Hello World!".into(), + }, + } + .write_to(Vec::new()), + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // segment size + // hello world contents + "0a + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 00 06 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + ), + ]); + + Ok(()) + } +} + +#[cfg(test)] +#[cfg(feature = "server")] +mod proptests { + use bytes::BytesMut; + use iroh_base::SecretKey; + use proptest::prelude::*; + + use super::*; + + fn secret_key() -> impl Strategy { + prop::array::uniform32(any::()).prop_map(SecretKey::from) + } + + fn key() -> impl Strategy { + secret_key().prop_map(|key| key.public()) + } + + fn ecn() -> impl Strategy> { + (0..=3).prop_map(|n| match n { + 1 => Some(quinn_proto::EcnCodepoint::Ce), + 2 => Some(quinn_proto::EcnCodepoint::Ect0), + 3 => Some(quinn_proto::EcnCodepoint::Ect1), + _ => None, + }) + } + + fn datagrams() -> impl Strategy { + ( + ecn(), + prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE), + prop::collection::vec(any::(), 0..MAX_PAYLOAD_SIZE), + ) + .prop_map(|(ecn, segment_size, data)| Datagrams { + ecn, + segment_size: segment_size.map(|ss| std::cmp::min(data.len(), ss) as u16), + contents: Bytes::from(data), + }) + } + + /// Generates a random valid frame + fn server_client_frame() -> impl Strategy { + let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } + }); + let node_gone = key().prop_map(RelayToClientMsg::NodeGone); + let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); + let health = ".{0,65536}" + .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { + s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes + }) + .prop_map(|problem| RelayToClientMsg::Health { problem }); + let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { + RelayToClientMsg::Restarting { + reconnect_in: Duration::from_millis(reconnect_in.into()), + try_for: Duration::from_millis(try_for.into()), + } + }); + prop_oneof![recv_packet, node_gone, ping, pong, health, restarting] + } + + fn client_server_frame() -> impl Strategy { + let send_packet = + (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| ClientToRelayMsg::Datagrams { + dst_node_id, + datagrams, + }); + let ping = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Ping); + let pong = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Pong); + prop_oneof![send_packet, ping, pong] + } + + proptest! { + #[test] + fn server_client_frame_roundtrip(frame in server_client_frame()) { + let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + prop_assert_eq!(frame, decoded); + } + + #[test] + fn client_server_frame_roundtrip(frame in client_server_frame()) { + let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); + prop_assert_eq!(frame, decoded); + } } } diff --git a/iroh-relay/src/protos/send_recv.rs b/iroh-relay/src/protos/send_recv.rs deleted file mode 100644 index fe7357b4b52..00000000000 --- a/iroh-relay/src/protos/send_recv.rs +++ /dev/null @@ -1,626 +0,0 @@ -//! This module implements the relaying protocol used by the `server` and `client`. -//! -//! Protocol flow: -//! -//! Login: -//! * client connects -//! * -> client sends `FrameType::ClientInfo` -//! -//! Steady state: -//! * server occasionally sends `FrameType::KeepAlive` (or `FrameType::Ping`) -//! * client responds to any `FrameType::Ping` with a `FrameType::Pong` -//! * clients sends `FrameType::SendPacket` -//! * server then sends `FrameType::RecvPacket` to recipient - -use bytes::{BufMut, Bytes}; -use iroh_base::{NodeId, SignatureError}; -use n0_future::time::{self, Duration}; -use nested_enum_utils::common_fields; -use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; - -use super::relay::FrameType; -use crate::KeyCache; - -/// The maximum size of a packet sent over relay. -/// (This only includes the data bytes visible to magicsock, not -/// including its on-wire framing overhead) -pub const MAX_PACKET_SIZE: usize = 64 * 1024; - -/// Maximum size a datagram payload is allowed to be. -/// -/// This is [`MAX_PACKET_SIZE`] minus the length of an encoded public key minus 3 bytes, -/// one for ECN, and two for the segment size. -pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 3; - -/// The maximum frame size. -/// -/// This is also the minimum burst size that a rate-limiter has to accept. -#[cfg(not(wasm_browser))] -pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024; - -/// Interval in which we ping the relay server to ensure the connection is alive. -/// -/// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some -/// chance of recovering. -#[cfg(feature = "server")] -pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15); - -/// The number of packets buffered for sending per client -#[cfg(feature = "server")] -pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512; - -/// Protocol send errors. -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: n0_snafu::SpanTrace, -})] -#[allow(missing_docs)] -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum Error { - #[snafu(display("unexpected frame: got {got}, expected {expected}"))] - UnexpectedFrame { got: FrameType, expected: FrameType }, - #[snafu(display("Frame is too large, has {frame_len} bytes"))] - FrameTooLarge { frame_len: usize }, - #[snafu(transparent)] - Timeout { source: time::Elapsed }, - #[snafu(transparent)] - SerDe { source: postcard::Error }, - #[snafu(display("Invalid public key"))] - InvalidPublicKey { source: SignatureError }, - #[snafu(display("Invalid frame encoding"))] - InvalidFrame {}, - #[snafu(display("Invalid frame type: {frame_type}"))] - InvalidFrameType { frame_type: FrameType }, - #[snafu(display("Invalid protocol message encoding"))] - InvalidProtocolMessageEncoding { source: std::str::Utf8Error }, - #[snafu(display("Too few bytes"))] - TooSmall {}, -} - -/// The messages that a relay sends to clients or the clients receive from the relay. -#[derive(derive_more::Debug, Clone, PartialEq, Eq)] -pub enum RelayToClientMsg { - /// Represents datagrams sent from relays (originally sent to them by another client). - Datagrams { - /// The [`NodeId`] of the original sender. - remote_node_id: NodeId, - /// The datagrams and related metadata. - datagrams: Datagrams, - }, - /// Indicates that the client identified by the underlying public key had previously sent you a - /// packet but has now disconnected from the relay. - NodeGone(NodeId), - /// A one-way message from relay to client, declaring the connection health state. - Health { - /// If set, is a description of why the connection is unhealthy. - /// - /// If `None` means the connection is healthy again. - /// - /// The default condition is healthy, so the relay doesn't broadcast a [`RelayToClientMsg::Health`] - /// until a problem exists. - problem: String, - }, - /// A one-way message from relay to client, advertising that the relay is restarting. - Restarting { - /// An advisory duration that the client should wait before attempting to reconnect. - /// It might be zero. It exists for the relay to smear out the reconnects. - reconnect_in: Duration, - /// An advisory duration for how long the client should attempt to reconnect - /// before giving up and proceeding with its normal connection failure logic. The interval - /// between retries is undefined for now. A relay should not send a `try_for` duration more - /// than a few seconds. - try_for: Duration, - }, - /// Request from the relay to reply to the - /// other side with a [`ClientToRelayMsg::Pong`] with the given payload. - Ping([u8; 8]), - /// Reply to a [`ClientToRelayMsg::Ping`] from a client - /// with the payload sent previously in the ping. - Pong([u8; 8]), -} - -/// Messages that clients send to relays. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ClientToRelayMsg { - /// Request from the client to the server to reply to the - /// other side with a [`RelayToClientMsg::Pong`] with the given payload. - Ping([u8; 8]), - /// Reply to a [`RelayToClientMsg::Ping`] from a server - /// with the payload sent previously in the ping. - Pong([u8; 8]), - /// Request from the client to relay datagrams to given remote node. - Datagrams { - /// The remote node to relay to. - dst_node_id: NodeId, - /// The datagrams and related metadata to relay. - datagrams: Datagrams, - }, -} - -/// One or multiple datagrams being transferred via the relay. -/// -/// This type is modeled after [`quinn_proto::Transmit`] -/// (or even more similarly `quinn_udp::Transmit`, but we don't depend on that library here). -#[derive(derive_more::Debug, Clone, PartialEq, Eq)] -pub struct Datagrams { - /// Explicit congestion notification bits - pub ecn: Option, - /// The segment size if this transmission contains multiple datagrams. - /// This is `None` if the transmit only contains a single datagram - pub segment_size: Option, - /// The contents of the datagram(s) - #[debug(skip)] - pub contents: Bytes, -} - -impl> From for Datagrams { - fn from(bytes: T) -> Self { - Self { - ecn: None, - segment_size: None, - contents: Bytes::copy_from_slice(bytes.as_ref()), - } - } -} - -impl Datagrams { - fn write_to(&self, mut dst: O) -> O { - let ecn = self.ecn.map_or(0, |ecn| ecn as u8); - let segment_size = self.segment_size.unwrap_or_default(); - dst.put_u8(ecn); - dst.put_u16(segment_size); - dst.put(self.contents.as_ref()); - dst - } - - fn from_bytes(bytes: Bytes) -> Result { - // 1 bytes ECN, 2 bytes segment size - snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); - - let ecn_byte = bytes[0]; - let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte); - - let segment_size = u16::from_be_bytes(bytes[1..3].try_into().expect("length checked")); - let segment_size = if segment_size == 0 { - None - } else { - Some(segment_size) - }; - - let contents = bytes.slice(3..); - - Ok(Self { - ecn, - segment_size, - contents, - }) - } -} - -impl RelayToClientMsg { - /// Returns this frame's corresponding frame type. - pub fn typ(&self) -> FrameType { - match self { - Self::Datagrams { .. } => FrameType::RecvDatagrams, - Self::NodeGone { .. } => FrameType::NodeGone, - Self::Ping { .. } => FrameType::Ping, - Self::Pong { .. } => FrameType::Pong, - Self::Health { .. } => FrameType::Health, - Self::Restarting { .. } => FrameType::Restarting, - } - } - - /// Encodes this frame for sending over websockets. - /// - /// Specifically meant for being put into a binary websocket message frame. - #[cfg(feature = "server")] - pub(crate) fn write_to(&self, mut dst: O) -> O { - dst = self.typ().write_to(dst); - match self { - Self::Datagrams { - remote_node_id, - datagrams, - } => { - dst.put(remote_node_id.as_ref()); - dst = datagrams.write_to(dst); - } - Self::NodeGone(node_id) => { - dst.put(node_id.as_ref()); - } - Self::Ping(data) => { - dst.put(&data[..]); - } - Self::Pong(data) => { - dst.put(&data[..]); - } - Self::Health { problem } => { - dst.put(problem.as_ref()); - } - Self::Restarting { - reconnect_in, - try_for, - } => { - dst.put_u32(reconnect_in.as_millis() as u32); - dst.put_u32(try_for.as_millis() as u32); - } - } - dst - } - - /// Tries to decode a frame received over websockets. - /// - /// Specifically, bytes received from a binary websocket message frame. - #[allow(clippy::result_large_err)] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; - let frame_len = content.len(); - snafu::ensure!( - frame_len <= MAX_PACKET_SIZE, - FrameTooLargeSnafu { frame_len } - ); - - let res = match frame_type { - FrameType::RecvDatagrams => { - snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); - - let remote_node_id = cache - .key_from_slice(&content[..NodeId::LENGTH]) - .context(InvalidPublicKeySnafu)?; - let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; - Self::Datagrams { - remote_node_id, - datagrams, - } - } - FrameType::NodeGone => { - snafu::ensure!(content.len() == NodeId::LENGTH, InvalidFrameSnafu); - let node_id = cache - .key_from_slice(content.as_ref()) - .context(InvalidPublicKeySnafu)?; - Self::NodeGone(node_id) - } - FrameType::Ping => { - snafu::ensure!(content.len() == 8, InvalidFrameSnafu); - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Ping(data) - } - FrameType::Pong => { - snafu::ensure!(content.len() == 8, InvalidFrameSnafu); - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Pong(data) - } - FrameType::Health => { - let problem = std::str::from_utf8(&content) - .context(InvalidProtocolMessageEncodingSnafu)? - .to_owned(); - Self::Health { problem } - } - FrameType::Restarting => { - snafu::ensure!(content.len() == 4 + 4, InvalidFrameSnafu); - let reconnect_in = u32::from_be_bytes( - content[..4] - .try_into() - .map_err(|_| InvalidFrameSnafu.build())?, - ); - let try_for = u32::from_be_bytes( - content[4..] - .try_into() - .map_err(|_| InvalidFrameSnafu.build())?, - ); - let reconnect_in = Duration::from_millis(reconnect_in as u64); - let try_for = Duration::from_millis(try_for as u64); - Self::Restarting { - reconnect_in, - try_for, - } - } - _ => { - return Err(InvalidFrameTypeSnafu { frame_type }.build()); - } - }; - Ok(res) - } -} - -impl ClientToRelayMsg { - pub(crate) fn typ(&self) -> FrameType { - match self { - Self::Datagrams { .. } => FrameType::SendDatagrams, - Self::Ping { .. } => FrameType::Ping, - Self::Pong { .. } => FrameType::Pong, - } - } - - /// Encodes this frame for sending over websockets. - /// - /// Specifically meant for being put into a binary websocket message frame. - pub(crate) fn write_to(&self, mut dst: O) -> O { - dst = self.typ().write_to(dst); - match self { - Self::Datagrams { - dst_node_id, - datagrams, - } => { - dst.put(dst_node_id.as_ref()); - dst = datagrams.write_to(dst); - } - Self::Ping(data) => { - dst.put(&data[..]); - } - Self::Pong(data) => { - dst.put(&data[..]); - } - } - dst - } - - /// Tries to decode a frame received over websockets. - /// - /// Specifically, bytes received from a binary websocket message frame. - #[allow(clippy::result_large_err)] - #[cfg(feature = "server")] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; - let frame_len = content.len(); - snafu::ensure!( - frame_len <= MAX_PACKET_SIZE, - FrameTooLargeSnafu { frame_len } - ); - - let res = match frame_type { - FrameType::SendDatagrams => { - let dst_node_id = cache - .key_from_slice(&content[..NodeId::LENGTH]) - .context(InvalidPublicKeySnafu)?; - let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; - Self::Datagrams { - dst_node_id, - datagrams, - } - } - FrameType::Ping => { - snafu::ensure!(content.len() == 8, InvalidFrameSnafu); - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Ping(data) - } - FrameType::Pong => { - snafu::ensure!(content.len() == 8, InvalidFrameSnafu); - let mut data = [0u8; 8]; - data.copy_from_slice(&content[..8]); - Self::Pong(data) - } - _ => { - return Err(InvalidFrameTypeSnafu { frame_type }.build()); - } - }; - Ok(res) - } -} - -#[cfg(test)] -#[cfg(feature = "server")] -mod tests { - use data_encoding::HEXLOWER; - use iroh_base::SecretKey; - use n0_snafu::Result; - - use super::*; - - fn check_expected_bytes(frames: Vec<(Vec, &str)>) { - for (bytes, expected_hex) in frames { - let stripped: Vec = expected_hex - .chars() - .filter_map(|s| { - if s.is_ascii_whitespace() { - None - } else { - Some(s as u8) - } - }) - .collect(); - let expected_bytes = HEXLOWER.decode(&stripped).unwrap(); - assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes)); - } - } - - #[test] - fn test_server_client_frames_snapshot() -> Result { - let client_key = SecretKey::from_bytes(&[42u8; 32]); - - check_expected_bytes(vec![ - ( - RelayToClientMsg::Health { - problem: "Hello? Yes this is dog.".into(), - } - .write_to(Vec::new()), - "11 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 - 20 69 73 20 64 6f 67 2e", - ), - ( - RelayToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), - "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61", - ), - ( - RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0f 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), - "10 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - RelayToClientMsg::Datagrams { - remote_node_id: client_key.public(), - datagrams: Datagrams { - ecn: Some(quinn::EcnCodepoint::Ce), - segment_size: Some(6), - contents: "Hello World!".into(), - }, - } - .write_to(Vec::new()), - // frame type - // public key first 16 bytes - // public key second 16 bytes - // ECN byte - // segment size - // hello world contents bytes - "0b - 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 - 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 - 03 - 00 06 - 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", - ), - ( - RelayToClientMsg::Restarting { - reconnect_in: Duration::from_millis(10), - try_for: Duration::from_millis(20), - } - .write_to(Vec::new()), - "12 00 00 00 0a 00 00 00 14", - ), - ]); - - Ok(()) - } - - #[test] - fn test_client_server_frames_snapshot() -> Result { - let client_key = SecretKey::from_bytes(&[42u8; 32]); - - check_expected_bytes(vec![ - ( - ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0f 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()), - "10 2a 2a 2a 2a 2a 2a 2a 2a", - ), - ( - ClientToRelayMsg::Datagrams { - dst_node_id: client_key.public(), - datagrams: Datagrams { - ecn: Some(quinn::EcnCodepoint::Ce), - segment_size: Some(6), - contents: "Hello World!".into(), - }, - } - .write_to(Vec::new()), - // frame type - // public key first 16 bytes - // public key second 16 bytes - // ECN byte - // segment size - // hello world contents - "0a - 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 - 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 - 03 - 00 06 - 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", - ), - ]); - - Ok(()) - } -} - -#[cfg(test)] -#[cfg(feature = "server")] -mod proptests { - use bytes::BytesMut; - use iroh_base::SecretKey; - use proptest::prelude::*; - - use super::*; - - fn secret_key() -> impl Strategy { - prop::array::uniform32(any::()).prop_map(SecretKey::from) - } - - fn key() -> impl Strategy { - secret_key().prop_map(|key| key.public()) - } - - fn ecn() -> impl Strategy> { - (0..=3).prop_map(|n| match n { - 1 => Some(quinn_proto::EcnCodepoint::Ce), - 2 => Some(quinn_proto::EcnCodepoint::Ect0), - 3 => Some(quinn_proto::EcnCodepoint::Ect1), - _ => None, - }) - } - - fn datagrams() -> impl Strategy { - ( - ecn(), - prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE), - prop::collection::vec(any::(), 0..MAX_PAYLOAD_SIZE), - ) - .prop_map(|(ecn, segment_size, data)| Datagrams { - ecn, - segment_size: segment_size.map(|ss| std::cmp::min(data.len(), ss) as u16), - contents: Bytes::from(data), - }) - } - - /// Generates a random valid frame - fn server_client_frame() -> impl Strategy { - let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } - }); - let node_gone = key().prop_map(RelayToClientMsg::NodeGone); - let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); - let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); - let health = ".{0,65536}" - .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { - s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes - }) - .prop_map(|problem| RelayToClientMsg::Health { problem }); - let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { - RelayToClientMsg::Restarting { - reconnect_in: Duration::from_millis(reconnect_in.into()), - try_for: Duration::from_millis(try_for.into()), - } - }); - prop_oneof![recv_packet, node_gone, ping, pong, health, restarting] - } - - fn client_server_frame() -> impl Strategy { - let send_packet = - (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| ClientToRelayMsg::Datagrams { - dst_node_id, - datagrams, - }); - let ping = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Ping); - let pong = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Pong); - prop_oneof![send_packet, ping, pong] - } - - proptest! { - #[test] - fn server_client_frame_roundtrip(frame in server_client_frame()) { - let encoded = frame.clone().write_to(BytesMut::new()).freeze(); - let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); - prop_assert_eq!(frame, decoded); - } - - #[test] - fn client_server_frame_roundtrip(frame in client_server_frame()) { - let encoded = frame.clone().write_to(BytesMut::new()).freeze(); - let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); - prop_assert_eq!(frame, decoded); - } - } -} diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 94760cc7054..94a3690ed55 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -766,7 +766,7 @@ mod tests { dns::DnsResolver, protos::{ handshake, - send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, + relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, }, }; diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 7b44e2d85bc..2875817dbab 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -18,7 +18,7 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg, PING_INTERVAL}, + relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg, PING_INTERVAL}, streams::StreamError, }, server::{ @@ -573,7 +573,7 @@ mod tests { use tracing_test::traced_test; use super::*; - use crate::{client::conn::Conn, protos::relay::FrameType}; + use crate::{client::conn::Conn, protos::common::FrameType}; async fn recv_frame< E: snafu::Error + Sync + Send + 'static, @@ -641,7 +641,9 @@ mod tests { data: Datagrams::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; - let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) + .await + .e()?; assert_eq!( frame, RelayToClientMsg::Datagrams { @@ -656,7 +658,9 @@ mod tests { .send(packet.clone()) .await .context("send")?; - let frame = recv_frame(FrameType::RecvDatagrams, &mut io_rw).await.e()?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) + .await + .e()?; assert_eq!( frame, RelayToClientMsg::Datagrams { diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index e3555e28b05..b89d73b82ff 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -16,7 +16,7 @@ use tracing::{debug, trace}; use super::client::{Client, Config, ForwardPacketError}; use crate::{ - protos::send_recv::Datagrams, + protos::relay::Datagrams, server::{ client::{PacketScope, SendError}, metrics::Metrics, @@ -201,7 +201,7 @@ mod tests { use super::*; use crate::{ client::conn::Conn, - protos::{relay::FrameType, send_recv::RelayToClientMsg}, + protos::{common::FrameType, relay::RelayToClientMsg}, server::streams::RelayedStream, }; @@ -255,7 +255,7 @@ mod tests { // send packet let data = b"hello world!"; clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; assert_eq!( frame, RelayToClientMsg::Datagrams { @@ -266,7 +266,7 @@ mod tests { // send disco packet clients.send_disco_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RecvDatagrams, &mut a_rw).await?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; assert_eq!( frame, RelayToClientMsg::Datagrams { diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 69e3ab34e63..ba4a90bb0ff 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -25,7 +25,7 @@ use super::{clients::Clients, streams::InvalidBucketConfig, AccessConfig, SpawnE use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, http::{RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, - protos::send_recv::PER_CLIENT_SEND_QUEUE_DEPTH, + protos::relay::PER_CLIENT_SEND_QUEUE_DEPTH, server::{ client::Config, metrics::Metrics, @@ -36,7 +36,7 @@ use crate::{ }; use crate::{ http::{CLIENT_AUTH_HEADER, WEBSOCKET_UPGRADE_PROTOCOL}, - protos::{handshake, send_recv::MAX_FRAME_SIZE, streams::WsBytesFramed}, + protos::{handshake, relay::MAX_FRAME_SIZE, streams::WsBytesFramed}, server::streams::RateLimited, }; @@ -837,7 +837,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, + protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, }; pub(crate) fn make_tls_config() -> TlsConfig { diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index e6dbdcf9318..dd631a6890e 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -15,7 +15,7 @@ use tracing::instrument; use super::{ClientRateLimit, Metrics}; use crate::{ protos::{ - send_recv::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg}, + relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg}, streams::{StreamError, WsBytesFramed}, }, ExportKeyingMaterial, KeyCache, @@ -71,7 +71,7 @@ impl RelayedStream { fn limits() -> tokio_websockets::Limits { tokio_websockets::Limits::default() - .max_payload_len(Some(crate::protos::send_recv::MAX_FRAME_SIZE)) + .max_payload_len(Some(crate::protos::relay::MAX_FRAME_SIZE)) } } diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index 86974247548..e35aab2cedc 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -5,7 +5,7 @@ use std::{ use bytes::Bytes; use iroh_base::{NodeId, RelayUrl}; -use iroh_relay::protos::send_recv::Datagrams; +use iroh_relay::protos::relay::Datagrams; use n0_future::{ ready, task::{self, AbortOnDropHandle}, diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index b51bface9ab..e29e4f67112 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -42,7 +42,7 @@ use iroh_base::{NodeId, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, - protos::send_recv::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, + protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, PingTracker, }; use n0_future::{ @@ -1234,7 +1234,7 @@ mod tests { }; use iroh_base::{NodeId, RelayUrl, SecretKey}; - use iroh_relay::{protos::send_recv::Datagrams, PingTracker}; + use iroh_relay::{protos::relay::Datagrams, PingTracker}; use n0_snafu::{Error, Result, ResultExt}; use tokio::sync::{mpsc, oneshot}; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; From 57d89377c1cf957c79329cd051080cf03e2b8ed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 17:29:22 +0200 Subject: [PATCH 45/80] Fix docs --- iroh-relay/src/protos/common.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index 4ced6113473..aae464a8c88 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -1,4 +1,4 @@ -//! Common types between the [`super::handshake`] and [`super::send_recv`] protocols. +//! Common types between the [`super::handshake`] and [`super::relay`] protocols. //! //! Hosts the [`FrameType`] enum to make sure we're not accidentally reusing frame type //! integers for different frames. From 58d7026d0ad0cf205be051efc659520b56bcb3f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 17:44:23 +0200 Subject: [PATCH 46/80] Cleanup packet iter logic --- iroh/src/magicsock/transports/relay/actor.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index e29e4f67112..70e03ed079b 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -625,13 +625,13 @@ impl ActiveRelayActor { // TODO(frando): can we avoid the clone here? let metrics = self.metrics.clone(); let packet_iter = batch.into_iter().map(|item| { - Ok(ClientToRelayMsg::Datagrams { dst_node_id: item.remote_node, datagrams: item.datagrams }) - }); - let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { - if let Ok(ClientToRelayMsg::Datagrams { dst_node_id: _node_id, datagrams }) = m { - metrics.send_relay.inc_by(datagrams.contents.len() as _); - } + metrics.send_relay.inc_by(item.datagrams.contents.len() as _); + Ok(ClientToRelayMsg::Datagrams { + dst_node_id: item.remote_node, + datagrams: item.datagrams + }) }); + let mut packet_stream = n0_future::stream::iter(packet_iter); let fut = client_sink.send_all(&mut packet_stream); self.run_sending(fut, &mut state, &mut client_stream).await?; } From f2ec6784887bd08078021c5dd6e2dfa6f980bd8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 17:47:15 +0200 Subject: [PATCH 47/80] Small cleanup --- iroh/src/net_report/reportgen.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh/src/net_report/reportgen.rs b/iroh/src/net_report/reportgen.rs index 9260ab60666..06e46f1e381 100644 --- a/iroh/src/net_report/reportgen.rs +++ b/iroh/src/net_report/reportgen.rs @@ -379,7 +379,7 @@ impl Actor { let res = match res { Some(Ok(Ok(report))) => Ok(report), Some(Ok(Err(err))) => { - warn!("probe failed: {:#?}", err); + warn!("probe failed: {:#}", err); Err(probes_error::ProbeFailureSnafu {}.into_error(err)) } Some(Err(time::Elapsed { .. })) => { From 48b3fe3d02eae7d21bf096c00d92f7043838f4f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 4 Jul 2025 17:48:33 +0200 Subject: [PATCH 48/80] cargo make format --- iroh-relay/src/client/conn.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index a48739080af..e105930e1cb 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -20,9 +20,7 @@ use crate::client::streams::{MaybeTlsStream, ProxyStream}; use crate::{ protos::{ handshake, - relay::{ - ClientToRelayMsg, Error as RecvRelayError, RelayToClientMsg, MAX_PAYLOAD_SIZE, - }, + relay::{ClientToRelayMsg, Error as RecvRelayError, RelayToClientMsg, MAX_PAYLOAD_SIZE}, streams::WsBytesFramed, }, MAX_PACKET_SIZE, From a3378263b3b9791d0c4cd6147293df9185d84dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Tue, 8 Jul 2025 22:19:17 +0200 Subject: [PATCH 49/80] Use a perf-improved branch of `tokio-websockets` And configure out tokio-websocket's flushing behavior (we do our own). --- Cargo.lock | 5 ++--- Cargo.toml | 4 ++++ iroh-relay/src/client.rs | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2172e048871..5d63d9c76e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4988,8 +4988,7 @@ dependencies = [ [[package]] name = "tokio-websockets" version = "0.11.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fcaf159b4e7a376b05b5bfd77bfd38f3324f5fce751b4213bfc7eaa47affb4e" +source = "git+https://github.com/Gelbpunkt/tokio-websockets.git?branch=frame-queue-buf#0d00440907a33e52ab1fa81469f23c007025605c" dependencies = [ "base64", "bytes", @@ -5584,7 +5583,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 593f1d0ec73..294e7002c35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,3 +40,7 @@ unexpected_cfgs = { level = "warn", check-cfg = ["cfg(iroh_docsrs)", "cfg(iroh_l [workspace.lints.clippy] unused-async = "warn" + +[patch.crates-io] +# TODO(matheus23): need to wait for a release +tokio-websockets = { git = "https://github.com/Gelbpunkt/tokio-websockets.git", branch = "frame-queue-buf" } diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index 86b9502393a..a912e5b2708 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -249,7 +249,8 @@ impl ClientBuilder { } .build() })? - .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))); + .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) + .config(tokio_websockets::Config::default().flush_threshold(usize::MAX)); if let Some(client_auth) = KeyMaterialClientAuth::new(&self.secret_key, &stream) { debug!("Using TLS key export for relay client authentication"); builder = builder From e1cf2ba8fb7abf9ec69c73c9e87517e65ffe4704 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 9 Jul 2025 12:09:59 +0200 Subject: [PATCH 50/80] Vector your writes like noone's watching --- iroh-relay/src/client/streams.rs | 28 +++++++++++++--------------- iroh-relay/src/server/streams.rs | 21 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/iroh-relay/src/client/streams.rs b/iroh-relay/src/client/streams.rs index d1565f69b06..b08c6917e25 100644 --- a/iroh-relay/src/client/streams.rs +++ b/iroh-relay/src/client/streams.rs @@ -64,6 +64,7 @@ impl AsyncWrite for ProxyStream { Self::Proxied(stream) => Pin::new(stream.get_mut().1).poll_shutdown(cx), } } + fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -74,6 +75,13 @@ impl AsyncWrite for ProxyStream { Self::Proxied(stream) => Pin::new(stream.get_mut().1).poll_write_vectored(cx, bufs), } } + + fn is_write_vectored(&self) -> bool { + match self { + ProxyStream::Raw(stream) => stream.is_write_vectored(), + ProxyStream::Proxied(stream) => stream.get_ref().1.is_write_vectored(), + } + } } impl ProxyStream { @@ -170,6 +178,7 @@ impl AsyncWrite for MaybeTlsStream { Self::Test(stream) => Pin::new(stream).poll_shutdown(cx), } } + fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -180,26 +189,15 @@ impl AsyncWrite for MaybeTlsStream { Self::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), #[cfg(test)] Self::Test(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), - } } } -impl MaybeTlsStream { - pub fn local_addr(&self) -> std::io::Result { - match self { - Self::Raw(s) => s.local_addr(), - Self::Tls(s) => s.get_ref().0.local_addr(), - #[cfg(test)] - Self::Test(_) => Ok(SocketAddr::new(std::net::Ipv4Addr::LOCALHOST.into(), 1337)), - } - } - - pub fn peer_addr(&self) -> std::io::Result { + fn is_write_vectored(&self) -> bool { match self { - Self::Raw(s) => s.peer_addr(), - Self::Tls(s) => s.get_ref().0.peer_addr(), + Self::Raw(stream) => stream.is_write_vectored(), + Self::Tls(stream) => stream.is_write_vectored(), #[cfg(test)] - Self::Test(_) => Ok(SocketAddr::new(std::net::Ipv4Addr::LOCALHOST.into(), 1337)), + Self::Test(stream) => stream.is_write_vectored(), } } } diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index dd631a6890e..44515f0832b 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -217,6 +217,15 @@ impl AsyncWrite for MaybeTlsStream { MaybeTlsStream::Test(ref mut s) => Pin::new(s).poll_write_vectored(cx, bufs), } } + + fn is_write_vectored(&self) -> bool { + match self { + MaybeTlsStream::Plain(s) => s.is_write_vectored(), + MaybeTlsStream::Tls(s) => s.is_write_vectored(), + #[cfg(test)] + MaybeTlsStream::Test(s) => s.is_write_vectored(), + } + } } /// Rate limiter for reading from a [`RelayedStream`]. @@ -431,6 +440,18 @@ impl AsyncWrite for RateLimited { ) -> Poll> { Pin::new(&mut self.inner).poll_shutdown(cx) } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } } impl ExportKeyingMaterial for RateLimited { From 3339d907e38bb0534f0015de65726d98795ed4cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 9 Jul 2025 14:38:30 +0200 Subject: [PATCH 51/80] Cleanup --- iroh-relay/src/client/streams.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/iroh-relay/src/client/streams.rs b/iroh-relay/src/client/streams.rs index b08c6917e25..fce08b830bb 100644 --- a/iroh-relay/src/client/streams.rs +++ b/iroh-relay/src/client/streams.rs @@ -88,14 +88,14 @@ impl ProxyStream { pub fn local_addr(&self) -> std::io::Result { match self { Self::Raw(s) => s.local_addr(), - Self::Proxied(s) => s.get_ref().1.local_addr(), + Self::Proxied(s) => s.get_ref().1.as_ref().local_addr(), } } pub fn peer_addr(&self) -> std::io::Result { match self { Self::Raw(s) => s.peer_addr(), - Self::Proxied(s) => s.get_ref().1.peer_addr(), + Self::Proxied(s) => s.get_ref().1.as_ref().peer_addr(), } } } @@ -189,8 +189,8 @@ impl AsyncWrite for MaybeTlsStream { Self::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), #[cfg(test)] Self::Test(stream) => Pin::new(stream).poll_write_vectored(cx, bufs), + } } -} fn is_write_vectored(&self) -> bool { match self { From 739e8f9626eda812acbe1fe8386b699e2761668c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 9 Jul 2025 15:06:58 +0200 Subject: [PATCH 52/80] Use staging relays for the integration test --- iroh/tests/integration.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index 7534aaa32d3..68a5caa7135 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -9,11 +9,9 @@ //! //! In the past we've hit relay rate-limits from all the tests in our CI, but I expect //! we won't hit these with only this integration test. -use std::str::FromStr; - use iroh::{ discovery::{pkarr::PkarrResolver, Discovery}, - Endpoint, RelayMap, RelayMode, RelayUrl, + Endpoint, RelayMode, }; use n0_future::{ task, @@ -39,19 +37,13 @@ async fn simple_node_id_based_connection_transfer() -> Result { std::panic::set_hook(Box::new(console_error_panic_hook::hook)); setup_logging(); - // TODO(matheus23): Replace this with actual production relays eventually - let relay_map = RelayMode::Custom(RelayMap::from_iter([RelayUrl::from_str( - "https://philipp.iroh.link.", - ) - .e()?])); - let client = Endpoint::builder() - .relay_mode(relay_map.clone()) + .relay_mode(RelayMode::Staging) .discovery_n0() .bind() .await?; let server = Endpoint::builder() - .relay_mode(relay_map) + .relay_mode(RelayMode::Staging) .discovery_n0() .alpns(vec![ECHO_ALPN.to_vec()]) .bind() From bd3f43def2a8e0697a0f1981eb43eccdc321cf3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 9 Jul 2025 15:11:27 +0200 Subject: [PATCH 53/80] Fix wasm import --- iroh/tests/integration.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index 68a5caa7135..e48909603ac 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -124,6 +124,8 @@ async fn simple_node_id_based_connection_transfer() -> Result { #[cfg(wasm_browser)] fn setup_logging() { + use std::str::FromStr; + tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_str("trace").expect("hardcoded")) .with_max_level(tracing::level_filters::LevelFilter::DEBUG) From 7c9ba5d5a4562b2ba1893dd0c0b50d72e2808140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 9 Jul 2025 15:11:45 +0200 Subject: [PATCH 54/80] Use debug level --- iroh/tests/integration.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index e48909603ac..17ab91f9eb7 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -127,7 +127,7 @@ fn setup_logging() { use std::str::FromStr; tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_str("trace").expect("hardcoded")) + .with_env_filter(tracing_subscriber::EnvFilter::from_str("debug").expect("hardcoded")) .with_max_level(tracing::level_filters::LevelFilter::DEBUG) .with_writer( // To avoide trace events in the browser from showing their JS backtrace From 716984a78e45885872bfb6418684d948f03fa11f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Wed, 9 Jul 2025 15:20:07 +0200 Subject: [PATCH 55/80] Update staging URLs --- iroh/src/defaults.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/iroh/src/defaults.rs b/iroh/src/defaults.rs index cb6327f9f46..6d3fa70d62a 100644 --- a/iroh/src/defaults.rs +++ b/iroh/src/defaults.rs @@ -86,11 +86,9 @@ pub mod staging { use super::*; /// Hostname of the default NA relay. - // TODO(ramfox): for `0.91` release, make sure we have canary staging relays - pub const NA_RELAY_HOSTNAME: &str = "use1-1.relay.n0.iroh-canary.iroh.link."; + pub const NA_RELAY_HOSTNAME: &str = "staging-use1-1.relay.iroh.network."; /// Hostname of the default EU relay. - // TODO(ramfox): for `0.91` release, make sure we have canary staging relays - pub const EU_RELAY_HOSTNAME: &str = "euc1-1.relay.n0.iroh-canary.iroh.link."; + pub const EU_RELAY_HOSTNAME: &str = "staging-euw1-1.relay.iroh.network."; /// Get the default [`RelayMap`]. pub fn default_relay_map() -> RelayMap { From 76893685747e67160070c38c1a53b0885f72aac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 10 Jul 2025 10:11:16 +0200 Subject: [PATCH 56/80] Implement server-side ws subprotocol versioning --- iroh-relay/src/client/tls.rs | 3 +- iroh-relay/src/http.rs | 7 +- iroh-relay/src/server/http_server.rs | 232 ++++++++++++++++----------- 3 files changed, 147 insertions(+), 95 deletions(-) diff --git a/iroh-relay/src/client/tls.rs b/iroh-relay/src/client/tls.rs index e70e4373770..c480f8560b0 100644 --- a/iroh-relay/src/client/tls.rs +++ b/iroh-relay/src/client/tls.rs @@ -13,6 +13,7 @@ use hyper::{upgrade::Parts, Request}; use n0_future::{task, time}; use rustls::client::Resumption; use snafu::{OptionExt, ResultExt}; +use tracing::error; use super::{ streams::{MaybeTlsStream, ProxyStream}, @@ -243,7 +244,7 @@ impl MaybeTlsStreamBuilder { .context(ProxyConnectSnafu)?; task::spawn(async move { if let Err(err) = conn.with_upgrades().await { - tracing::error!("Proxy connection failed: {:?}", err); + error!("Proxy connection failed: {:?}", err); } }); diff --git a/iroh-relay/src/http.rs b/iroh-relay/src/http.rs index 5957f816808..c3fe7d40aa3 100644 --- a/iroh-relay/src/http.rs +++ b/iroh-relay/src/http.rs @@ -13,5 +13,10 @@ pub const RELAY_PATH: &str = "/relay"; /// The HTTP path under which the relay allows doing latency queries for testing. pub const RELAY_PROBE_PATH: &str = "/ping"; +/// The websocket sub-protocol version that we currently support. +/// +/// This is sent as the websocket sub-protocol header `Sec-Websocket-Protocol` from +/// the client and answered from the server. +pub const RELAY_PROTOCOL_VERSION: &str = "iroh-relay-v1"; /// The HTTP header name for relay client authentication -pub const CLIENT_AUTH_HEADER: HeaderName = HeaderName::from_static("x-iroh-relay-client-auth"); +pub const CLIENT_AUTH_HEADER: HeaderName = HeaderName::from_static("x-iroh-relay-client-auth-v1"); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index ba4a90bb0ff..1074dedac7a 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -4,7 +4,10 @@ use std::{ use bytes::Bytes; use derive_more::Debug; -use http::{header::CONNECTION, response::Builder as ResponseBuilder}; +use http::{ + header::{CONNECTION, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION}, + response::Builder as ResponseBuilder, +}; use hyper::{ body::Incoming, header::{HeaderValue, SEC_WEBSOCKET_ACCEPT, UPGRADE}, @@ -12,9 +15,9 @@ use hyper::{ upgrade::Upgraded, HeaderMap, Method, Request, Response, StatusCode, }; -use n0_future::{time::Elapsed, FutureExt}; +use n0_future::time::Elapsed; use nested_enum_utils::common_fields; -use snafu::{Backtrace, ResultExt, Snafu}; +use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls_acme::AcmeAcceptor; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; @@ -35,7 +38,7 @@ use crate::{ KeyCache, }; use crate::{ - http::{CLIENT_AUTH_HEADER, WEBSOCKET_UPGRADE_PROTOCOL}, + http::{CLIENT_AUTH_HEADER, RELAY_PROTOCOL_VERSION, WEBSOCKET_UPGRADE_PROTOCOL}, protos::{handshake, relay::MAX_FRAME_SIZE, streams::WsBytesFramed}, server::streams::RateLimited, }; @@ -64,11 +67,6 @@ fn derive_accept_key(client_key: &HeaderValue) -> String { data_encoding::BASE64.encode(&sha1.finalize()) } -/// Creates a new [`BytesBody`] with no content. -fn body_empty() -> BytesBody { - http_body_util::Full::new(hyper::body::Bytes::new()) -} - /// Creates a new [`BytesBody`] with given content. fn body_full(content: impl Into) -> BytesBody { http_body_util::Full::new(content.into()) @@ -433,101 +431,134 @@ struct Inner { metrics: Arc, } +#[derive(Debug, Snafu)] +enum RelayUpgradeReqError { + #[snafu(display("missing header: {header}"))] + MissingHeader { header: http::HeaderName }, + #[snafu(display("invalid header value for {header}: {details}"))] + InvalidHeader { + header: http::HeaderName, + details: String, + }, + #[snafu(display( + "invalid header value for {SEC_WEBSOCKET_VERSION}: unsupported websocket version, only supporting {SUPPORTED_WEBSOCKET_VERSION}" + ))] + UnsupportedWebsocketVersion, + #[snafu(display( + "invalid header value for {SEC_WEBSOCKET_PROTOCOL}: unsupported relay version: we support {we_support} but you only provide {you_support}" + ))] + UnsupportedRelayVersion { + we_support: &'static str, + you_support: String, + }, +} + impl RelayService { + fn build_response(&self) -> http::response::Builder { + let mut res = Response::builder(); + for (key, value) in self.0.headers.iter() { + res = res.header(key, value); + } + res + } + /// Upgrades the HTTP connection to the relay protocol, runs relay client. - fn call_client_conn( + async fn handle_relay_ws_upgrade( &self, mut req: Request, - ) -> Pin, hyper::Error>> + Send>> { - // TODO: soooo much cloning. See if there is an alternative - let this = self.clone(); - let mut builder = Response::builder(); - for (key, value) in self.0.headers.iter() { - builder = builder.header(key, value); + ) -> Result, RelayUpgradeReqError> { + fn expect_header( + req: &Request, + header: http::HeaderName, + ) -> Result<&HeaderValue, RelayUpgradeReqError> { + req.headers() + .get(&header) + .context(MissingHeaderSnafu { header }) } - async move { - { - // Send a 400 to any request that doesn't have an `Upgrade` header. - if req.headers().get(UPGRADE) - != Some(&HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL)) - { - return Ok(builder - .status(StatusCode::BAD_REQUEST) - .body(body_empty()) - .expect("valid body")); - }; + // Send a 400 to any request that doesn't have an `Upgrade` header. + let upgrade_header = expect_header(&req, UPGRADE)?; + snafu::ensure!( + upgrade_header == &HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), + InvalidHeaderSnafu { + header: UPGRADE, + details: format!("value must be {WEBSOCKET_UPGRADE_PROTOCOL}"), + } + ); - let Some(key) = req.headers().get("Sec-WebSocket-Key").cloned() else { - warn!("missing header Sec-WebSocket-Key for websocket relay protocol"); - return Ok(builder - .status(StatusCode::BAD_REQUEST) - .body(body_empty()) - .expect("valid body")); - }; + let key = expect_header(&req, SEC_WEBSOCKET_KEY)?.clone(); + let version = expect_header(&req, SEC_WEBSOCKET_VERSION)?.clone(); - let Some(version) = req.headers().get("Sec-WebSocket-Version").cloned() else { - warn!("missing header Sec-WebSocket-Version for websocket relay protocol"); - return Ok(builder - .status(StatusCode::BAD_REQUEST) - .body(body_empty()) - .expect("valid body")); - }; + snafu::ensure!( + version.as_bytes() == SUPPORTED_WEBSOCKET_VERSION.as_bytes(), + UnsupportedWebsocketVersionSnafu + ); - if version.as_bytes() != SUPPORTED_WEBSOCKET_VERSION.as_bytes() { - warn!("invalid header Sec-WebSocket-Version: {:?}", version); - return Ok(builder - .status(StatusCode::BAD_REQUEST) - // It's convention to send back the version(s) we *do* support - .header("Sec-WebSocket-Version", SUPPORTED_WEBSOCKET_VERSION) - .body(body_empty()) - .expect("valid body")); - } + let subprotocols = expect_header(&req, SEC_WEBSOCKET_PROTOCOL)? + .to_str() + .ok() + .context(InvalidHeaderSnafu { + header: SEC_WEBSOCKET_PROTOCOL, + details: format!("header value is not ascii"), + })?; + let supports_our_version = subprotocols + .split_whitespace() + .any(|p| p == RELAY_PROTOCOL_VERSION); + snafu::ensure!( + supports_our_version, + UnsupportedRelayVersionSnafu { + we_support: RELAY_PROTOCOL_VERSION, + you_support: subprotocols.to_string(), + } + ); - let client_auth_header = req.headers().get(CLIENT_AUTH_HEADER).cloned(); - - // Setup a future that will eventually receive the upgraded - // connection and talk a new protocol, and spawn the future - // into the runtime. - // - // Note: This can't possibly be fulfilled until the 101 response - // is returned below, so it's better to spawn this future instead - // waiting for it to complete to then return a response. - tokio::task::spawn( - async move { - match hyper::upgrade::on(&mut req).await { - Ok(upgraded) => { - if let Err(err) = this - .0 - .relay_connection_handler(upgraded, client_auth_header) - .await - { - warn!("error accepting upgraded connection: {err:#}",); - } else { - debug!("upgraded connection completed"); - }; - } - Err(err) => warn!("upgrade error: {err:#}"), - } + let client_auth_header = req.headers().get(CLIENT_AUTH_HEADER).cloned(); + + // Setup a future that will eventually receive the upgraded + // connection and talk a new protocol, and spawn the future + // into the runtime. + // + // Note: This can't possibly be fulfilled until the 101 response + // is returned below, so it's better to spawn this future instead + // waiting for it to complete to then return a response. + tokio::task::spawn({ + let this = self.clone(); + async move { + match hyper::upgrade::on(&mut req).await { + Ok(upgraded) => { + if let Err(err) = this + .0 + .relay_connection_handler(upgraded, client_auth_header) + .await + { + warn!("error accepting upgraded connection: {err:#}",); + } else { + debug!("upgraded connection completed"); + }; } - .instrument(debug_span!("handler")), - ); - - // Now return a 101 Response saying we agree to the upgrade to the - // websocket upgrade protocol - builder = builder.status(StatusCode::SWITCHING_PROTOCOLS).header( - UPGRADE, - HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), - ); - - Ok(builder - .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key)) - .header(CONNECTION, "upgrade") - .body(body_full("switching to websocket protocol")) - .expect("valid body")) + Err(err) => warn!("upgrade error: {err:#}"), + } } - } - .boxed() + .instrument(debug_span!("handler")) + }); + + // Now return a 101 Response saying we agree to the upgrade to the + // websocket upgrade protocol + Ok(self + .build_response() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header( + UPGRADE, + HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), + ) + .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(&key)) + .header( + SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_static(RELAY_PROTOCOL_VERSION), + ) + .header(CONNECTION, "upgrade") + .body(body_full("switching to websocket protocol")) + .expect("valid body")) } } @@ -543,7 +574,22 @@ impl Service> for RelayService { (&hyper::Method::GET, RELAY_PATH) ) { let this = self.clone(); - return Box::pin(async move { this.call_client_conn(req).await.map_err(Into::into) }); + return Box::pin(async move { + match this.handle_relay_ws_upgrade(req).await { + Ok(response) => Ok(response), + // It's convention to send back the version(s) we *do* support + Err(e @ RelayUpgradeReqError::UnsupportedWebsocketVersion) => this + .build_response() + .status(StatusCode::BAD_REQUEST) + .header(SEC_WEBSOCKET_VERSION, SUPPORTED_WEBSOCKET_VERSION) + .body(body_full(e.to_string())), + Err(e) => this + .build_response() + .status(StatusCode::BAD_REQUEST) + .body(body_full(e.to_string())), + } + .map_err(Into::into) + }); } // Otherwise handle the relay connection as normal. From 9a731fa34425c7bc0c4562d909216f3f21fb3afa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 10 Jul 2025 10:14:49 +0200 Subject: [PATCH 57/80] Implement relay protocol versioning client-side --- iroh-relay/src/client.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/iroh-relay/src/client.rs b/iroh-relay/src/client.rs index a912e5b2708..9befb89c5ad 100644 --- a/iroh-relay/src/client.rs +++ b/iroh-relay/src/client.rs @@ -200,10 +200,11 @@ impl ClientBuilder { /// Establishes a new connection to the relay server. #[cfg(not(wasm_browser))] pub async fn connect(&self) -> Result { + use http::header::SEC_WEBSOCKET_PROTOCOL; use tls::MaybeTlsStreamBuilder; use crate::{ - http::CLIENT_AUTH_HEADER, + http::{CLIENT_AUTH_HEADER, RELAY_PROTOCOL_VERSION}, protos::{handshake::KeyMaterialClientAuth, relay::MAX_FRAME_SIZE}, }; @@ -249,6 +250,11 @@ impl ClientBuilder { } .build() })? + .add_header( + SEC_WEBSOCKET_PROTOCOL, + http::HeaderValue::from_static(RELAY_PROTOCOL_VERSION), + ) + .expect("valid header name and value") .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) .config(tokio_websockets::Config::default().flush_threshold(usize::MAX)); if let Some(client_auth) = KeyMaterialClientAuth::new(&self.secret_key, &stream) { @@ -300,6 +306,8 @@ impl ClientBuilder { /// Establishes a new connection to the relay server. #[cfg(wasm_browser)] pub async fn connect(&self) -> Result { + use crate::http::RELAY_PROTOCOL_VERSION; + let mut dial_url = (*self.url).clone(); dial_url.set_path(RELAY_PATH); // The relay URL is exchanged with the http(s) scheme in tickets and similar. @@ -319,7 +327,9 @@ impl ClientBuilder { debug!(%dial_url, "Dialing relay by websocket"); - let (_, ws_stream) = ws_stream_wasm::WsMeta::connect(dial_url.as_str(), None).await?; + let (_, ws_stream) = + ws_stream_wasm::WsMeta::connect(dial_url.as_str(), Some(vec![RELAY_PROTOCOL_VERSION])) + .await?; let conn = Conn::new(ws_stream, self.key_cache.clone(), &self.secret_key).await?; event!( From 7efad8a65dc81fcf15720b41ce840569a2e1c7c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 10 Jul 2025 10:54:56 +0200 Subject: [PATCH 58/80] Internal docs & prefer `Vec::from(Bytes)` over `to_vec()` --- iroh-relay/src/lib.rs | 18 ++++++++++++++++++ iroh-relay/src/protos/streams.rs | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/iroh-relay/src/lib.rs b/iroh-relay/src/lib.rs index d8440f0baab..76565b9cf5f 100644 --- a/iroh-relay/src/lib.rs +++ b/iroh-relay/src/lib.rs @@ -55,7 +55,25 @@ pub use self::{ relay_map::{RelayMap, RelayNode, RelayQuicConfig}, }; +/// This trait allows anything that ends up potentially +/// wrapping a TLS stream use the underlying [`export_keying_material`] +/// function. +/// +/// [`export_keying_material`]: rustls::ConnectionCommon::export_keying_material pub(crate) trait ExportKeyingMaterial { + /// If this type ends up wrapping a TLS stream, then this tries + /// to export keying material by calling the underlying [`export_keying_material`] + /// function. + /// + /// However unlike that function, this returns `Option`, in case the + /// underlying stream might not be wrapping TLS, e.g. as in the case of + /// [`MaybeTlsStream`]. + /// + /// For more information on what this function does, see the + /// [`export_keying_material`] documentation. + /// + /// [`export_keying_material`]: rustls::ConnectionCommon::export_keying_material + /// [`MaybeTlsStream`]: crate::client::streams::MaybeTlsStream #[cfg_attr(wasm_browser, allow(unused))] fn export_keying_material>( &self, diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 31b4d143ec5..3b146599b86 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -150,7 +150,7 @@ impl Sink for WsBytesFramed { type Error = StreamError; fn start_send(mut self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { - let msg = ws_stream_wasm::WsMessage::Binary(bytes.to_vec()); + let msg = ws_stream_wasm::WsMessage::Binary(Vec::from(bytes)); Pin::new(&mut self.io).start_send(msg).map_err(Into::into) } From beef2d6667b1ecab290b3ef1074c1807d1d7ecf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 10 Jul 2025 10:58:17 +0200 Subject: [PATCH 59/80] Make clippy happy --- iroh-relay/src/server/http_server.rs | 38 +++++++++++++--------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 1074dedac7a..d1fe878af55 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -463,7 +463,7 @@ impl RelayService { } /// Upgrades the HTTP connection to the relay protocol, runs relay client. - async fn handle_relay_ws_upgrade( + fn handle_relay_ws_upgrade( &self, mut req: Request, ) -> Result, RelayUpgradeReqError> { @@ -479,7 +479,7 @@ impl RelayService { // Send a 400 to any request that doesn't have an `Upgrade` header. let upgrade_header = expect_header(&req, UPGRADE)?; snafu::ensure!( - upgrade_header == &HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), + upgrade_header == HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), InvalidHeaderSnafu { header: UPGRADE, details: format!("value must be {WEBSOCKET_UPGRADE_PROTOCOL}"), @@ -499,7 +499,7 @@ impl RelayService { .ok() .context(InvalidHeaderSnafu { header: SEC_WEBSOCKET_PROTOCOL, - details: format!("header value is not ascii"), + details: "header value is not ascii".to_string(), })?; let supports_our_version = subprotocols .split_whitespace() @@ -573,23 +573,21 @@ impl Service> for RelayService { (req.method(), req.uri().path()), (&hyper::Method::GET, RELAY_PATH) ) { - let this = self.clone(); - return Box::pin(async move { - match this.handle_relay_ws_upgrade(req).await { - Ok(response) => Ok(response), - // It's convention to send back the version(s) we *do* support - Err(e @ RelayUpgradeReqError::UnsupportedWebsocketVersion) => this - .build_response() - .status(StatusCode::BAD_REQUEST) - .header(SEC_WEBSOCKET_VERSION, SUPPORTED_WEBSOCKET_VERSION) - .body(body_full(e.to_string())), - Err(e) => this - .build_response() - .status(StatusCode::BAD_REQUEST) - .body(body_full(e.to_string())), - } - .map_err(Into::into) - }); + let res = match self.handle_relay_ws_upgrade(req) { + Ok(response) => Ok(response), + // It's convention to send back the version(s) we *do* support + Err(e @ RelayUpgradeReqError::UnsupportedWebsocketVersion) => self + .build_response() + .status(StatusCode::BAD_REQUEST) + .header(SEC_WEBSOCKET_VERSION, SUPPORTED_WEBSOCKET_VERSION) + .body(body_full(e.to_string())), + Err(e) => self + .build_response() + .status(StatusCode::BAD_REQUEST) + .body(body_full(e.to_string())), + } + .map_err(Into::into); + return Box::pin(async move { res }); } // Otherwise handle the relay connection as normal. From 9b85f32dc18a976951382061c79a8ba620a3a6c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 10 Jul 2025 15:04:14 +0200 Subject: [PATCH 60/80] Adjust frame type parsing and numbers and comments --- iroh-relay/src/protos/common.rs | 74 +++++++++++++++++------------- iroh-relay/src/protos/handshake.rs | 28 ++++++----- iroh-relay/src/protos/relay.rs | 14 +++--- 3 files changed, 68 insertions(+), 48 deletions(-) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index aae464a8c88..b6c29661a5c 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -3,55 +3,65 @@ //! Hosts the [`FrameType`] enum to make sure we're not accidentally reusing frame type //! integers for different frames. -use bytes::{BufMut, Bytes}; +use bytes::{Buf, BufMut}; +use nested_enum_utils::common_fields; use quinn_proto::{coding::Codec, VarInt}; +use snafu::{Backtrace, OptionExt, Snafu}; /// Possible frame types during handshaking #[repr(u32)] -#[derive(Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::FromPrimitive)] +#[derive( + Copy, Clone, PartialEq, Eq, Debug, num_enum::IntoPrimitive, num_enum::TryFromPrimitive, +)] // needs to be pub due to being exposed in error types pub enum FrameType { /// The server frame type for the challenge response - ServerChallenge = 2, + ServerChallenge = 1, /// The client frame type for the authentication frame - ClientAuth = 3, + ClientAuth = 2, /// The server frame type for authentication confirmation - ServerConfirmsAuth = 4, + ServerConfirmsAuth = 3, /// The server frame type for authentication denial - ServerDeniesAuth = 5, + ServerDeniesAuth = 4, /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - ClientToRelayDatagrams = 10, + ClientToRelayDatagrams = 5, /// 32B src pub key + ECN byte + segment size u16 + datagrams contents - RelayToClientDatagrams = 11, + RelayToClientDatagrams = 6, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` /// to B so B can forget that a reverse path exists on that connection to get back to A /// /// 32B pub key of peer that's gone - NodeGone = 14, - /// Frames 9-11 concern meshing, which we have eliminated from our version of the protocol. + NodeGone = 7, /// Messages with these frames will be ignored. /// 8 byte ping payload, to be echoed back in FrameType::Pong - Ping = 15, + Ping = 8, /// 8 byte payload, the contents of ping being replied to - Pong = 16, - /// Sent from server to client to tell the client if their connection is - /// unhealthy somehow. - Health = 17, + Pong = 9, + /// Sent from server to client to tell the client if their connection is unhealthy somehow. + /// Contains only UTF-8 bytes. + Health = 10, /// Sent from server to client for the server to declare that it's restarting. /// Payload is two big endian u32 durations in milliseconds: when to reconnect, /// and how long to try total. - /// - /// Handled on the `[relay::Client]`, but currently never sent on the `[relay::Server]` - Restarting = 18, - /// The frame type was unknown. - /// - /// This frame is the result of parsing any future frame types that this implementation - /// does not yet understand. - #[num_enum(default)] - Unknown, + Restarting = 11, +} + +#[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: n0_snafu::SpanTrace, +})] +#[allow(missing_docs)] +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum FrameTypeError { + #[snafu(display("not enough bytes to parse frame type"))] + UnexpectedEnd {}, + #[snafu(display("frame type unknown"))] + UnknownFrameType { tag: VarInt }, } impl std::fmt::Display for FrameType { @@ -69,13 +79,15 @@ impl FrameType { /// Parses the frame type (as a QUIC-encoded varint) from the first couple of bytes given /// and returns the frame type and the rest. - pub(crate) fn from_bytes(bytes: Bytes) -> Option<(Self, Bytes)> { - let mut cursor = std::io::Cursor::new(&bytes); - let tag = VarInt::decode(&mut cursor).ok()?; - let tag_u32 = u32::try_from(u64::from(tag)).ok()?; - let frame_type = FrameType::from(tag_u32); - let content = bytes.slice(cursor.position() as usize..); - Some((frame_type, content)) + pub(crate) fn from_bytes(buf: &mut impl Buf) -> Result { + let tag = VarInt::decode(buf).ok().context(UnexpectedEndSnafu)?; + let tag_u32 = u32::try_from(u64::from(tag)) + .ok() + .context(UnknownFrameTypeSnafu { tag })?; + let frame_type = FrameType::try_from(tag_u32) + .ok() + .context(UnknownFrameTypeSnafu { tag })?; + Ok(frame_type) } } diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 1c09e5a0a54..e1b888c571d 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -36,9 +36,13 @@ use n0_future::{ use nested_enum_utils::common_fields; #[cfg(feature = "server")] use rand::{CryptoRng, RngCore}; -use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; +use snafu::{Backtrace, ResultExt, Snafu}; +use tracing::trace; -use super::{common::FrameType, streams::BytesStreamSink}; +use super::{ + common::{FrameType, FrameTypeError}, + streams::BytesStreamSink, +}; use crate::ExportKeyingMaterial; /// Authentication message from the client. @@ -144,6 +148,8 @@ pub enum Error { Timeout { source: Elapsed }, #[snafu(display("Handshake stream ended prematurely"))] UnexpectedEnd {}, + #[snafu(transparent)] + FrameTypeError { source: FrameTypeError }, #[snafu(display("The relay denied our authentication ({reason})"))] ServerDeniedAuth { reason: String }, #[snafu(display("Unexpected tag, got {frame_type}, but expected one of {expected_types:?}"))] @@ -363,7 +369,7 @@ pub(crate) async fn serverside( })?; if client_auth.verify(io) { - tracing::trace!(?client_auth.public_key, "authentication succeeded via keying material"); + trace!(?client_auth.public_key, "authentication succeeded via keying material"); return Ok(SuccessfulAuthentication { client_key: client_auth.public_key, mechanism: Mechanism::SignedKeyMaterial, @@ -378,13 +384,13 @@ pub(crate) async fn serverside( let client_auth: ClientAuth = deserialize_frame(frame)?; if client_auth.verify(&challenge) { - tracing::trace!(?client_auth.public_key, "authentication succeeded via challenge"); + trace!(?client_auth.public_key, "authentication succeeded via challenge"); Ok(SuccessfulAuthentication { client_key: client_auth.public_key, mechanism: Mechanism::SignedChallenge, }) } else { - tracing::trace!(?client_auth.public_key, "authentication failed"); + trace!(?client_auth.public_key, "authentication failed"); let denial = ServerDeniesAuth { reason: "signature invalid".into(), }; @@ -404,11 +410,11 @@ impl SuccessfulAuthentication { is_authorized: bool, ) -> Result { if is_authorized { - tracing::trace!("authorizing client"); + trace!("authorizing client"); write_frame(io, ServerConfirmsAuth).await?; Ok(self.client_key) } else { - tracing::trace!("denying client auth"); + trace!("denying client auth"); let denial = ServerDeniesAuth { reason: "not authorized".into(), }; @@ -426,7 +432,7 @@ async fn write_frame( frame: F, ) -> Result<(), Error> { let mut bytes = BytesMut::new(); - tracing::trace!(frame_type = %F::TAG, "Writing frame"); + trace!(frame_type = %F::TAG, "Writing frame"); F::TAG.write_to(&mut bytes); let bytes = postcard::to_io(&frame, bytes.writer()) .expect("serialization failed") // buffer can't become "full" without being a critical failure, datastructures shouldn't ever fail serialization @@ -442,13 +448,13 @@ async fn read_frame( expected_types: &[FrameType], timeout: time::Duration, ) -> Result<(FrameType, Bytes), Error> { - let recv = time::timeout(timeout, io.try_next()) + let mut payload = time::timeout(timeout, io.try_next()) .await .context(TimeoutSnafu)?? .ok_or_else(|| UnexpectedEndSnafu.build())?; - let (frame_type, payload) = FrameType::from_bytes(recv).context(UnexpectedEndSnafu)?; - tracing::trace!(%frame_type, "Reading frame"); + let frame_type = FrameType::from_bytes(&mut payload)?; + trace!(%frame_type, "Reading frame"); snafu::ensure!( expected_types.contains(&frame_type), UnexpectedFrameTypeSnafu { diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 2ff18691db0..e809e719192 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -11,9 +11,9 @@ use bytes::{BufMut, Bytes}; use iroh_base::{NodeId, SignatureError}; use n0_future::time::{self, Duration}; use nested_enum_utils::common_fields; -use snafu::{Backtrace, OptionExt, ResultExt, Snafu}; +use snafu::{Backtrace, ResultExt, Snafu}; -use super::common::FrameType; +use super::common::{FrameType, FrameTypeError}; use crate::KeyCache; /// The maximum size of a packet sent over relay. @@ -62,6 +62,8 @@ pub enum Error { Timeout { source: time::Elapsed }, #[snafu(transparent)] SerDe { source: postcard::Error }, + #[snafu(transparent)] + FrameTypeError { source: FrameTypeError }, #[snafu(display("Invalid public key"))] InvalidPublicKey { source: SignatureError }, #[snafu(display("Invalid frame encoding"))] @@ -248,8 +250,8 @@ impl RelayToClientMsg { /// /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result { + let frame_type = FrameType::from_bytes(&mut content)?; let frame_len = content.len(); snafu::ensure!( frame_len <= MAX_PACKET_SIZE, @@ -358,8 +360,8 @@ impl ClientToRelayMsg { /// Specifically, bytes received from a binary websocket message frame. #[allow(clippy::result_large_err)] #[cfg(feature = "server")] - pub(crate) fn from_bytes(bytes: Bytes, cache: &KeyCache) -> Result { - let (frame_type, content) = FrameType::from_bytes(bytes).context(InvalidFrameSnafu)?; + pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result { + let frame_type = FrameType::from_bytes(&mut content)?; let frame_len = content.len(); snafu::ensure!( frame_len <= MAX_PACKET_SIZE, From cdea65e7790957edb8c578b5cfcadffd75eedf1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Thu, 10 Jul 2025 15:19:54 +0200 Subject: [PATCH 61/80] Choose correct buffer size --- iroh-relay/src/client/conn.rs | 3 +- iroh-relay/src/protos/relay.rs | 64 ++++++++++++++++++++++++++++++-- iroh-relay/src/server/client.rs | 3 +- iroh-relay/src/server/streams.rs | 3 +- 4 files changed, 63 insertions(+), 10 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index e105930e1cb..6d92ea5d240 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -7,7 +7,6 @@ use std::{ task::{ready, Context, Poll}, }; -use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::{Sink, Stream}; use nested_enum_utils::common_fields; @@ -151,7 +150,7 @@ impl Sink for Conn { } Pin::new(&mut self.conn) - .start_send(frame.write_to(BytesMut::new()).freeze()) + .start_send(frame.to_bytes().freeze()) .map_err(Into::into) } diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index e809e719192..e5adf737fcf 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -7,7 +7,7 @@ //! * server then sends [`FrameType::RelayToClientDatagrams`] to recipient //! * server sends [`FrameType::NodeGone`] when the other client disconnects -use bytes::{BufMut, Bytes}; +use bytes::{BufMut, Bytes, BytesMut}; use iroh_base::{NodeId, SignatureError}; use n0_future::time::{self, Duration}; use nested_enum_utils::common_fields; @@ -172,6 +172,12 @@ impl Datagrams { dst } + fn encoded_len(&self) -> usize { + 1 // ECN byte + + 2 // segment size + + self.contents.len() + } + fn from_bytes(bytes: Bytes) -> Result { // 1 bytes ECN, 2 bytes segment size snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); @@ -209,6 +215,10 @@ impl RelayToClientMsg { } } + pub(crate) fn to_bytes(&self) -> BytesMut { + self.write_to(BytesMut::with_capacity(self.encoded_len())) + } + /// Encodes this frame for sending over websockets. /// /// Specifically meant for being put into a binary websocket message frame. @@ -246,6 +256,24 @@ impl RelayToClientMsg { dst } + pub(crate) fn encoded_len(&self) -> usize { + let payload_len = match self { + Self::Datagrams { datagrams, .. } => { + 32 // nodeid + + datagrams.encoded_len() + } + Self::NodeGone(_) => 32, + Self::Health { problem } => problem.len(), + Self::Restarting { .. } => { + 4 // u32 + + 4 // u32 + } + Self::Ping(_) | Self::Pong(_) => 8, + }; + 1 // frame type + + payload_len + } + /// Tries to decode a frame received over websockets. /// /// Specifically, bytes received from a binary websocket message frame. @@ -332,6 +360,10 @@ impl ClientToRelayMsg { } } + pub(crate) fn to_bytes(&self) -> BytesMut { + self.write_to(BytesMut::with_capacity(self.encoded_len())) + } + /// Encodes this frame for sending over websockets. /// /// Specifically meant for being put into a binary websocket message frame. @@ -355,6 +387,17 @@ impl ClientToRelayMsg { dst } + pub(crate) fn encoded_len(&self) -> usize { + match self { + Self::Ping(_) => 8, + Self::Pong(_) => 8, + Self::Datagrams { datagrams, .. } => { + 32 // node id + + datagrams.encoded_len() + } + } + } + /// Tries to decode a frame received over websockets. /// /// Specifically, bytes received from a binary websocket message frame. @@ -533,7 +576,6 @@ mod tests { #[cfg(test)] #[cfg(feature = "server")] mod proptests { - use bytes::BytesMut; use iroh_base::SecretKey; use proptest::prelude::*; @@ -608,16 +650,30 @@ mod proptests { proptest! { #[test] fn server_client_frame_roundtrip(frame in server_client_frame()) { - let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let encoded = frame.to_bytes().freeze(); let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } #[test] fn client_server_frame_roundtrip(frame in client_server_frame()) { - let encoded = frame.clone().write_to(BytesMut::new()).freeze(); + let encoded = frame.to_bytes().freeze(); let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap(); prop_assert_eq!(frame, decoded); } + + #[test] + fn server_client_frame_encoded_len(frame in server_client_frame()) { + let claimed_encoded_len = frame.encoded_len(); + let actual_encoded_len = frame.to_bytes().len(); + prop_assert_eq!(claimed_encoded_len, actual_encoded_len); + } + + #[test] + fn client_server_frame_encoded_len(frame in client_server_frame()) { + let claimed_encoded_len = frame.encoded_len(); + let actual_encoded_len = frame.to_bytes().len(); + prop_assert_eq!(claimed_encoded_len, actual_encoded_len); + } } } diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 2875817dbab..185103ea41f 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -565,7 +565,6 @@ impl ClientCounter { #[cfg(test)] mod tests { - use bytes::BytesMut; use iroh_base::SecretKey; use n0_future::Stream; use n0_snafu::{Result, ResultExt}; @@ -738,7 +737,7 @@ mod tests { dst_node_id: target, datagrams: data.clone(), }; - let frame_len = frame.clone().write_to(BytesMut::new()).freeze().len(); + let frame_len = frame.to_bytes().freeze().len(); assert_eq!(frame_len, LIMIT as usize); // Send a frame, it should arrive. diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index 44515f0832b..c1c00b438a2 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -6,7 +6,6 @@ use std::{ task::{Context, Poll}, }; -use bytes::BytesMut; use n0_future::{ready, time, FutureExt, Sink, Stream}; use snafu::{Backtrace, Snafu}; use tokio::io::{AsyncRead, AsyncWrite}; @@ -83,7 +82,7 @@ impl Sink for RelayedStream { } fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> { - Pin::new(&mut self.inner).start_send(item.write_to(BytesMut::new()).freeze()) + Pin::new(&mut self.inner).start_send(item.to_bytes().freeze()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { From 5b0755681d1297650fa5509bcdf4465d0087ed1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 11 Jul 2025 11:35:28 +0200 Subject: [PATCH 62/80] Adjust comment --- iroh/src/magicsock/transports/relay.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index e35aab2cedc..3f8cca0ea28 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -306,13 +306,7 @@ impl RelaySender { } } -/// Split a transmit containing a GSO payload into individual packets. -/// -/// This allocates the data. -/// -/// If the transmit has a segment size it contains multiple GSO packets. It will be split -/// into multiple packets according to that segment size. If it does not have a segment -/// size, the contents will be sent as a single packet. +/// Translate a UDP transmit to the `Datagrams` type for sending over the relay. // TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to // figure out where they allocate the Vec. fn datagrams_from_transmit(transmit: &Transmit<'_>) -> Datagrams { From f66e765af38bf24fa7c4932afa4bcca3d0e3a9ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 11 Jul 2025 14:17:13 +0200 Subject: [PATCH 63/80] Cleanup --- iroh-relay/src/client/conn.rs | 2 +- iroh-relay/src/protos/relay.rs | 32 ++++++++++++++++++---------- iroh-relay/src/server/http_server.rs | 15 +++++++------ 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index f0fb0aa98eb..bd9fdfca9a5 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -104,7 +104,7 @@ impl Conn { Ok(Self { conn, key_cache }) } - #[cfg(test)] + #[cfg(all(test, feature = "server"))] pub(crate) fn test(io: tokio::io::DuplexStream) -> Self { use crate::protos::relay::MAX_FRAME_SIZE; Self { diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 7363983c487..0f0e8a286e7 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -257,6 +257,7 @@ impl RelayToClientMsg { dst } + #[cfg(feature = "server")] pub(crate) fn encoded_len(&self) -> usize { let payload_len = match self { Self::Datagrams { datagrams, .. } => { @@ -389,14 +390,16 @@ impl ClientToRelayMsg { } pub(crate) fn encoded_len(&self) -> usize { - match self { + let payload_len = match self { Self::Ping(_) => 8, Self::Pong(_) => 8, Self::Datagrams { datagrams, .. } => { 32 // node id + datagrams.encoded_len() } - } + }; + 1 // frame type (all frame types currently encode as 1 byte varint) + + payload_len } /// Tries to decode a frame received over websockets. @@ -479,22 +482,22 @@ mod tests { problem: "Hello? Yes this is dog.".into(), } .write_to(Vec::new()), - "11 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 + "0a 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 20 69 73 20 64 6f 67 2e", ), ( RelayToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), - "0e 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + "07 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61", ), ( RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0f 2a 2a 2a 2a 2a 2a 2a 2a", + "08 2a 2a 2a 2a 2a 2a 2a 2a", ), ( RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), - "10 2a 2a 2a 2a 2a 2a 2a 2a", + "09 2a 2a 2a 2a 2a 2a 2a 2a", ), ( RelayToClientMsg::Datagrams { @@ -512,7 +515,7 @@ mod tests { // ECN byte // segment size // hello world contents bytes - "0b + "06 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 03 @@ -525,7 +528,7 @@ mod tests { try_for: Duration::from_millis(20), } .write_to(Vec::new()), - "12 00 00 00 0a 00 00 00 14", + "0b 00 00 00 0a 00 00 00 14", ), ]); @@ -539,11 +542,11 @@ mod tests { check_expected_bytes(vec![ ( ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()), - "0f 2a 2a 2a 2a 2a 2a 2a 2a", + "08 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()), - "10 2a 2a 2a 2a 2a 2a 2a 2a", + "09 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ClientToRelayMsg::Datagrams { @@ -561,7 +564,7 @@ mod tests { // ECN byte // segment size // hello world contents - "0a + "05 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 03 @@ -676,5 +679,12 @@ mod proptests { let actual_encoded_len = frame.to_bytes().len(); prop_assert_eq!(claimed_encoded_len, actual_encoded_len); } + + #[test] + fn datagrams_encoded_len(datagrams in datagrams()) { + let claimed_encoded_len = datagrams.encoded_len(); + let actual_encoded_len = datagrams.write_to(Vec::new()).len(); + prop_assert_eq!(claimed_encoded_len, actual_encoded_len); + } } } diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 4383a69005e..77d72da43bb 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -26,8 +26,15 @@ use tracing::{debug, debug_span, error, info, info_span, trace, warn, Instrument use super::{clients::Clients, streams::InvalidBucketConfig, AccessConfig, SpawnError}; use crate::{ defaults::{timeouts::SERVER_WRITE_TIMEOUT, DEFAULT_KEY_CACHE_CAPACITY}, - http::{RELAY_PATH, SUPPORTED_WEBSOCKET_VERSION}, - protos::relay::PER_CLIENT_SEND_QUEUE_DEPTH, + http::{ + CLIENT_AUTH_HEADER, RELAY_PATH, RELAY_PROTOCOL_VERSION, SUPPORTED_WEBSOCKET_VERSION, + WEBSOCKET_UPGRADE_PROTOCOL, + }, + protos::{ + handshake, + relay::{MAX_FRAME_SIZE, PER_CLIENT_SEND_QUEUE_DEPTH}, + streams::WsBytesFramed, + }, server::{ client::Config, metrics::Metrics, @@ -36,10 +43,6 @@ use crate::{ }, KeyCache, }; -use crate::{ - http::{CLIENT_AUTH_HEADER, RELAY_PROTOCOL_VERSION, WEBSOCKET_UPGRADE_PROTOCOL}, - protos::{handshake, relay::MAX_FRAME_SIZE, streams::WsBytesFramed}, -}; type BytesBody = http_body_util::Full; type HyperError = Box; From 8c0ae1626c7d2b56f864852cadd734dc94591297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 11 Jul 2025 14:32:01 +0200 Subject: [PATCH 64/80] Diff reduction --- iroh-relay/src/client/conn.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index bd9fdfca9a5..d7588edd7c1 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -19,7 +19,7 @@ use crate::client::streams::{MaybeTlsStream, ProxyStream}; use crate::{ protos::{ handshake, - relay::{ClientToRelayMsg, Error as RecvRelayError, RelayToClientMsg, MAX_PAYLOAD_SIZE}, + relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg, MAX_PAYLOAD_SIZE}, streams::WsBytesFramed, }, MAX_PACKET_SIZE, @@ -57,7 +57,7 @@ pub enum SendError { #[non_exhaustive] pub enum RecvError { #[snafu(transparent)] - Protocol { source: RecvRelayError }, + Protocol { source: ProtoError }, #[snafu(transparent)] StreamError { #[cfg(not(wasm_browser))] @@ -75,11 +75,11 @@ pub enum RecvError { /// - A [`Sink`] for [`ClientToRelayMsg`] to send to the server. #[derive(derive_more::Debug)] pub(crate) struct Conn { - #[debug("tokio_websockets::WebSocketStream")] #[cfg(not(wasm_browser))] + #[debug("tokio_websockets::WebSocketStream")] pub(crate) conn: WsBytesFramed>, - #[debug("ws_stream_wasm::WsStream")] #[cfg(wasm_browser)] + #[debug("ws_stream_wasm::WsStream")] pub(crate) conn: WsBytesFramed, pub(crate) key_cache: KeyCache, } @@ -124,8 +124,7 @@ impl Stream for Conn { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let msg = ready!(Pin::new(&mut self.conn).poll_next(cx)); - match msg { + match ready!(Pin::new(&mut self.conn).poll_next(cx)) { Some(Ok(msg)) => { let message = RelayToClientMsg::from_bytes(msg, &self.key_cache); Poll::Ready(Some(message.map_err(Into::into))) From f6cffe9b1292eca60037f6e3c475f467a6d5e5b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 11 Jul 2025 14:52:36 +0200 Subject: [PATCH 65/80] Cleanup / reduce diff --- iroh-relay/src/protos/handshake.rs | 6 +++--- iroh-relay/src/protos/relay.rs | 8 +++---- iroh-relay/src/protos/streams.rs | 8 +++---- iroh-relay/src/server/client.rs | 32 ++++++++++++++-------------- iroh-relay/src/server/clients.rs | 2 +- iroh-relay/src/server/http_server.rs | 11 +++++----- iroh/tests/integration.rs | 2 +- 7 files changed, 32 insertions(+), 37 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index e1b888c571d..78588baac0b 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -404,10 +404,10 @@ pub(crate) async fn serverside( #[cfg(feature = "server")] impl SuccessfulAuthentication { - pub async fn authorize( + pub async fn authorize_if( self, - io: &mut (impl BytesStreamSink + ExportKeyingMaterial), is_authorized: bool, + io: &mut (impl BytesStreamSink + ExportKeyingMaterial), ) -> Result { if is_authorized { trace!("authorizing client"); @@ -598,7 +598,7 @@ mod tests { .context("serverside")?; let mechanism = auth_n.mechanism; let is_authorized = restricted_to.is_none_or(|key| key == auth_n.client_key); - let key = auth_n.authorize(&mut server_io, is_authorized).await?; + let key = auth_n.authorize_if(is_authorized, &mut server_io).await?; Ok((key, mechanism)) } .instrument(info_span!("serverside")), diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 0f0e8a286e7..6f0d69bc4c2 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -265,12 +265,12 @@ impl RelayToClientMsg { + datagrams.encoded_len() } Self::NodeGone(_) => 32, + Self::Ping(_) | Self::Pong(_) => 8, Self::Health { problem } => problem.len(), Self::Restarting { .. } => { 4 // u32 + 4 // u32 } - Self::Ping(_) | Self::Pong(_) => 8, }; 1 // frame type + payload_len @@ -391,8 +391,7 @@ impl ClientToRelayMsg { pub(crate) fn encoded_len(&self) -> usize { let payload_len = match self { - Self::Ping(_) => 8, - Self::Pong(_) => 8, + Self::Ping(_) | Self::Pong(_) => 8, Self::Datagrams { datagrams, .. } => { 32 // node id + datagrams.encoded_len() @@ -577,8 +576,7 @@ mod tests { } } -#[cfg(test)] -#[cfg(feature = "server")] +#[cfg(all(test, feature = "server"))] mod proptests { use iroh_base::SecretKey; use proptest::prelude::*; diff --git a/iroh-relay/src/protos/streams.rs b/iroh-relay/src/protos/streams.rs index 3b146599b86..1ed90154347 100644 --- a/iroh-relay/src/protos/streams.rs +++ b/iroh-relay/src/protos/streams.rs @@ -9,6 +9,7 @@ use bytes::Bytes; use n0_future::{ready, Sink, Stream}; #[cfg(not(wasm_browser))] use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::warn; use crate::ExportKeyingMaterial; @@ -90,10 +91,7 @@ impl Stream for WsBytesFramed { continue; // Responding appropriately to these is done inside of tokio_websockets/browser impls } if !msg.is_binary() { - tracing::warn!( - ?msg, - "Got websocket message of unsupported type, skipping." - ); + warn!(?msg, "Got websocket message of unsupported type, skipping."); continue; } return Poll::Ready(Some(Ok(msg.into_payload().into()))); @@ -115,7 +113,7 @@ impl Stream for WsBytesFramed { return Poll::Ready(Some(Ok(msg.into()))) } Some(msg) => { - tracing::warn!(?msg, "Got websocket message of unsupported type, skipping."); + warn!(?msg, "Got websocket message of unsupported type, skipping."); continue; } } diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 5f75d778abd..50d051f1a89 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -24,7 +24,7 @@ use crate::{ server::{ clients::Clients, metrics::Metrics, - streams::{RecvError as StreamRecvError, RelayedStream}, + streams::{RecvError, RelayedStream}, }, PingTracker, }; @@ -168,7 +168,7 @@ impl Client { } } -/// Receive frame error +/// Error for [`Actor::handle_frame`] #[common_fields({ backtrace: Option, })] @@ -181,19 +181,19 @@ pub enum HandleFrameError { #[snafu(display("Stream terminated"))] StreamTerminated {}, #[snafu(transparent)] - Recv { source: StreamRecvError }, + Recv { source: RecvError }, #[snafu(transparent)] - Send { source: SendFrameError }, + Send { source: WriteFrameError }, } -/// Send frame error +/// Error for [`Actor::write_frame`] #[common_fields({ backtrace: Option, })] #[allow(missing_docs)] #[derive(Debug, Snafu)] #[non_exhaustive] -pub enum SendFrameError { +pub enum WriteFrameError { #[snafu(transparent)] Stream { source: StreamError }, #[snafu(transparent)] @@ -224,7 +224,7 @@ pub enum RunError { }, #[snafu(display("Failed to send disco packet"))] DiscoPacketSend { - source: SendFrameError, + source: WriteFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, @@ -235,7 +235,7 @@ pub enum RunError { }, #[snafu(display("Failed to send packet"))] PacketSend { - source: SendFrameError, + source: WriteFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, @@ -246,13 +246,13 @@ pub enum RunError { }, #[snafu(display("NodeGone write frame failed"))] NodeGoneWriteFrame { - source: SendFrameError, + source: WriteFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, #[snafu(display("Keep alive write frame failed"))] KeepAliveWriteFrame { - source: SendFrameError, + source: WriteFrameError, #[snafu(implicit)] span_trace: n0_snafu::SpanTrace, }, @@ -395,7 +395,7 @@ impl Actor { /// Writes the given frame to the connection. /// /// Errors if the send does not happen within the `timeout` duration - async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), SendFrameError> { + async fn write_frame(&mut self, frame: RelayToClientMsg) -> Result<(), WriteFrameError> { tokio::time::timeout(self.timeout, self.stream.send(frame)).await??; Ok(()) } @@ -404,7 +404,7 @@ impl Actor { /// /// Errors if the send does not happen within the `timeout` duration /// Does not flush. - async fn send_raw(&mut self, packet: Packet) -> Result<(), SendFrameError> { + async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> { let remote_node_id = packet.src; let datagrams = packet.data; @@ -418,7 +418,7 @@ impl Actor { .await } - async fn send_packet(&mut self, packet: Packet) -> Result<(), SendFrameError> { + async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> { trace!("send packet"); match self.send_raw(packet).await { Ok(()) => { @@ -432,7 +432,7 @@ impl Actor { } } - async fn send_disco_packet(&mut self, packet: Packet) -> Result<(), SendFrameError> { + async fn send_disco_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> { trace!("send disco packet"); match self.send_raw(packet).await { Ok(()) => { @@ -449,7 +449,7 @@ impl Actor { /// Handles frame read results. async fn handle_frame( &mut self, - maybe_frame: Option>, + maybe_frame: Option>, ) -> Result<(), HandleFrameError> { trace!(?maybe_frame, "handle incoming frame"); let frame = match maybe_frame { @@ -585,7 +585,7 @@ mod tests { Some(Ok(frame)) => { if frame_type != frame.typ() { snafu::whatever!( - "Unepxected frame, got {}, but expected {}", + "Unexpected frame, got {}, but expected {}", frame.typ(), frame_type ); diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index b89d73b82ff..bdec78042b0 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -216,7 +216,7 @@ mod tests { Some(Ok(frame)) => { if frame_type != frame.typ() { snafu::whatever!( - "Unepxected frame, got {}, but expected {}", + "Unexpected frame, got {}, but expected {}", frame.typ(), frame_type ); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 77d72da43bb..4a47625aae3 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -477,7 +477,6 @@ impl RelayService { .context(MissingHeaderSnafu { header }) } - // Send a 400 to any request that doesn't have an `Upgrade` header. let upgrade_header = expect_header(&req, UPGRADE)?; snafu::ensure!( upgrade_header == HeaderValue::from_static(WEBSOCKET_UPGRADE_PROTOCOL), @@ -668,10 +667,10 @@ impl Inner { self.metrics.accepts.inc(); // Create a server builder with default config - let builder = tokio_websockets::ServerBuilder::new() - .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))); - // Serve will create a WebSocketStream on an already upgraded connection - let websocket = builder.serve(io); + let websocket = tokio_websockets::ServerBuilder::new() + .limits(tokio_websockets::Limits::default().max_payload_len(Some(MAX_FRAME_SIZE))) + // Serve will create a WebSocketStream on an already upgraded connection + .serve(io); let mut io = WsBytesFramed { io: websocket }; @@ -681,7 +680,7 @@ impl Inner { trace!(?authentication.mechanism, "accept: verified authentication"); let is_authorized = self.access.is_allowed(authentication.client_key).await; - let client_key = authentication.authorize(&mut io, is_authorized).await?; + let client_key = authentication.authorize_if(is_authorized, &mut io).await?; trace!("accept: verified authorization"); diff --git a/iroh/tests/integration.rs b/iroh/tests/integration.rs index 17ab91f9eb7..e48909603ac 100644 --- a/iroh/tests/integration.rs +++ b/iroh/tests/integration.rs @@ -127,7 +127,7 @@ fn setup_logging() { use std::str::FromStr; tracing_subscriber::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_str("debug").expect("hardcoded")) + .with_env_filter(tracing_subscriber::EnvFilter::from_str("trace").expect("hardcoded")) .with_max_level(tracing::level_filters::LevelFilter::DEBUG) .with_writer( // To avoide trace events in the browser from showing their JS backtrace From e20915d125919d76b9b3c41265c98d40bca16bdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 11 Jul 2025 15:20:34 +0200 Subject: [PATCH 66/80] Fix docs, don't `cfg_attr` on serde impls --- iroh-relay/src/protos/handshake.rs | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/iroh-relay/src/protos/handshake.rs b/iroh-relay/src/protos/handshake.rs index 78588baac0b..5dd7ff56f8f 100644 --- a/iroh-relay/src/protos/handshake.rs +++ b/iroh-relay/src/protos/handshake.rs @@ -10,9 +10,9 @@ //! //! One way is via an explicitly sent challenge: //! -//! 1. Once a websocket connection is opened, a client receives a challenge (the [`ServerChallenge`] frame) +//! 1. Once a websocket connection is opened, a client receives a challenge (the `ServerChallenge` frame) //! 2. The client sends back what is essentially a signature of that challenge with their secret key -//! that matches the NodeId they have, as well as the NodeId (the [`ClientAuth`] frame) +//! that matches the NodeId they have, as well as the NodeId (the `ClientAuth` frame) //! //! The second way is very similar to the [Concealed HTTP Auth RFC], and involves send a header that //! contains a signature of some shared keying material extracted from TLS ([RFC 5705]). @@ -46,8 +46,7 @@ use super::{ use crate::ExportKeyingMaterial; /// Authentication message from the client. -#[derive(derive_more::Debug, serde::Serialize)] -#[cfg_attr(feature = "server", derive(serde::Deserialize))] +#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(wasm_browser, allow(unused))] pub(crate) struct KeyMaterialClientAuth { /// The client's public key @@ -62,8 +61,7 @@ pub(crate) struct KeyMaterialClientAuth { } /// A challenge for the client to sign with their secret key for NodeId authentication. -#[derive(derive_more::Debug, serde::Deserialize)] -#[cfg_attr(feature = "server", derive(serde::Serialize))] +#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)] pub(crate) struct ServerChallenge { /// The challenge to sign. /// Must be randomly generated with an RNG that is safe to use for crypto. @@ -73,8 +71,7 @@ pub(crate) struct ServerChallenge { /// Authentication message from the client. /// /// Used when authentication via [`KeyMaterialClientAuth`] didn't work. -#[derive(derive_more::Debug, serde::Serialize)] -#[cfg_attr(feature = "server", derive(serde::Deserialize))] +#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)] pub(crate) struct ClientAuth { /// The client's public key, a.k.a. the `NodeId` pub(crate) public_key: PublicKey, @@ -86,13 +83,11 @@ pub(crate) struct ClientAuth { } /// Confirmation of successful connection. -#[derive(derive_more::Debug, serde::Deserialize)] -#[cfg_attr(feature = "server", derive(serde::Serialize))] +#[derive(derive_more::Debug, serde::Serialize, serde::Deserialize)] pub(crate) struct ServerConfirmsAuth; /// Denial of connection. The client couldn't be verified as authentic. -#[derive(derive_more::Debug, Clone, serde::Deserialize)] -#[cfg_attr(feature = "server", derive(serde::Serialize))] +#[derive(derive_more::Debug, Clone, serde::Serialize, serde::Deserialize)] pub(crate) struct ServerDeniesAuth { reason: String, } @@ -311,10 +306,10 @@ pub(crate) async fn clientside( /// This represents successful authentication for the client with the `client_key` public key /// via the authentication [`Mechanism`] `mechanism`. /// -/// You must call [`SuccessfulAuthentication::authorize`] to finish the protocol. +/// You must call [`SuccessfulAuthentication::authorize_if`] to finish the protocol. #[cfg(feature = "server")] #[derive(Debug)] -#[must_use = "the protocol is not finished unless `authorize` is called"] +#[must_use = "the protocol is not finished unless `authorize_if` is called"] pub(crate) struct SuccessfulAuthentication { pub(crate) client_key: PublicKey, pub(crate) mechanism: Mechanism, @@ -342,7 +337,7 @@ pub(crate) enum Mechanism { /// If this fails, the protocol falls back to doing a normal extra round trip with a challenge. /// /// The return value [`SuccessfulAuthentication`] still needs to be resolved by calling -/// [`SuccessfulAuthentication::authorize`] to finish the whole authorization protocol +/// [`SuccessfulAuthentication::authorize_if`] to finish the whole authorization protocol /// (otherwise the client won't be notified about auth success or failure). #[cfg(feature = "server")] pub(crate) async fn serverside( From 16c89c294164c3d6853b036d5f9ab1046e84dbc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 14:46:03 +0200 Subject: [PATCH 67/80] Revert logical changes from "Send ECN bits and use stride instead of custom split protocol" Reverting 7a5550ac81d7a01bb70a53687be4c69aede2c748 Also: Adjusting revert so it compiles --- iroh-relay/src/client/conn.rs | 8 +- iroh-relay/src/protos/common.rs | 4 +- iroh-relay/src/protos/relay.rs | 214 ++++------------- iroh-relay/src/server.rs | 55 ++--- iroh-relay/src/server/client.rs | 81 +++---- iroh-relay/src/server/clients.rs | 34 ++- iroh-relay/src/server/http_server.rs | 158 ++++++------- iroh/src/magicsock.rs | 4 + iroh/src/magicsock/transports/relay.rs | 50 ++-- iroh/src/magicsock/transports/relay/actor.rs | 234 ++++++++++++++++--- 10 files changed, 426 insertions(+), 416 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index d7588edd7c1..d30f8f06695 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -19,7 +19,7 @@ use crate::client::streams::{MaybeTlsStream, ProxyStream}; use crate::{ protos::{ handshake, - relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg, MAX_PAYLOAD_SIZE}, + relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg}, streams::WsBytesFramed, }, MAX_PACKET_SIZE, @@ -143,9 +143,9 @@ impl Sink for Conn { } fn start_send(mut self: Pin<&mut Self>, frame: ClientToRelayMsg) -> Result<(), Self::Error> { - if let ClientToRelayMsg::Datagrams { datagrams, .. } = &frame { - let size = datagrams.contents.len(); - snafu::ensure!(size <= MAX_PAYLOAD_SIZE, ExceedsMaxPacketSizeSnafu { size }); + if let ClientToRelayMsg::SendPacket { .. } = &frame { + let size = frame.encoded_len(); + snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); } Pin::new(&mut self.conn) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index b6c29661a5c..04dd4b61542 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -24,9 +24,9 @@ pub enum FrameType { /// The server frame type for authentication denial ServerDeniesAuth = 4, /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - ClientToRelayDatagrams = 5, + SendPacket = 5, /// 32B src pub key + ECN byte + segment size u16 + datagrams contents - RelayToClientDatagrams = 6, + RecvPacket = 6, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 6f0d69bc4c2..dfdc7a4f2ed 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -21,12 +21,6 @@ use crate::KeyCache; /// including its on-wire framing overhead) pub const MAX_PACKET_SIZE: usize = 64 * 1024; -/// Maximum size a datagram payload is allowed to be. -/// -/// This is [`MAX_PACKET_SIZE`] minus the length of an encoded public key minus 3 bytes, -/// one for ECN, and two for the segment size. -pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 3; - /// The maximum frame size. /// /// This is also the minimum burst size that a rate-limiter has to accept. @@ -80,11 +74,12 @@ pub enum Error { #[derive(derive_more::Debug, Clone, PartialEq, Eq)] pub enum RelayToClientMsg { /// Represents datagrams sent from relays (originally sent to them by another client). - Datagrams { + ReceivedPacket { /// The [`NodeId`] of the original sender. - remote_node_id: NodeId, - /// The datagrams and related metadata. - datagrams: Datagrams, + src_key: NodeId, + /// The received packet bytes. + #[debug(skip)] + content: Bytes, }, /// Indicates that the client identified by the underlying public key had previously sent you a /// packet but has now disconnected from the relay. @@ -128,85 +123,19 @@ pub enum ClientToRelayMsg { /// with the payload sent previously in the ping. Pong([u8; 8]), /// Request from the client to relay datagrams to given remote node. - Datagrams { + SendPacket { /// The remote node to relay to. - dst_node_id: NodeId, + dst_key: NodeId, /// The datagrams and related metadata to relay. - datagrams: Datagrams, + packet: Bytes, }, } -/// One or multiple datagrams being transferred via the relay. -/// -/// This type is modeled after [`quinn_proto::Transmit`] -/// (or even more similarly `quinn_udp::Transmit`, but we don't depend on that library here). -#[derive(derive_more::Debug, Clone, PartialEq, Eq)] -pub struct Datagrams { - /// Explicit congestion notification bits - pub ecn: Option, - /// The segment size if this transmission contains multiple datagrams. - /// This is `None` if the transmit only contains a single datagram - pub segment_size: Option, - /// The contents of the datagram(s) - #[debug(skip)] - pub contents: Bytes, -} - -impl> From for Datagrams { - fn from(bytes: T) -> Self { - Self { - ecn: None, - segment_size: None, - contents: Bytes::copy_from_slice(bytes.as_ref()), - } - } -} - -impl Datagrams { - fn write_to(&self, mut dst: O) -> O { - let ecn = self.ecn.map_or(0, |ecn| ecn as u8); - let segment_size = self.segment_size.unwrap_or_default(); - dst.put_u8(ecn); - dst.put_u16(segment_size); - dst.put(self.contents.as_ref()); - dst - } - - fn encoded_len(&self) -> usize { - 1 // ECN byte - + 2 // segment size - + self.contents.len() - } - - fn from_bytes(bytes: Bytes) -> Result { - // 1 bytes ECN, 2 bytes segment size - snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); - - let ecn_byte = bytes[0]; - let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte); - - let segment_size = u16::from_be_bytes(bytes[1..3].try_into().expect("length checked")); - let segment_size = if segment_size == 0 { - None - } else { - Some(segment_size) - }; - - let contents = bytes.slice(3..); - - Ok(Self { - ecn, - segment_size, - contents, - }) - } -} - impl RelayToClientMsg { /// Returns this frame's corresponding frame type. pub fn typ(&self) -> FrameType { match self { - Self::Datagrams { .. } => FrameType::RelayToClientDatagrams, + Self::ReceivedPacket { .. } => FrameType::RecvPacket, Self::NodeGone { .. } => FrameType::NodeGone, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, @@ -227,12 +156,12 @@ impl RelayToClientMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::Datagrams { - remote_node_id, - datagrams, + Self::ReceivedPacket { + src_key: remote_node_id, + content, } => { dst.put(remote_node_id.as_ref()); - dst = datagrams.write_to(dst); + dst.put(content.as_ref()); } Self::NodeGone(node_id) => { dst.put(node_id.as_ref()); @@ -260,9 +189,9 @@ impl RelayToClientMsg { #[cfg(feature = "server")] pub(crate) fn encoded_len(&self) -> usize { let payload_len = match self { - Self::Datagrams { datagrams, .. } => { + Self::ReceivedPacket { content, .. } => { 32 // nodeid - + datagrams.encoded_len() + + content.len() } Self::NodeGone(_) => 32, Self::Ping(_) | Self::Pong(_) => 8, @@ -289,17 +218,14 @@ impl RelayToClientMsg { ); let res = match frame_type { - FrameType::RelayToClientDatagrams => { + FrameType::RecvPacket => { snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); - let remote_node_id = cache + let src_key = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; - let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; - Self::Datagrams { - remote_node_id, - datagrams, - } + let content = content.slice(NodeId::LENGTH..); + Self::ReceivedPacket { src_key, content } } FrameType::NodeGone => { snafu::ensure!(content.len() == NodeId::LENGTH, InvalidFrameSnafu); @@ -356,7 +282,7 @@ impl RelayToClientMsg { impl ClientToRelayMsg { pub(crate) fn typ(&self) -> FrameType { match self { - Self::Datagrams { .. } => FrameType::ClientToRelayDatagrams, + Self::SendPacket { .. } => FrameType::SendPacket, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, } @@ -372,12 +298,9 @@ impl ClientToRelayMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::Datagrams { - dst_node_id, - datagrams, - } => { - dst.put(dst_node_id.as_ref()); - dst = datagrams.write_to(dst); + Self::SendPacket { dst_key, packet } => { + dst.put(dst_key.as_ref()); + dst.put(packet.as_ref()); } Self::Ping(data) => { dst.put(&data[..]); @@ -392,9 +315,9 @@ impl ClientToRelayMsg { pub(crate) fn encoded_len(&self) -> usize { let payload_len = match self { Self::Ping(_) | Self::Pong(_) => 8, - Self::Datagrams { datagrams, .. } => { + Self::SendPacket { packet, .. } => { 32 // node id - + datagrams.encoded_len() + + packet.len() } }; 1 // frame type (all frame types currently encode as 1 byte varint) @@ -415,15 +338,12 @@ impl ClientToRelayMsg { ); let res = match frame_type { - FrameType::ClientToRelayDatagrams => { - let dst_node_id = cache + FrameType::SendPacket => { + let dst_key = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; - let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; - Self::Datagrams { - dst_node_id, - datagrams, - } + let packet = content.slice(NodeId::LENGTH..); + Self::SendPacket { dst_key, packet } } FrameType::Ping => { snafu::ensure!(content.len() == 8, InvalidFrameSnafu); @@ -499,13 +419,9 @@ mod tests { "09 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - RelayToClientMsg::Datagrams { - remote_node_id: client_key.public(), - datagrams: Datagrams { - ecn: Some(quinn::EcnCodepoint::Ce), - segment_size: Some(6), - contents: "Hello World!".into(), - }, + RelayToClientMsg::ReceivedPacket { + src_key: client_key.public(), + content: "Hello World!".into(), } .write_to(Vec::new()), // frame type @@ -548,13 +464,9 @@ mod tests { "09 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ClientToRelayMsg::Datagrams { - dst_node_id: client_key.public(), - datagrams: Datagrams { - ecn: Some(quinn::EcnCodepoint::Ce), - segment_size: Some(6), - contents: "Hello World!".into(), - }, + ClientToRelayMsg::SendPacket { + dst_key: client_key.public(), + packet: "Hello World!".into(), } .write_to(Vec::new()), // frame type @@ -591,44 +503,22 @@ mod proptests { secret_key().prop_map(|key| key.public()) } - fn ecn() -> impl Strategy> { - (0..=3).prop_map(|n| match n { - 1 => Some(quinn_proto::EcnCodepoint::Ce), - 2 => Some(quinn_proto::EcnCodepoint::Ect0), - 3 => Some(quinn_proto::EcnCodepoint::Ect1), - _ => None, - }) - } - - fn datagrams() -> impl Strategy { - ( - ecn(), - prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE), - prop::collection::vec(any::(), 0..MAX_PAYLOAD_SIZE), - ) - .prop_map(|(ecn, segment_size, data)| Datagrams { - ecn, - segment_size: segment_size.map(|ss| std::cmp::min(data.len(), ss) as u16), - contents: Bytes::from(data), - }) + /// Generates random data, up to the maximum packet size minus the given number of bytes + fn data(consumed: usize) -> impl Strategy { + let len = MAX_PACKET_SIZE - consumed; + prop::collection::vec(any::(), 0..len).prop_map(Bytes::from) } /// Generates a random valid frame fn server_client_frame() -> impl Strategy { - let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } - }); - let node_gone = key().prop_map(RelayToClientMsg::NodeGone); + let recv_packet = (key(), data(32)) + .prop_map(|(src_key, content)| RelayToClientMsg::ReceivedPacket { src_key, content }); + let node_gone = key().prop_map(|node_id| RelayToClientMsg::NodeGone(node_id)); let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); - let health = ".{0,65536}" - .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { - s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes - }) - .prop_map(|problem| RelayToClientMsg::Health { problem }); + let health = data(0).prop_map(|_problem| RelayToClientMsg::Health { + problem: "".to_string(), + }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { RelayToClientMsg::Restarting { reconnect_in: Duration::from_millis(reconnect_in.into()), @@ -639,11 +529,8 @@ mod proptests { } fn client_server_frame() -> impl Strategy { - let send_packet = - (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| ClientToRelayMsg::Datagrams { - dst_node_id, - datagrams, - }); + let send_packet = (key(), data(32)) + .prop_map(|(dst_key, packet)| ClientToRelayMsg::SendPacket { dst_key, packet }); let ping = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Pong); prop_oneof![send_packet, ping, pong] @@ -677,12 +564,5 @@ mod proptests { let actual_encoded_len = frame.to_bytes().len(); prop_assert_eq!(claimed_encoded_len, actual_encoded_len); } - - #[test] - fn datagrams_encoded_len(datagrams in datagrams()) { - let claimed_encoded_len = datagrams.encoded_len(); - let actual_encoded_len = datagrams.write_to(Vec::new()).len(); - prop_assert_eq!(claimed_encoded_len, actual_encoded_len); - } } } diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 94a3690ed55..3660760f4af 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -750,6 +750,7 @@ impl hyper::service::Service> for CaptivePortalService { mod tests { use std::{net::Ipv4Addr, time::Duration}; + use bytes::Bytes; use http::StatusCode; use iroh_base::{NodeId, RelayUrl, SecretKey}; use n0_future::{FutureExt, SinkExt, StreamExt}; @@ -766,7 +767,7 @@ mod tests { dns::DnsResolver, protos::{ handshake, - relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, + relay::{ClientToRelayMsg, RelayToClientMsg}, }, }; @@ -790,14 +791,14 @@ mod tests { client_a: &mut crate::client::Client, client_b: &mut crate::client::Client, b_key: NodeId, - msg: Datagrams, + msg: Bytes, ) -> Result { // try resend 10 times for _ in 0..10 { client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: b_key, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), }) .await?; let Ok(res) = tokio::time::timeout(Duration::from_millis(500), client_b.next()).await @@ -911,34 +912,26 @@ mod tests { info!("sending a -> b"); // send message from a to b - let msg = Datagrams::from("hello, b"); + let msg = Bytes::from_static(b"hello, b"); let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - let RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } = res - else { + let RelayToClientMsg::ReceivedPacket { src_key, content } = res else { panic!("client_b received unexpected message {res:?}"); }; - assert_eq!(a_key, remote_node_id); - assert_eq!(msg, datagrams); + assert_eq!(a_key, src_key); + assert_eq!(msg, content); info!("sending b -> a"); // send message from b to a - let msg = Datagrams::from("howdy, a"); + let msg = Bytes::from_static(b"howdy, a"); let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - let RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } = res - else { + let RelayToClientMsg::ReceivedPacket { src_key, content } = res else { panic!("client_a received unexpected message {res:?}"); }; - assert_eq!(b_key, remote_node_id); - assert_eq!(msg, datagrams); + assert_eq!(b_key, src_key); + assert_eq!(msg, content); Ok(()) } @@ -1010,16 +1003,12 @@ mod tests { .await?; // send message from b to c - let msg = Datagrams::from("hello, c"); + let msg = Bytes::from_static(b"hello, c"); let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?; - if let RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } = res - { - assert_eq!(b_key, remote_node_id); - assert_eq!(msg, datagrams); + if let RelayToClientMsg::ReceivedPacket { src_key, content } = res { + assert_eq!(b_key, src_key); + assert_eq!(msg, content); } else { panic!("client_c received unexpected message {res:?}"); } @@ -1051,12 +1040,12 @@ mod tests { // send messages from a to b, without b receiving anything. // we should still keep succeeding to send, even if the packet won't be forwarded // by the relay server because the server's send queue for b fills up. - let msg = Datagrams::from("hello, b"); + let msg = Bytes::from_static(b"hello, b"); for _i in 0..1000 { client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: b_key, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), }) .await?; } diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 50d051f1a89..3c230d9c59d 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -2,6 +2,7 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; +use bytes::Bytes; use iroh_base::NodeId; use n0_future::{SinkExt, StreamExt}; use nested_enum_utils::common_fields; @@ -18,7 +19,7 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg, PING_INTERVAL}, + relay::{ClientToRelayMsg, RelayToClientMsg, PING_INTERVAL}, streams::StreamError, }, server::{ @@ -35,7 +36,7 @@ pub(super) struct Packet { /// The sender of the packet src: NodeId, /// The data packet bytes. - data: Datagrams, + data: Bytes, } /// Configuration for a [`Client`]. @@ -150,7 +151,7 @@ impl Client { pub(super) fn try_send_packet( &self, src: NodeId, - data: Datagrams, + data: Bytes, ) -> Result<(), TrySendError> { self.send_queue.try_send(Packet { src, data }) } @@ -158,7 +159,7 @@ impl Client { pub(super) fn try_send_disco_packet( &self, src: NodeId, - data: Datagrams, + data: Bytes, ) -> Result<(), TrySendError> { self.disco_send_queue.try_send(Packet { src, data }) } @@ -405,17 +406,14 @@ impl Actor { /// Errors if the send does not happen within the `timeout` duration /// Does not flush. async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> { - let remote_node_id = packet.src; - let datagrams = packet.data; + let src_key = packet.src; + let content = packet.data; - if let Ok(len) = datagrams.contents.len().try_into() { + if let Ok(len) = content.len().try_into() { self.metrics.bytes_sent.inc_by(len); } - self.write_frame(RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - }) - .await + self.write_frame(RelayToClientMsg::ReceivedPacket { src_key, content }) + .await } async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> { @@ -458,13 +456,10 @@ impl Actor { }; match frame { - ClientToRelayMsg::Datagrams { - dst_node_id: dst_key, - datagrams, - } => { - let packet_len = datagrams.contents.len(); + ClientToRelayMsg::SendPacket { dst_key, packet } => { + let packet_len = packet.len(); if let Err(err @ ForwardPacketError { .. }) = - self.handle_frame_send_packet(dst_key, datagrams) + self.handle_frame_send_packet(dst_key, packet) { warn!("failed to handle send packet frame: {err:#}"); } @@ -483,12 +478,8 @@ impl Actor { Ok(()) } - fn handle_frame_send_packet( - &self, - dst: NodeId, - data: Datagrams, - ) -> Result<(), ForwardPacketError> { - if disco::looks_like_disco_wrapper(&data.contents) { + fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<(), ForwardPacketError> { + if disco::looks_like_disco_wrapper(&data) { self.metrics.disco_packets_recv.inc(); self.clients .send_disco_packet(dst, data, self.node_id, &self.metrics)?; @@ -637,17 +628,15 @@ mod tests { println!(" send packet"); let packet = Packet { src: node_id, - data: Datagrams::from(&data[..]), + data: Bytes::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) - .await - .e()?; + let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, - RelayToClientMsg::Datagrams { - remote_node_id: node_id, - datagrams: data.to_vec().into() + RelayToClientMsg::ReceivedPacket { + src_key: node_id, + content: data.to_vec().into() } ); @@ -657,14 +646,12 @@ mod tests { .send(packet.clone()) .await .context("send")?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) - .await - .e()?; + let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; assert_eq!( frame, - RelayToClientMsg::Datagrams { - remote_node_id: node_id, - datagrams: data.to_vec().into() + RelayToClientMsg::ReceivedPacket { + src_key: node_id, + content: data.to_vec().into() } ); @@ -692,9 +679,9 @@ mod tests { println!(" send packet"); let data = b"hello world!"; io_rw - .send(ClientToRelayMsg::Datagrams { - dst_node_id: target, - datagrams: Datagrams::from(data), + .send(ClientToRelayMsg::SendPacket { + dst_key: target, + packet: Bytes::from_static(data), }) .await .context("send")?; @@ -706,9 +693,9 @@ mod tests { disco_data.extend_from_slice(target.as_bytes()); disco_data.extend_from_slice(data); io_rw - .send(ClientToRelayMsg::Datagrams { - dst_node_id: target, - datagrams: disco_data.clone().into(), + .send(ClientToRelayMsg::SendPacket { + dst_key: target, + packet: disco_data.clone().into(), }) .await .context("send")?; @@ -731,11 +718,11 @@ mod tests { let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?; // Prepare a frame to send, assert its size. - let data = Datagrams::from(b"hello world!!1"); + let data = Bytes::from_static(b"hello world!!1"); let target = SecretKey::generate(rand::thread_rng()).public(); - let frame = ClientToRelayMsg::Datagrams { - dst_node_id: target, - datagrams: data.clone(), + let frame = ClientToRelayMsg::SendPacket { + dst_key: target, + packet: data.clone(), }; let frame_len = frame.to_bytes().len(); assert_eq!(frame_len, LIMIT as usize); diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index bdec78042b0..cacb0540bcd 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -9,18 +9,16 @@ use std::{ }, }; +use bytes::Bytes; use dashmap::DashMap; use iroh_base::NodeId; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, trace}; use super::client::{Client, Config, ForwardPacketError}; -use crate::{ - protos::relay::Datagrams, - server::{ - client::{PacketScope, SendError}, - metrics::Metrics, - }, +use crate::server::{ + client::{PacketScope, SendError}, + metrics::Metrics, }; /// Manages the connections to all currently connected clients. @@ -110,7 +108,7 @@ impl Clients { pub(super) fn send_packet( &self, dst: NodeId, - data: Datagrams, + data: Bytes, src: NodeId, metrics: &Metrics, ) -> Result<(), ForwardPacketError> { @@ -150,7 +148,7 @@ impl Clients { pub(super) fn send_disco_packet( &self, dst: NodeId, - data: Datagrams, + data: Bytes, src: NodeId, metrics: &Metrics, ) -> Result<(), ForwardPacketError> { @@ -254,24 +252,24 @@ mod tests { // send packet let data = b"hello world!"; - clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; + clients.send_packet(a_key, Bytes::from_static(&data[..]), b_key, &metrics)?; + let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, - RelayToClientMsg::Datagrams { - remote_node_id: b_key, - datagrams: data.to_vec().into(), + RelayToClientMsg::ReceivedPacket { + src_key: b_key, + content: data.to_vec().into(), } ); // send disco packet - clients.send_disco_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; + clients.send_disco_packet(a_key, Bytes::from_static(&data[..]), b_key, &metrics)?; + let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, - RelayToClientMsg::Datagrams { - remote_node_id: b_key, - datagrams: data.to_vec().into(), + RelayToClientMsg::ReceivedPacket { + src_key: b_key, + content: data.to_vec().into(), } ); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 4a47625aae3..5628990e7a0 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -879,7 +879,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, + protos::relay::{ClientToRelayMsg, RelayToClientMsg}, }; pub(crate) fn make_tls_config() -> TlsConfig { @@ -947,11 +947,11 @@ mod tests { assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("sending message from a to b"); - let msg = Datagrams::from(b"hi there, client b!"); + let msg = Bytes::from_static(b"hi there, client b!"); client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: b_key, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), }) .await?; info!("waiting for message from a on b"); @@ -961,11 +961,11 @@ mod tests { assert_eq!(msg, got_msg); info!("sending message from b to a"); - let msg = Datagrams::from(b"right back at ya, client b!"); + let msg = Bytes::from_static(b"right back at ya, client b!"); client_b - .send(ClientToRelayMsg::Datagrams { - dst_node_id: a_key, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: a_key, + packet: msg.clone(), }) .await?; info!("waiting for message b on a"); @@ -996,7 +996,7 @@ mod tests { fn process_msg( msg: Option>, - ) -> Option<(PublicKey, Datagrams)> { + ) -> Option<(PublicKey, Bytes)> { match msg { Some(Err(e)) => { info!("client `recv` error {e}"); @@ -1004,12 +1004,12 @@ mod tests { } Some(Ok(msg)) => { info!("got message on: {msg:?}"); - if let RelayToClientMsg::Datagrams { - remote_node_id: source, - datagrams, + if let RelayToClientMsg::ReceivedPacket { + src_key: source, + content, } = msg { - Some((source, datagrams)) + Some((source, content)) } else { None } @@ -1067,11 +1067,11 @@ mod tests { assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("sending message from a to b"); - let msg = Datagrams::from(b"hi there, client b!"); + let msg = Bytes::from_static(b"hi there, client b!"); client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: b_key, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: b_key, + packet: msg.clone(), }) .await?; info!("waiting for message from a on b"); @@ -1081,11 +1081,11 @@ mod tests { assert_eq!(msg, got_msg); info!("sending message from b to a"); - let msg = Datagrams::from(b"right back at ya, client b!"); + let msg = Bytes::from_static(b"right back at ya, client b!"); client_b - .send(ClientToRelayMsg::Datagrams { - dst_node_id: a_key, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: a_key, + packet: msg.clone(), }) .await?; info!("waiting for message b on a"); @@ -1144,20 +1144,17 @@ mod tests { handler_task.await.context("join")??; info!("Send message from A to B."); - let msg = Datagrams::from(b"hello client b!!"); + let msg = Bytes::from_static(b"hello client b!!"); client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_b, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_b, + packet: msg.clone(), }) .await?; match client_b.next().await.unwrap()? { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } => { - assert_eq!(public_key_a, remote_node_id); - assert_eq!(msg, datagrams); + RelayToClientMsg::ReceivedPacket { src_key, content } => { + assert_eq!(public_key_a, src_key); + assert_eq!(msg, content); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1165,20 +1162,17 @@ mod tests { } info!("Send message from B to A."); - let msg = Datagrams::from(b"nice to meet you client a!!"); + let msg = Bytes::from_static(b"nice to meet you client a!!"); client_b - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_a, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_a, + packet: msg.clone(), }) .await?; match client_a.next().await.unwrap()? { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } => { - assert_eq!(public_key_b, remote_node_id); - assert_eq!(msg, datagrams); + RelayToClientMsg::ReceivedPacket { src_key, content } => { + assert_eq!(public_key_b, src_key); + assert_eq!(msg, content); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1191,9 +1185,9 @@ mod tests { info!("Fail to send message from A to B."); let res = client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_b, - datagrams: Datagrams::from(b"try to send"), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_b, + packet: Bytes::from_static(b"try to send"), }) .await; assert!(res.is_err()); @@ -1234,20 +1228,17 @@ mod tests { handler_task.await.context("join")??; info!("Send message from A to B."); - let msg = Datagrams::from(b"hello client b!!"); + let msg = Bytes::from_static(b"hello client b!!"); client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_b, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_b, + packet: msg.clone(), }) .await?; match client_b.next().await.expect("eos")? { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } => { - assert_eq!(public_key_a, remote_node_id); - assert_eq!(msg, datagrams); + RelayToClientMsg::ReceivedPacket { src_key, content } => { + assert_eq!(public_key_a, src_key); + assert_eq!(msg, content); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1255,20 +1246,17 @@ mod tests { } info!("Send message from B to A."); - let msg = Datagrams::from(b"nice to meet you client a!!"); + let msg = Bytes::from_static(b"nice to meet you client a!!"); client_b - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_a, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_a, + packet: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } => { - assert_eq!(public_key_b, remote_node_id); - assert_eq!(msg, datagrams); + RelayToClientMsg::ReceivedPacket { src_key, content } => { + assert_eq!(public_key_b, src_key); + assert_eq!(msg, content); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1286,20 +1274,17 @@ mod tests { // assert!(client_b.recv().await.is_err()); info!("Send message from A to B."); - let msg = Datagrams::from(b"are you still there, b?!"); + let msg = Bytes::from_static(b"are you still there, b?!"); client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_b, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_b, + packet: msg.clone(), }) .await?; match new_client_b.next().await.expect("eos")? { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } => { - assert_eq!(public_key_a, remote_node_id); - assert_eq!(msg, datagrams); + RelayToClientMsg::ReceivedPacket { src_key, content } => { + assert_eq!(public_key_a, src_key); + assert_eq!(msg, content); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1307,20 +1292,17 @@ mod tests { } info!("Send message from B to A."); - let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!"); + let msg = Bytes::from_static(b"just had a spot of trouble but I'm back now,a!!"); new_client_b - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_a, - datagrams: msg.clone(), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_a, + packet: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, - } => { - assert_eq!(public_key_b, remote_node_id); - assert_eq!(msg, datagrams); + RelayToClientMsg::ReceivedPacket { src_key, content } => { + assert_eq!(public_key_b, src_key); + assert_eq!(msg, content); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1332,9 +1314,9 @@ mod tests { info!("Sending message from A to B fails"); let res = client_a - .send(ClientToRelayMsg::Datagrams { - dst_node_id: public_key_b, - datagrams: Datagrams::from(b"try to send"), + .send(ClientToRelayMsg::SendPacket { + dst_key: public_key_b, + packet: Bytes::from_static(b"try to send"), }) .await; assert!(res.is_err()); diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 73e45a86cb3..dd57974ba71 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -148,6 +148,10 @@ pub(crate) struct Options { pub(crate) metrics: EndpointMetrics, } +/// Contents of a relay message. Use a SmallVec to avoid allocations for the very +/// common case of a single packet. +type RelayContents = SmallVec<[Bytes; 1]>; + /// Handle for [`MagicSock`]. /// /// Dereferences to [`MagicSock`], and handles closing. diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index 3f8cca0ea28..559ceafcb78 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -5,16 +5,18 @@ use std::{ use bytes::Bytes; use iroh_base::{NodeId, RelayUrl}; -use iroh_relay::protos::relay::Datagrams; use n0_future::{ ready, task::{self, AbortOnDropHandle}, }; use n0_watcher::{Watchable, Watcher as _}; +use smallvec::SmallVec; use tokio::sync::mpsc; use tokio_util::sync::PollSender; use tracing::{error, info_span, trace, warn, Instrument}; +use crate::magicsock::RelayContents; + use super::{Addr, Transmit}; mod actor; @@ -99,12 +101,9 @@ impl RelayTransport { } }; - buf_out[..dm.datagrams.contents.len()].copy_from_slice(&dm.datagrams.contents); - meta_out.len = dm.datagrams.contents.len(); - meta_out.stride = dm - .datagrams - .segment_size - .map_or(dm.datagrams.contents.len(), |s| s as usize); + buf_out[..dm.buf.len()].copy_from_slice(&dm.buf); + meta_out.len = dm.buf.len(); + meta_out.stride = dm.buf.len(); meta_out.ecn = None; meta_out.dst_ip = None; // TODO: insert the relay url for this relay @@ -188,7 +187,7 @@ impl RelaySender { dest_node: NodeId, transmit: &Transmit<'_>, ) -> io::Result<()> { - let contents = datagrams_from_transmit(transmit); + let contents = split_packets(transmit); let item = RelaySendItem { remote_node: dest_node, @@ -230,7 +229,7 @@ impl RelaySender { trace!(node = %dest_node.fmt_short(), relay_url = %dest_url, "send relay: message queued"); - let contents = datagrams_from_transmit(transmit); + let contents = split_packets(transmit); let item = RelaySendItem { remote_node: dest_node, url: dest_url.clone(), @@ -268,7 +267,7 @@ impl RelaySender { dest_node: NodeId, transmit: &Transmit<'_>, ) -> io::Result<()> { - let contents = datagrams_from_transmit(transmit); + let contents = split_packets(transmit); let item = RelaySendItem { remote_node: dest_node, @@ -306,19 +305,26 @@ impl RelaySender { } } -/// Translate a UDP transmit to the `Datagrams` type for sending over the relay. +/// Split a transmit containing a GSO payload into individual packets. +/// +/// This allocates the data. +/// +/// If the transmit has a segment size it contains multiple GSO packets. It will be split +/// into multiple packets according to that segment size. If it does not have a segment +/// size, the contents will be sent as a single packet. // TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to // figure out where they allocate the Vec. -fn datagrams_from_transmit(transmit: &Transmit<'_>) -> Datagrams { - Datagrams { - ecn: transmit.ecn.map(|ecn| match ecn { - quinn_udp::EcnCodepoint::Ect0 => quinn_proto::EcnCodepoint::Ect0, - quinn_udp::EcnCodepoint::Ect1 => quinn_proto::EcnCodepoint::Ect1, - quinn_udp::EcnCodepoint::Ce => quinn_proto::EcnCodepoint::Ce, - }), - segment_size: transmit.segment_size.map(|ss| ss as u16), - contents: Bytes::copy_from_slice(transmit.contents), +fn split_packets(transmit: &Transmit<'_>) -> RelayContents { + let mut res = SmallVec::with_capacity(1); + let contents = transmit.contents; + if let Some(segment_size) = transmit.segment_size { + for chunk in contents.chunks(segment_size) { + res.push(Bytes::from(chunk.to_vec())); + } + } else { + res.push(Bytes::from(contents.to_vec())); } + res } #[cfg(test)] @@ -345,7 +351,7 @@ mod tests { let mut expected_msgs: BTreeSet = (0..capacity).collect(); while !expected_msgs.is_empty() { let datagram: RelayRecvDatagram = receiver.recv().await.unwrap(); - let msg_num = usize::from_le_bytes(datagram.datagrams.contents.as_ref().try_into().unwrap()); + let msg_num = usize::from_le_bytes(datagram.buf.as_ref().try_into().unwrap()); debug!("Received {msg_num}"); if !expected_msgs.remove(&msg_num) { @@ -365,7 +371,7 @@ mod tests { .try_send(RelayRecvDatagram { url, src: NodeId::from_bytes(&[0u8; 32]).unwrap(), - datagrams: Datagrams::from(&i.to_le_bytes()), + buf: Bytes::copy_from_slice(&i.to_le_bytes()), }) .unwrap(); } diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index 70e03ed079b..bf2fc0c90b6 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -38,12 +38,13 @@ use std::{ }; use backon::{Backoff, BackoffBuilder, ExponentialBuilder}; -use iroh_base::{NodeId, RelayUrl, SecretKey}; +use bytes::{Bytes, BytesMut}; +use iroh_base::{NodeId, PublicKey, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, - protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, - PingTracker, + protos::relay::{ClientToRelayMsg, RelayToClientMsg}, + PingTracker, MAX_PACKET_SIZE, }; use n0_future::{ task::JoinSet, @@ -61,11 +62,18 @@ use url::Url; #[cfg(not(wasm_browser))] use crate::dns::DnsResolver; -use crate::{magicsock::Metrics as MagicsockMetrics, net_report::Report, util::MaybeFuture}; +use crate::{ + magicsock::{Metrics as MagicsockMetrics, RelayContents}, + net_report::Report, + util::MaybeFuture, +}; /// How long a non-home relay connection needs to be idle (last written to) before we close it. const RELAY_INACTIVE_CLEANUP_TIME: Duration = Duration::from_secs(60); +/// Maximum size a datagram payload is allowed to be. +const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - PublicKey::LENGTH; + /// Interval in which we ping the relay server to ensure the connection is alive. /// /// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some @@ -618,20 +626,26 @@ impl ActiveRelayActor { self.reset_inactive_timeout(); // TODO: This allocation is *very* unfortunate. But so is the // allocation *inside* of PacketizeIter... - let batch = std::mem::replace( + let dgrams = std::mem::replace( &mut send_datagrams_buf, Vec::with_capacity(SEND_DATAGRAM_BATCH_SIZE), ); // TODO(frando): can we avoid the clone here? let metrics = self.metrics.clone(); - let packet_iter = batch.into_iter().map(|item| { - metrics.send_relay.inc_by(item.datagrams.contents.len() as _); - Ok(ClientToRelayMsg::Datagrams { - dst_node_id: item.remote_node, - datagrams: item.datagrams + let packet_iter = dgrams.into_iter().flat_map(|datagrams| { + PacketizeIter::<_, MAX_PAYLOAD_SIZE>::new( + datagrams.remote_node, + datagrams.datagrams.clone(), + ) + .map(|p| { + Ok(ClientToRelayMsg::SendPacket { dst_key: p.node_id, packet: p.payload }) }) }); - let mut packet_stream = n0_future::stream::iter(packet_iter); + let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { + if let Ok(ClientToRelayMsg::SendPacket { dst_key: _node_id, packet: payload }) = m { + metrics.send_relay.inc_by(payload.len() as _); + } + }); let fut = client_sink.send_all(&mut packet_stream); self.run_sending(fut, &mut state, &mut client_stream).await?; } @@ -666,11 +680,11 @@ impl ActiveRelayActor { fn handle_relay_msg(&mut self, msg: RelayToClientMsg, state: &mut ConnectedRelayState) { match msg { - RelayToClientMsg::Datagrams { - remote_node_id, - datagrams, + RelayToClientMsg::ReceivedPacket { + src_key: remote_node_id, + content, } => { - trace!(len = %datagrams.contents.len(), "received msg"); + trace!(len = %content.len(), "received msg"); // If this is a new sender, register a route for this peer. if state .last_packet_src @@ -682,12 +696,14 @@ impl ActiveRelayActor { state.last_packet_src = Some(remote_node_id); state.nodes_present.insert(remote_node_id); } - if let Err(err) = self.relay_datagrams_recv.try_send(RelayRecvDatagram { - url: self.url.clone(), - src: remote_node_id, - datagrams, - }) { - warn!("Dropping received relay packet: {err:#}"); + for datagram in PacketSplitIter::new(self.url.clone(), remote_node_id, content) { + let Ok(datagram) = datagram else { + warn!("Invalid packet split"); + break; + }; + if let Err(err) = self.relay_datagrams_recv.try_send(datagram) { + warn!("Dropping received relay packet: {err:#}"); + } } } RelayToClientMsg::NodeGone(node_id) => { @@ -836,7 +852,7 @@ pub(crate) struct RelaySendItem { /// The home relay of the remote node. pub(crate) url: RelayUrl, /// One or more datagrams to send. - pub(crate) datagrams: Datagrams, + pub(crate) datagrams: RelayContents, } pub(super) struct RelayActor { @@ -1216,6 +1232,18 @@ struct ActiveRelayHandle { datagrams_send_queue: mpsc::Sender, } +/// A packet to send over the relay. +/// +/// This is nothing but a newtype, it should be constructed using [`PacketizeIter`]. This +/// is a packet of one or more datagrams, each prefixed with a u16-be length. This is what +/// the `Frame::SendPacket` of the `DerpCodec` transports and is produced by +/// [`PacketizeIter`] and transformed back into datagrams using [`PacketSplitIter`]. +#[derive(Debug, PartialEq, Eq)] +struct RelaySendPacket { + node_id: NodeId, + payload: Bytes, +} + /// A single datagram received from a relay server. /// /// This could be either a QUIC or DISCO packet. @@ -1223,9 +1251,116 @@ struct ActiveRelayHandle { pub(crate) struct RelayRecvDatagram { pub(crate) url: RelayUrl, pub(crate) src: NodeId, - pub(crate) datagrams: Datagrams, + pub(crate) buf: Bytes, } +/// Combines datagrams into a single DISCO frame of at most MAX_PACKET_SIZE. +/// +/// The disco `iroh_relay::protos::Frame::SendPacket` frame can contain more then a single +/// datagram. Each datagram in this frame is prefixed with a little-endian 2-byte length +/// prefix. This occurs when Quinn sends a GSO transmit containing more than one datagram, +/// which are split using `split_packets`. +/// +/// The [`PacketSplitIter`] does the inverse and splits such packets back into individual +/// datagrams. +struct PacketizeIter { + node_id: NodeId, + iter: std::iter::Peekable, + buffer: BytesMut, +} + +impl PacketizeIter { + /// Create a new new PacketizeIter from something that can be turned into an + /// iterator of slices, like a `Vec`. + fn new(node_id: NodeId, iter: impl IntoIterator) -> Self { + Self { + node_id, + iter: iter.into_iter().peekable(), + buffer: BytesMut::with_capacity(N), + } + } +} + +impl Iterator for PacketizeIter +where + I::Item: AsRef<[u8]>, +{ + type Item = RelaySendPacket; + + fn next(&mut self) -> Option { + use bytes::BufMut; + while let Some(next_bytes) = self.iter.peek() { + let next_bytes = next_bytes.as_ref(); + assert!(next_bytes.len() + 2 <= N); + let next_length: u16 = next_bytes.len().try_into().expect("items < 64k size"); + if self.buffer.len() + next_bytes.len() + 2 > N { + break; + } + self.buffer.put_u16_le(next_length); + self.buffer.put_slice(next_bytes); + self.iter.next(); + } + if !self.buffer.is_empty() { + Some(RelaySendPacket { + node_id: self.node_id, + payload: self.buffer.split().freeze(), + }) + } else { + None + } + } +} + +/// Splits a single [`ServerToClientMsg::ReceivedPacket`] frame into datagrams. +/// +/// This splits packets joined by [`PacketizeIter`] back into individual datagrams. See +/// that struct for more details. +#[derive(Debug)] +struct PacketSplitIter { + url: RelayUrl, + src: NodeId, + bytes: Bytes, +} + +impl PacketSplitIter { + /// Create a new PacketSplitIter from a packet. + fn new(url: RelayUrl, src: NodeId, bytes: Bytes) -> Self { + Self { url, src, bytes } + } + + fn fail(&mut self) -> Option> { + self.bytes.clear(); + Some(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "", + ))) + } +} + +impl Iterator for PacketSplitIter { + type Item = std::io::Result; + + fn next(&mut self) -> Option { + use bytes::Buf; + if self.bytes.has_remaining() { + if self.bytes.remaining() < 2 { + return self.fail(); + } + let len = self.bytes.get_u16_le() as usize; + if self.bytes.remaining() < len { + return self.fail(); + } + let buf = self.bytes.split_to(len); + Some(Ok(RelayRecvDatagram { + url: self.url.clone(), + src: self.src, + buf, + })) + } else { + None + } + } +} #[cfg(test)] mod tests { use std::{ @@ -1233,9 +1368,11 @@ mod tests { time::Duration, }; + use bytes::Bytes; use iroh_base::{NodeId, RelayUrl, SecretKey}; - use iroh_relay::{protos::relay::Datagrams, PingTracker}; + use iroh_relay::PingTracker; use n0_snafu::{Error, Result, ResultExt}; + use smallvec::smallvec; use tokio::sync::{mpsc, oneshot}; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{info, info_span, Instrument}; @@ -1243,11 +1380,42 @@ mod tests { use super::{ ActiveRelayActor, ActiveRelayActorOptions, ActiveRelayMessage, ActiveRelayPrioMessage, - RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, RELAY_INACTIVE_CLEANUP_TIME, - UNDELIVERABLE_DATAGRAM_TIMEOUT, + PacketizeIter, RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, MAX_PACKET_SIZE, + RELAY_INACTIVE_CLEANUP_TIME, UNDELIVERABLE_DATAGRAM_TIMEOUT, }; use crate::{dns::DnsResolver, test_utils}; + #[test] + fn test_packetize_iter() { + let node_id = SecretKey::generate(rand::thread_rng()).public(); + let empty_vec: Vec = Vec::new(); + let mut iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, empty_vec); + assert_eq!(None, iter.next()); + + let single_vec = vec!["Hello"]; + let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, single_vec); + let result = iter.collect::>(); + assert_eq!(1, result.len()); + assert_eq!( + &[5, 0, b'H', b'e', b'l', b'l', b'o'], + &result[0].payload[..] + ); + + let spacer = vec![0u8; MAX_PACKET_SIZE - 10]; + let multiple_vec = vec![&b"Hello"[..], &spacer, &b"World"[..]]; + let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, multiple_vec); + let result = iter.collect::>(); + assert_eq!(2, result.len()); + assert_eq!( + &[5, 0, b'H', b'e', b'l', b'l', b'o'], + &result[0].payload[..7] + ); + assert_eq!( + &[5, 0, b'W', b'o', b'r', b'l', b'd'], + &result[1].payload[..] + ); + } + /// Starts a new [`ActiveRelayActor`]. #[allow(clippy::too_many_arguments)] fn start_active_relay_actor( @@ -1308,16 +1476,12 @@ mod tests { loop { let datagram = recv_datagram_rx.recv().await; if let Some(recv) = datagram { - let RelayRecvDatagram { - url: _, - src, - datagrams, - } = recv; + let RelayRecvDatagram { url: _, src, buf } = recv; info!(from = src.fmt_short(), "Received datagram"); let send = RelaySendItem { remote_node: src, url: relay_url.clone(), - datagrams, + datagrams: smallvec![buf], }; send_datagram_tx.send(send).await.ok(); } @@ -1358,10 +1522,10 @@ mod tests { let RelayRecvDatagram { url: _, src: _, - datagrams, + buf, } = rx.recv().await.unwrap(); - assert_eq!(datagrams, item.datagrams); + assert_eq!(buf.as_ref(), item.datagrams[0]); Ok::<_, Error>(()) }) @@ -1404,7 +1568,7 @@ mod tests { let hello_send_item = RelaySendItem { remote_node: peer_node, url: relay_url.clone(), - datagrams: Datagrams::from(b"hello"), + datagrams: smallvec![Bytes::from_static(b"hello")], }; send_recv_echo( hello_send_item.clone(), From 6ea0b267309c73e4421fd9123404a978bdd1e55b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 14:58:55 +0200 Subject: [PATCH 68/80] Fix tests --- iroh-relay/src/protos/common.rs | 20 ++++++------ iroh-relay/src/protos/relay.rs | 54 ++++++++++++--------------------- iroh-relay/src/server/client.rs | 2 +- 3 files changed, 30 insertions(+), 46 deletions(-) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index 04dd4b61542..9e005259ac8 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -16,15 +16,15 @@ use snafu::{Backtrace, OptionExt, Snafu}; // needs to be pub due to being exposed in error types pub enum FrameType { /// The server frame type for the challenge response - ServerChallenge = 1, + ServerChallenge = 0, /// The client frame type for the authentication frame - ClientAuth = 2, + ClientAuth = 1, /// The server frame type for authentication confirmation - ServerConfirmsAuth = 3, + ServerConfirmsAuth = 2, /// The server frame type for authentication denial - ServerDeniesAuth = 4, + ServerDeniesAuth = 3, /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - SendPacket = 5, + SendPacket = 4, /// 32B src pub key + ECN byte + segment size u16 + datagrams contents RecvPacket = 6, /// Sent from server to client to signal that a previous sender is no longer connected. @@ -33,20 +33,20 @@ pub enum FrameType { /// to B so B can forget that a reverse path exists on that connection to get back to A /// /// 32B pub key of peer that's gone - NodeGone = 7, + NodeGone = 8, /// Messages with these frames will be ignored. /// 8 byte ping payload, to be echoed back in FrameType::Pong - Ping = 8, + Ping = 9, /// 8 byte payload, the contents of ping being replied to - Pong = 9, + Pong = 10, /// Sent from server to client to tell the client if their connection is unhealthy somehow. /// Contains only UTF-8 bytes. - Health = 10, + Health = 11, /// Sent from server to client for the server to declare that it's restarting. /// Payload is two big endian u32 durations in milliseconds: when to reconnect, /// and how long to try total. - Restarting = 11, + Restarting = 12, } #[common_fields({ diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index dfdc7a4f2ed..7a64f7b02a0 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -401,22 +401,22 @@ mod tests { problem: "Hello? Yes this is dog.".into(), } .write_to(Vec::new()), - "0a 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 + "0b 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73 20 69 73 20 64 6f 67 2e", ), ( RelayToClientMsg::NodeGone(client_key.public()).write_to(Vec::new()), - "07 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61", ), ( RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()), - "08 2a 2a 2a 2a 2a 2a 2a 2a", + "09 2a 2a 2a 2a 2a 2a 2a 2a", ), ( RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()), - "09 2a 2a 2a 2a 2a 2a 2a 2a", + "0a 2a 2a 2a 2a 2a 2a 2a 2a", ), ( RelayToClientMsg::ReceivedPacket { @@ -424,18 +424,9 @@ mod tests { content: "Hello World!".into(), } .write_to(Vec::new()), - // frame type - // public key first 16 bytes - // public key second 16 bytes - // ECN byte - // segment size - // hello world contents bytes - "06 - 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 - 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 - 03 - 00 06 - 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + "06 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ( RelayToClientMsg::Restarting { @@ -443,7 +434,7 @@ mod tests { try_for: Duration::from_millis(20), } .write_to(Vec::new()), - "0b 00 00 00 0a 00 00 00 14", + "0c 00 00 00 0a 00 00 00 14", ), ]); @@ -457,30 +448,21 @@ mod tests { check_expected_bytes(vec![ ( ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()), - "08 2a 2a 2a 2a 2a 2a 2a 2a", + "09 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()), - "09 2a 2a 2a 2a 2a 2a 2a 2a", + "0a 2a 2a 2a 2a 2a 2a 2a 2a", ), ( ClientToRelayMsg::SendPacket { dst_key: client_key.public(), - packet: "Hello World!".into(), + packet: "Goodbye!".into(), } .write_to(Vec::new()), - // frame type - // public key first 16 bytes - // public key second 16 bytes - // ECN byte - // segment size - // hello world contents - "05 - 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 - 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 - 03 - 00 06 - 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + "04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e + a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d + 61 47 6f 6f 64 62 79 65 21", ), ]); @@ -516,9 +498,11 @@ mod proptests { let node_gone = key().prop_map(|node_id| RelayToClientMsg::NodeGone(node_id)); let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); - let health = data(0).prop_map(|_problem| RelayToClientMsg::Health { - problem: "".to_string(), - }); + let health = ".{0,65536}" + .prop_filter("exceeds max payload size", |s| { + s.len() < 65536 // a single unicode character can match a regex "." but take up multiple bytes + }) + .prop_map(|problem| RelayToClientMsg::Health { problem }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { RelayToClientMsg::Restarting { reconnect_in: Duration::from_millis(reconnect_in.into()), diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 3c230d9c59d..5a2026c8e44 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -718,7 +718,7 @@ mod tests { let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?; // Prepare a frame to send, assert its size. - let data = Bytes::from_static(b"hello world!!1"); + let data = Bytes::from_static(b"hello world!1eins"); let target = SecretKey::generate(rand::thread_rng()).public(); let frame = ClientToRelayMsg::SendPacket { dst_key: target, From e2e5926ad2503319f3795d7b9bc8c32ec3a642fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 15:47:09 +0200 Subject: [PATCH 69/80] `cargo make format` --- iroh/src/magicsock/transports/relay.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index 559ceafcb78..db5dea34374 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -15,9 +15,8 @@ use tokio::sync::mpsc; use tokio_util::sync::PollSender; use tracing::{error, info_span, trace, warn, Instrument}; -use crate::magicsock::RelayContents; - use super::{Addr, Transmit}; +use crate::magicsock::RelayContents; mod actor; From 805d103b8d6554adcc1d9ad5436eda7d138603a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 15:51:10 +0200 Subject: [PATCH 70/80] Fix docs --- iroh-relay/src/protos/relay.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 7a64f7b02a0..370bdd95088 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -3,8 +3,8 @@ //! Protocol flow: //! * server occasionally sends [`FrameType::Ping`] //! * client responds to any [`FrameType::Ping`] with a [`FrameType::Pong`] -//! * clients sends [`FrameType::ClientToRelayDatagrams`] -//! * server then sends [`FrameType::RelayToClientDatagrams`] to recipient +//! * clients sends [`FrameType::SendPacket`] +//! * server then sends [`FrameType::RecvPacket`] to recipient //! * server sends [`FrameType::NodeGone`] when the other client disconnects use bytes::{BufMut, Bytes, BytesMut}; From a448330cc239f6761586399a597543dc7d474c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 15:51:38 +0200 Subject: [PATCH 71/80] clippy fix --- iroh-relay/src/protos/relay.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 370bdd95088..ca56d4f57ba 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -495,7 +495,7 @@ mod proptests { fn server_client_frame() -> impl Strategy { let recv_packet = (key(), data(32)) .prop_map(|(src_key, content)| RelayToClientMsg::ReceivedPacket { src_key, content }); - let node_gone = key().prop_map(|node_id| RelayToClientMsg::NodeGone(node_id)); + let node_gone = key().prop_map(RelayToClientMsg::NodeGone); let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); let health = ".{0,65536}" From 11dfe8ec259ac30a3aa5565b8133000f447bf481 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 16:05:51 +0200 Subject: [PATCH 72/80] Fix more documentation --- iroh/src/magicsock/transports/relay/actor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index bf2fc0c90b6..dbfaa75ad93 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -1311,7 +1311,7 @@ where } } -/// Splits a single [`ServerToClientMsg::ReceivedPacket`] frame into datagrams. +/// Splits a single [`RelayToClientMsg::ReceivedPacket`] frame into datagrams. /// /// This splits packets joined by [`PacketizeIter`] back into individual datagrams. See /// that struct for more details. From a38e63677f4e72a86579a1a64fa65b877c3ad452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 15:16:47 +0200 Subject: [PATCH 73/80] Send ECN bits and use stride instead of custom split protocol --- iroh-relay/src/client/conn.rs | 8 +- iroh-relay/src/protos/common.rs | 4 +- iroh-relay/src/protos/relay.rs | 208 +++++++++++++---- iroh-relay/src/server.rs | 55 +++-- iroh-relay/src/server/client.rs | 81 ++++--- iroh-relay/src/server/clients.rs | 34 +-- iroh-relay/src/server/http_server.rs | 158 +++++++------ iroh/src/magicsock.rs | 4 - iroh/src/magicsock/transports/relay.rs | 48 ++-- iroh/src/magicsock/transports/relay/actor.rs | 234 +++---------------- 10 files changed, 412 insertions(+), 422 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index d30f8f06695..d7588edd7c1 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -19,7 +19,7 @@ use crate::client::streams::{MaybeTlsStream, ProxyStream}; use crate::{ protos::{ handshake, - relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg}, + relay::{ClientToRelayMsg, Error as ProtoError, RelayToClientMsg, MAX_PAYLOAD_SIZE}, streams::WsBytesFramed, }, MAX_PACKET_SIZE, @@ -143,9 +143,9 @@ impl Sink for Conn { } fn start_send(mut self: Pin<&mut Self>, frame: ClientToRelayMsg) -> Result<(), Self::Error> { - if let ClientToRelayMsg::SendPacket { .. } = &frame { - let size = frame.encoded_len(); - snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); + if let ClientToRelayMsg::Datagrams { datagrams, .. } = &frame { + let size = datagrams.contents.len(); + snafu::ensure!(size <= MAX_PAYLOAD_SIZE, ExceedsMaxPacketSizeSnafu { size }); } Pin::new(&mut self.conn) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index 9e005259ac8..6253c03c4cb 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -24,9 +24,9 @@ pub enum FrameType { /// The server frame type for authentication denial ServerDeniesAuth = 3, /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - SendPacket = 4, + ClientToRelayDatagrams = 4, /// 32B src pub key + ECN byte + segment size u16 + datagrams contents - RecvPacket = 6, + RelayToClientDatagrams = 6, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index ca56d4f57ba..6de26cfd000 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -21,6 +21,12 @@ use crate::KeyCache; /// including its on-wire framing overhead) pub const MAX_PACKET_SIZE: usize = 64 * 1024; +/// Maximum size a datagram payload is allowed to be. +/// +/// This is [`MAX_PACKET_SIZE`] minus the length of an encoded public key minus 3 bytes, +/// one for ECN, and two for the segment size. +pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 3; + /// The maximum frame size. /// /// This is also the minimum burst size that a rate-limiter has to accept. @@ -74,12 +80,11 @@ pub enum Error { #[derive(derive_more::Debug, Clone, PartialEq, Eq)] pub enum RelayToClientMsg { /// Represents datagrams sent from relays (originally sent to them by another client). - ReceivedPacket { + Datagrams { /// The [`NodeId`] of the original sender. - src_key: NodeId, - /// The received packet bytes. - #[debug(skip)] - content: Bytes, + remote_node_id: NodeId, + /// The datagrams and related metadata. + datagrams: Datagrams, }, /// Indicates that the client identified by the underlying public key had previously sent you a /// packet but has now disconnected from the relay. @@ -123,19 +128,85 @@ pub enum ClientToRelayMsg { /// with the payload sent previously in the ping. Pong([u8; 8]), /// Request from the client to relay datagrams to given remote node. - SendPacket { + Datagrams { /// The remote node to relay to. - dst_key: NodeId, + dst_node_id: NodeId, /// The datagrams and related metadata to relay. - packet: Bytes, + datagrams: Datagrams, }, } +/// One or multiple datagrams being transferred via the relay. +/// +/// This type is modeled after [`quinn_proto::Transmit`] +/// (or even more similarly `quinn_udp::Transmit`, but we don't depend on that library here). +#[derive(derive_more::Debug, Clone, PartialEq, Eq)] +pub struct Datagrams { + /// Explicit congestion notification bits + pub ecn: Option, + /// The segment size if this transmission contains multiple datagrams. + /// This is `None` if the transmit only contains a single datagram + pub segment_size: Option, + /// The contents of the datagram(s) + #[debug(skip)] + pub contents: Bytes, +} + +impl> From for Datagrams { + fn from(bytes: T) -> Self { + Self { + ecn: None, + segment_size: None, + contents: Bytes::copy_from_slice(bytes.as_ref()), + } + } +} + +impl Datagrams { + fn write_to(&self, mut dst: O) -> O { + let ecn = self.ecn.map_or(0, |ecn| ecn as u8); + let segment_size = self.segment_size.unwrap_or_default(); + dst.put_u8(ecn); + dst.put_u16(segment_size); + dst.put(self.contents.as_ref()); + dst + } + + fn encoded_len(&self) -> usize { + 1 // ECN byte + + 2 // segment size + + self.contents.len() + } + + fn from_bytes(bytes: Bytes) -> Result { + // 1 bytes ECN, 2 bytes segment size + snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); + + let ecn_byte = bytes[0]; + let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte); + + let segment_size = u16::from_be_bytes(bytes[1..3].try_into().expect("length checked")); + let segment_size = if segment_size == 0 { + None + } else { + Some(segment_size) + }; + + let contents = bytes.slice(3..); + + Ok(Self { + ecn, + segment_size, + contents, + }) + } +} + impl RelayToClientMsg { /// Returns this frame's corresponding frame type. pub fn typ(&self) -> FrameType { match self { - Self::ReceivedPacket { .. } => FrameType::RecvPacket, + Self::Datagrams { .. } => FrameType::RelayToClientDatagrams, Self::NodeGone { .. } => FrameType::NodeGone, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, @@ -156,12 +227,12 @@ impl RelayToClientMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::ReceivedPacket { - src_key: remote_node_id, - content, + Self::Datagrams { + remote_node_id, + datagrams, } => { dst.put(remote_node_id.as_ref()); - dst.put(content.as_ref()); + dst = datagrams.write_to(dst); } Self::NodeGone(node_id) => { dst.put(node_id.as_ref()); @@ -189,9 +260,9 @@ impl RelayToClientMsg { #[cfg(feature = "server")] pub(crate) fn encoded_len(&self) -> usize { let payload_len = match self { - Self::ReceivedPacket { content, .. } => { + Self::Datagrams { datagrams, .. } => { 32 // nodeid - + content.len() + + datagrams.encoded_len() } Self::NodeGone(_) => 32, Self::Ping(_) | Self::Pong(_) => 8, @@ -218,14 +289,17 @@ impl RelayToClientMsg { ); let res = match frame_type { - FrameType::RecvPacket => { + FrameType::RelayToClientDatagrams => { snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); - let src_key = cache + let remote_node_id = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; - let content = content.slice(NodeId::LENGTH..); - Self::ReceivedPacket { src_key, content } + let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + Self::Datagrams { + remote_node_id, + datagrams, + } } FrameType::NodeGone => { snafu::ensure!(content.len() == NodeId::LENGTH, InvalidFrameSnafu); @@ -282,7 +356,7 @@ impl RelayToClientMsg { impl ClientToRelayMsg { pub(crate) fn typ(&self) -> FrameType { match self { - Self::SendPacket { .. } => FrameType::SendPacket, + Self::Datagrams { .. } => FrameType::ClientToRelayDatagrams, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, } @@ -298,9 +372,12 @@ impl ClientToRelayMsg { pub(crate) fn write_to(&self, mut dst: O) -> O { dst = self.typ().write_to(dst); match self { - Self::SendPacket { dst_key, packet } => { - dst.put(dst_key.as_ref()); - dst.put(packet.as_ref()); + Self::Datagrams { + dst_node_id, + datagrams, + } => { + dst.put(dst_node_id.as_ref()); + dst = datagrams.write_to(dst); } Self::Ping(data) => { dst.put(&data[..]); @@ -315,9 +392,9 @@ impl ClientToRelayMsg { pub(crate) fn encoded_len(&self) -> usize { let payload_len = match self { Self::Ping(_) | Self::Pong(_) => 8, - Self::SendPacket { packet, .. } => { + Self::Datagrams { datagrams, .. } => { 32 // node id - + packet.len() + + datagrams.encoded_len() } }; 1 // frame type (all frame types currently encode as 1 byte varint) @@ -338,12 +415,15 @@ impl ClientToRelayMsg { ); let res = match frame_type { - FrameType::SendPacket => { - let dst_key = cache + FrameType::ClientToRelayDatagrams => { + let dst_node_id = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; - let packet = content.slice(NodeId::LENGTH..); - Self::SendPacket { dst_key, packet } + let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + Self::Datagrams { + dst_node_id, + datagrams, + } } FrameType::Ping => { snafu::ensure!(content.len() == 8, InvalidFrameSnafu); @@ -419,9 +499,13 @@ mod tests { "0a 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - RelayToClientMsg::ReceivedPacket { - src_key: client_key.public(), - content: "Hello World!".into(), + RelayToClientMsg::Datagrams { + remote_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: Some(6), + contents: "Hello World!".into(), + }, } .write_to(Vec::new()), "06 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e @@ -455,9 +539,13 @@ mod tests { "0a 2a 2a 2a 2a 2a 2a 2a 2a", ), ( - ClientToRelayMsg::SendPacket { - dst_key: client_key.public(), - packet: "Goodbye!".into(), + ClientToRelayMsg::Datagrams { + dst_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: Some(6), + contents: "Hello World!".into(), + }, } .write_to(Vec::new()), "04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e @@ -485,22 +573,42 @@ mod proptests { secret_key().prop_map(|key| key.public()) } - /// Generates random data, up to the maximum packet size minus the given number of bytes - fn data(consumed: usize) -> impl Strategy { - let len = MAX_PACKET_SIZE - consumed; - prop::collection::vec(any::(), 0..len).prop_map(Bytes::from) + fn ecn() -> impl Strategy> { + (0..=3).prop_map(|n| match n { + 1 => Some(quinn_proto::EcnCodepoint::Ce), + 2 => Some(quinn_proto::EcnCodepoint::Ect0), + 3 => Some(quinn_proto::EcnCodepoint::Ect1), + _ => None, + }) + } + + fn datagrams() -> impl Strategy { + ( + ecn(), + prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE), + prop::collection::vec(any::(), 0..MAX_PAYLOAD_SIZE), + ) + .prop_map(|(ecn, segment_size, data)| Datagrams { + ecn, + segment_size: segment_size.map(|ss| std::cmp::min(data.len(), ss) as u16), + contents: Bytes::from(data), + }) } /// Generates a random valid frame fn server_client_frame() -> impl Strategy { - let recv_packet = (key(), data(32)) - .prop_map(|(src_key, content)| RelayToClientMsg::ReceivedPacket { src_key, content }); + let recv_packet = (key(), datagrams()).prop_map(|(remote_node_id, datagrams)| { + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } + }); let node_gone = key().prop_map(RelayToClientMsg::NodeGone); let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); let health = ".{0,65536}" - .prop_filter("exceeds max payload size", |s| { - s.len() < 65536 // a single unicode character can match a regex "." but take up multiple bytes + .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { + s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes }) .prop_map(|problem| RelayToClientMsg::Health { problem }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { @@ -513,8 +621,11 @@ mod proptests { } fn client_server_frame() -> impl Strategy { - let send_packet = (key(), data(32)) - .prop_map(|(dst_key, packet)| ClientToRelayMsg::SendPacket { dst_key, packet }); + let send_packet = + (key(), datagrams()).prop_map(|(dst_node_id, datagrams)| ClientToRelayMsg::Datagrams { + dst_node_id, + datagrams, + }); let ping = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(ClientToRelayMsg::Pong); prop_oneof![send_packet, ping, pong] @@ -548,5 +659,12 @@ mod proptests { let actual_encoded_len = frame.to_bytes().len(); prop_assert_eq!(claimed_encoded_len, actual_encoded_len); } + + #[test] + fn datagrams_encoded_len(datagrams in datagrams()) { + let claimed_encoded_len = datagrams.encoded_len(); + let actual_encoded_len = datagrams.write_to(Vec::new()).len(); + prop_assert_eq!(claimed_encoded_len, actual_encoded_len); + } } } diff --git a/iroh-relay/src/server.rs b/iroh-relay/src/server.rs index 3660760f4af..94a3690ed55 100644 --- a/iroh-relay/src/server.rs +++ b/iroh-relay/src/server.rs @@ -750,7 +750,6 @@ impl hyper::service::Service> for CaptivePortalService { mod tests { use std::{net::Ipv4Addr, time::Duration}; - use bytes::Bytes; use http::StatusCode; use iroh_base::{NodeId, RelayUrl, SecretKey}; use n0_future::{FutureExt, SinkExt, StreamExt}; @@ -767,7 +766,7 @@ mod tests { dns::DnsResolver, protos::{ handshake, - relay::{ClientToRelayMsg, RelayToClientMsg}, + relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, }, }; @@ -791,14 +790,14 @@ mod tests { client_a: &mut crate::client::Client, client_b: &mut crate::client::Client, b_key: NodeId, - msg: Bytes, + msg: Datagrams, ) -> Result { // try resend 10 times for _ in 0..10 { client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; let Ok(res) = tokio::time::timeout(Duration::from_millis(500), client_b.next()).await @@ -912,26 +911,34 @@ mod tests { info!("sending a -> b"); // send message from a to b - let msg = Bytes::from_static(b"hello, b"); + let msg = Datagrams::from("hello, b"); let res = try_send_recv(&mut client_a, &mut client_b, b_key, msg.clone()).await?; - let RelayToClientMsg::ReceivedPacket { src_key, content } = res else { + let RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } = res + else { panic!("client_b received unexpected message {res:?}"); }; - assert_eq!(a_key, src_key); - assert_eq!(msg, content); + assert_eq!(a_key, remote_node_id); + assert_eq!(msg, datagrams); info!("sending b -> a"); // send message from b to a - let msg = Bytes::from_static(b"howdy, a"); + let msg = Datagrams::from("howdy, a"); let res = try_send_recv(&mut client_b, &mut client_a, a_key, msg.clone()).await?; - let RelayToClientMsg::ReceivedPacket { src_key, content } = res else { + let RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } = res + else { panic!("client_a received unexpected message {res:?}"); }; - assert_eq!(b_key, src_key); - assert_eq!(msg, content); + assert_eq!(b_key, remote_node_id); + assert_eq!(msg, datagrams); Ok(()) } @@ -1003,12 +1010,16 @@ mod tests { .await?; // send message from b to c - let msg = Bytes::from_static(b"hello, c"); + let msg = Datagrams::from("hello, c"); let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?; - if let RelayToClientMsg::ReceivedPacket { src_key, content } = res { - assert_eq!(b_key, src_key); - assert_eq!(msg, content); + if let RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } = res + { + assert_eq!(b_key, remote_node_id); + assert_eq!(msg, datagrams); } else { panic!("client_c received unexpected message {res:?}"); } @@ -1040,12 +1051,12 @@ mod tests { // send messages from a to b, without b receiving anything. // we should still keep succeeding to send, even if the packet won't be forwarded // by the relay server because the server's send queue for b fills up. - let msg = Bytes::from_static(b"hello, b"); + let msg = Datagrams::from("hello, b"); for _i in 0..1000 { client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; } diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 5a2026c8e44..50d051f1a89 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -2,7 +2,6 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; -use bytes::Bytes; use iroh_base::NodeId; use n0_future::{SinkExt, StreamExt}; use nested_enum_utils::common_fields; @@ -19,7 +18,7 @@ use tracing::{debug, trace, warn, Instrument}; use crate::{ protos::{ disco, - relay::{ClientToRelayMsg, RelayToClientMsg, PING_INTERVAL}, + relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg, PING_INTERVAL}, streams::StreamError, }, server::{ @@ -36,7 +35,7 @@ pub(super) struct Packet { /// The sender of the packet src: NodeId, /// The data packet bytes. - data: Bytes, + data: Datagrams, } /// Configuration for a [`Client`]. @@ -151,7 +150,7 @@ impl Client { pub(super) fn try_send_packet( &self, src: NodeId, - data: Bytes, + data: Datagrams, ) -> Result<(), TrySendError> { self.send_queue.try_send(Packet { src, data }) } @@ -159,7 +158,7 @@ impl Client { pub(super) fn try_send_disco_packet( &self, src: NodeId, - data: Bytes, + data: Datagrams, ) -> Result<(), TrySendError> { self.disco_send_queue.try_send(Packet { src, data }) } @@ -406,14 +405,17 @@ impl Actor { /// Errors if the send does not happen within the `timeout` duration /// Does not flush. async fn send_raw(&mut self, packet: Packet) -> Result<(), WriteFrameError> { - let src_key = packet.src; - let content = packet.data; + let remote_node_id = packet.src; + let datagrams = packet.data; - if let Ok(len) = content.len().try_into() { + if let Ok(len) = datagrams.contents.len().try_into() { self.metrics.bytes_sent.inc_by(len); } - self.write_frame(RelayToClientMsg::ReceivedPacket { src_key, content }) - .await + self.write_frame(RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + }) + .await } async fn send_packet(&mut self, packet: Packet) -> Result<(), WriteFrameError> { @@ -456,10 +458,13 @@ impl Actor { }; match frame { - ClientToRelayMsg::SendPacket { dst_key, packet } => { - let packet_len = packet.len(); + ClientToRelayMsg::Datagrams { + dst_node_id: dst_key, + datagrams, + } => { + let packet_len = datagrams.contents.len(); if let Err(err @ ForwardPacketError { .. }) = - self.handle_frame_send_packet(dst_key, packet) + self.handle_frame_send_packet(dst_key, datagrams) { warn!("failed to handle send packet frame: {err:#}"); } @@ -478,8 +483,12 @@ impl Actor { Ok(()) } - fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<(), ForwardPacketError> { - if disco::looks_like_disco_wrapper(&data) { + fn handle_frame_send_packet( + &self, + dst: NodeId, + data: Datagrams, + ) -> Result<(), ForwardPacketError> { + if disco::looks_like_disco_wrapper(&data.contents) { self.metrics.disco_packets_recv.inc(); self.clients .send_disco_packet(dst, data, self.node_id, &self.metrics)?; @@ -628,15 +637,17 @@ mod tests { println!(" send packet"); let packet = Packet { src: node_id, - data: Bytes::from(&data[..]), + data: Datagrams::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; - let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) + .await + .e()?; assert_eq!( frame, - RelayToClientMsg::ReceivedPacket { - src_key: node_id, - content: data.to_vec().into() + RelayToClientMsg::Datagrams { + remote_node_id: node_id, + datagrams: data.to_vec().into() } ); @@ -646,12 +657,14 @@ mod tests { .send(packet.clone()) .await .context("send")?; - let frame = recv_frame(FrameType::RecvPacket, &mut io_rw).await.e()?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) + .await + .e()?; assert_eq!( frame, - RelayToClientMsg::ReceivedPacket { - src_key: node_id, - content: data.to_vec().into() + RelayToClientMsg::Datagrams { + remote_node_id: node_id, + datagrams: data.to_vec().into() } ); @@ -679,9 +692,9 @@ mod tests { println!(" send packet"); let data = b"hello world!"; io_rw - .send(ClientToRelayMsg::SendPacket { - dst_key: target, - packet: Bytes::from_static(data), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: target, + datagrams: Datagrams::from(data), }) .await .context("send")?; @@ -693,9 +706,9 @@ mod tests { disco_data.extend_from_slice(target.as_bytes()); disco_data.extend_from_slice(data); io_rw - .send(ClientToRelayMsg::SendPacket { - dst_key: target, - packet: disco_data.clone().into(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: target, + datagrams: disco_data.clone().into(), }) .await .context("send")?; @@ -718,11 +731,11 @@ mod tests { let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?; // Prepare a frame to send, assert its size. - let data = Bytes::from_static(b"hello world!1eins"); + let data = Datagrams::from(b"hello world!!1"); let target = SecretKey::generate(rand::thread_rng()).public(); - let frame = ClientToRelayMsg::SendPacket { - dst_key: target, - packet: data.clone(), + let frame = ClientToRelayMsg::Datagrams { + dst_node_id: target, + datagrams: data.clone(), }; let frame_len = frame.to_bytes().len(); assert_eq!(frame_len, LIMIT as usize); diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index cacb0540bcd..bdec78042b0 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -9,16 +9,18 @@ use std::{ }, }; -use bytes::Bytes; use dashmap::DashMap; use iroh_base::NodeId; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, trace}; use super::client::{Client, Config, ForwardPacketError}; -use crate::server::{ - client::{PacketScope, SendError}, - metrics::Metrics, +use crate::{ + protos::relay::Datagrams, + server::{ + client::{PacketScope, SendError}, + metrics::Metrics, + }, }; /// Manages the connections to all currently connected clients. @@ -108,7 +110,7 @@ impl Clients { pub(super) fn send_packet( &self, dst: NodeId, - data: Bytes, + data: Datagrams, src: NodeId, metrics: &Metrics, ) -> Result<(), ForwardPacketError> { @@ -148,7 +150,7 @@ impl Clients { pub(super) fn send_disco_packet( &self, dst: NodeId, - data: Bytes, + data: Datagrams, src: NodeId, metrics: &Metrics, ) -> Result<(), ForwardPacketError> { @@ -252,24 +254,24 @@ mod tests { // send packet let data = b"hello world!"; - clients.send_packet(a_key, Bytes::from_static(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; + clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; assert_eq!( frame, - RelayToClientMsg::ReceivedPacket { - src_key: b_key, - content: data.to_vec().into(), + RelayToClientMsg::Datagrams { + remote_node_id: b_key, + datagrams: data.to_vec().into(), } ); // send disco packet - clients.send_disco_packet(a_key, Bytes::from_static(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; + clients.send_disco_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; + let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; assert_eq!( frame, - RelayToClientMsg::ReceivedPacket { - src_key: b_key, - content: data.to_vec().into(), + RelayToClientMsg::Datagrams { + remote_node_id: b_key, + datagrams: data.to_vec().into(), } ); diff --git a/iroh-relay/src/server/http_server.rs b/iroh-relay/src/server/http_server.rs index 5628990e7a0..4a47625aae3 100644 --- a/iroh-relay/src/server/http_server.rs +++ b/iroh-relay/src/server/http_server.rs @@ -879,7 +879,7 @@ mod tests { use crate::{ client::{conn::Conn, Client, ClientBuilder, ConnectError}, dns::DnsResolver, - protos::relay::{ClientToRelayMsg, RelayToClientMsg}, + protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, }; pub(crate) fn make_tls_config() -> TlsConfig { @@ -947,11 +947,11 @@ mod tests { assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("sending message from a to b"); - let msg = Bytes::from_static(b"hi there, client b!"); + let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message from a on b"); @@ -961,11 +961,11 @@ mod tests { assert_eq!(msg, got_msg); info!("sending message from b to a"); - let msg = Bytes::from_static(b"right back at ya, client b!"); + let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToRelayMsg::SendPacket { - dst_key: a_key, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: a_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message b on a"); @@ -996,7 +996,7 @@ mod tests { fn process_msg( msg: Option>, - ) -> Option<(PublicKey, Bytes)> { + ) -> Option<(PublicKey, Datagrams)> { match msg { Some(Err(e)) => { info!("client `recv` error {e}"); @@ -1004,12 +1004,12 @@ mod tests { } Some(Ok(msg)) => { info!("got message on: {msg:?}"); - if let RelayToClientMsg::ReceivedPacket { - src_key: source, - content, + if let RelayToClientMsg::Datagrams { + remote_node_id: source, + datagrams, } = msg { - Some((source, content)) + Some((source, datagrams)) } else { None } @@ -1067,11 +1067,11 @@ mod tests { assert!(matches!(pong, RelayToClientMsg::Pong { .. })); info!("sending message from a to b"); - let msg = Bytes::from_static(b"hi there, client b!"); + let msg = Datagrams::from(b"hi there, client b!"); client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: b_key, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: b_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message from a on b"); @@ -1081,11 +1081,11 @@ mod tests { assert_eq!(msg, got_msg); info!("sending message from b to a"); - let msg = Bytes::from_static(b"right back at ya, client b!"); + let msg = Datagrams::from(b"right back at ya, client b!"); client_b - .send(ClientToRelayMsg::SendPacket { - dst_key: a_key, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: a_key, + datagrams: msg.clone(), }) .await?; info!("waiting for message b on a"); @@ -1144,17 +1144,20 @@ mod tests { handler_task.await.context("join")??; info!("Send message from A to B."); - let msg = Bytes::from_static(b"hello client b!!"); + let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_b, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_b, + datagrams: msg.clone(), }) .await?; match client_b.next().await.unwrap()? { - RelayToClientMsg::ReceivedPacket { src_key, content } => { - assert_eq!(public_key_a, src_key); - assert_eq!(msg, content); + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } => { + assert_eq!(public_key_a, remote_node_id); + assert_eq!(msg, datagrams); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1162,17 +1165,20 @@ mod tests { } info!("Send message from B to A."); - let msg = Bytes::from_static(b"nice to meet you client a!!"); + let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_a, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_a, + datagrams: msg.clone(), }) .await?; match client_a.next().await.unwrap()? { - RelayToClientMsg::ReceivedPacket { src_key, content } => { - assert_eq!(public_key_b, src_key); - assert_eq!(msg, content); + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } => { + assert_eq!(public_key_b, remote_node_id); + assert_eq!(msg, datagrams); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1185,9 +1191,9 @@ mod tests { info!("Fail to send message from A to B."); let res = client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_b, - packet: Bytes::from_static(b"try to send"), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_b, + datagrams: Datagrams::from(b"try to send"), }) .await; assert!(res.is_err()); @@ -1228,17 +1234,20 @@ mod tests { handler_task.await.context("join")??; info!("Send message from A to B."); - let msg = Bytes::from_static(b"hello client b!!"); + let msg = Datagrams::from(b"hello client b!!"); client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_b, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_b, + datagrams: msg.clone(), }) .await?; match client_b.next().await.expect("eos")? { - RelayToClientMsg::ReceivedPacket { src_key, content } => { - assert_eq!(public_key_a, src_key); - assert_eq!(msg, content); + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } => { + assert_eq!(public_key_a, remote_node_id); + assert_eq!(msg, datagrams); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1246,17 +1255,20 @@ mod tests { } info!("Send message from B to A."); - let msg = Bytes::from_static(b"nice to meet you client a!!"); + let msg = Datagrams::from(b"nice to meet you client a!!"); client_b - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_a, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_a, + datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - RelayToClientMsg::ReceivedPacket { src_key, content } => { - assert_eq!(public_key_b, src_key); - assert_eq!(msg, content); + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } => { + assert_eq!(public_key_b, remote_node_id); + assert_eq!(msg, datagrams); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1274,17 +1286,20 @@ mod tests { // assert!(client_b.recv().await.is_err()); info!("Send message from A to B."); - let msg = Bytes::from_static(b"are you still there, b?!"); + let msg = Datagrams::from(b"are you still there, b?!"); client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_b, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_b, + datagrams: msg.clone(), }) .await?; match new_client_b.next().await.expect("eos")? { - RelayToClientMsg::ReceivedPacket { src_key, content } => { - assert_eq!(public_key_a, src_key); - assert_eq!(msg, content); + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } => { + assert_eq!(public_key_a, remote_node_id); + assert_eq!(msg, datagrams); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1292,17 +1307,20 @@ mod tests { } info!("Send message from B to A."); - let msg = Bytes::from_static(b"just had a spot of trouble but I'm back now,a!!"); + let msg = Datagrams::from(b"just had a spot of trouble but I'm back now,a!!"); new_client_b - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_a, - packet: msg.clone(), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_a, + datagrams: msg.clone(), }) .await?; match client_a.next().await.expect("eos")? { - RelayToClientMsg::ReceivedPacket { src_key, content } => { - assert_eq!(public_key_b, src_key); - assert_eq!(msg, content); + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, + } => { + assert_eq!(public_key_b, remote_node_id); + assert_eq!(msg, datagrams); } msg => { whatever!("expected ReceivedDatagrams msg, got {msg:?}"); @@ -1314,9 +1332,9 @@ mod tests { info!("Sending message from A to B fails"); let res = client_a - .send(ClientToRelayMsg::SendPacket { - dst_key: public_key_b, - packet: Bytes::from_static(b"try to send"), + .send(ClientToRelayMsg::Datagrams { + dst_node_id: public_key_b, + datagrams: Datagrams::from(b"try to send"), }) .await; assert!(res.is_err()); diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index dd57974ba71..73e45a86cb3 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -148,10 +148,6 @@ pub(crate) struct Options { pub(crate) metrics: EndpointMetrics, } -/// Contents of a relay message. Use a SmallVec to avoid allocations for the very -/// common case of a single packet. -type RelayContents = SmallVec<[Bytes; 1]>; - /// Handle for [`MagicSock`]. /// /// Dereferences to [`MagicSock`], and handles closing. diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index db5dea34374..d3b7f09be78 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -5,12 +5,12 @@ use std::{ use bytes::Bytes; use iroh_base::{NodeId, RelayUrl}; +use iroh_relay::protos::relay::Datagrams; use n0_future::{ ready, task::{self, AbortOnDropHandle}, }; use n0_watcher::{Watchable, Watcher as _}; -use smallvec::SmallVec; use tokio::sync::mpsc; use tokio_util::sync::PollSender; use tracing::{error, info_span, trace, warn, Instrument}; @@ -100,9 +100,12 @@ impl RelayTransport { } }; - buf_out[..dm.buf.len()].copy_from_slice(&dm.buf); - meta_out.len = dm.buf.len(); - meta_out.stride = dm.buf.len(); + buf_out[..dm.datagrams.contents.len()].copy_from_slice(&dm.datagrams.contents); + meta_out.len = dm.datagrams.contents.len(); + meta_out.stride = dm + .datagrams + .segment_size + .map_or(dm.datagrams.contents.len(), |s| s as usize); meta_out.ecn = None; meta_out.dst_ip = None; // TODO: insert the relay url for this relay @@ -186,7 +189,7 @@ impl RelaySender { dest_node: NodeId, transmit: &Transmit<'_>, ) -> io::Result<()> { - let contents = split_packets(transmit); + let contents = datagrams_from_transmit(transmit); let item = RelaySendItem { remote_node: dest_node, @@ -228,7 +231,7 @@ impl RelaySender { trace!(node = %dest_node.fmt_short(), relay_url = %dest_url, "send relay: message queued"); - let contents = split_packets(transmit); + let contents = datagrams_from_transmit(transmit); let item = RelaySendItem { remote_node: dest_node, url: dest_url.clone(), @@ -266,7 +269,7 @@ impl RelaySender { dest_node: NodeId, transmit: &Transmit<'_>, ) -> io::Result<()> { - let contents = split_packets(transmit); + let contents = datagrams_from_transmit(transmit); let item = RelaySendItem { remote_node: dest_node, @@ -304,26 +307,19 @@ impl RelaySender { } } -/// Split a transmit containing a GSO payload into individual packets. -/// -/// This allocates the data. -/// -/// If the transmit has a segment size it contains multiple GSO packets. It will be split -/// into multiple packets according to that segment size. If it does not have a segment -/// size, the contents will be sent as a single packet. +/// Translate a UDP transmit to the `Datagrams` type for sending over the relay. // TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to // figure out where they allocate the Vec. -fn split_packets(transmit: &Transmit<'_>) -> RelayContents { - let mut res = SmallVec::with_capacity(1); - let contents = transmit.contents; - if let Some(segment_size) = transmit.segment_size { - for chunk in contents.chunks(segment_size) { - res.push(Bytes::from(chunk.to_vec())); - } - } else { - res.push(Bytes::from(contents.to_vec())); +fn datagrams_from_transmit(transmit: &Transmit<'_>) -> Datagrams { + Datagrams { + ecn: transmit.ecn.map(|ecn| match ecn { + quinn_udp::EcnCodepoint::Ect0 => quinn_proto::EcnCodepoint::Ect0, + quinn_udp::EcnCodepoint::Ect1 => quinn_proto::EcnCodepoint::Ect1, + quinn_udp::EcnCodepoint::Ce => quinn_proto::EcnCodepoint::Ce, + }), + segment_size: transmit.segment_size.map(|ss| ss as u16), + contents: Bytes::copy_from_slice(transmit.contents), } - res } #[cfg(test)] @@ -350,7 +346,7 @@ mod tests { let mut expected_msgs: BTreeSet = (0..capacity).collect(); while !expected_msgs.is_empty() { let datagram: RelayRecvDatagram = receiver.recv().await.unwrap(); - let msg_num = usize::from_le_bytes(datagram.buf.as_ref().try_into().unwrap()); + let msg_num = usize::from_le_bytes(datagram.datagrams.contents.as_ref().try_into().unwrap()); debug!("Received {msg_num}"); if !expected_msgs.remove(&msg_num) { @@ -370,7 +366,7 @@ mod tests { .try_send(RelayRecvDatagram { url, src: NodeId::from_bytes(&[0u8; 32]).unwrap(), - buf: Bytes::copy_from_slice(&i.to_le_bytes()), + datagrams: Datagrams::from(&i.to_le_bytes()), }) .unwrap(); } diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index dbfaa75ad93..70e03ed079b 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -38,13 +38,12 @@ use std::{ }; use backon::{Backoff, BackoffBuilder, ExponentialBuilder}; -use bytes::{Bytes, BytesMut}; -use iroh_base::{NodeId, PublicKey, RelayUrl, SecretKey}; +use iroh_base::{NodeId, RelayUrl, SecretKey}; use iroh_relay::{ self as relay, client::{Client, ConnectError, RecvError, SendError}, - protos::relay::{ClientToRelayMsg, RelayToClientMsg}, - PingTracker, MAX_PACKET_SIZE, + protos::relay::{ClientToRelayMsg, Datagrams, RelayToClientMsg}, + PingTracker, }; use n0_future::{ task::JoinSet, @@ -62,18 +61,11 @@ use url::Url; #[cfg(not(wasm_browser))] use crate::dns::DnsResolver; -use crate::{ - magicsock::{Metrics as MagicsockMetrics, RelayContents}, - net_report::Report, - util::MaybeFuture, -}; +use crate::{magicsock::Metrics as MagicsockMetrics, net_report::Report, util::MaybeFuture}; /// How long a non-home relay connection needs to be idle (last written to) before we close it. const RELAY_INACTIVE_CLEANUP_TIME: Duration = Duration::from_secs(60); -/// Maximum size a datagram payload is allowed to be. -const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - PublicKey::LENGTH; - /// Interval in which we ping the relay server to ensure the connection is alive. /// /// The default QUIC max_idle_timeout is 30s, so setting that to half this time gives some @@ -626,26 +618,20 @@ impl ActiveRelayActor { self.reset_inactive_timeout(); // TODO: This allocation is *very* unfortunate. But so is the // allocation *inside* of PacketizeIter... - let dgrams = std::mem::replace( + let batch = std::mem::replace( &mut send_datagrams_buf, Vec::with_capacity(SEND_DATAGRAM_BATCH_SIZE), ); // TODO(frando): can we avoid the clone here? let metrics = self.metrics.clone(); - let packet_iter = dgrams.into_iter().flat_map(|datagrams| { - PacketizeIter::<_, MAX_PAYLOAD_SIZE>::new( - datagrams.remote_node, - datagrams.datagrams.clone(), - ) - .map(|p| { - Ok(ClientToRelayMsg::SendPacket { dst_key: p.node_id, packet: p.payload }) + let packet_iter = batch.into_iter().map(|item| { + metrics.send_relay.inc_by(item.datagrams.contents.len() as _); + Ok(ClientToRelayMsg::Datagrams { + dst_node_id: item.remote_node, + datagrams: item.datagrams }) }); - let mut packet_stream = n0_future::stream::iter(packet_iter).inspect(|m| { - if let Ok(ClientToRelayMsg::SendPacket { dst_key: _node_id, packet: payload }) = m { - metrics.send_relay.inc_by(payload.len() as _); - } - }); + let mut packet_stream = n0_future::stream::iter(packet_iter); let fut = client_sink.send_all(&mut packet_stream); self.run_sending(fut, &mut state, &mut client_stream).await?; } @@ -680,11 +666,11 @@ impl ActiveRelayActor { fn handle_relay_msg(&mut self, msg: RelayToClientMsg, state: &mut ConnectedRelayState) { match msg { - RelayToClientMsg::ReceivedPacket { - src_key: remote_node_id, - content, + RelayToClientMsg::Datagrams { + remote_node_id, + datagrams, } => { - trace!(len = %content.len(), "received msg"); + trace!(len = %datagrams.contents.len(), "received msg"); // If this is a new sender, register a route for this peer. if state .last_packet_src @@ -696,14 +682,12 @@ impl ActiveRelayActor { state.last_packet_src = Some(remote_node_id); state.nodes_present.insert(remote_node_id); } - for datagram in PacketSplitIter::new(self.url.clone(), remote_node_id, content) { - let Ok(datagram) = datagram else { - warn!("Invalid packet split"); - break; - }; - if let Err(err) = self.relay_datagrams_recv.try_send(datagram) { - warn!("Dropping received relay packet: {err:#}"); - } + if let Err(err) = self.relay_datagrams_recv.try_send(RelayRecvDatagram { + url: self.url.clone(), + src: remote_node_id, + datagrams, + }) { + warn!("Dropping received relay packet: {err:#}"); } } RelayToClientMsg::NodeGone(node_id) => { @@ -852,7 +836,7 @@ pub(crate) struct RelaySendItem { /// The home relay of the remote node. pub(crate) url: RelayUrl, /// One or more datagrams to send. - pub(crate) datagrams: RelayContents, + pub(crate) datagrams: Datagrams, } pub(super) struct RelayActor { @@ -1232,18 +1216,6 @@ struct ActiveRelayHandle { datagrams_send_queue: mpsc::Sender, } -/// A packet to send over the relay. -/// -/// This is nothing but a newtype, it should be constructed using [`PacketizeIter`]. This -/// is a packet of one or more datagrams, each prefixed with a u16-be length. This is what -/// the `Frame::SendPacket` of the `DerpCodec` transports and is produced by -/// [`PacketizeIter`] and transformed back into datagrams using [`PacketSplitIter`]. -#[derive(Debug, PartialEq, Eq)] -struct RelaySendPacket { - node_id: NodeId, - payload: Bytes, -} - /// A single datagram received from a relay server. /// /// This could be either a QUIC or DISCO packet. @@ -1251,116 +1223,9 @@ struct RelaySendPacket { pub(crate) struct RelayRecvDatagram { pub(crate) url: RelayUrl, pub(crate) src: NodeId, - pub(crate) buf: Bytes, + pub(crate) datagrams: Datagrams, } -/// Combines datagrams into a single DISCO frame of at most MAX_PACKET_SIZE. -/// -/// The disco `iroh_relay::protos::Frame::SendPacket` frame can contain more then a single -/// datagram. Each datagram in this frame is prefixed with a little-endian 2-byte length -/// prefix. This occurs when Quinn sends a GSO transmit containing more than one datagram, -/// which are split using `split_packets`. -/// -/// The [`PacketSplitIter`] does the inverse and splits such packets back into individual -/// datagrams. -struct PacketizeIter { - node_id: NodeId, - iter: std::iter::Peekable, - buffer: BytesMut, -} - -impl PacketizeIter { - /// Create a new new PacketizeIter from something that can be turned into an - /// iterator of slices, like a `Vec`. - fn new(node_id: NodeId, iter: impl IntoIterator) -> Self { - Self { - node_id, - iter: iter.into_iter().peekable(), - buffer: BytesMut::with_capacity(N), - } - } -} - -impl Iterator for PacketizeIter -where - I::Item: AsRef<[u8]>, -{ - type Item = RelaySendPacket; - - fn next(&mut self) -> Option { - use bytes::BufMut; - while let Some(next_bytes) = self.iter.peek() { - let next_bytes = next_bytes.as_ref(); - assert!(next_bytes.len() + 2 <= N); - let next_length: u16 = next_bytes.len().try_into().expect("items < 64k size"); - if self.buffer.len() + next_bytes.len() + 2 > N { - break; - } - self.buffer.put_u16_le(next_length); - self.buffer.put_slice(next_bytes); - self.iter.next(); - } - if !self.buffer.is_empty() { - Some(RelaySendPacket { - node_id: self.node_id, - payload: self.buffer.split().freeze(), - }) - } else { - None - } - } -} - -/// Splits a single [`RelayToClientMsg::ReceivedPacket`] frame into datagrams. -/// -/// This splits packets joined by [`PacketizeIter`] back into individual datagrams. See -/// that struct for more details. -#[derive(Debug)] -struct PacketSplitIter { - url: RelayUrl, - src: NodeId, - bytes: Bytes, -} - -impl PacketSplitIter { - /// Create a new PacketSplitIter from a packet. - fn new(url: RelayUrl, src: NodeId, bytes: Bytes) -> Self { - Self { url, src, bytes } - } - - fn fail(&mut self) -> Option> { - self.bytes.clear(); - Some(Err(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "", - ))) - } -} - -impl Iterator for PacketSplitIter { - type Item = std::io::Result; - - fn next(&mut self) -> Option { - use bytes::Buf; - if self.bytes.has_remaining() { - if self.bytes.remaining() < 2 { - return self.fail(); - } - let len = self.bytes.get_u16_le() as usize; - if self.bytes.remaining() < len { - return self.fail(); - } - let buf = self.bytes.split_to(len); - Some(Ok(RelayRecvDatagram { - url: self.url.clone(), - src: self.src, - buf, - })) - } else { - None - } - } -} #[cfg(test)] mod tests { use std::{ @@ -1368,11 +1233,9 @@ mod tests { time::Duration, }; - use bytes::Bytes; use iroh_base::{NodeId, RelayUrl, SecretKey}; - use iroh_relay::PingTracker; + use iroh_relay::{protos::relay::Datagrams, PingTracker}; use n0_snafu::{Error, Result, ResultExt}; - use smallvec::smallvec; use tokio::sync::{mpsc, oneshot}; use tokio_util::{sync::CancellationToken, task::AbortOnDropHandle}; use tracing::{info, info_span, Instrument}; @@ -1380,42 +1243,11 @@ mod tests { use super::{ ActiveRelayActor, ActiveRelayActorOptions, ActiveRelayMessage, ActiveRelayPrioMessage, - PacketizeIter, RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, MAX_PACKET_SIZE, - RELAY_INACTIVE_CLEANUP_TIME, UNDELIVERABLE_DATAGRAM_TIMEOUT, + RelayConnectionOptions, RelayRecvDatagram, RelaySendItem, RELAY_INACTIVE_CLEANUP_TIME, + UNDELIVERABLE_DATAGRAM_TIMEOUT, }; use crate::{dns::DnsResolver, test_utils}; - #[test] - fn test_packetize_iter() { - let node_id = SecretKey::generate(rand::thread_rng()).public(); - let empty_vec: Vec = Vec::new(); - let mut iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, empty_vec); - assert_eq!(None, iter.next()); - - let single_vec = vec!["Hello"]; - let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, single_vec); - let result = iter.collect::>(); - assert_eq!(1, result.len()); - assert_eq!( - &[5, 0, b'H', b'e', b'l', b'l', b'o'], - &result[0].payload[..] - ); - - let spacer = vec![0u8; MAX_PACKET_SIZE - 10]; - let multiple_vec = vec![&b"Hello"[..], &spacer, &b"World"[..]]; - let iter = PacketizeIter::<_, MAX_PACKET_SIZE>::new(node_id, multiple_vec); - let result = iter.collect::>(); - assert_eq!(2, result.len()); - assert_eq!( - &[5, 0, b'H', b'e', b'l', b'l', b'o'], - &result[0].payload[..7] - ); - assert_eq!( - &[5, 0, b'W', b'o', b'r', b'l', b'd'], - &result[1].payload[..] - ); - } - /// Starts a new [`ActiveRelayActor`]. #[allow(clippy::too_many_arguments)] fn start_active_relay_actor( @@ -1476,12 +1308,16 @@ mod tests { loop { let datagram = recv_datagram_rx.recv().await; if let Some(recv) = datagram { - let RelayRecvDatagram { url: _, src, buf } = recv; + let RelayRecvDatagram { + url: _, + src, + datagrams, + } = recv; info!(from = src.fmt_short(), "Received datagram"); let send = RelaySendItem { remote_node: src, url: relay_url.clone(), - datagrams: smallvec![buf], + datagrams, }; send_datagram_tx.send(send).await.ok(); } @@ -1522,10 +1358,10 @@ mod tests { let RelayRecvDatagram { url: _, src: _, - buf, + datagrams, } = rx.recv().await.unwrap(); - assert_eq!(buf.as_ref(), item.datagrams[0]); + assert_eq!(datagrams, item.datagrams); Ok::<_, Error>(()) }) @@ -1568,7 +1404,7 @@ mod tests { let hello_send_item = RelaySendItem { remote_node: peer_node, url: relay_url.clone(), - datagrams: smallvec![Bytes::from_static(b"hello")], + datagrams: Datagrams::from(b"hello"), }; send_recv_echo( hello_send_item.clone(), From a1be02ce6b6ff81fdca45324fce05a431d091eba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 11 Jul 2025 16:03:08 +0200 Subject: [PATCH 74/80] feat(iroh-relay)!: Save 2 bytes for non-GSO datagrams --- iroh-relay/src/protos/common.rs | 8 +- iroh-relay/src/protos/relay.rs | 124 +++++++++++++++++++++++++------ iroh-relay/src/server/client.rs | 6 +- iroh-relay/src/server/clients.rs | 4 +- 4 files changed, 111 insertions(+), 31 deletions(-) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index 6253c03c4cb..0b3045837e1 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -23,10 +23,14 @@ pub enum FrameType { ServerConfirmsAuth = 2, /// The server frame type for authentication denial ServerDeniesAuth = 3, + /// 32B dest pub key + ECN bytes + one datagram's content + ClientToRelayDatagram = 4, /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - ClientToRelayDatagrams = 4, + ClientToRelayDatagrams = 5, + /// 32B src pub key + ECN bytes + one datagram's content + RelayToClientDatagram = 6, /// 32B src pub key + ECN byte + segment size u16 + datagrams contents - RelayToClientDatagrams = 6, + RelayToClientDatagrams = 7, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 6de26cfd000..44f90bd5724 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -7,7 +7,7 @@ //! * server then sends [`FrameType::RecvPacket`] to recipient //! * server sends [`FrameType::NodeGone`] when the other client disconnects -use bytes::{BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use iroh_base::{NodeId, SignatureError}; use n0_future::time::{self, Duration}; use nested_enum_utils::common_fields; @@ -165,39 +165,46 @@ impl> From for Datagrams { impl Datagrams { fn write_to(&self, mut dst: O) -> O { let ecn = self.ecn.map_or(0, |ecn| ecn as u8); - let segment_size = self.segment_size.unwrap_or_default(); dst.put_u8(ecn); - dst.put_u16(segment_size); + if let Some(segment_size) = self.segment_size { + dst.put_u16(segment_size); + } dst.put(self.contents.as_ref()); dst } fn encoded_len(&self) -> usize { 1 // ECN byte - + 2 // segment size + + self.segment_size.map_or(0, |_| 2) // segment size, when None, then a packed representation is assumed + self.contents.len() } - fn from_bytes(bytes: Bytes) -> Result { - // 1 bytes ECN, 2 bytes segment size - snafu::ensure!(bytes.len() > 3, InvalidFrameSnafu); + fn from_bytes(mut bytes: Bytes, is_batch: bool) -> Result { + if is_batch { + // 1 bytes ECN, 2 bytes segment size + snafu::ensure!(bytes.len() >= 3, InvalidFrameSnafu); + } else { + snafu::ensure!(bytes.len() >= 1, InvalidFrameSnafu); + } - let ecn_byte = bytes[0]; + let ecn_byte = bytes.get_u8(); let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte); - let segment_size = u16::from_be_bytes(bytes[1..3].try_into().expect("length checked")); - let segment_size = if segment_size == 0 { - None + let segment_size = if is_batch { + let segment_size = bytes.get_u16(); // length checked above + if segment_size == 0 { + None + } else { + Some(segment_size) + } } else { - Some(segment_size) + None }; - let contents = bytes.slice(3..); - Ok(Self { ecn, segment_size, - contents, + contents: bytes, }) } } @@ -206,7 +213,13 @@ impl RelayToClientMsg { /// Returns this frame's corresponding frame type. pub fn typ(&self) -> FrameType { match self { - Self::Datagrams { .. } => FrameType::RelayToClientDatagrams, + Self::Datagrams { datagrams, .. } => { + if datagrams.segment_size.is_some() { + FrameType::RelayToClientDatagrams + } else { + FrameType::RelayToClientDatagram + } + } Self::NodeGone { .. } => FrameType::NodeGone, Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, @@ -289,13 +302,16 @@ impl RelayToClientMsg { ); let res = match frame_type { - FrameType::RelayToClientDatagrams => { + FrameType::RelayToClientDatagram | FrameType::RelayToClientDatagrams => { snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); let remote_node_id = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; - let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + let datagrams = Datagrams::from_bytes( + content.slice(NodeId::LENGTH..), + frame_type == FrameType::RelayToClientDatagrams, + )?; Self::Datagrams { remote_node_id, datagrams, @@ -356,7 +372,13 @@ impl RelayToClientMsg { impl ClientToRelayMsg { pub(crate) fn typ(&self) -> FrameType { match self { - Self::Datagrams { .. } => FrameType::ClientToRelayDatagrams, + Self::Datagrams { datagrams, .. } => { + if datagrams.segment_size.is_some() { + FrameType::ClientToRelayDatagrams + } else { + FrameType::ClientToRelayDatagram + } + } Self::Ping { .. } => FrameType::Ping, Self::Pong { .. } => FrameType::Pong, } @@ -415,11 +437,14 @@ impl ClientToRelayMsg { ); let res = match frame_type { - FrameType::ClientToRelayDatagrams => { + FrameType::ClientToRelayDatagram | FrameType::ClientToRelayDatagrams => { let dst_node_id = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; - let datagrams = Datagrams::from_bytes(content.slice(NodeId::LENGTH..))?; + let datagrams = Datagrams::from_bytes( + content.slice(NodeId::LENGTH..), + frame_type == FrameType::ClientToRelayDatagrams, + )?; Self::Datagrams { dst_node_id, datagrams, @@ -508,9 +533,39 @@ mod tests { }, } .write_to(Vec::new()), - "06 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // segment size + // hello world contents bytes + "07 + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 00 06 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + ), + ( + RelayToClientMsg::Datagrams { + remote_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: None, + contents: "Hello World!".into(), + }, + } + .write_to(Vec::new()), + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // hello world contents bytes + "06 + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ( RelayToClientMsg::Restarting { @@ -552,6 +607,27 @@ mod tests { a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 47 6f 6f 64 62 79 65 21", ), + ( + ClientToRelayMsg::Datagrams { + dst_node_id: client_key.public(), + datagrams: Datagrams { + ecn: Some(quinn::EcnCodepoint::Ce), + segment_size: None, + contents: "Hello World!".into(), + }, + } + .write_to(Vec::new()), + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // hello world contents + "04 + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", + ), ]); Ok(()) diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index 50d051f1a89..67db0868644 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -640,7 +640,7 @@ mod tests { data: Datagrams::from(&data[..]), }; send_queue_s.send(packet.clone()).await.context("send")?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) + let frame = recv_frame(FrameType::RelayToClientDatagram, &mut io_rw) .await .e()?; assert_eq!( @@ -657,7 +657,7 @@ mod tests { .send(packet.clone()) .await .context("send")?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut io_rw) + let frame = recv_frame(FrameType::RelayToClientDatagram, &mut io_rw) .await .e()?; assert_eq!( @@ -731,7 +731,7 @@ mod tests { let mut stream = RelayedStream::test_limited(io_read, LIMIT / 10, LIMIT)?; // Prepare a frame to send, assert its size. - let data = Datagrams::from(b"hello world!!1"); + let data = Datagrams::from(b"hello world!!!!!"); let target = SecretKey::generate(rand::thread_rng()).public(); let frame = ClientToRelayMsg::Datagrams { dst_node_id: target, diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index bdec78042b0..a889f5d9e09 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -255,7 +255,7 @@ mod tests { // send packet let data = b"hello world!"; clients.send_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; + let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a_rw).await?; assert_eq!( frame, RelayToClientMsg::Datagrams { @@ -266,7 +266,7 @@ mod tests { // send disco packet clients.send_disco_packet(a_key, Datagrams::from(&data[..]), b_key, &metrics)?; - let frame = recv_frame(FrameType::RelayToClientDatagrams, &mut a_rw).await?; + let frame = recv_frame(FrameType::RelayToClientDatagram, &mut a_rw).await?; assert_eq!( frame, RelayToClientMsg::Datagrams { From 3a6d751c69b9300eddfdd358c589a28c0ba823ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Mon, 14 Jul 2025 16:29:11 +0200 Subject: [PATCH 75/80] Fix snapshot tests --- iroh-relay/src/protos/relay.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 44f90bd5724..7cbd8df2772 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -603,9 +603,18 @@ mod tests { }, } .write_to(Vec::new()), - "04 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e - a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d - 61 47 6f 6f 64 62 79 65 21", + // frame type + // public key first 16 bytes + // public key second 16 bytes + // ECN byte + // Segment size + // hello world contents + "05 + 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7 + 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61 + 03 + 00 06 + 48 65 6c 6c 6f 20 57 6f 72 6c 64 21", ), ( ClientToRelayMsg::Datagrams { From edcafb4a35c020b83a315cecf67274c5e5145416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 18 Jul 2025 09:27:50 +0200 Subject: [PATCH 76/80] Address code review --- iroh-relay/src/protos/common.rs | 4 ++-- iroh-relay/src/protos/relay.rs | 12 ++++++------ iroh/src/magicsock/transports/relay.rs | 2 -- iroh/src/magicsock/transports/relay/actor.rs | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/iroh-relay/src/protos/common.rs b/iroh-relay/src/protos/common.rs index 8e041c88ce8..09b5adaed76 100644 --- a/iroh-relay/src/protos/common.rs +++ b/iroh-relay/src/protos/common.rs @@ -26,11 +26,11 @@ pub enum FrameType { /// 32B dest pub key + ECN bytes + one datagram's content ClientToRelayDatagram = 4, /// 32B dest pub key + ECN byte + segment size u16 + datagrams contents - ClientToRelayDatagrams = 5, + ClientToRelayDatagramBatch = 5, /// 32B src pub key + ECN bytes + one datagram's content RelayToClientDatagram = 6, /// 32B src pub key + ECN byte + segment size u16 + datagrams contents - RelayToClientDatagrams = 7, + RelayToClientDatagramBatch = 7, /// Sent from server to client to signal that a previous sender is no longer connected. /// /// That is, if A sent to B, and then if A disconnects, the server sends `FrameType::PeerGone` diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index d4d967406a9..a1cbcca3773 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -213,7 +213,7 @@ impl RelayToClientMsg { match self { Self::Datagrams { datagrams, .. } => { if datagrams.segment_size.is_some() { - FrameType::RelayToClientDatagrams + FrameType::RelayToClientDatagramBatch } else { FrameType::RelayToClientDatagram } @@ -299,7 +299,7 @@ impl RelayToClientMsg { ); let res = match frame_type { - FrameType::RelayToClientDatagram | FrameType::RelayToClientDatagrams => { + FrameType::RelayToClientDatagram | FrameType::RelayToClientDatagramBatch => { snafu::ensure!(content.len() >= NodeId::LENGTH, InvalidFrameSnafu); let remote_node_id = cache @@ -307,7 +307,7 @@ impl RelayToClientMsg { .context(InvalidPublicKeySnafu)?; let datagrams = Datagrams::from_bytes( content.slice(NodeId::LENGTH..), - frame_type == FrameType::RelayToClientDatagrams, + frame_type == FrameType::RelayToClientDatagramBatch, )?; Self::Datagrams { remote_node_id, @@ -371,7 +371,7 @@ impl ClientToRelayMsg { match self { Self::Datagrams { datagrams, .. } => { if datagrams.segment_size.is_some() { - FrameType::ClientToRelayDatagrams + FrameType::ClientToRelayDatagramBatch } else { FrameType::ClientToRelayDatagram } @@ -434,13 +434,13 @@ impl ClientToRelayMsg { ); let res = match frame_type { - FrameType::ClientToRelayDatagram | FrameType::ClientToRelayDatagrams => { + FrameType::ClientToRelayDatagram | FrameType::ClientToRelayDatagramBatch => { let dst_node_id = cache .key_from_slice(&content[..NodeId::LENGTH]) .context(InvalidPublicKeySnafu)?; let datagrams = Datagrams::from_bytes( content.slice(NodeId::LENGTH..), - frame_type == FrameType::ClientToRelayDatagrams, + frame_type == FrameType::ClientToRelayDatagramBatch, )?; Self::Datagrams { dst_node_id, diff --git a/iroh/src/magicsock/transports/relay.rs b/iroh/src/magicsock/transports/relay.rs index 3f8cca0ea28..56a9da4cd23 100644 --- a/iroh/src/magicsock/transports/relay.rs +++ b/iroh/src/magicsock/transports/relay.rs @@ -307,8 +307,6 @@ impl RelaySender { } /// Translate a UDP transmit to the `Datagrams` type for sending over the relay. -// TODO: If quinn stayed on bytes this would probably be much cheaper, probably. Need to -// figure out where they allocate the Vec. fn datagrams_from_transmit(transmit: &Transmit<'_>) -> Datagrams { Datagrams { ecn: transmit.ecn.map(|ecn| match ecn { diff --git a/iroh/src/magicsock/transports/relay/actor.rs b/iroh/src/magicsock/transports/relay/actor.rs index 8bef8a46e58..da7f15261a2 100644 --- a/iroh/src/magicsock/transports/relay/actor.rs +++ b/iroh/src/magicsock/transports/relay/actor.rs @@ -670,7 +670,7 @@ impl ActiveRelayActor { remote_node_id, datagrams, } => { - trace!(len = %datagrams.contents.len(), "received msg"); + trace!(len = datagrams.contents.len(), "received msg"); // If this is a new sender, register a route for this peer. if state .last_packet_src From 5990021d500024ea3c227eba09a1356dc761749b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 18 Jul 2025 09:33:58 +0200 Subject: [PATCH 77/80] Remove `MAX_PAYLOAD_SIZE` (its calculations are wrong) --- iroh-relay/src/client/conn.rs | 6 +++--- iroh-relay/src/protos/relay.rs | 12 ++++-------- iroh-relay/src/server/streams.rs | 6 +++--- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/iroh-relay/src/client/conn.rs b/iroh-relay/src/client/conn.rs index 370bc07fcf3..93e70ddd084 100644 --- a/iroh-relay/src/client/conn.rs +++ b/iroh-relay/src/client/conn.rs @@ -145,10 +145,10 @@ impl Sink for Conn { } fn start_send(mut self: Pin<&mut Self>, frame: ClientToRelayMsg) -> Result<(), Self::Error> { + let size = frame.encoded_len(); + snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); if let ClientToRelayMsg::Datagrams { datagrams, .. } = &frame { - let size = datagrams.contents.len(); - snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); - snafu::ensure!(size != 0, EmptyPacketSnafu); + snafu::ensure!(!datagrams.contents.is_empty(), EmptyPacketSnafu); } Pin::new(&mut self.conn) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index a1cbcca3773..e28bbe9b2bc 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -21,12 +21,6 @@ use crate::KeyCache; /// including its on-wire framing overhead) pub const MAX_PACKET_SIZE: usize = 64 * 1024; -/// Maximum size a datagram payload is allowed to be. -/// -/// This is [`MAX_PACKET_SIZE`] minus the length of an encoded public key minus 3 bytes, -/// one for ECN, and two for the segment size. -pub const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 3; - /// The maximum frame size. /// /// This is also the minimum burst size that a rate-limiter has to accept. @@ -665,6 +659,8 @@ mod proptests { } fn datagrams() -> impl Strategy { + // The max payload size (conservatively, since with segment_size = 0 we'd have slightly more space) + const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - NodeId::LENGTH - 1 /* ECN bytes */ - 2 /* segment size */; ( ecn(), prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE), @@ -689,8 +685,8 @@ mod proptests { let ping = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Ping); let pong = prop::array::uniform8(any::()).prop_map(RelayToClientMsg::Pong); let health = ".{0,65536}" - .prop_filter("exceeds MAX_PAYLOAD_SIZE", |s| { - s.len() < MAX_PAYLOAD_SIZE // a single unicode character can match a regex "." but take up multiple bytes + .prop_filter("exceeds MAX_PACKET_SIZE", |s| { + s.len() < MAX_PACKET_SIZE // a single unicode character can match a regex "." but take up multiple bytes }) .prop_map(|problem| RelayToClientMsg::Health { problem }); let restarting = (any::(), any::()).prop_map(|(reconnect_in, try_for)| { diff --git a/iroh-relay/src/server/streams.rs b/iroh-relay/src/server/streams.rs index a50ddda5343..45b235ae4d6 100644 --- a/iroh-relay/src/server/streams.rs +++ b/iroh-relay/src/server/streams.rs @@ -94,10 +94,10 @@ impl Sink for RelayedStream { } fn start_send(mut self: Pin<&mut Self>, item: RelayToClientMsg) -> Result<(), Self::Error> { + let size = item.encoded_len(); + snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); if let RelayToClientMsg::Datagrams { datagrams, .. } = &item { - let size = datagrams.contents.len(); - snafu::ensure!(size <= MAX_PACKET_SIZE, ExceedsMaxPacketSizeSnafu { size }); - snafu::ensure!(size != 0, EmptyPacketSnafu); + snafu::ensure!(!datagrams.contents.is_empty(), EmptyPacketSnafu); } Pin::new(&mut self.inner) From 1a34025f8b3b84a53687ffe14affe60b22867ce3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 18 Jul 2025 09:49:25 +0200 Subject: [PATCH 78/80] Fix documentation --- iroh-relay/src/protos/relay.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index e28bbe9b2bc..48d28aeab56 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -3,8 +3,8 @@ //! Protocol flow: //! * server occasionally sends [`FrameType::Ping`] //! * client responds to any [`FrameType::Ping`] with a [`FrameType::Pong`] -//! * clients sends [`FrameType::SendPacket`] -//! * server then sends [`FrameType::RecvPacket`] to recipient +//! * clients sends [`FrameType::ClientToRelayDatagram`] or [`FrameType::ClientToRelayDatagramBatch`] +//! * server then sends [`FrameType::RelayToClientDatagram`] or [`FrameType::RelayToClientDatagramBatch`] to recipient //! * server sends [`FrameType::NodeGone`] when the other client disconnects use bytes::{Buf, BufMut, Bytes, BytesMut}; From 6f42a31b61ae8964d972c0203acbe49b83a1edc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 18 Jul 2025 09:50:36 +0200 Subject: [PATCH 79/80] Ignore clippy in this case --- iroh-relay/src/protos/relay.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 48d28aeab56..5fd22f9fa3b 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -171,6 +171,7 @@ impl Datagrams { + self.contents.len() } + #[allow(clippy::len_zero)] fn from_bytes(mut bytes: Bytes, is_batch: bool) -> Result { if is_batch { // 1 bytes ECN, 2 bytes segment size From 4606f7322c7db5a6090efb0b9c7c1169f432a173 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Kr=C3=BCger?= Date: Fri, 18 Jul 2025 10:01:53 +0200 Subject: [PATCH 80/80] Use `FrameType::encoded_len` in `ClientToRelayMsg::encoded_len` --- iroh-relay/src/protos/relay.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/iroh-relay/src/protos/relay.rs b/iroh-relay/src/protos/relay.rs index 5fd22f9fa3b..d0e5aed3fe5 100644 --- a/iroh-relay/src/protos/relay.rs +++ b/iroh-relay/src/protos/relay.rs @@ -411,8 +411,7 @@ impl ClientToRelayMsg { + datagrams.encoded_len() } }; - 1 // frame type (all frame types currently encode as 1 byte varint) - + payload_len + self.typ().encoded_len() + payload_len } /// Tries to decode a frame received over websockets.