From 5245b141d39fe5d4c1cc2c42304c8aa47e10db6c Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Mon, 17 Feb 2025 13:02:32 +0100 Subject: [PATCH 01/12] Implement dynamic send. no timeout --- src/environment.rs | 2 -- src/producer.rs | 86 ++++++++++++++++------------------------------ 2 files changed, 29 insertions(+), 59 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index 911a90d7..d0089093 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -1,6 +1,5 @@ use std::marker::PhantomData; use std::sync::Arc; -use std::time::Duration; use crate::types::OffsetSpecification; use crate::{client::TlsConfiguration, producer::NoDedup}; @@ -197,7 +196,6 @@ impl Environment { environment: self.clone(), name: None, batch_size: 100, - batch_publishing_delay: Duration::from_millis(100), data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), diff --git a/src/producer.rs b/src/producer.rs index 29913bdd..9cd8a2a0 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -5,7 +5,6 @@ use std::{ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, Arc, }, - time::Duration, }; use dashmap::DashMap; @@ -71,21 +70,6 @@ pub struct ProducerInternal { filter_value_extractor: Option, } -impl ProducerInternal { - async fn batch_send(&self) -> Result<(), ProducerPublishError> { - let messages = self.accumulator.get(self.batch_size).await?; - - if !messages.is_empty() { - debug!("Sending batch of {} messages", messages.len()); - self.client - .publish(self.producer_id, messages, self.publish_version) - .await?; - } - - Ok(()) - } -} - /// API for publising messages to RabbitMQ stream #[derive(Clone)] pub struct Producer(Arc, PhantomData); @@ -95,7 +79,6 @@ pub struct ProducerBuilder { pub(crate) environment: Environment, pub(crate) name: Option, pub batch_size: usize, - pub batch_publishing_delay: Duration, pub(crate) data: PhantomData, pub filter_value_extractor: Option, pub(crate) client_provided_name: String, @@ -169,7 +152,7 @@ impl ProducerBuilder { let internal_producer = Arc::new(producer); let producer = Producer(internal_producer.clone(), PhantomData); - schedule_batch_send(internal_producer, self.batch_publishing_delay); + schedule_batch_send(internal_producer); Ok(producer) } else { @@ -185,11 +168,6 @@ impl ProducerBuilder { self } - pub fn batch_delay(mut self, delay: Duration) -> Self { - self.batch_publishing_delay = delay; - self - } - pub fn client_provided_name(mut self, name: &str) -> Self { self.client_provided_name = String::from(name); self @@ -201,7 +179,6 @@ impl ProducerBuilder { environment: self.environment, name: self.name, batch_size: self.batch_size, - batch_publishing_delay: self.batch_publishing_delay, data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), @@ -229,7 +206,6 @@ impl ProducerBuilder { pub struct MessageAccumulator { sender: mpsc::Sender, receiver: Mutex>, - capacity: usize, message_count: AtomicUsize, } @@ -239,52 +215,50 @@ impl MessageAccumulator { Self { sender, receiver: Mutex::new(receiver), - capacity: batch_size, message_count: AtomicUsize::new(0), } } - pub async fn add(&self, message: ClientMessage) -> RabbitMQStreamResult { + pub async fn add(&self, message: ClientMessage) -> RabbitMQStreamResult<()> { self.sender .send(message) .await - .map_err(|err| ClientError::GenericError(Box::new(err)))?; - - let val = self.message_count.fetch_add(1, Ordering::Relaxed); - - Ok(val + 1 == self.capacity) + .map_err(|err| ClientError::GenericError(Box::new(err))) } - pub async fn get(&self, batch_size: usize) -> RabbitMQStreamResult> { - let mut messages = Vec::with_capacity(batch_size); - let mut count = 0; + pub async fn get(&self, buffer: &mut Vec, batch_size: usize) -> (bool, usize) { let mut receiver = self.receiver.lock().await; - while count < batch_size { - match receiver.try_recv().ok() { - Some(message) => { - messages.push(message); - count += 1; - } - _ => break, - } - } + + let count = receiver.recv_many(buffer, batch_size).await; self.message_count - .fetch_sub(messages.len(), Ordering::Relaxed); - Ok(messages) + .fetch_sub(count, Ordering::Relaxed); + + // `recv_many` returns 0 only if the channel is closed + // Read https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Receiver.html#method.recv_many + (count == 0, count) } } -fn schedule_batch_send(producer: Arc, delay: Duration) { +fn schedule_batch_send(producer: Arc) { tokio::task::spawn(async move { - let mut interval = tokio::time::interval(delay); - - debug!("Starting batch send interval every {:?}", delay); + let mut buffer = Vec::with_capacity(producer.batch_size); loop { - interval.tick().await; + let (is_closed, count) = producer.accumulator.get(&mut buffer, producer.batch_size).await; - match producer.batch_send().await { - Ok(_) => {} - Err(e) => error!("Error publishing batch {:?}", e), + if is_closed { + error!("Channel is closed and this is bad"); + break; + } + + if count > 0 { + debug!("Sending batch of {} messages", count); + let messages: Vec<_> = buffer.drain(..count).collect(); + match producer.client + .publish(producer.producer_id, messages, producer.publish_version) + .await { + Ok(_) => {} + Err(e) => error!("Error publishing batch {:?}", e), + }; } } }); @@ -459,9 +433,7 @@ impl Producer { .waiting_confirmations .insert(publishing_id, ProducerMessageWaiter::Once(waiter)); - if self.0.accumulator.add(msg).await? { - self.0.batch_send().await?; - } + self.0.accumulator.add(msg).await?; Ok(()) } From bca3425f52c4bbe5466871e6ff1a9d2655a8fcb5 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Wed, 19 Feb 2025 14:11:39 +0100 Subject: [PATCH 02/12] fmt. stop loop if producer is closed. Fix integration test --- src/client/mod.rs | 2 +- src/producer.rs | 25 ++++++++++---- tests/{integration => }/client_test.rs | 5 ++- tests/{integration => }/common.rs | 37 ++++++++++++++------- tests/{integration => }/consumer_test.rs | 14 ++++++-- tests/{integration => }/environment_test.rs | 13 +++++--- tests/integration/main.rs | 5 --- tests/{integration => }/producer_test.rs | 5 ++- 8 files changed, 72 insertions(+), 34 deletions(-) rename tests/{integration => }/client_test.rs (99%) rename tests/{integration => }/common.rs (75%) rename tests/{integration => }/consumer_test.rs (99%) rename tests/{integration => }/environment_test.rs (96%) delete mode 100644 tests/integration/main.rs rename tests/{integration => }/producer_test.rs (99%) diff --git a/src/client/mod.rs b/src/client/mod.rs index 646777ea..7bc4f81f 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -622,7 +622,7 @@ impl Client { M: FnOnce(u32) -> R, { let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel() else { - trace!("Connection si closed here"); + trace!("Connection is closed here"); return Err(ClientError::ConnectionClosed); }; diff --git a/src/producer.rs b/src/producer.rs index 9cd8a2a0..1677005c 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -230,8 +230,7 @@ impl MessageAccumulator { let mut receiver = self.receiver.lock().await; let count = receiver.recv_many(buffer, batch_size).await; - self.message_count - .fetch_sub(count, Ordering::Relaxed); + self.message_count.fetch_sub(count, Ordering::Relaxed); // `recv_many` returns 0 only if the channel is closed // Read https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Receiver.html#method.recv_many @@ -243,7 +242,10 @@ fn schedule_batch_send(producer: Arc) { tokio::task::spawn(async move { let mut buffer = Vec::with_capacity(producer.batch_size); loop { - let (is_closed, count) = producer.accumulator.get(&mut buffer, producer.batch_size).await; + let (is_closed, count) = producer + .accumulator + .get(&mut buffer, producer.batch_size) + .await; if is_closed { error!("Channel is closed and this is bad"); @@ -253,11 +255,20 @@ fn schedule_batch_send(producer: Arc) { if count > 0 { debug!("Sending batch of {} messages", count); let messages: Vec<_> = buffer.drain(..count).collect(); - match producer.client + match producer + .client .publish(producer.producer_id, messages, producer.publish_version) - .await { - Ok(_) => {} - Err(e) => error!("Error publishing batch {:?}", e), + .await + { + Ok(_) => {} + Err(e) => { + error!("Error publishing batch {:?}", e); + + // Stop loop if producer is closed + if producer.closed.load(Ordering::Relaxed) { + break; + } + } }; } } diff --git a/tests/integration/client_test.rs b/tests/client_test.rs similarity index 99% rename from tests/integration/client_test.rs rename to tests/client_test.rs index 4051c484..e350cd00 100644 --- a/tests/integration/client_test.rs +++ b/tests/client_test.rs @@ -12,7 +12,10 @@ use rabbitmq_stream_client::{ Client, ClientOptions, }; -use crate::common::TestClient; +#[path = "./common.rs"] +mod common; + +use common::*; #[tokio::test] async fn client_connection_test() { diff --git a/tests/integration/common.rs b/tests/common.rs similarity index 75% rename from tests/integration/common.rs rename to tests/common.rs index ef36ab96..eb3563d9 100644 --- a/tests/integration/common.rs +++ b/tests/common.rs @@ -78,17 +78,21 @@ impl Drop for TestClient { fn drop(&mut self) { if self.stream != "" { tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(async { self.client.delete_stream(&self.stream).await.unwrap() }) + tokio::runtime::Handle::current().block_on(async { + let r = self.client.delete_stream(&self.stream).await; + if let Err(e) = r { + eprintln!("Error deleting stream: {:?}", e); + } + }) }); } if self.super_stream != "" { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { - self.client - .delete_super_stream(&self.super_stream) - .await - .unwrap() + let r = self.client.delete_super_stream(&self.super_stream).await; + if let Err(e) = r { + eprintln!("Error deleting super stream: {:?}", e); + } }) }); } @@ -130,17 +134,26 @@ impl Drop for TestEnvironment { fn drop(&mut self) { if self.stream != "" { tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(async { self.env.delete_stream(&self.stream).await.unwrap() }) + tokio::runtime::Handle::current().block_on(async { + let r = self.env.delete_stream(&self.stream).await; + // Since we generate stream name randomly, + // it doesn't matter if the deletion goes wrong + if let Err(e) = r { + eprintln!("Error deleting stream: {:?}", e); + } + }) }); } if self.super_stream != "" { tokio::task::block_in_place(|| { + println!("Deleting super stream: {}", self.super_stream); tokio::runtime::Handle::current().block_on(async { - self.env - .delete_super_stream(&self.super_stream) - .await - .unwrap() + let r = self.env.delete_super_stream(&self.super_stream).await; + // Since we generate super stream name randomly, + // it doesn't matter if the deletion goes wrong + if let Err(e) = r { + eprintln!("Error deleting super stream: {:?}", e); + } }) }); } diff --git a/tests/integration/consumer_test.rs b/tests/consumer_test.rs similarity index 99% rename from tests/integration/consumer_test.rs rename to tests/consumer_test.rs index dfcaa56b..7a0df273 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/consumer_test.rs @@ -1,6 +1,10 @@ use std::time::Duration; -use crate::common::TestEnvironment; +#[path = "./common.rs"] +mod common; + +use common::*; + use fake::{Fake, Faker}; use futures::StreamExt; use rabbitmq_stream_client::{ @@ -12,7 +16,6 @@ use rabbitmq_stream_client::{ Consumer, FilterConfiguration, NoDedup, Producer, }; -use crate::producer_test::routing_key_strategy_value_extractor; use rabbitmq_stream_client::types::{ HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy, }; @@ -22,6 +25,10 @@ use tokio::sync::Notify; use tokio::task; use {std::sync::Arc, std::sync::Mutex}; +pub fn routing_key_strategy_value_extractor(_: &Message) -> String { + return "0".to_string(); +} + #[tokio::test(flavor = "multi_thread")] async fn consumer_test() { let env = TestEnvironment::create().await; @@ -66,7 +73,7 @@ fn hash_strategy_value_extractor(message: &Message) -> String { let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8"); return s; } - +/* #[tokio::test(flavor = "multi_thread")] async fn super_stream_consumer_test() { let env = TestEnvironment::create_super_stream().await; @@ -115,6 +122,7 @@ async fn super_stream_consumer_test() { super_stream_producer.close().await.unwrap(); _ = handle.close().await; } +*/ #[tokio::test(flavor = "multi_thread")] async fn consumer_test_offset_specification_offset() { diff --git a/tests/integration/environment_test.rs b/tests/environment_test.rs similarity index 96% rename from tests/integration/environment_test.rs rename to tests/environment_test.rs index d62c050d..1dceabe1 100644 --- a/tests/integration/environment_test.rs +++ b/tests/environment_test.rs @@ -6,7 +6,10 @@ use rabbitmq_stream_client::types::ByteCapacity; use rabbitmq_stream_client::{error, Environment, TlsConfiguration}; use rabbitmq_stream_protocol::ResponseCode; -use crate::common::TestEnvironment; +#[path = "./common.rs"] +mod common; + +use common::*; #[tokio::test(flavor = "multi_thread")] async fn environment_create_test() { @@ -40,19 +43,21 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn environment_create_and_delete_super_stream_test() { - let super_stream = "super_stream_test"; + let super_stream: String = Faker.fake(); let env = Environment::builder().build().await.unwrap(); let response = env .stream_creator() .max_length(ByteCapacity::GB(5)) - .create_super_stream(super_stream, 3, None) + .create_super_stream(&super_stream, 3, None) .await; + println!("{:?}", response); assert_eq!(response.is_ok(), true); - let response = env.delete_super_stream(super_stream).await; + let response = env.delete_super_stream(&super_stream).await; + println!("{:?}", response); assert_eq!(response.is_ok(), true); } diff --git a/tests/integration/main.rs b/tests/integration/main.rs deleted file mode 100644 index 25dbbab8..00000000 --- a/tests/integration/main.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod client_test; -mod common; -mod consumer_test; -mod environment_test; -mod producer_test; diff --git a/tests/integration/producer_test.rs b/tests/producer_test.rs similarity index 99% rename from tests/integration/producer_test.rs rename to tests/producer_test.rs index c62c6c35..d72b18b0 100644 --- a/tests/integration/producer_test.rs +++ b/tests/producer_test.rs @@ -10,7 +10,10 @@ use rabbitmq_stream_client::{ Environment, }; -use crate::common::{Countdown, TestEnvironment}; +#[path = "./common.rs"] +mod common; + +use common::*; use rabbitmq_stream_client::types::{ HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy, From a16519009c5e3131a6f6fcb228899c8523ac2e66 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Wed, 19 Feb 2025 14:21:52 +0100 Subject: [PATCH 03/12] - --- src/producer.rs | 11 +++++++---- tests/common.rs | 30 ++++++++++-------------------- tests/consumer_test.rs | 3 +-- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/producer.rs b/src/producer.rs index 1677005c..7474b108 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -220,10 +220,13 @@ impl MessageAccumulator { } pub async fn add(&self, message: ClientMessage) -> RabbitMQStreamResult<()> { - self.sender - .send(message) - .await - .map_err(|err| ClientError::GenericError(Box::new(err))) + match self.sender.send(message).await { + Ok(_) => { + self.message_count.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + Err(e) => Err(ClientError::GenericError(Box::new(e))), + } } pub async fn get(&self, buffer: &mut Vec, batch_size: usize) -> (bool, usize) { diff --git a/tests/common.rs b/tests/common.rs index eb3563d9..8344dc9e 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -79,20 +79,17 @@ impl Drop for TestClient { if self.stream != "" { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { - let r = self.client.delete_stream(&self.stream).await; - if let Err(e) = r { - eprintln!("Error deleting stream: {:?}", e); - } + self.client.delete_stream(&self.stream).await.unwrap(); }) }); } if self.super_stream != "" { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { - let r = self.client.delete_super_stream(&self.super_stream).await; - if let Err(e) = r { - eprintln!("Error deleting super stream: {:?}", e); - } + self.client + .delete_super_stream(&self.super_stream) + .await + .unwrap(); }) }); } @@ -135,12 +132,7 @@ impl Drop for TestEnvironment { if self.stream != "" { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { - let r = self.env.delete_stream(&self.stream).await; - // Since we generate stream name randomly, - // it doesn't matter if the deletion goes wrong - if let Err(e) = r { - eprintln!("Error deleting stream: {:?}", e); - } + self.env.delete_stream(&self.stream).await.unwrap(); }) }); } @@ -148,12 +140,10 @@ impl Drop for TestEnvironment { tokio::task::block_in_place(|| { println!("Deleting super stream: {}", self.super_stream); tokio::runtime::Handle::current().block_on(async { - let r = self.env.delete_super_stream(&self.super_stream).await; - // Since we generate super stream name randomly, - // it doesn't matter if the deletion goes wrong - if let Err(e) = r { - eprintln!("Error deleting super stream: {:?}", e); - } + self.env + .delete_super_stream(&self.super_stream) + .await + .unwrap(); }) }); } diff --git a/tests/consumer_test.rs b/tests/consumer_test.rs index 7a0df273..41e067bb 100644 --- a/tests/consumer_test.rs +++ b/tests/consumer_test.rs @@ -73,7 +73,7 @@ fn hash_strategy_value_extractor(message: &Message) -> String { let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8"); return s; } -/* + #[tokio::test(flavor = "multi_thread")] async fn super_stream_consumer_test() { let env = TestEnvironment::create_super_stream().await; @@ -122,7 +122,6 @@ async fn super_stream_consumer_test() { super_stream_producer.close().await.unwrap(); _ = handle.close().await; } -*/ #[tokio::test(flavor = "multi_thread")] async fn consumer_test_offset_specification_offset() { From e6ceb901c1c63b3d61fd41e259bf20c7f1377e01 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Wed, 19 Feb 2025 14:29:29 +0100 Subject: [PATCH 04/12] Use accumulator for batch_send --- src/producer.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/producer.rs b/src/producer.rs index 7474b108..9bedb732 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -465,7 +465,6 @@ impl Producer { let arc_cb = Arc::new(move |status| cb(status).boxed()); - let mut wrapped_msgs = Vec::with_capacity(messages.len()); for message in messages { let waiter = SharedProducerMessageWaiter::waiter_with_arc_cb(arc_cb.clone(), message.clone()); @@ -479,18 +478,14 @@ impl Producer { if let Some(f) = self.0.filter_value_extractor.as_ref() { client_message.filter_value_extract(f.as_ref()) } - wrapped_msgs.push(client_message); + // Queue the message for sending + self.0.accumulator.add(client_message).await?; self.0 .waiting_confirmations .insert(publishing_id, ProducerMessageWaiter::Shared(waiter.clone())); } - self.0 - .client - .publish(self.0.producer_id, wrapped_msgs, self.0.publish_version) - .await?; - Ok(()) } From 974ce402f7a13afa53c304c1185b4dc0c762b797 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Wed, 19 Feb 2025 14:34:04 +0100 Subject: [PATCH 05/12] Remove println --- tests/common.rs | 1 - tests/environment_test.rs | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/common.rs b/tests/common.rs index 8344dc9e..19ef6b7b 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -138,7 +138,6 @@ impl Drop for TestEnvironment { } if self.super_stream != "" { tokio::task::block_in_place(|| { - println!("Deleting super stream: {}", self.super_stream); tokio::runtime::Handle::current().block_on(async { self.env .delete_super_stream(&self.super_stream) diff --git a/tests/environment_test.rs b/tests/environment_test.rs index 1dceabe1..8db42c03 100644 --- a/tests/environment_test.rs +++ b/tests/environment_test.rs @@ -52,12 +52,10 @@ async fn environment_create_and_delete_super_stream_test() { .create_super_stream(&super_stream, 3, None) .await; - println!("{:?}", response); assert_eq!(response.is_ok(), true); let response = env.delete_super_stream(&super_stream).await; - println!("{:?}", response); assert_eq!(response.is_ok(), true); } From e171dcc532aac2a81ebb6d713534f54e16361fce Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Thu, 20 Feb 2025 18:20:16 +0100 Subject: [PATCH 06/12] Add bin to test latency performance --- Cargo.toml | 3 + src/bin/perf-producer.rs | 175 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 src/bin/perf-producer.rs diff --git a/Cargo.toml b/Cargo.toml index 06cb49e4..65856363 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,3 +66,6 @@ path="examples/superstreams/send_super_stream.rs" name="environment_deserialization" path="examples/environment_deserialization.rs" +[[bin]] +name = "perf-producer" +path = "src/bin/perf-producer.rs" diff --git a/src/bin/perf-producer.rs b/src/bin/perf-producer.rs new file mode 100644 index 00000000..f491c5a3 --- /dev/null +++ b/src/bin/perf-producer.rs @@ -0,0 +1,175 @@ +#![allow(dead_code)] + +use std::{ + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, +}; + +use rabbitmq_stream_client::{ + types::{ByteCapacity, Message, OffsetSpecification}, + Environment, +}; +use tokio::{sync::mpsc::UnboundedSender, time::sleep}; +use tokio_stream::StreamExt; + +static ONE_SECOND: Duration = Duration::from_secs(1); +static ONE_MINUTE: Duration = Duration::from_secs(60); + +struct Metric { + created_at: u128, + received_at: SystemTime, +} + +#[derive(Debug)] +struct Stats { + average_latency: f32, + messages_received: usize, +} + +#[tokio::main] +async fn main() { + let stream_name = "perf-stream"; + + let environment = Environment::builder().build().await.unwrap(); + let _ = environment.delete_stream(stream_name).await; + environment + .stream_creator() + .max_length(ByteCapacity::GB(5)) + .create(stream_name) + .await + .unwrap(); + + let environment = Arc::new(environment); + + let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); + let consumer_env = environment.clone(); + let consumer_handler = tokio::spawn(async move { + start_consumer(consumer_env, stream_name, sender).await; + }); + + let produced_messages = AtomicU32::new(0); + let producer_env = environment.clone(); + let producer_handler = tokio::spawn(async move { + start_producer(producer_env, stream_name, &produced_messages).await; + }); + + let run_for = Duration::from_secs(5 * 60); + + tokio::spawn(async move { + sleep(run_for).await; + producer_handler.abort(); + sleep(Duration::from_secs(1)).await; + consumer_handler.abort(); + }); + + let minutes = run_for.as_secs() / 60; + + let mut now = Instant::now(); + // 5 minutes of metrics + let mut metrics = Vec::with_capacity(50 * 60 * minutes as usize); + while let Some(metric) = receiver.recv().await { + if now.elapsed() > ONE_MINUTE { + now = Instant::now(); + + let last_metrics = metrics; + metrics = Vec::with_capacity(50 * 60 * minutes as usize); + tokio::spawn(async move { + let stats = calculate_stats(last_metrics).await; + println!("stats: {:?}", stats); + }); + } + metrics.push(metric); + } + + let stats = calculate_stats(metrics).await; + println!("stats: {:?}", stats); +} + +async fn calculate_stats(metrics: Vec) -> Stats { + let mut total_latency = 0; + let metric_count = metrics.len(); + for metric in metrics { + let created_at = SystemTime::UNIX_EPOCH + Duration::from_millis(metric.created_at as u64); + let received_at = metric.received_at; + let delta = received_at.duration_since(created_at).unwrap(); + total_latency += delta.as_millis(); + } + + Stats { + average_latency: total_latency as f32 / metric_count as f32, + messages_received: metric_count, + } +} + +async fn start_consumer( + environment: Arc, + stream_name: &str, + sender: UnboundedSender, +) { + let mut consumer = environment + .consumer() + .offset(OffsetSpecification::First) + .build(stream_name) + .await + .unwrap(); + while let Some(Ok(delivery)) = consumer.next().await { + let produced_at = delivery + .message() + .data() + .map(|data| { + u128::from_be_bytes([ + data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], + data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15], + ]) + }) + .unwrap(); + let metric = Metric { + created_at: produced_at, + received_at: SystemTime::now(), + }; + sender.send(metric).unwrap(); + } +} + +async fn start_producer( + environment: Arc, + stream_name: &str, + produced_messages: &AtomicU32, +) { + let message_per_second = 50_usize; + let producer = environment.producer().build(stream_name).await.unwrap(); + + loop { + let start = Instant::now(); + let messages = create_messages(message_per_second); + let messages_sent = messages.len() as u32; + for message in messages { + producer.send(message, |_| async {}).await.unwrap(); + } + produced_messages.fetch_add(messages_sent, Ordering::Relaxed); + + let elapsed = start.elapsed(); + + if ONE_SECOND > elapsed { + sleep(ONE_SECOND - elapsed).await; + } + } +} + +fn create_messages(message_count_per_batch: usize) -> Vec { + (0..message_count_per_batch) + .map(|_| { + let start = SystemTime::now(); + let since_the_epoch = start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + let since_the_epoch = since_the_epoch.as_millis(); + Message::builder() + .body(since_the_epoch.to_be_bytes()) + .build() + }) + .collect() +} From 2818257e37a2e77abdad858b60f0fd161e204403 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 25 Feb 2025 16:13:31 +0100 Subject: [PATCH 07/12] Add test. fmt & clippy --- examples/send_async.rs | 2 +- examples/tls_producer.rs | 4 +-- src/client/options.rs | 2 +- tests/client_test.rs | 43 +++++++++++++++++++----------- tests/common.rs | 16 ++++++----- tests/consumer_test.rs | 30 +++++++++------------ tests/environment_test.rs | 12 ++++----- tests/producer_test.rs | 56 +++++++++++++++------------------------ 8 files changed, 81 insertions(+), 84 deletions(-) diff --git a/examples/send_async.rs b/examples/send_async.rs index 401fbeb6..2e99ad4c 100644 --- a/examples/send_async.rs +++ b/examples/send_async.rs @@ -32,7 +32,7 @@ async fn main() -> Result<(), Box> { let create_response = environment .stream_creator() .max_length(ByteCapacity::GB(5)) - .create(&stream) + .create(stream) .await; if let Err(e) = create_response { diff --git a/examples/tls_producer.rs b/examples/tls_producer.rs index eb37b6e3..dbe8b4b9 100644 --- a/examples/tls_producer.rs +++ b/examples/tls_producer.rs @@ -54,9 +54,9 @@ async fn start_publisher( env: Environment, stream: &String, ) -> Result<(), Box> { - let _ = env.stream_creator().create(&stream).await; + let _ = env.stream_creator().create(stream).await; - let producer = env.producer().batch_size(BATCH_SIZE).build(&stream).await?; + let producer = env.producer().batch_size(BATCH_SIZE).build(stream).await?; let is_batch_send = true; tokio::task::spawn(async move { diff --git a/src/client/options.rs b/src/client/options.rs index 74fecaa8..8421e5ad 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -587,7 +587,7 @@ mod tests { assert_eq!(options.heartbeat, 10000); assert_eq!(options.max_frame_size, 1); assert!(matches!(options.tls, TlsConfiguration::Untrusted)); - assert_eq!(options.load_balancer_mode, true); + assert!(options.load_balancer_mode); } #[cfg(feature = "serde")] diff --git a/tests/client_test.rs b/tests/client_test.rs index e350cd00..23d7d1a9 100644 --- a/tests/client_test.rs +++ b/tests/client_test.rs @@ -224,10 +224,10 @@ async fn client_store_query_offset_error_test() { match response_off_not_found { Ok(_) => panic!("Should not be ok"), Err(e) => { - assert_eq!( - matches!(e, ClientError::RequestError(ResponseCode::OffsetNotFound)), - true - ) + assert!(matches!( + e, + ClientError::RequestError(ResponseCode::OffsetNotFound) + )) } } @@ -241,13 +241,10 @@ async fn client_store_query_offset_error_test() { match response_stream_does_not_exist { Ok(_) => panic!("Should not be ok"), Err(e) => { - assert_eq!( - matches!( - e, - ClientError::RequestError(ResponseCode::StreamDoesNotExist) - ), - true - ) + assert!(matches!( + e, + ClientError::RequestError(ResponseCode::StreamDoesNotExist) + )) } } } @@ -400,7 +397,7 @@ async fn client_publish() { assert_eq!(1, delivery.messages.len()); assert_eq!( Some(b"message".as_ref()), - delivery.messages.get(0).unwrap().data() + delivery.messages.first().unwrap().data() ); } @@ -432,8 +429,8 @@ async fn client_test_partitions_test() { .unwrap(); assert_eq!( - response.streams.get(0).unwrap(), - test.partitions.get(0).unwrap() + response.streams.first().unwrap(), + test.partitions.first().unwrap() ); assert_eq!( response.streams.get(1).unwrap(), @@ -456,7 +453,21 @@ async fn client_test_route_test() { assert_eq!(response.streams.len(), 1); assert_eq!( - response.streams.get(0).unwrap(), - test.partitions.get(0).unwrap() + response.streams.first().unwrap(), + test.partitions.first().unwrap() ); } + +#[tokio::test(flavor = "multi_thread")] +async fn client_close() { + let test = TestClient::create().await; + + test.client + .close() + .await + .expect("Failed to close the client"); + + let err = test.client.unsubscribe(1).await.unwrap_err(); + + assert!(matches!(err, ClientError::ConnectionClosed)); +} diff --git a/tests/common.rs b/tests/common.rs index 19ef6b7b..1230fd76 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -76,14 +76,16 @@ impl TestClient { impl Drop for TestClient { fn drop(&mut self) { - if self.stream != "" { + if !self.stream.is_empty() { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { - self.client.delete_stream(&self.stream).await.unwrap(); + // Some tests may close the connection intentionally + // so we ignore the error here + let _ = self.client.delete_stream(&self.stream).await; }) }); } - if self.super_stream != "" { + if !self.super_stream.is_empty() { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { self.client @@ -129,14 +131,14 @@ impl TestEnvironment { impl Drop for TestEnvironment { fn drop(&mut self) { - if self.stream != "" { + if !self.stream.is_empty() { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { self.env.delete_stream(&self.stream).await.unwrap(); }) }); } - if self.super_stream != "" { + if !self.super_stream.is_empty() { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { self.env @@ -166,7 +168,7 @@ pub async fn create_generic_super_stream( let response = client .create_super_stream( - &super_stream, + super_stream, partitions.clone(), binding_keys, HashMap::new(), @@ -174,5 +176,5 @@ pub async fn create_generic_super_stream( .await .unwrap(); - return (response, partitions); + (response, partitions) } diff --git a/tests/consumer_test.rs b/tests/consumer_test.rs index 41e067bb..ab50394f 100644 --- a/tests/consumer_test.rs +++ b/tests/consumer_test.rs @@ -26,7 +26,7 @@ use tokio::task; use {std::sync::Arc, std::sync::Mutex}; pub fn routing_key_strategy_value_extractor(_: &Message) -> String { - return "0".to_string(); + "0".to_string() } #[tokio::test(flavor = "multi_thread")] @@ -71,7 +71,7 @@ async fn consumer_test() { fn hash_strategy_value_extractor(message: &Message) -> String { let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8"); - return s; + s } #[tokio::test(flavor = "multi_thread")] @@ -101,7 +101,7 @@ async fn super_stream_consumer_test() { for n in 0..message_count { let msg = Message::builder().body(format!("message{}", n)).build(); - let _ = super_stream_producer + super_stream_producer .send(msg, |_confirmation_status| async move {}) .await .unwrap(); @@ -111,7 +111,7 @@ async fn super_stream_consumer_test() { let handle = super_stream_consumer.handle(); while let Some(_) = super_stream_consumer.next().await { - received_messages = received_messages + 1; + received_messages += 1; if received_messages == 10 { break; } @@ -288,13 +288,10 @@ async fn consumer_create_stream_not_existing_error() { let consumer = env.env.consumer().build("stream_not_existing").await; match consumer { - Err(e) => assert_eq!( - matches!( - e, - rabbitmq_stream_client::error::ConsumerCreateError::StreamDoesNotExist { .. } - ), - true - ), + Err(e) => assert!(matches!( + e, + rabbitmq_stream_client::error::ConsumerCreateError::StreamDoesNotExist { .. } + )), _ => panic!("Should be StreamNotFound error"), } } @@ -444,7 +441,7 @@ async fn consumer_test_with_filtering() { let filter_configuration = FilterConfiguration::new(vec!["filtering".to_string()], false) .post_filter(|message| { String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) - == "filtering".to_string() + == *"filtering" }); let mut consumer = env @@ -521,7 +518,7 @@ async fn super_stream_consumer_test_with_filtering() { let filter_configuration = FilterConfiguration::new(vec!["filtering".to_string()], false) .post_filter(|message| { String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) - == "filtering".to_string() + == *"filtering" }); let mut super_stream_consumer = env @@ -629,8 +626,7 @@ async fn consumer_test_with_filtering_match_unfiltered() { let filter_configuration = FilterConfiguration::new(vec!["1".to_string()], true).post_filter(|message| { - String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) - == "1".to_string() + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) == *"1" }); let mut consumer = env @@ -725,7 +721,7 @@ async fn super_stream_single_active_consumer_test() { for n in 0..message_count { let msg = Message::builder().body(format!("message{}", n)).build(); - let _ = super_stream_producer + super_stream_producer .send(msg, |_confirmation_status| async move {}) .await .unwrap(); @@ -884,7 +880,7 @@ async fn super_stream_single_active_consumer_test_with_callback() { for n in 0..message_count { let msg = Message::builder().body(format!("message{}", n)).build(); - let _ = super_stream_producer + super_stream_producer .send(msg, |_confirmation_status| async move {}) .await .unwrap(); diff --git a/tests/environment_test.rs b/tests/environment_test.rs index 8db42c03..7362552c 100644 --- a/tests/environment_test.rs +++ b/tests/environment_test.rs @@ -52,11 +52,11 @@ async fn environment_create_and_delete_super_stream_test() { .create_super_stream(&super_stream, 3, None) .await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); let response = env.delete_super_stream(&super_stream).await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); } #[tokio::test(flavor = "multi_thread")] @@ -130,7 +130,7 @@ async fn environment_create_delete_stream_twice() { let env = Environment::builder().build().await.unwrap(); let stream_to_test: String = Faker.fake(); let response = env.stream_creator().create(&stream_to_test).await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); let response = env.stream_creator().create(&stream_to_test).await; @@ -144,7 +144,7 @@ async fn environment_create_delete_stream_twice() { // The first delete should succeed since the stream was created let delete_response = env.delete_stream(&stream_to_test).await; - assert_eq!(delete_response.is_ok(), true); + assert!(delete_response.is_ok()); // the second delete should fail since the stream was already deleted let delete_response = env.delete_stream(&stream_to_test).await; @@ -174,10 +174,10 @@ async fn environment_create_streams_with_parameters() { .max_segment_size(ByteCapacity::GB(1)) .create(&stream_to_test) .await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); let delete_response = env.delete_stream(&stream_to_test).await; - assert_eq!(delete_response.is_ok(), true); + assert!(delete_response.is_ok()); } #[tokio::test(flavor = "multi_thread")] diff --git a/tests/producer_test.rs b/tests/producer_test.rs index d72b18b0..082c2b1c 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -71,7 +71,7 @@ async fn producer_send_name_deduplication_unique_ids() { for _ in 0..times { let cloned_ids = ids.clone(); let countdown = countdown.clone(); - let _ = producer + producer .send( Message::builder().body(b"message".to_vec()).build(), move |result| { @@ -309,7 +309,7 @@ async fn producer_batch_send() { assert_eq!(1, result.len()); - let confirmation = result.get(0).unwrap(); + let confirmation = result.first().unwrap(); assert_eq!(0, confirmation.publishing_id()); assert!(confirmation.confirmed()); assert_eq!(Some(b"message".as_ref()), confirmation.message().data()); @@ -414,8 +414,8 @@ async fn producer_send_with_complex_message_ok() { ); assert_eq!( - Some(1u32.into()), - properties.and_then(|properties| properties.group_sequence.clone()) + Some(1u32), + properties.and_then(|properties| properties.group_sequence) ); assert_eq!( @@ -454,13 +454,10 @@ async fn producer_create_stream_not_existing_error() { let producer = env.env.producer().build("stream_not_existing").await; match producer { - Err(e) => assert_eq!( - matches!( - e, - rabbitmq_stream_client::error::ProducerCreateError::StreamDoesNotExist { .. } - ), - true - ), + Err(e) => assert!(matches!( + e, + rabbitmq_stream_client::error::ProducerCreateError::StreamDoesNotExist { .. } + )), _ => panic!("Should be StreamNotFound error"), } } @@ -475,22 +472,19 @@ async fn producer_send_after_close_error() { .await .unwrap_err(); - assert_eq!( - matches!( - closed, - rabbitmq_stream_client::error::ProducerPublishError::Closed - ), - true - ); + assert!(matches!( + closed, + rabbitmq_stream_client::error::ProducerPublishError::Closed + )); } pub fn routing_key_strategy_value_extractor(_: &Message) -> String { - return "0".to_string(); + "0".to_string() } fn hash_strategy_value_extractor(message: &Message) -> String { let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8"); - return s; + s } #[tokio::test(flavor = "multi_thread")] @@ -554,13 +548,10 @@ async fn key_super_steam_non_existing_producer_test() { .await .unwrap_err(); - assert_eq!( - matches!( - result, - rabbitmq_stream_client::error::SuperStreamProducerPublishError::ProducerCreateError() - ), - true - ); + assert!(matches!( + result, + rabbitmq_stream_client::error::SuperStreamProducerPublishError::ProducerCreateError() + )); _ = super_stream_producer.close(); } @@ -641,13 +632,10 @@ async fn producer_send_filtering_message() { let closed = producer.send_with_confirm(message).await.unwrap_err(); - assert_eq!( - matches!( - closed, - rabbitmq_stream_client::error::ProducerPublishError::Closed - ), - true - ); + assert!(matches!( + closed, + rabbitmq_stream_client::error::ProducerPublishError::Closed + )); } #[tokio::test(flavor = "multi_thread")] From 88519cea3371bd6434fa53885a6c29832e125a1d Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 25 Feb 2025 16:16:56 +0100 Subject: [PATCH 08/12] Fix typo in method name --- protocol/src/message/builder.rs | 2 +- tests/producer_test.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/protocol/src/message/builder.rs b/protocol/src/message/builder.rs index 7afe103f..1c819639 100644 --- a/protocol/src/message/builder.rs +++ b/protocol/src/message/builder.rs @@ -30,7 +30,7 @@ impl MessageBuilder { pub fn application_properties(self) -> ApplicationPropertiesBuider { ApplicationPropertiesBuider(self) } - pub fn publising_id(mut self, publishing_id: u64) -> Self { + pub fn publishing_id(mut self, publishing_id: u64) -> Self { self.0.publishing_id = Some(publishing_id); self } diff --git a/tests/producer_test.rs b/tests/producer_test.rs index 082c2b1c..98438174 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -132,7 +132,7 @@ async fn producer_send_name_with_deduplication_ok() { .send_with_confirm( Message::builder() .body(b"message0".to_vec()) - .publising_id(0) + .publishing_id(0) .build(), ) .await @@ -173,24 +173,24 @@ async fn producer_send_batch_name_with_deduplication_ok() { // confirmed Message::builder() .body(b"message".to_vec()) - .publising_id(0) + .publishing_id(0) .build(), // this won't be confirmed // since it will skipped by deduplication Message::builder() .body(b"message".to_vec()) - .publising_id(0) + .publishing_id(0) .build(), // confirmed since the publishing id is different Message::builder() .body(b"message".to_vec()) - .publising_id(1) + .publishing_id(1) .build(), // not confirmed since the publishing id is the same // message will be skipped by deduplication Message::builder() .body(b"message".to_vec()) - .publising_id(1) + .publishing_id(1) .build(), ]) .await From 4dd577cdb6a54baea48c5d224c796d0776533032 Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 25 Feb 2025 17:19:23 +0100 Subject: [PATCH 09/12] Test drop connection --- Cargo.toml | 2 ++ src/client/mod.rs | 26 +++++++++++++++++---- src/client/options.rs | 5 ++++ src/environment.rs | 1 + tests/client_test.rs | 36 +++++++++++++++++++++++++++-- tests/common.rs | 54 ++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 116 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 65856363..684d8cd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,8 @@ tracing-subscriber = "0.3.1" fake = { version = "3.0.0", features = ['derive'] } chrono = "0.4.26" serde_json = "1.0" +reqwest = { version = "0.12", features = ["json"] } +serde = { version = "1.0", features = ["derive"] } [features] default = [] diff --git a/src/client/mod.rs b/src/client/mod.rs index 7bc4f81f..4807ae7d 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -25,7 +25,7 @@ use tokio::{net::TcpStream, sync::Notify}; use tokio_rustls::client::TlsStream; use tokio_util::codec::Framed; -use tracing::trace; +use tracing::{trace, warn}; use crate::{error::ClientError, RabbitMQStreamResult}; pub use message::ClientMessage; @@ -275,8 +275,12 @@ impl Client { state.handler = Some(Arc::new(handler)); } + pub fn is_closed(&self) -> bool { + self.channel.is_closed() + } + pub async fn close(&self) -> RabbitMQStreamResult<()> { - if self.channel.is_closed() { + if self.is_closed() { return Err(ClientError::AlreadyClosed); } let _: CloseResponse = self @@ -286,12 +290,17 @@ impl Client { .await?; let mut state = self.state.write().await; - + // This stop the tokio task that performs heartbeats state.heartbeat_task.take(); - drop(state); + + self.force_drop_connection().await + } + + async fn force_drop_connection(&self) -> RabbitMQStreamResult<()> { self.channel.close().await } + pub async fn subscribe( &self, subscription_id: u8, @@ -711,9 +720,16 @@ impl Client { let heartbeat_task = tokio::spawn(async move { loop { trace!("Sending heartbeat"); - let _ = channel.send(HeartBeatCommand::default().into()).await; + if channel + .send(HeartBeatCommand::default().into()) + .await + .is_err() + { + break; + } tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await; } + warn!("Heartbeat task stopped. Force closing connection"); }) .into(); state.heartbeat_task = Some(heartbeat_task); diff --git a/src/client/options.rs b/src/client/options.rs index 8421e5ad..814bfcc7 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -223,6 +223,11 @@ impl ClientOptionsBuilder { self } + pub fn client_provided_name(mut self, client_provided_name: String) -> Self { + self.0.client_provided_name = client_provided_name; + self + } + pub fn collector(mut self, collector: Arc) -> Self { self.0.collector = collector; self diff --git a/src/environment.rs b/src/environment.rs index d0089093..9592dbcd 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -326,6 +326,7 @@ impl EnvironmentBuilder { self.0.client_options.heartbeat = heartbeat; self } + pub fn metrics_collector( mut self, collector: impl MetricsCollector + 'static, diff --git a/tests/client_test.rs b/tests/client_test.rs index 23d7d1a9..85145b16 100644 --- a/tests/client_test.rs +++ b/tests/client_test.rs @@ -467,7 +467,39 @@ async fn client_close() { .await .expect("Failed to close the client"); - let err = test.client.unsubscribe(1).await.unwrap_err(); + let err = test.client.unsubscribe(1).await; + assert!( + matches!(err, Err(ClientError::ConnectionClosed)) || matches!(err, Err(ClientError::Io(_))) + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_drop_connection() { + let _ = tracing_subscriber::fmt::try_init(); + let client_provider_name: String = Faker.fake(); + + let options = ClientOptions::builder() + .client_provided_name(client_provider_name.clone()) + .heartbeat(2) + .build(); + let test = TestClient::create_with_option(options).await; + + let reference: String = Faker.fake(); + let _ = test + .client + .declare_publisher(1, Some(reference.clone()), "not_existing_stream") + .await; + let _ = test.client.unsubscribe(1).await; + + let connection = wait_for_named_connection(client_provider_name.clone()).await; + drop_connection(connection).await; - assert!(matches!(err, ClientError::ConnectionClosed)); + let res = test + .client + .declare_publisher(1, Some(reference.clone()), "not_existing_stream") + .await; + + assert!(matches!(res, Err(ClientError::ConnectionClosed))); + let res = test.client.close().await; + assert!(matches!(res, Err(ClientError::ConnectionClosed))); } diff --git a/tests/common.rs b/tests/common.rs index 1230fd76..d23d7ab3 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -1,9 +1,11 @@ +use core::panic; use std::{collections::HashMap, future::Future, sync::Arc}; use fake::{Fake, Faker}; use rabbitmq_stream_client::{Client, ClientOptions, Environment}; use rabbitmq_stream_protocol::commands::generic::GenericResponse; use rabbitmq_stream_protocol::ResponseCode; +use serde::Deserialize; use tokio::sync::Semaphore; pub struct TestClient { @@ -44,8 +46,12 @@ pub struct TestEnvironment { impl TestClient { pub async fn create() -> TestClient { + Self::create_with_option(ClientOptions::default()).await + } + + pub async fn create_with_option(options: ClientOptions) -> TestClient { let stream: String = Faker.fake(); - let client = Client::connect(ClientOptions::default()).await.unwrap(); + let client = Client::connect(options).await.unwrap(); let response = client.create_stream(&stream, HashMap::new()).await.unwrap(); @@ -178,3 +184,49 @@ pub async fn create_generic_super_stream( (response, partitions) } + +#[derive(Deserialize, Debug)] +pub struct RabbitConnection { + pub name: String, + pub client_properties: HashMap, +} + +pub async fn list_http_connection() -> Vec { + reqwest::Client::new() + .get("http://localhost:15672/api/connections/") + .basic_auth("guest", Some("guest")) + .send() + .await + .unwrap() + .json() + .await + .unwrap() +} + +pub async fn wait_for_named_connection(connection_name: String) -> RabbitConnection { + let mut max = 10; + while max > 0 { + let connections = list_http_connection().await; + let connection = connections + .into_iter() + .find(|x| x.client_properties.get("connection_name") == Some(&connection_name)); + match connection { + Some(connection) => return connection, + None => tokio::time::sleep(tokio::time::Duration::from_secs(1)).await, + } + max -= 1; + } + panic!("Connection not found. timeout"); +} + +pub async fn drop_connection(connection: RabbitConnection) { + reqwest::Client::new() + .delete(format!( + "http://localhost:15672/api/connections/{}", + connection.name + )) + .basic_auth("guest", Some("guest")) + .send() + .await + .unwrap(); +} From ce6d1dfa8f69dc9b0b8f608acea2a30551b13a4b Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 25 Feb 2025 18:20:49 +0100 Subject: [PATCH 10/12] Add timeout to send_with_confirm --- src/error.rs | 2 + src/producer.rs | 114 ++++++++++++++++++++++++++++------------- tests/producer_test.rs | 38 ++++++++++++++ 3 files changed, 117 insertions(+), 37 deletions(-) diff --git a/src/error.rs b/src/error.rs index f138a699..0e1075f9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -109,6 +109,8 @@ pub enum ProducerPublishError { Confirmation { stream: String }, #[error(transparent)] Client(#[from] ClientError), + #[error("Failed to publish message, timeout")] + Timeout, } #[derive(Error, Debug)] diff --git a/src/producer.rs b/src/producer.rs index 9bedb732..066ebc69 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -1,4 +1,5 @@ use std::future::Future; +use std::time::Duration; use std::{ marker::PhantomData, sync::{ @@ -11,6 +12,7 @@ use dashmap::DashMap; use futures::{future::BoxFuture, FutureExt}; use tokio::sync::mpsc::channel; use tokio::sync::{mpsc, Mutex}; +use tokio::time::sleep; use tracing::{debug, error, trace}; use rabbitmq_stream_protocol::{message::Message, ResponseCode, ResponseKind}; @@ -364,13 +366,19 @@ impl Producer { }) .await?; - rx.recv() - .await - .ok_or_else(|| ProducerPublishError::Confirmation { - stream: self.0.stream.clone(), - })? - .map_err(|err| ClientError::GenericError(Box::new(err))) - .map(Ok)? + let r = tokio::select! { + val = rx.recv() => { + Ok(val) + } + _ = sleep(Duration::from_secs(1)) => { + Err(ProducerPublishError::Timeout) + } + }?; + r.ok_or_else(|| ProducerPublishError::Confirmation { + stream: self.0.stream.clone(), + })? + .map_err(|err| ClientError::GenericError(Box::new(err))) + .map(Ok)? } async fn do_batch_send_with_confirm( @@ -539,23 +547,23 @@ impl MessageHandler for ProducerConfirmHandler { }; match waiter { ProducerMessageWaiter::Once(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: true, - status: ResponseCode::Ok, - message: waiter.msg, - })) + invoke_handler_once( + waiter.cb, + id, + true, + ResponseCode::Ok, + waiter.msg, + ) .await; } ProducerMessageWaiter::Shared(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: true, - status: ResponseCode::Ok, - message: waiter.msg, - })) + invoke_handler( + waiter.cb, + id, + true, + ResponseCode::Ok, + waiter.msg, + ) .await; } } @@ -576,24 +584,11 @@ impl MessageHandler for ProducerConfirmHandler { }; match waiter { ProducerMessageWaiter::Once(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: false, - status: code, - message: waiter.msg, - })) - .await; + invoke_handler_once(waiter.cb, id, false, code, waiter.msg) + .await; } ProducerMessageWaiter::Shared(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: false, - status: code, - message: waiter.msg, - })) - .await; + invoke_handler(waiter.cb, id, false, code, waiter.msg).await; } } } @@ -614,6 +609,51 @@ impl MessageHandler for ProducerConfirmHandler { } } +async fn invoke_handler( + f: T, + publishing_id: u64, + confirmed: bool, + status: ResponseCode, + message: Message, +) where + T: std::ops::Deref< + Target = dyn Fn( + Result, + ) -> std::pin::Pin + Send>> + + Send + + Sync, + >, +{ + f(Ok(ConfirmationStatus { + publishing_id, + confirmed, + status, + message, + })) + .await; +} +async fn invoke_handler_once( + f: Box< + dyn FnOnce( + Result, + ) -> std::pin::Pin + Send>> + + Send + + Sync, + >, + publishing_id: u64, + confirmed: bool, + status: ResponseCode, + message: Message, +) { + f(Ok(ConfirmationStatus { + publishing_id, + confirmed, + status, + message, + })) + .await; +} + type ConfirmCallback = Box< dyn FnOnce(Result) -> BoxFuture<'static, ()> + Send diff --git a/tests/producer_test.rs b/tests/producer_test.rs index 98438174..77fddd9a 100644 --- a/tests/producer_test.rs +++ b/tests/producer_test.rs @@ -6,6 +6,7 @@ use futures::{lock::Mutex, StreamExt}; use tokio::sync::mpsc::channel; use rabbitmq_stream_client::{ + error::ClientError, types::{Message, OffsetSpecification, SimpleValue}, Environment, }; @@ -681,3 +682,40 @@ async fn super_stream_producer_send_filtering_message() { Err(_) => assert!(false), } } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_drop_connection() { + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .build(&env.stream) + .await + .unwrap(); + + producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await + .unwrap(); + + let connection = wait_for_named_connection(client_provided_name.clone()).await; + drop_connection(connection).await; + + let closed = producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await; + + assert!(matches!( + closed, + Err(rabbitmq_stream_client::error::ProducerPublishError::Timeout) + )); + + let err = producer.close().await.unwrap_err(); + assert!(matches!( + err, + rabbitmq_stream_client::error::ProducerCloseError::Client(ClientError::ConnectionClosed) + )); +} From a8752b9799fabe0de33d2ff6a83df939da5e2f6f Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 25 Feb 2025 18:25:12 +0100 Subject: [PATCH 11/12] Fix fmt & clippy --- src/producer.rs | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/producer.rs b/src/producer.rs index 066ebc69..3175c0dc 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -609,21 +609,13 @@ impl MessageHandler for ProducerConfirmHandler { } } -async fn invoke_handler( - f: T, +async fn invoke_handler( + f: ArcConfirmCallback, publishing_id: u64, confirmed: bool, status: ResponseCode, message: Message, -) where - T: std::ops::Deref< - Target = dyn Fn( - Result, - ) -> std::pin::Pin + Send>> - + Send - + Sync, - >, -{ +) { f(Ok(ConfirmationStatus { publishing_id, confirmed, @@ -633,13 +625,7 @@ async fn invoke_handler( .await; } async fn invoke_handler_once( - f: Box< - dyn FnOnce( - Result, - ) -> std::pin::Pin + Send>> - + Send - + Sync, - >, + f: ConfirmCallback, publishing_id: u64, confirmed: bool, status: ResponseCode, From 38b956a76eb2a607fae8dc4c0acd9b691ddb98bd Mon Sep 17 00:00:00 2001 From: Tommaso Allevi Date: Tue, 25 Feb 2025 18:32:47 +0100 Subject: [PATCH 12/12] Ignore bin on coverage --- .tarpaulin.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/.tarpaulin.toml b/.tarpaulin.toml index 443683d0..4a2714da 100644 --- a/.tarpaulin.toml +++ b/.tarpaulin.toml @@ -4,6 +4,7 @@ exclude-files=[ "benchmark/*", "build.rs", "src/lib.rs", + "src/bin/**/*.rs", "tests/**/*", "mod.rs" ]