diff --git a/.tarpaulin.toml b/.tarpaulin.toml index 443683d0..4a2714da 100644 --- a/.tarpaulin.toml +++ b/.tarpaulin.toml @@ -4,6 +4,7 @@ exclude-files=[ "benchmark/*", "build.rs", "src/lib.rs", + "src/bin/**/*.rs", "tests/**/*", "mod.rs" ] diff --git a/Cargo.toml b/Cargo.toml index 06cb49e4..684d8cd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,8 @@ tracing-subscriber = "0.3.1" fake = { version = "3.0.0", features = ['derive'] } chrono = "0.4.26" serde_json = "1.0" +reqwest = { version = "0.12", features = ["json"] } +serde = { version = "1.0", features = ["derive"] } [features] default = [] @@ -66,3 +68,6 @@ path="examples/superstreams/send_super_stream.rs" name="environment_deserialization" path="examples/environment_deserialization.rs" +[[bin]] +name = "perf-producer" +path = "src/bin/perf-producer.rs" diff --git a/examples/send_async.rs b/examples/send_async.rs index 401fbeb6..2e99ad4c 100644 --- a/examples/send_async.rs +++ b/examples/send_async.rs @@ -32,7 +32,7 @@ async fn main() -> Result<(), Box> { let create_response = environment .stream_creator() .max_length(ByteCapacity::GB(5)) - .create(&stream) + .create(stream) .await; if let Err(e) = create_response { diff --git a/examples/tls_producer.rs b/examples/tls_producer.rs index eb37b6e3..dbe8b4b9 100644 --- a/examples/tls_producer.rs +++ b/examples/tls_producer.rs @@ -54,9 +54,9 @@ async fn start_publisher( env: Environment, stream: &String, ) -> Result<(), Box> { - let _ = env.stream_creator().create(&stream).await; + let _ = env.stream_creator().create(stream).await; - let producer = env.producer().batch_size(BATCH_SIZE).build(&stream).await?; + let producer = env.producer().batch_size(BATCH_SIZE).build(stream).await?; let is_batch_send = true; tokio::task::spawn(async move { diff --git a/protocol/src/message/builder.rs b/protocol/src/message/builder.rs index 7afe103f..1c819639 100644 --- a/protocol/src/message/builder.rs +++ b/protocol/src/message/builder.rs @@ -30,7 +30,7 @@ impl MessageBuilder { pub fn application_properties(self) -> ApplicationPropertiesBuider { ApplicationPropertiesBuider(self) } - pub fn publising_id(mut self, publishing_id: u64) -> Self { + pub fn publishing_id(mut self, publishing_id: u64) -> Self { self.0.publishing_id = Some(publishing_id); self } diff --git a/src/bin/perf-producer.rs b/src/bin/perf-producer.rs new file mode 100644 index 00000000..f491c5a3 --- /dev/null +++ b/src/bin/perf-producer.rs @@ -0,0 +1,175 @@ +#![allow(dead_code)] + +use std::{ + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, +}; + +use rabbitmq_stream_client::{ + types::{ByteCapacity, Message, OffsetSpecification}, + Environment, +}; +use tokio::{sync::mpsc::UnboundedSender, time::sleep}; +use tokio_stream::StreamExt; + +static ONE_SECOND: Duration = Duration::from_secs(1); +static ONE_MINUTE: Duration = Duration::from_secs(60); + +struct Metric { + created_at: u128, + received_at: SystemTime, +} + +#[derive(Debug)] +struct Stats { + average_latency: f32, + messages_received: usize, +} + +#[tokio::main] +async fn main() { + let stream_name = "perf-stream"; + + let environment = Environment::builder().build().await.unwrap(); + let _ = environment.delete_stream(stream_name).await; + environment + .stream_creator() + .max_length(ByteCapacity::GB(5)) + .create(stream_name) + .await + .unwrap(); + + let environment = Arc::new(environment); + + let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); + let consumer_env = environment.clone(); + let consumer_handler = tokio::spawn(async move { + start_consumer(consumer_env, stream_name, sender).await; + }); + + let produced_messages = AtomicU32::new(0); + let producer_env = environment.clone(); + let producer_handler = tokio::spawn(async move { + start_producer(producer_env, stream_name, &produced_messages).await; + }); + + let run_for = Duration::from_secs(5 * 60); + + tokio::spawn(async move { + sleep(run_for).await; + producer_handler.abort(); + sleep(Duration::from_secs(1)).await; + consumer_handler.abort(); + }); + + let minutes = run_for.as_secs() / 60; + + let mut now = Instant::now(); + // 5 minutes of metrics + let mut metrics = Vec::with_capacity(50 * 60 * minutes as usize); + while let Some(metric) = receiver.recv().await { + if now.elapsed() > ONE_MINUTE { + now = Instant::now(); + + let last_metrics = metrics; + metrics = Vec::with_capacity(50 * 60 * minutes as usize); + tokio::spawn(async move { + let stats = calculate_stats(last_metrics).await; + println!("stats: {:?}", stats); + }); + } + metrics.push(metric); + } + + let stats = calculate_stats(metrics).await; + println!("stats: {:?}", stats); +} + +async fn calculate_stats(metrics: Vec) -> Stats { + let mut total_latency = 0; + let metric_count = metrics.len(); + for metric in metrics { + let created_at = SystemTime::UNIX_EPOCH + Duration::from_millis(metric.created_at as u64); + let received_at = metric.received_at; + let delta = received_at.duration_since(created_at).unwrap(); + total_latency += delta.as_millis(); + } + + Stats { + average_latency: total_latency as f32 / metric_count as f32, + messages_received: metric_count, + } +} + +async fn start_consumer( + environment: Arc, + stream_name: &str, + sender: UnboundedSender, +) { + let mut consumer = environment + .consumer() + .offset(OffsetSpecification::First) + .build(stream_name) + .await + .unwrap(); + while let Some(Ok(delivery)) = consumer.next().await { + let produced_at = delivery + .message() + .data() + .map(|data| { + u128::from_be_bytes([ + data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], + data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15], + ]) + }) + .unwrap(); + let metric = Metric { + created_at: produced_at, + received_at: SystemTime::now(), + }; + sender.send(metric).unwrap(); + } +} + +async fn start_producer( + environment: Arc, + stream_name: &str, + produced_messages: &AtomicU32, +) { + let message_per_second = 50_usize; + let producer = environment.producer().build(stream_name).await.unwrap(); + + loop { + let start = Instant::now(); + let messages = create_messages(message_per_second); + let messages_sent = messages.len() as u32; + for message in messages { + producer.send(message, |_| async {}).await.unwrap(); + } + produced_messages.fetch_add(messages_sent, Ordering::Relaxed); + + let elapsed = start.elapsed(); + + if ONE_SECOND > elapsed { + sleep(ONE_SECOND - elapsed).await; + } + } +} + +fn create_messages(message_count_per_batch: usize) -> Vec { + (0..message_count_per_batch) + .map(|_| { + let start = SystemTime::now(); + let since_the_epoch = start + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + let since_the_epoch = since_the_epoch.as_millis(); + Message::builder() + .body(since_the_epoch.to_be_bytes()) + .build() + }) + .collect() +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 646777ea..4807ae7d 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -25,7 +25,7 @@ use tokio::{net::TcpStream, sync::Notify}; use tokio_rustls::client::TlsStream; use tokio_util::codec::Framed; -use tracing::trace; +use tracing::{trace, warn}; use crate::{error::ClientError, RabbitMQStreamResult}; pub use message::ClientMessage; @@ -275,8 +275,12 @@ impl Client { state.handler = Some(Arc::new(handler)); } + pub fn is_closed(&self) -> bool { + self.channel.is_closed() + } + pub async fn close(&self) -> RabbitMQStreamResult<()> { - if self.channel.is_closed() { + if self.is_closed() { return Err(ClientError::AlreadyClosed); } let _: CloseResponse = self @@ -286,12 +290,17 @@ impl Client { .await?; let mut state = self.state.write().await; - + // This stop the tokio task that performs heartbeats state.heartbeat_task.take(); - drop(state); + + self.force_drop_connection().await + } + + async fn force_drop_connection(&self) -> RabbitMQStreamResult<()> { self.channel.close().await } + pub async fn subscribe( &self, subscription_id: u8, @@ -622,7 +631,7 @@ impl Client { M: FnOnce(u32) -> R, { let Some((correlation_id, mut receiver)) = self.dispatcher.response_channel() else { - trace!("Connection si closed here"); + trace!("Connection is closed here"); return Err(ClientError::ConnectionClosed); }; @@ -711,9 +720,16 @@ impl Client { let heartbeat_task = tokio::spawn(async move { loop { trace!("Sending heartbeat"); - let _ = channel.send(HeartBeatCommand::default().into()).await; + if channel + .send(HeartBeatCommand::default().into()) + .await + .is_err() + { + break; + } tokio::time::sleep(Duration::from_secs(heartbeat_interval.into())).await; } + warn!("Heartbeat task stopped. Force closing connection"); }) .into(); state.heartbeat_task = Some(heartbeat_task); diff --git a/src/client/options.rs b/src/client/options.rs index 74fecaa8..814bfcc7 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -223,6 +223,11 @@ impl ClientOptionsBuilder { self } + pub fn client_provided_name(mut self, client_provided_name: String) -> Self { + self.0.client_provided_name = client_provided_name; + self + } + pub fn collector(mut self, collector: Arc) -> Self { self.0.collector = collector; self @@ -587,7 +592,7 @@ mod tests { assert_eq!(options.heartbeat, 10000); assert_eq!(options.max_frame_size, 1); assert!(matches!(options.tls, TlsConfiguration::Untrusted)); - assert_eq!(options.load_balancer_mode, true); + assert!(options.load_balancer_mode); } #[cfg(feature = "serde")] diff --git a/src/environment.rs b/src/environment.rs index 911a90d7..9592dbcd 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -1,6 +1,5 @@ use std::marker::PhantomData; use std::sync::Arc; -use std::time::Duration; use crate::types::OffsetSpecification; use crate::{client::TlsConfiguration, producer::NoDedup}; @@ -197,7 +196,6 @@ impl Environment { environment: self.clone(), name: None, batch_size: 100, - batch_publishing_delay: Duration::from_millis(100), data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), @@ -328,6 +326,7 @@ impl EnvironmentBuilder { self.0.client_options.heartbeat = heartbeat; self } + pub fn metrics_collector( mut self, collector: impl MetricsCollector + 'static, diff --git a/src/error.rs b/src/error.rs index f138a699..0e1075f9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -109,6 +109,8 @@ pub enum ProducerPublishError { Confirmation { stream: String }, #[error(transparent)] Client(#[from] ClientError), + #[error("Failed to publish message, timeout")] + Timeout, } #[derive(Error, Debug)] diff --git a/src/producer.rs b/src/producer.rs index 29913bdd..3175c0dc 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -1,17 +1,18 @@ use std::future::Future; +use std::time::Duration; use std::{ marker::PhantomData, sync::{ atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, Arc, }, - time::Duration, }; use dashmap::DashMap; use futures::{future::BoxFuture, FutureExt}; use tokio::sync::mpsc::channel; use tokio::sync::{mpsc, Mutex}; +use tokio::time::sleep; use tracing::{debug, error, trace}; use rabbitmq_stream_protocol::{message::Message, ResponseCode, ResponseKind}; @@ -71,21 +72,6 @@ pub struct ProducerInternal { filter_value_extractor: Option, } -impl ProducerInternal { - async fn batch_send(&self) -> Result<(), ProducerPublishError> { - let messages = self.accumulator.get(self.batch_size).await?; - - if !messages.is_empty() { - debug!("Sending batch of {} messages", messages.len()); - self.client - .publish(self.producer_id, messages, self.publish_version) - .await?; - } - - Ok(()) - } -} - /// API for publising messages to RabbitMQ stream #[derive(Clone)] pub struct Producer(Arc, PhantomData); @@ -95,7 +81,6 @@ pub struct ProducerBuilder { pub(crate) environment: Environment, pub(crate) name: Option, pub batch_size: usize, - pub batch_publishing_delay: Duration, pub(crate) data: PhantomData, pub filter_value_extractor: Option, pub(crate) client_provided_name: String, @@ -169,7 +154,7 @@ impl ProducerBuilder { let internal_producer = Arc::new(producer); let producer = Producer(internal_producer.clone(), PhantomData); - schedule_batch_send(internal_producer, self.batch_publishing_delay); + schedule_batch_send(internal_producer); Ok(producer) } else { @@ -185,11 +170,6 @@ impl ProducerBuilder { self } - pub fn batch_delay(mut self, delay: Duration) -> Self { - self.batch_publishing_delay = delay; - self - } - pub fn client_provided_name(mut self, name: &str) -> Self { self.client_provided_name = String::from(name); self @@ -201,7 +181,6 @@ impl ProducerBuilder { environment: self.environment, name: self.name, batch_size: self.batch_size, - batch_publishing_delay: self.batch_publishing_delay, data: PhantomData, filter_value_extractor: None, client_provided_name: String::from("rust-stream-producer"), @@ -229,7 +208,6 @@ impl ProducerBuilder { pub struct MessageAccumulator { sender: mpsc::Sender, receiver: Mutex>, - capacity: usize, message_count: AtomicUsize, } @@ -239,52 +217,64 @@ impl MessageAccumulator { Self { sender, receiver: Mutex::new(receiver), - capacity: batch_size, message_count: AtomicUsize::new(0), } } - pub async fn add(&self, message: ClientMessage) -> RabbitMQStreamResult { - self.sender - .send(message) - .await - .map_err(|err| ClientError::GenericError(Box::new(err)))?; - - let val = self.message_count.fetch_add(1, Ordering::Relaxed); - - Ok(val + 1 == self.capacity) + pub async fn add(&self, message: ClientMessage) -> RabbitMQStreamResult<()> { + match self.sender.send(message).await { + Ok(_) => { + self.message_count.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + Err(e) => Err(ClientError::GenericError(Box::new(e))), + } } - pub async fn get(&self, batch_size: usize) -> RabbitMQStreamResult> { - let mut messages = Vec::with_capacity(batch_size); - let mut count = 0; + pub async fn get(&self, buffer: &mut Vec, batch_size: usize) -> (bool, usize) { let mut receiver = self.receiver.lock().await; - while count < batch_size { - match receiver.try_recv().ok() { - Some(message) => { - messages.push(message); - count += 1; - } - _ => break, - } - } - self.message_count - .fetch_sub(messages.len(), Ordering::Relaxed); - Ok(messages) + + let count = receiver.recv_many(buffer, batch_size).await; + self.message_count.fetch_sub(count, Ordering::Relaxed); + + // `recv_many` returns 0 only if the channel is closed + // Read https://docs.rs/tokio/latest/tokio/sync/mpsc/struct.Receiver.html#method.recv_many + (count == 0, count) } } -fn schedule_batch_send(producer: Arc, delay: Duration) { +fn schedule_batch_send(producer: Arc) { tokio::task::spawn(async move { - let mut interval = tokio::time::interval(delay); - - debug!("Starting batch send interval every {:?}", delay); + let mut buffer = Vec::with_capacity(producer.batch_size); loop { - interval.tick().await; + let (is_closed, count) = producer + .accumulator + .get(&mut buffer, producer.batch_size) + .await; + + if is_closed { + error!("Channel is closed and this is bad"); + break; + } - match producer.batch_send().await { - Ok(_) => {} - Err(e) => error!("Error publishing batch {:?}", e), + if count > 0 { + debug!("Sending batch of {} messages", count); + let messages: Vec<_> = buffer.drain(..count).collect(); + match producer + .client + .publish(producer.producer_id, messages, producer.publish_version) + .await + { + Ok(_) => {} + Err(e) => { + error!("Error publishing batch {:?}", e); + + // Stop loop if producer is closed + if producer.closed.load(Ordering::Relaxed) { + break; + } + } + }; } } }); @@ -376,13 +366,19 @@ impl Producer { }) .await?; - rx.recv() - .await - .ok_or_else(|| ProducerPublishError::Confirmation { - stream: self.0.stream.clone(), - })? - .map_err(|err| ClientError::GenericError(Box::new(err))) - .map(Ok)? + let r = tokio::select! { + val = rx.recv() => { + Ok(val) + } + _ = sleep(Duration::from_secs(1)) => { + Err(ProducerPublishError::Timeout) + } + }?; + r.ok_or_else(|| ProducerPublishError::Confirmation { + stream: self.0.stream.clone(), + })? + .map_err(|err| ClientError::GenericError(Box::new(err))) + .map(Ok)? } async fn do_batch_send_with_confirm( @@ -459,9 +455,7 @@ impl Producer { .waiting_confirmations .insert(publishing_id, ProducerMessageWaiter::Once(waiter)); - if self.0.accumulator.add(msg).await? { - self.0.batch_send().await?; - } + self.0.accumulator.add(msg).await?; Ok(()) } @@ -479,7 +473,6 @@ impl Producer { let arc_cb = Arc::new(move |status| cb(status).boxed()); - let mut wrapped_msgs = Vec::with_capacity(messages.len()); for message in messages { let waiter = SharedProducerMessageWaiter::waiter_with_arc_cb(arc_cb.clone(), message.clone()); @@ -493,18 +486,14 @@ impl Producer { if let Some(f) = self.0.filter_value_extractor.as_ref() { client_message.filter_value_extract(f.as_ref()) } - wrapped_msgs.push(client_message); + // Queue the message for sending + self.0.accumulator.add(client_message).await?; self.0 .waiting_confirmations .insert(publishing_id, ProducerMessageWaiter::Shared(waiter.clone())); } - self.0 - .client - .publish(self.0.producer_id, wrapped_msgs, self.0.publish_version) - .await?; - Ok(()) } @@ -558,23 +547,23 @@ impl MessageHandler for ProducerConfirmHandler { }; match waiter { ProducerMessageWaiter::Once(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: true, - status: ResponseCode::Ok, - message: waiter.msg, - })) + invoke_handler_once( + waiter.cb, + id, + true, + ResponseCode::Ok, + waiter.msg, + ) .await; } ProducerMessageWaiter::Shared(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: true, - status: ResponseCode::Ok, - message: waiter.msg, - })) + invoke_handler( + waiter.cb, + id, + true, + ResponseCode::Ok, + waiter.msg, + ) .await; } } @@ -595,24 +584,11 @@ impl MessageHandler for ProducerConfirmHandler { }; match waiter { ProducerMessageWaiter::Once(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: false, - status: code, - message: waiter.msg, - })) - .await; + invoke_handler_once(waiter.cb, id, false, code, waiter.msg) + .await; } ProducerMessageWaiter::Shared(waiter) => { - let cb = waiter.cb; - cb(Ok(ConfirmationStatus { - publishing_id: id, - confirmed: false, - status: code, - message: waiter.msg, - })) - .await; + invoke_handler(waiter.cb, id, false, code, waiter.msg).await; } } } @@ -633,6 +609,37 @@ impl MessageHandler for ProducerConfirmHandler { } } +async fn invoke_handler( + f: ArcConfirmCallback, + publishing_id: u64, + confirmed: bool, + status: ResponseCode, + message: Message, +) { + f(Ok(ConfirmationStatus { + publishing_id, + confirmed, + status, + message, + })) + .await; +} +async fn invoke_handler_once( + f: ConfirmCallback, + publishing_id: u64, + confirmed: bool, + status: ResponseCode, + message: Message, +) { + f(Ok(ConfirmationStatus { + publishing_id, + confirmed, + status, + message, + })) + .await; +} + type ConfirmCallback = Box< dyn FnOnce(Result) -> BoxFuture<'static, ()> + Send diff --git a/tests/integration/client_test.rs b/tests/client_test.rs similarity index 85% rename from tests/integration/client_test.rs rename to tests/client_test.rs index 4051c484..85145b16 100644 --- a/tests/integration/client_test.rs +++ b/tests/client_test.rs @@ -12,7 +12,10 @@ use rabbitmq_stream_client::{ Client, ClientOptions, }; -use crate::common::TestClient; +#[path = "./common.rs"] +mod common; + +use common::*; #[tokio::test] async fn client_connection_test() { @@ -221,10 +224,10 @@ async fn client_store_query_offset_error_test() { match response_off_not_found { Ok(_) => panic!("Should not be ok"), Err(e) => { - assert_eq!( - matches!(e, ClientError::RequestError(ResponseCode::OffsetNotFound)), - true - ) + assert!(matches!( + e, + ClientError::RequestError(ResponseCode::OffsetNotFound) + )) } } @@ -238,13 +241,10 @@ async fn client_store_query_offset_error_test() { match response_stream_does_not_exist { Ok(_) => panic!("Should not be ok"), Err(e) => { - assert_eq!( - matches!( - e, - ClientError::RequestError(ResponseCode::StreamDoesNotExist) - ), - true - ) + assert!(matches!( + e, + ClientError::RequestError(ResponseCode::StreamDoesNotExist) + )) } } } @@ -397,7 +397,7 @@ async fn client_publish() { assert_eq!(1, delivery.messages.len()); assert_eq!( Some(b"message".as_ref()), - delivery.messages.get(0).unwrap().data() + delivery.messages.first().unwrap().data() ); } @@ -429,8 +429,8 @@ async fn client_test_partitions_test() { .unwrap(); assert_eq!( - response.streams.get(0).unwrap(), - test.partitions.get(0).unwrap() + response.streams.first().unwrap(), + test.partitions.first().unwrap() ); assert_eq!( response.streams.get(1).unwrap(), @@ -453,7 +453,53 @@ async fn client_test_route_test() { assert_eq!(response.streams.len(), 1); assert_eq!( - response.streams.get(0).unwrap(), - test.partitions.get(0).unwrap() + response.streams.first().unwrap(), + test.partitions.first().unwrap() + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn client_close() { + let test = TestClient::create().await; + + test.client + .close() + .await + .expect("Failed to close the client"); + + let err = test.client.unsubscribe(1).await; + assert!( + matches!(err, Err(ClientError::ConnectionClosed)) || matches!(err, Err(ClientError::Io(_))) ); } + +#[tokio::test(flavor = "multi_thread")] +async fn client_drop_connection() { + let _ = tracing_subscriber::fmt::try_init(); + let client_provider_name: String = Faker.fake(); + + let options = ClientOptions::builder() + .client_provided_name(client_provider_name.clone()) + .heartbeat(2) + .build(); + let test = TestClient::create_with_option(options).await; + + let reference: String = Faker.fake(); + let _ = test + .client + .declare_publisher(1, Some(reference.clone()), "not_existing_stream") + .await; + let _ = test.client.unsubscribe(1).await; + + let connection = wait_for_named_connection(client_provider_name.clone()).await; + drop_connection(connection).await; + + let res = test + .client + .declare_publisher(1, Some(reference.clone()), "not_existing_stream") + .await; + + assert!(matches!(res, Err(ClientError::ConnectionClosed))); + let res = test.client.close().await; + assert!(matches!(res, Err(ClientError::ConnectionClosed))); +} diff --git a/tests/integration/common.rs b/tests/common.rs similarity index 65% rename from tests/integration/common.rs rename to tests/common.rs index ef36ab96..d23d7ab3 100644 --- a/tests/integration/common.rs +++ b/tests/common.rs @@ -1,9 +1,11 @@ +use core::panic; use std::{collections::HashMap, future::Future, sync::Arc}; use fake::{Fake, Faker}; use rabbitmq_stream_client::{Client, ClientOptions, Environment}; use rabbitmq_stream_protocol::commands::generic::GenericResponse; use rabbitmq_stream_protocol::ResponseCode; +use serde::Deserialize; use tokio::sync::Semaphore; pub struct TestClient { @@ -44,8 +46,12 @@ pub struct TestEnvironment { impl TestClient { pub async fn create() -> TestClient { + Self::create_with_option(ClientOptions::default()).await + } + + pub async fn create_with_option(options: ClientOptions) -> TestClient { let stream: String = Faker.fake(); - let client = Client::connect(ClientOptions::default()).await.unwrap(); + let client = Client::connect(options).await.unwrap(); let response = client.create_stream(&stream, HashMap::new()).await.unwrap(); @@ -76,19 +82,22 @@ impl TestClient { impl Drop for TestClient { fn drop(&mut self) { - if self.stream != "" { + if !self.stream.is_empty() { tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(async { self.client.delete_stream(&self.stream).await.unwrap() }) + tokio::runtime::Handle::current().block_on(async { + // Some tests may close the connection intentionally + // so we ignore the error here + let _ = self.client.delete_stream(&self.stream).await; + }) }); } - if self.super_stream != "" { + if !self.super_stream.is_empty() { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { self.client .delete_super_stream(&self.super_stream) .await - .unwrap() + .unwrap(); }) }); } @@ -128,19 +137,20 @@ impl TestEnvironment { impl Drop for TestEnvironment { fn drop(&mut self) { - if self.stream != "" { + if !self.stream.is_empty() { tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(async { self.env.delete_stream(&self.stream).await.unwrap() }) + tokio::runtime::Handle::current().block_on(async { + self.env.delete_stream(&self.stream).await.unwrap(); + }) }); } - if self.super_stream != "" { + if !self.super_stream.is_empty() { tokio::task::block_in_place(|| { tokio::runtime::Handle::current().block_on(async { self.env .delete_super_stream(&self.super_stream) .await - .unwrap() + .unwrap(); }) }); } @@ -164,7 +174,7 @@ pub async fn create_generic_super_stream( let response = client .create_super_stream( - &super_stream, + super_stream, partitions.clone(), binding_keys, HashMap::new(), @@ -172,5 +182,51 @@ pub async fn create_generic_super_stream( .await .unwrap(); - return (response, partitions); + (response, partitions) +} + +#[derive(Deserialize, Debug)] +pub struct RabbitConnection { + pub name: String, + pub client_properties: HashMap, +} + +pub async fn list_http_connection() -> Vec { + reqwest::Client::new() + .get("http://localhost:15672/api/connections/") + .basic_auth("guest", Some("guest")) + .send() + .await + .unwrap() + .json() + .await + .unwrap() +} + +pub async fn wait_for_named_connection(connection_name: String) -> RabbitConnection { + let mut max = 10; + while max > 0 { + let connections = list_http_connection().await; + let connection = connections + .into_iter() + .find(|x| x.client_properties.get("connection_name") == Some(&connection_name)); + match connection { + Some(connection) => return connection, + None => tokio::time::sleep(tokio::time::Duration::from_secs(1)).await, + } + max -= 1; + } + panic!("Connection not found. timeout"); +} + +pub async fn drop_connection(connection: RabbitConnection) { + reqwest::Client::new() + .delete(format!( + "http://localhost:15672/api/connections/{}", + connection.name + )) + .basic_auth("guest", Some("guest")) + .send() + .await + .unwrap(); } diff --git a/tests/integration/consumer_test.rs b/tests/consumer_test.rs similarity index 97% rename from tests/integration/consumer_test.rs rename to tests/consumer_test.rs index dfcaa56b..ab50394f 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/consumer_test.rs @@ -1,6 +1,10 @@ use std::time::Duration; -use crate::common::TestEnvironment; +#[path = "./common.rs"] +mod common; + +use common::*; + use fake::{Fake, Faker}; use futures::StreamExt; use rabbitmq_stream_client::{ @@ -12,7 +16,6 @@ use rabbitmq_stream_client::{ Consumer, FilterConfiguration, NoDedup, Producer, }; -use crate::producer_test::routing_key_strategy_value_extractor; use rabbitmq_stream_client::types::{ HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy, }; @@ -22,6 +25,10 @@ use tokio::sync::Notify; use tokio::task; use {std::sync::Arc, std::sync::Mutex}; +pub fn routing_key_strategy_value_extractor(_: &Message) -> String { + "0".to_string() +} + #[tokio::test(flavor = "multi_thread")] async fn consumer_test() { let env = TestEnvironment::create().await; @@ -64,7 +71,7 @@ async fn consumer_test() { fn hash_strategy_value_extractor(message: &Message) -> String { let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8"); - return s; + s } #[tokio::test(flavor = "multi_thread")] @@ -94,7 +101,7 @@ async fn super_stream_consumer_test() { for n in 0..message_count { let msg = Message::builder().body(format!("message{}", n)).build(); - let _ = super_stream_producer + super_stream_producer .send(msg, |_confirmation_status| async move {}) .await .unwrap(); @@ -104,7 +111,7 @@ async fn super_stream_consumer_test() { let handle = super_stream_consumer.handle(); while let Some(_) = super_stream_consumer.next().await { - received_messages = received_messages + 1; + received_messages += 1; if received_messages == 10 { break; } @@ -281,13 +288,10 @@ async fn consumer_create_stream_not_existing_error() { let consumer = env.env.consumer().build("stream_not_existing").await; match consumer { - Err(e) => assert_eq!( - matches!( - e, - rabbitmq_stream_client::error::ConsumerCreateError::StreamDoesNotExist { .. } - ), - true - ), + Err(e) => assert!(matches!( + e, + rabbitmq_stream_client::error::ConsumerCreateError::StreamDoesNotExist { .. } + )), _ => panic!("Should be StreamNotFound error"), } } @@ -437,7 +441,7 @@ async fn consumer_test_with_filtering() { let filter_configuration = FilterConfiguration::new(vec!["filtering".to_string()], false) .post_filter(|message| { String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) - == "filtering".to_string() + == *"filtering" }); let mut consumer = env @@ -514,7 +518,7 @@ async fn super_stream_consumer_test_with_filtering() { let filter_configuration = FilterConfiguration::new(vec!["filtering".to_string()], false) .post_filter(|message| { String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) - == "filtering".to_string() + == *"filtering" }); let mut super_stream_consumer = env @@ -622,8 +626,7 @@ async fn consumer_test_with_filtering_match_unfiltered() { let filter_configuration = FilterConfiguration::new(vec!["1".to_string()], true).post_filter(|message| { - String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) - == "1".to_string() + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) == *"1" }); let mut consumer = env @@ -718,7 +721,7 @@ async fn super_stream_single_active_consumer_test() { for n in 0..message_count { let msg = Message::builder().body(format!("message{}", n)).build(); - let _ = super_stream_producer + super_stream_producer .send(msg, |_confirmation_status| async move {}) .await .unwrap(); @@ -877,7 +880,7 @@ async fn super_stream_single_active_consumer_test_with_callback() { for n in 0..message_count { let msg = Message::builder().body(format!("message{}", n)).build(); - let _ = super_stream_producer + super_stream_producer .send(msg, |_confirmation_status| async move {}) .await .unwrap(); diff --git a/tests/integration/environment_test.rs b/tests/environment_test.rs similarity index 94% rename from tests/integration/environment_test.rs rename to tests/environment_test.rs index d62c050d..7362552c 100644 --- a/tests/integration/environment_test.rs +++ b/tests/environment_test.rs @@ -6,7 +6,10 @@ use rabbitmq_stream_client::types::ByteCapacity; use rabbitmq_stream_client::{error, Environment, TlsConfiguration}; use rabbitmq_stream_protocol::ResponseCode; -use crate::common::TestEnvironment; +#[path = "./common.rs"] +mod common; + +use common::*; #[tokio::test(flavor = "multi_thread")] async fn environment_create_test() { @@ -40,20 +43,20 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn environment_create_and_delete_super_stream_test() { - let super_stream = "super_stream_test"; + let super_stream: String = Faker.fake(); let env = Environment::builder().build().await.unwrap(); let response = env .stream_creator() .max_length(ByteCapacity::GB(5)) - .create_super_stream(super_stream, 3, None) + .create_super_stream(&super_stream, 3, None) .await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); - let response = env.delete_super_stream(super_stream).await; + let response = env.delete_super_stream(&super_stream).await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); } #[tokio::test(flavor = "multi_thread")] @@ -127,7 +130,7 @@ async fn environment_create_delete_stream_twice() { let env = Environment::builder().build().await.unwrap(); let stream_to_test: String = Faker.fake(); let response = env.stream_creator().create(&stream_to_test).await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); let response = env.stream_creator().create(&stream_to_test).await; @@ -141,7 +144,7 @@ async fn environment_create_delete_stream_twice() { // The first delete should succeed since the stream was created let delete_response = env.delete_stream(&stream_to_test).await; - assert_eq!(delete_response.is_ok(), true); + assert!(delete_response.is_ok()); // the second delete should fail since the stream was already deleted let delete_response = env.delete_stream(&stream_to_test).await; @@ -171,10 +174,10 @@ async fn environment_create_streams_with_parameters() { .max_segment_size(ByteCapacity::GB(1)) .create(&stream_to_test) .await; - assert_eq!(response.is_ok(), true); + assert!(response.is_ok()); let delete_response = env.delete_stream(&stream_to_test).await; - assert_eq!(delete_response.is_ok(), true); + assert!(delete_response.is_ok()); } #[tokio::test(flavor = "multi_thread")] diff --git a/tests/integration/main.rs b/tests/integration/main.rs deleted file mode 100644 index 25dbbab8..00000000 --- a/tests/integration/main.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod client_test; -mod common; -mod consumer_test; -mod environment_test; -mod producer_test; diff --git a/tests/integration/producer_test.rs b/tests/producer_test.rs similarity index 91% rename from tests/integration/producer_test.rs rename to tests/producer_test.rs index c62c6c35..77fddd9a 100644 --- a/tests/integration/producer_test.rs +++ b/tests/producer_test.rs @@ -6,11 +6,15 @@ use futures::{lock::Mutex, StreamExt}; use tokio::sync::mpsc::channel; use rabbitmq_stream_client::{ + error::ClientError, types::{Message, OffsetSpecification, SimpleValue}, Environment, }; -use crate::common::{Countdown, TestEnvironment}; +#[path = "./common.rs"] +mod common; + +use common::*; use rabbitmq_stream_client::types::{ HashRoutingMurmurStrategy, RoutingKeyRoutingStrategy, RoutingStrategy, @@ -68,7 +72,7 @@ async fn producer_send_name_deduplication_unique_ids() { for _ in 0..times { let cloned_ids = ids.clone(); let countdown = countdown.clone(); - let _ = producer + producer .send( Message::builder().body(b"message".to_vec()).build(), move |result| { @@ -129,7 +133,7 @@ async fn producer_send_name_with_deduplication_ok() { .send_with_confirm( Message::builder() .body(b"message0".to_vec()) - .publising_id(0) + .publishing_id(0) .build(), ) .await @@ -170,24 +174,24 @@ async fn producer_send_batch_name_with_deduplication_ok() { // confirmed Message::builder() .body(b"message".to_vec()) - .publising_id(0) + .publishing_id(0) .build(), // this won't be confirmed // since it will skipped by deduplication Message::builder() .body(b"message".to_vec()) - .publising_id(0) + .publishing_id(0) .build(), // confirmed since the publishing id is different Message::builder() .body(b"message".to_vec()) - .publising_id(1) + .publishing_id(1) .build(), // not confirmed since the publishing id is the same // message will be skipped by deduplication Message::builder() .body(b"message".to_vec()) - .publising_id(1) + .publishing_id(1) .build(), ]) .await @@ -306,7 +310,7 @@ async fn producer_batch_send() { assert_eq!(1, result.len()); - let confirmation = result.get(0).unwrap(); + let confirmation = result.first().unwrap(); assert_eq!(0, confirmation.publishing_id()); assert!(confirmation.confirmed()); assert_eq!(Some(b"message".as_ref()), confirmation.message().data()); @@ -411,8 +415,8 @@ async fn producer_send_with_complex_message_ok() { ); assert_eq!( - Some(1u32.into()), - properties.and_then(|properties| properties.group_sequence.clone()) + Some(1u32), + properties.and_then(|properties| properties.group_sequence) ); assert_eq!( @@ -451,13 +455,10 @@ async fn producer_create_stream_not_existing_error() { let producer = env.env.producer().build("stream_not_existing").await; match producer { - Err(e) => assert_eq!( - matches!( - e, - rabbitmq_stream_client::error::ProducerCreateError::StreamDoesNotExist { .. } - ), - true - ), + Err(e) => assert!(matches!( + e, + rabbitmq_stream_client::error::ProducerCreateError::StreamDoesNotExist { .. } + )), _ => panic!("Should be StreamNotFound error"), } } @@ -472,22 +473,19 @@ async fn producer_send_after_close_error() { .await .unwrap_err(); - assert_eq!( - matches!( - closed, - rabbitmq_stream_client::error::ProducerPublishError::Closed - ), - true - ); + assert!(matches!( + closed, + rabbitmq_stream_client::error::ProducerPublishError::Closed + )); } pub fn routing_key_strategy_value_extractor(_: &Message) -> String { - return "0".to_string(); + "0".to_string() } fn hash_strategy_value_extractor(message: &Message) -> String { let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8"); - return s; + s } #[tokio::test(flavor = "multi_thread")] @@ -551,13 +549,10 @@ async fn key_super_steam_non_existing_producer_test() { .await .unwrap_err(); - assert_eq!( - matches!( - result, - rabbitmq_stream_client::error::SuperStreamProducerPublishError::ProducerCreateError() - ), - true - ); + assert!(matches!( + result, + rabbitmq_stream_client::error::SuperStreamProducerPublishError::ProducerCreateError() + )); _ = super_stream_producer.close(); } @@ -638,13 +633,10 @@ async fn producer_send_filtering_message() { let closed = producer.send_with_confirm(message).await.unwrap_err(); - assert_eq!( - matches!( - closed, - rabbitmq_stream_client::error::ProducerPublishError::Closed - ), - true - ); + assert!(matches!( + closed, + rabbitmq_stream_client::error::ProducerPublishError::Closed + )); } #[tokio::test(flavor = "multi_thread")] @@ -690,3 +682,40 @@ async fn super_stream_producer_send_filtering_message() { Err(_) => assert!(false), } } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_drop_connection() { + let _ = tracing_subscriber::fmt::try_init(); + let client_provided_name: String = Faker.fake(); + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .client_provided_name(&client_provided_name) + .build(&env.stream) + .await + .unwrap(); + + producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await + .unwrap(); + + let connection = wait_for_named_connection(client_provided_name.clone()).await; + drop_connection(connection).await; + + let closed = producer + .send_with_confirm(Message::builder().body(b"message".to_vec()).build()) + .await; + + assert!(matches!( + closed, + Err(rabbitmq_stream_client::error::ProducerPublishError::Timeout) + )); + + let err = producer.close().await.unwrap_err(); + assert!(matches!( + err, + rabbitmq_stream_client::error::ProducerCloseError::Client(ClientError::ConnectionClosed) + )); +}