From 69e5d42e094bbf6b2d1ef51a8a95e991ebc0497c Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Sat, 27 Jul 2024 10:57:34 +0800 Subject: [PATCH 01/12] add the load_balancer_mode attribute to the Client Options. --- README.md | 13 +++++++++++++ src/client/options.rs | 9 +++++++++ src/consumer.rs | 27 ++++++++++++++++++++------- src/environment.rs | 5 +++++ src/producer.rs | 29 ++++++++++++++++++++++------- 5 files changed, 69 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 4ba1e33d..aa8b5ddc 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,19 @@ let environment = Environment::builder() .build() ``` +##### Building the environment with a load balancer + +```rust,no_run +use rabbitmq_stream_client::Environment; + + +let environment = Environment::builder() + .load_balancer_mode(true) + .build() ``` + + + ##### Publishing messages ```rust,no_run diff --git a/src/client/options.rs b/src/client/options.rs index 16ae1b72..829c0f1d 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -12,6 +12,7 @@ pub struct ClientOptions { pub(crate) v_host: String, pub(crate) heartbeat: u32, pub(crate) max_frame_size: u32, + pub(crate) load_balancer_mode: bool, pub(crate) tls: TlsConfiguration, pub(crate) collector: Arc, } @@ -39,6 +40,7 @@ impl Default for ClientOptions { v_host: "/".to_owned(), heartbeat: 60, max_frame_size: 1048576, + load_balancer_mode: false, collector: Arc::new(NopMetricsCollector {}), tls: TlsConfiguration { enabled: false, @@ -117,6 +119,11 @@ impl ClientOptionsBuilder { self } + pub fn load_balancer_mode(mut self, load_balancer_mode: bool) -> Self { + self.0.load_balancer_mode = load_balancer_mode; + self + } + pub fn build(self) -> ClientOptions { self.0 } @@ -145,6 +152,7 @@ mod tests { client_keys_path: String::from(""), }) .collector(Arc::new(NopMetricsCollector {})) + .load_balancer_mode(true) .build(); assert_eq!(options.host, "test"); assert_eq!(options.port, 8888); @@ -154,5 +162,6 @@ mod tests { assert_eq!(options.heartbeat, 10000); assert_eq!(options.max_frame_size, 1); assert_eq!(options.tls.enabled, true); + assert_eq!(options.load_balancer_mode, true); } } diff --git a/src/consumer.rs b/src/consumer.rs index 59971d3f..58959028 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -76,12 +76,26 @@ impl ConsumerBuilder { metadata.replicas, stream ); - client = Client::connect(ClientOptions { - host: replica.host.clone(), - port: replica.port as u16, - ..self.environment.options.client_options - }) - .await?; + let load_balancer_mode = self.environment.options.client_options.load_balancer_mode; + if load_balancer_mode { + let options = self.environment.options.client_options.clone(); + loop { + let temp_client = Client::connect(options.clone()).await?; + let mapping = temp_client.connection_properties().await; + let advertised_host = mapping.get("advertised_host").unwrap(); + if *advertised_host == replica.host.clone() { + client = temp_client; + break; + } + } + } else { + client = Client::connect(ClientOptions { + host: replica.host.clone(), + port: replica.port as u16, + ..self.environment.options.client_options + }) + .await?; + } } } else { return Err(ConsumerCreateError::StreamDoesNotExist { @@ -100,7 +114,6 @@ impl ConsumerBuilder { waker: AtomicWaker::new(), metrics_collector: collector, }); - let msg_handler = ConsumerMessageHandler(consumer.clone()); client.set_handler(msg_handler).await; diff --git a/src/environment.rs b/src/environment.rs index 84bd11fd..b63077be 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ 
-126,6 +126,11 @@ impl EnvironmentBuilder { self.0.client_options.collector = Arc::new(collector); self } + + pub fn load_balancer_mode(mut self, load_balancer_mode: bool) -> EnvironmentBuilder { + self.0.client_options.load_balancer_mode = load_balancer_mode; + self + } } #[derive(Clone, Default)] pub struct EnvironmentOptions { diff --git a/src/producer.rs b/src/producer.rs index 39ad8bce..aa289187 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -119,13 +119,28 @@ impl ProducerBuilder { metadata.leader, stream ); - client.close().await?; - client = Client::connect(ClientOptions { - host: metadata.leader.host.clone(), - port: metadata.leader.port as u16, - ..self.environment.options.client_options - }) - .await?; + let load_balancer_mode: bool = self.environment.options.client_options.load_balancer_mode; + if load_balancer_mode { + // Producer must connect to leader node + let options: ClientOptions = self.environment.options.client_options.clone(); + loop { + let temp_client = Client::connect(options.clone()).await?; + let mapping = temp_client.connection_properties().await; + let advertised_host = mapping.get("advertised_host").unwrap(); + if *advertised_host == metadata.leader.host.clone() { + client = temp_client; + break; + } + } + } else { + client.close().await?; + client = Client::connect(ClientOptions { + host: metadata.leader.host.clone(), + port: metadata.leader.port as u16, + ..self.environment.options.client_options + }) + .await? + }; } else { return Err(ProducerCreateError::StreamDoesNotExist { stream: stream.into(), From 76550c11773efee163131cf562a4c98089074837 Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Tue, 30 Jul 2024 09:08:44 +0800 Subject: [PATCH 02/12] fixup cargo fmt --- src/producer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/producer.rs b/src/producer.rs index aa289187..7b3ec5e0 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -119,7 +119,7 @@ impl ProducerBuilder { metadata.leader, stream ); - let load_balancer_mode: bool = self.environment.options.client_options.load_balancer_mode; + let load_balancer_mode = self.environment.options.client_options.load_balancer_mode; if load_balancer_mode { // Producer must connect to leader node let options: ClientOptions = self.environment.options.client_options.clone(); From 9faf502a0bc717d22ffbd1af3baae35ff21494ed Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Tue, 30 Jul 2024 16:37:45 +0800 Subject: [PATCH 03/12] fixup: remove unwrap and skip the temp_client if no advertised_host property --- src/consumer.rs | 9 +++++---- src/producer.rs | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/consumer.rs b/src/consumer.rs index 58959028..409cd374 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -82,10 +82,11 @@ impl ConsumerBuilder { loop { let temp_client = Client::connect(options.clone()).await?; let mapping = temp_client.connection_properties().await; - let advertised_host = mapping.get("advertised_host").unwrap(); - if *advertised_host == replica.host.clone() { - client = temp_client; - break; + if let Some(advertised_host) = mapping.get("advertised_host") { + if *advertised_host == replica.host.clone() { + client = temp_client; + break; + } } } } else { diff --git a/src/producer.rs b/src/producer.rs index 7b3ec5e0..db8169ca 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -126,10 +126,11 @@ impl ProducerBuilder { loop { let temp_client = Client::connect(options.clone()).await?; let mapping = temp_client.connection_properties().await; 
- let advertised_host = mapping.get("advertised_host").unwrap(); - if *advertised_host == metadata.leader.host.clone() { - client = temp_client; - break; + if let Some(advertised_host) = mapping.get("advertised_host") { + if *advertised_host == metadata.leader.host.clone() { + client = temp_client; + break; + } } } } else { From 17eb38d585ec69ff3f676e3d154b1b9c98a8cb5a Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Tue, 30 Jul 2024 18:18:11 +0800 Subject: [PATCH 04/12] fixup: fixup clippy error --- src/environment.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/environment.rs b/src/environment.rs index b63077be..cf472729 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -121,7 +121,7 @@ impl EnvironmentBuilder { } pub fn metrics_collector( mut self, - collector: impl MetricsCollector + Send + Sync + 'static, + collector: impl MetricsCollector + 'static, ) -> EnvironmentBuilder { self.0.client_options.collector = Arc::new(collector); self From 3d8d59c3814bf8a5600155327dcc6b3aef7459fb Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Wed, 7 Aug 2024 19:10:18 +0800 Subject: [PATCH 05/12] add: add exchange command versions request and response --- .../src/commands/exchange_command_versions.rs | 244 ++++++++++++++++++ protocol/src/commands/mod.rs | 1 + protocol/src/protocol.rs | 5 + protocol/src/request/mod.rs | 27 +- protocol/src/request/shims.rs | 8 +- protocol/src/response/mod.rs | 31 ++- src/client/mod.rs | 14 + tests/integration/client_test.rs | 9 +- tests/integration/consumer_test.rs | 2 +- 9 files changed, 330 insertions(+), 11 deletions(-) create mode 100644 protocol/src/commands/exchange_command_versions.rs diff --git a/protocol/src/commands/exchange_command_versions.rs b/protocol/src/commands/exchange_command_versions.rs new file mode 100644 index 00000000..a954e61f --- /dev/null +++ b/protocol/src/commands/exchange_command_versions.rs @@ -0,0 +1,244 @@ +use std::io::Write; + +use crate::{ + codec::{decoder::read_vec, Decoder, Encoder}, + error::{DecodeError, EncodeError}, + protocol::commands::{COMMAND_EXCHANGE_COMMAND_VERSIONS, COMMAND_PUBLISH}, + response::{FromResponse, ResponseCode}, +}; + +use super::Command; +use byteorder::{BigEndian, WriteBytesExt}; + +#[cfg(test)] +use fake::Fake; + +#[cfg_attr(test, derive(fake::Dummy))] +#[derive(PartialEq, Eq, Debug)] +struct ExchangeCommandVersion(u16, u16, u16); + +impl Encoder for ExchangeCommandVersion { + fn encoded_size(&self) -> u32 { + self.0.encoded_size() + self.1.encoded_size() + self.2.encoded_size() + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.0.encode(writer)?; + self.1.encode(writer)?; + self.2.encode(writer)?; + + Ok(()) + } +} + +impl Decoder for ExchangeCommandVersion { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, key) = u16::decode(input)?; + let (input, min_version) = u16::decode(input)?; + let (input, max_version) = u16::decode(input)?; + Ok((input, ExchangeCommandVersion(key, min_version, max_version))) + } +} + +impl Encoder for Vec { + fn encoded_size(&self) -> u32 { + 4 + self.iter().fold(0, |acc, v| acc + v.encoded_size()) + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + writer.write_u32::(self.len() as u32)?; + for x in self { + x.encode(writer)?; + } + Ok(()) + } +} + +impl Decoder for Vec { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, result) = read_vec(input)?; + Ok((input, result)) + } +} + +#[cfg_attr(test, 
derive(fake::Dummy))] +#[derive(PartialEq, Eq, Debug)] +pub struct ExchangeCommandVersionsRequest { + correlation_id: u32, + key: u16, + min_version: u16, + max_version: u16, +} + +impl ExchangeCommandVersionsRequest { + pub fn new(correlation_id: u32, min_version: u16, max_version: u16) -> Self { + Self { + correlation_id, + min_version, + max_version, + key: COMMAND_PUBLISH, + } + } +} + +impl Encoder for ExchangeCommandVersionsRequest { + fn encoded_size(&self) -> u32 { + self.correlation_id.encoded_size() + + vec![ExchangeCommandVersion( + self.key, + self.min_version, + self.max_version, + )] + .encoded_size() + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.correlation_id.encode(writer)?; + vec![ExchangeCommandVersion( + self.key, + self.min_version, + self.max_version, + )] + .encode(writer)?; + Ok(()) + } +} + +impl Command for ExchangeCommandVersionsRequest { + fn key(&self) -> u16 { + COMMAND_EXCHANGE_COMMAND_VERSIONS + } +} + +impl Decoder for ExchangeCommandVersionsRequest { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, correlation_id) = u32::decode(input)?; + let (input, commands) = >::decode(input)?; + let command = commands.get(0); + match command { + Some(&ExchangeCommandVersion(key, min_version, max_version)) => Ok(( + input, + ExchangeCommandVersionsRequest { + correlation_id, + key, + min_version, + max_version, + }, + )), + None => Ok(( + input, + ExchangeCommandVersionsRequest { + correlation_id, + key: 0, + min_version: 0, + max_version: 0, + }, + )), + } + } +} + +#[cfg_attr(test, derive(fake::Dummy))] +#[derive(PartialEq, Eq, Debug)] +pub struct ExchangeCommandVersionsResponse { + pub(crate) correlation_id: u32, + response_code: ResponseCode, + commands: Vec, +} + +impl ExchangeCommandVersionsResponse { + // pub fn new( + // correlation_id: u32, + // response_code: ResponseCode, + // commands: Vec, + // ) -> Self { + // Self { + // correlation_id, + // response_code, + // commands, + // } + // } + + pub fn code(&self) -> &ResponseCode { + &self.response_code + } + + pub fn is_ok(&self) -> bool { + self.response_code == ResponseCode::Ok + } + + pub fn key_version(&self, key_command: u16) -> (u16, u16) { + for i in &self.commands { + match i { + ExchangeCommandVersion(match_key_command, min_version, max_version) => { + if *match_key_command == key_command { + return (min_version.clone(), max_version.clone()); + } + } + } + } + + (1, 1) + } +} + +impl Encoder for ExchangeCommandVersionsResponse { + fn encoded_size(&self) -> u32 { + self.correlation_id.encoded_size() + + self.response_code.encoded_size() + + self.commands.encoded_size() + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.correlation_id.encode(writer)?; + self.response_code.encode(writer)?; + self.commands.encode(writer)?; + Ok(()) + } +} + +impl Decoder for ExchangeCommandVersionsResponse { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, correlation_id) = u32::decode(input)?; + let (input, response_code) = ResponseCode::decode(input)?; + let (input, commands) = >::decode(input)?; + + Ok(( + input, + ExchangeCommandVersionsResponse { + correlation_id, + response_code, + commands, + }, + )) + } +} + +impl FromResponse for ExchangeCommandVersionsResponse { + fn from_response(response: crate::Response) -> Option { + match response.kind { + crate::ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + Some(exchange_command_versions) + } + _ => None, + } + } 
+} + +#[cfg(test)] +mod tests { + + use crate::commands::tests::command_encode_decode_test; + + use super::{ExchangeCommandVersionsRequest, ExchangeCommandVersionsResponse}; + + #[test] + fn exchange_command_versions_request_test() { + command_encode_decode_test::(); + } + + #[test] + fn exchange_command_versions_response_test() { + command_encode_decode_test::(); + } +} diff --git a/protocol/src/commands/mod.rs b/protocol/src/commands/mod.rs index 5f07ded0..d9c0dd60 100644 --- a/protocol/src/commands/mod.rs +++ b/protocol/src/commands/mod.rs @@ -5,6 +5,7 @@ pub mod declare_publisher; pub mod delete; pub mod delete_publisher; pub mod deliver; +pub mod exchange_command_versions; pub mod generic; pub mod heart_beat; pub mod metadata; diff --git a/protocol/src/protocol.rs b/protocol/src/protocol.rs index b1484a40..28d3006f 100644 --- a/protocol/src/protocol.rs +++ b/protocol/src/protocol.rs @@ -25,6 +25,11 @@ pub mod commands { pub const COMMAND_OPEN: u16 = 21; pub const COMMAND_CLOSE: u16 = 22; pub const COMMAND_HEARTBEAT: u16 = 23; + pub const COMMAND_ROUTE: u16 = 24; + pub const COMMAND_PARTITIONS: u16 = 25; + pub const COMMAND_CONSUMER_UPDATE: u16 = 26; + pub const COMMAND_EXCHANGE_COMMAND_VERSIONS: u16 = 27; + pub const COMMAND_STREAMS_STATS: u16 = 28; } // server responses diff --git a/protocol/src/request/mod.rs b/protocol/src/request/mod.rs index 05d52608..fc616436 100644 --- a/protocol/src/request/mod.rs +++ b/protocol/src/request/mod.rs @@ -5,7 +5,8 @@ use crate::{ commands::{ close::CloseRequest, create_stream::CreateStreamCommand, credit::CreditCommand, declare_publisher::DeclarePublisherCommand, delete::Delete, - delete_publisher::DeletePublisherCommand, heart_beat::HeartBeatCommand, + delete_publisher::DeletePublisherCommand, + exchange_command_versions::ExchangeCommandVersionsRequest, heart_beat::HeartBeatCommand, metadata::MetadataCommand, open::OpenCommand, peer_properties::PeerPropertiesCommand, publish::PublishCommand, query_offset::QueryOffsetRequest, query_publisher_sequence::QueryPublisherRequest, @@ -58,6 +59,7 @@ pub enum RequestKind { QueryPublisherSequence(QueryPublisherRequest), StoreOffset(StoreOffset), Unsubscribe(UnSubscribeCommand), + ExchangeCommandVersions(ExchangeCommandVersionsRequest), } impl Encoder for RequestKind { @@ -82,6 +84,9 @@ impl Encoder for RequestKind { RequestKind::QueryPublisherSequence(query_publisher) => query_publisher.encoded_size(), RequestKind::StoreOffset(store_offset) => store_offset.encoded_size(), RequestKind::Unsubscribe(unsubscribe) => unsubscribe.encoded_size(), + RequestKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encoded_size() + } } } @@ -106,6 +111,9 @@ impl Encoder for RequestKind { RequestKind::QueryPublisherSequence(query_publisher) => query_publisher.encode(writer), RequestKind::StoreOffset(store_offset) => store_offset.encode(writer), RequestKind::Unsubscribe(unsubcribe) => unsubcribe.encode(writer), + RequestKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encode(writer) + } } } } @@ -171,6 +179,9 @@ impl Decoder for Request { COMMAND_UNSUBSCRIBE => { UnSubscribeCommand::decode(input).map(|(i, kind)| (i, kind.into()))? } + COMMAND_EXCHANGE_COMMAND_VERSIONS => { + ExchangeCommandVersionsRequest::decode(input).map(|(i, kind)| (i, kind.into()))? 
+ } n => return Err(DecodeError::UnsupportedResponseType(n)), }; Ok((input, Request { header, kind: cmd })) @@ -185,10 +196,11 @@ mod tests { commands::{ close::CloseRequest, create_stream::CreateStreamCommand, credit::CreditCommand, declare_publisher::DeclarePublisherCommand, delete::Delete, - delete_publisher::DeletePublisherCommand, heart_beat::HeartBeatCommand, - metadata::MetadataCommand, open::OpenCommand, peer_properties::PeerPropertiesCommand, - publish::PublishCommand, query_offset::QueryOffsetRequest, - query_publisher_sequence::QueryPublisherRequest, + delete_publisher::DeletePublisherCommand, + exchange_command_versions::ExchangeCommandVersionsRequest, + heart_beat::HeartBeatCommand, metadata::MetadataCommand, open::OpenCommand, + peer_properties::PeerPropertiesCommand, publish::PublishCommand, + query_offset::QueryOffsetRequest, query_publisher_sequence::QueryPublisherRequest, sasl_authenticate::SaslAuthenticateCommand, sasl_handshake::SaslHandshakeCommand, store_offset::StoreOffset, subscribe::SubscribeCommand, tune::TunesCommand, unsubscribe::UnSubscribeCommand, Command, @@ -307,4 +319,9 @@ mod tests { assert!(remaining.is_empty()); } + + #[test] + fn request_exchange_command_versions_test() { + request_encode_decode_test::() + } } diff --git a/protocol/src/request/shims.rs b/protocol/src/request/shims.rs index 4acc7624..b68e99ce 100644 --- a/protocol/src/request/shims.rs +++ b/protocol/src/request/shims.rs @@ -2,7 +2,8 @@ use crate::{ commands::{ close::CloseRequest, create_stream::CreateStreamCommand, credit::CreditCommand, declare_publisher::DeclarePublisherCommand, delete::Delete, - delete_publisher::DeletePublisherCommand, heart_beat::HeartBeatCommand, + delete_publisher::DeletePublisherCommand, + exchange_command_versions::ExchangeCommandVersionsRequest, heart_beat::HeartBeatCommand, metadata::MetadataCommand, open::OpenCommand, peer_properties::PeerPropertiesCommand, publish::PublishCommand, query_offset::QueryOffsetRequest, query_publisher_sequence::QueryPublisherRequest, @@ -130,3 +131,8 @@ impl From for RequestKind { RequestKind::Unsubscribe(cmd) } } +impl From for RequestKind { + fn from(cmd: ExchangeCommandVersionsRequest) -> Self { + RequestKind::ExchangeCommandVersions(cmd) + } +} diff --git a/protocol/src/response/mod.rs b/protocol/src/response/mod.rs index a779c7f6..23b15a06 100644 --- a/protocol/src/response/mod.rs +++ b/protocol/src/response/mod.rs @@ -7,7 +7,8 @@ use crate::{ }, commands::{ close::CloseResponse, credit::CreditResponse, deliver::DeliverCommand, - generic::GenericResponse, heart_beat::HeartbeatResponse, metadata::MetadataResponse, + exchange_command_versions::ExchangeCommandVersionsResponse, generic::GenericResponse, + heart_beat::HeartbeatResponse, metadata::MetadataResponse, metadata_update::MetadataUpdateCommand, open::OpenResponse, peer_properties::PeerPropertiesResponse, publish_confirm::PublishConfirm, publish_error::PublishErrorResponse, query_offset::QueryOffsetResponse, @@ -66,6 +67,7 @@ pub enum ResponseKind { QueryOffset(QueryOffsetResponse), QueryPublisherSequence(QueryPublisherResponse), Credit(CreditResponse), + ExchangeCommandVersions(ExchangeCommandVersionsResponse), } impl Response { @@ -92,6 +94,9 @@ impl Response { ResponseKind::Heartbeat(_) => None, ResponseKind::Deliver(_) => None, ResponseKind::Credit(_) => None, + ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + Some(exchange_command_versions.correlation_id) + } } } @@ -164,7 +169,10 @@ impl Decoder for Response { COMMAND_QUERY_PUBLISHER_SEQUENCE 
=> QueryPublisherResponse::decode(input) .map(|(remaining, kind)| (remaining, ResponseKind::QueryPublisherSequence(kind)))?, - + COMMAND_EXCHANGE_COMMAND_VERSIONS => ExchangeCommandVersionsResponse::decode(input) + .map(|(remaining, kind)| { + (remaining, ResponseKind::ExchangeCommandVersions(kind)) + })?, n => return Err(DecodeError::UnsupportedResponseType(n)), }; Ok((input, Response { header, kind })) @@ -195,7 +203,8 @@ mod tests { use crate::{ codec::{Decoder, Encoder}, commands::{ - close::CloseResponse, deliver::DeliverCommand, generic::GenericResponse, + close::CloseResponse, deliver::DeliverCommand, + exchange_command_versions::ExchangeCommandVersionsResponse, generic::GenericResponse, heart_beat::HeartbeatResponse, metadata::MetadataResponse, metadata_update::MetadataUpdateCommand, open::OpenResponse, peer_properties::PeerPropertiesResponse, publish_confirm::PublishConfirm, @@ -213,6 +222,7 @@ mod tests { }, version::PROTOCOL_VERSION, }, + response::COMMAND_EXCHANGE_COMMAND_VERSIONS, types::Header, ResponseCode, }; @@ -238,6 +248,9 @@ mod tests { query_publisher.encoded_size() } ResponseKind::Credit(credit) => credit.encoded_size(), + ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encoded_size() + } } } @@ -263,6 +276,9 @@ mod tests { query_publisher.encode(writer) } ResponseKind::Credit(credit) => credit.encode(writer), + ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encode(writer) + } } } } @@ -423,4 +439,13 @@ mod tests { COMMAND_HEARTBEAT ); } + + #[test] + fn exchange_command_versions_response_test() { + response_test!( + ExchangeCommandVersionsResponse, + ResponseKind::ExchangeCommandVersions, + COMMAND_EXCHANGE_COMMAND_VERSIONS + ); + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index f0c671d6..c909dda4 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -16,6 +16,9 @@ use futures::{ Stream, StreamExt, TryFutureExt, }; use pin_project::pin_project; +use rabbitmq_stream_protocol::commands::exchange_command_versions::{ + ExchangeCommandVersionsRequest, ExchangeCommandVersionsResponse, +}; use rustls::PrivateKey; use rustls::ServerName; use std::{fs::File, io::BufReader, path::Path}; @@ -411,6 +414,17 @@ impl Client { .map(|sequence| sequence.from_response()) } + pub async fn exchange_command_versions( + &self, + min_version: u16, + max_version: u16, + ) -> RabbitMQStreamResult { + self.send_and_receive::(|correlation_id| { + ExchangeCommandVersionsRequest::new(correlation_id, min_version, max_version) + }) + .await + } + async fn create_connection( broker: &ClientOptions, ) -> Result< diff --git a/tests/integration/client_test.rs b/tests/integration/client_test.rs index b75eef8f..e8847d1f 100644 --- a/tests/integration/client_test.rs +++ b/tests/integration/client_test.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use fake::{Fake, Faker}; -use rabbitmq_stream_protocol::commands::close::CloseRequest; use tokio::sync::mpsc::channel; use rabbitmq_stream_client::error::ClientError; @@ -378,3 +377,11 @@ async fn client_handle_unexpected_connection_interruption() { let res = Client::connect(options).await; assert!(matches!(res, Err(ClientError::ConnectionClosed))); } + +#[tokio::test(flavor = "multi_thread")] +async fn client_exchange_command_versions() { + let test = TestClient::create().await; + + let response = test.client.exchange_command_versions(1, 1).await.unwrap(); + assert_eq!(&ResponseCode::Ok, response.code()); +} diff --git 
a/tests/integration/consumer_test.rs b/tests/integration/consumer_test.rs index 8f156b70..47984c10 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/integration/consumer_test.rs @@ -280,7 +280,7 @@ async fn consumer_test_with_store_offset() { consumer_store.handle().close().await.unwrap(); - let mut consumer_query = env + let consumer_query = env .env .consumer() .offset(OffsetSpecification::First) From 0a5dd5f4326f6a80b9197ddaf6049c38a2dcd1f0 Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Fri, 9 Aug 2024 17:08:58 +0800 Subject: [PATCH 06/12] update: add command version and producer filter --- protocol/src/codec/decoder.rs | 13 +++++++++++- protocol/src/codec/encoder.rs | 29 ++++++++++++++++++++++++++ protocol/src/codec/mod.rs | 9 ++++++++ protocol/src/commands/mod.rs | 5 +++++ protocol/src/commands/publish.rs | 22 +++++++++++++++++--- protocol/src/protocol.rs | 1 + protocol/src/request/shims.rs | 3 +-- protocol/src/types.rs | 4 +++- src/client/message.rs | 18 +++++++++++++++- src/client/mod.rs | 4 ++-- src/environment.rs | 1 + src/error.rs | 6 ++++++ src/producer.rs | 35 ++++++++++++++++++++++++++++++-- 13 files changed, 138 insertions(+), 12 deletions(-) diff --git a/protocol/src/codec/decoder.rs b/protocol/src/codec/decoder.rs index 64bb9052..5aead389 100644 --- a/protocol/src/codec/decoder.rs +++ b/protocol/src/codec/decoder.rs @@ -106,7 +106,18 @@ impl Decoder for PublishedMessage { let (input, publishing_id) = u64::decode(input)?; let (input, body) = read_vec::(input)?; let (_, message) = Message::decode(&body)?; - Ok((input, PublishedMessage::new(publishing_id, message))) + Ok((input, PublishedMessage::new(publishing_id, message, None))) + } + + fn decode_version_2(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, publishing_id) = u64::decode(input)?; + let (input, body) = read_vec::(input)?; + let (input, filter_value) = >::decode(input)?; + let (_, message) = Message::decode(&body)?; + Ok(( + input, + PublishedMessage::new(publishing_id, message, filter_value), + )) } } diff --git a/protocol/src/codec/encoder.rs b/protocol/src/codec/encoder.rs index 6f992d81..13943204 100644 --- a/protocol/src/codec/encoder.rs +++ b/protocol/src/codec/encoder.rs @@ -109,6 +109,21 @@ impl Encoder for PublishedMessage { self.message.encode(writer)?; Ok(()) } + + fn encoded_size_version_2(&self) -> u32 { + self.publishing_id.encoded_size() + + self.filter_value.encoded_size() + + 4 + + self.message.encoded_size() + } + + fn encode_version_2(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.publishing_id.encode(writer)?; + self.filter_value.encode(writer)?; + self.message.encoded_size().encode(writer)?; + self.message.encode(writer)?; + Ok(()) + } } impl Encoder for Vec { @@ -123,6 +138,20 @@ impl Encoder for Vec { } Ok(()) } + + fn encoded_size_version_2(&self) -> u32 { + 4 + self + .iter() + .fold(0, |acc, v| acc + v.encoded_size_version_2()) + } + + fn encode_version_2(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + writer.write_u32::(self.len() as u32)?; + for x in self { + x.encode_version_2(writer)?; + } + Ok(()) + } } impl Encoder for &str { diff --git a/protocol/src/codec/mod.rs b/protocol/src/codec/mod.rs index fc2701f8..ec6fddce 100644 --- a/protocol/src/codec/mod.rs +++ b/protocol/src/codec/mod.rs @@ -8,6 +8,12 @@ pub mod encoder; pub trait Encoder { fn encoded_size(&self) -> u32; fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError>; + fn encoded_size_version_2(&self) -> u32 { + self.encoded_size() + } + fn 
encode_version_2(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.encode(writer) + } } pub trait Decoder @@ -15,4 +21,7 @@ where Self: Sized, { fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError>; + fn decode_version_2(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + Decoder::decode(input) + } } diff --git a/protocol/src/commands/mod.rs b/protocol/src/commands/mod.rs index d9c0dd60..0a4ae942 100644 --- a/protocol/src/commands/mod.rs +++ b/protocol/src/commands/mod.rs @@ -1,3 +1,5 @@ +use crate::protocol::version::PROTOCOL_VERSION; + pub mod close; pub mod create_stream; pub mod credit; @@ -26,6 +28,9 @@ pub mod unsubscribe; pub trait Command { fn key(&self) -> u16; + fn version(&self) -> u16 { + PROTOCOL_VERSION + } } #[cfg(test)] diff --git a/protocol/src/commands/publish.rs b/protocol/src/commands/publish.rs index 22d2337a..361806dd 100644 --- a/protocol/src/commands/publish.rs +++ b/protocol/src/commands/publish.rs @@ -4,6 +4,7 @@ use crate::{ codec::{Decoder, Encoder}, error::{DecodeError, EncodeError}, protocol::commands::COMMAND_PUBLISH, + protocol::version::{PROTOCOL_VERSION, PROTOCOL_VERSION_2}, }; use super::Command; @@ -17,25 +18,35 @@ use fake::Fake; pub struct PublishCommand { publisher_id: u8, published_messages: Vec, + version: u16, } impl PublishCommand { - pub fn new(publisher_id: u8, published_messages: Vec) -> Self { + pub fn new(publisher_id: u8, published_messages: Vec, version: u16) -> Self { Self { publisher_id, published_messages, + version, } } } impl Encoder for PublishCommand { fn encoded_size(&self) -> u32 { - self.publisher_id.encoded_size() + self.published_messages.encoded_size() + if self.version == PROTOCOL_VERSION_2 { + self.publisher_id.encoded_size() + self.published_messages.encoded_size_version_2() + } else { + self.publisher_id.encoded_size() + self.published_messages.encoded_size() + } } fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { self.publisher_id.encode(writer)?; - self.published_messages.encode(writer)?; + if self.version == PROTOCOL_VERSION_2 { + self.published_messages.encode_version_2(writer)?; + } else { + self.published_messages.encode(writer)?; + } Ok(()) } } @@ -44,6 +55,10 @@ impl Command for PublishCommand { fn key(&self) -> u16 { COMMAND_PUBLISH } + + fn version(&self) -> u16 { + self.version + } } impl Decoder for PublishCommand { @@ -56,6 +71,7 @@ impl Decoder for PublishCommand { PublishCommand { publisher_id, published_messages, + version: 1, }, )) } diff --git a/protocol/src/protocol.rs b/protocol/src/protocol.rs index 28d3006f..d87be920 100644 --- a/protocol/src/protocol.rs +++ b/protocol/src/protocol.rs @@ -62,4 +62,5 @@ pub mod responses { #[allow(unused)] pub mod version { pub const PROTOCOL_VERSION: u16 = 1; + pub const PROTOCOL_VERSION_2: u16 = 2; } diff --git a/protocol/src/request/shims.rs b/protocol/src/request/shims.rs index b68e99ce..323d4d51 100644 --- a/protocol/src/request/shims.rs +++ b/protocol/src/request/shims.rs @@ -11,7 +11,6 @@ use crate::{ store_offset::StoreOffset, subscribe::SubscribeCommand, tune::TunesCommand, unsubscribe::UnSubscribeCommand, Command, }, - protocol::version::PROTOCOL_VERSION, types::Header, Request, RequestKind, }; @@ -21,7 +20,7 @@ where { fn from(cmd: T) -> Self { Request { - header: Header::new(cmd.key(), PROTOCOL_VERSION), + header: Header::new(cmd.key(), cmd.version()), kind: cmd.into(), } } diff --git a/protocol/src/types.rs b/protocol/src/types.rs index e8fe6277..d21a0bdf 100644 --- a/protocol/src/types.rs +++ 
b/protocol/src/types.rs @@ -30,13 +30,15 @@ use crate::{message::Message, ResponseCode}; pub struct PublishedMessage { pub(crate) publishing_id: u64, pub(crate) message: Message, + pub(crate) filter_value: Option, } impl PublishedMessage { - pub fn new(publishing_id: u64, message: Message) -> Self { + pub fn new(publishing_id: u64, message: Message, filter_value: Option) -> Self { Self { publishing_id, message, + filter_value, } } diff --git a/src/client/message.rs b/src/client/message.rs index 40bd4a9b..2b45e651 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -1,8 +1,10 @@ use rabbitmq_stream_protocol::message::Message; +use std::sync::Arc; pub trait BaseMessage { fn publishing_id(&self) -> Option; fn to_message(self) -> Message; + fn filter_value(&self) -> Option; } impl BaseMessage for Message { @@ -13,21 +15,31 @@ impl BaseMessage for Message { fn to_message(self) -> Message { self } + + fn filter_value(&self) -> Option { + None + } } #[derive(Debug)] pub struct ClientMessage { publishing_id: u64, message: Message, + filter_value: Option, } impl ClientMessage { - pub fn new(publishing_id: u64, message: Message) -> Self { + pub fn new(publishing_id: u64, message: Message, filter_value: Option) -> Self { Self { publishing_id, message, + filter_value, } } + + pub fn filter_value(&mut self, filter_value_extractor: &fn(Message) -> String) { + self.filter_value = Some(filter_value_extractor(self.message.clone())); + } } impl BaseMessage for ClientMessage { @@ -38,4 +50,8 @@ impl BaseMessage for ClientMessage { fn to_message(self) -> Message { self.message } + + fn filter_value(&self) -> Option { + self.filter_value.clone() + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index c909dda4..a2541b21 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -380,11 +380,11 @@ impl Client { .into() .into_iter() .map(|message| { - let publishing_id = message + let publishing_id: u64 = message .publishing_id() .unwrap_or_else(|| self.publish_sequence.fetch_add(1, Ordering::Relaxed)); - PublishedMessage::new(publishing_id, message.to_message()) + PublishedMessage::new(publishing_id, message.to_message(), message.filter_value()) }) .collect(); let sequences = messages diff --git a/src/environment.rs b/src/environment.rs index cf472729..647f7224 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -45,6 +45,7 @@ impl Environment { batch_size: 100, batch_publishing_delay: Duration::from_millis(100), data: PhantomData, + filter_value_extractor: None, } } diff --git a/src/error.rs b/src/error.rs index 8b3b9fbd..b3b3a838 100644 --- a/src/error.rs +++ b/src/error.rs @@ -88,6 +88,12 @@ pub enum ProducerCreateError { #[error(transparent)] Client(#[from] ClientError), + + #[error("Server publish version {server_version} not support client pulish version {client_version}")] + VersionNotSupport { + client_version: u16, + server_version: u16, + } } #[derive(Error, Debug)] diff --git a/src/producer.rs b/src/producer.rs index db8169ca..7f903b87 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -73,6 +73,8 @@ pub struct ProducerInternal { waiting_confirmations: WaiterMap, closed: Arc, accumulator: MessageAccumulator, + publish_version: u16, + filter_value_extractor: Option String>, } impl ProducerInternal { @@ -99,6 +101,7 @@ pub struct ProducerBuilder { pub batch_size: usize, pub batch_publishing_delay: Duration, pub(crate) data: PhantomData, + pub filter_value_extractor: Option String>, } #[derive(Clone)] @@ -170,6 +173,18 @@ impl ProducerBuilder { }; if response.is_ok() { + 
let mut publish_version = 1; + if let Some(filter_value_extractor) = self.filter_value_extractor { + publish_version = 2; + let exchange_command_version = client.exchange_command_versions(1, 2).await?; + let (_, max_version) = exchange_command_version.key_version(2); + if max_version < publish_version { + return Err(ProducerCreateError::VersionNotSupport { + client_version: publish_version, + server_version: max_version, + }); + } + } let producer = ProducerInternal { producer_id, batch_size: self.batch_size, @@ -177,8 +192,10 @@ impl ProducerBuilder { client, publish_sequence, waiting_confirmations, + publish_version, closed: Arc::new(AtomicBool::new(false)), accumulator: MessageAccumulator::new(self.batch_size), + filter_value_extractor: self.filter_value_extractor, }; let internal_producer = Arc::new(producer); @@ -211,8 +228,14 @@ impl ProducerBuilder { batch_size: self.batch_size, batch_publishing_delay: self.batch_publishing_delay, data: PhantomData, + filter_value_extractor: None, } } + + pub fn filter_value_extractor(mut self, filter_value_extractor: fn(Message) -> String) -> Self { + self.filter_value_extractor = Some(filter_value_extractor); + self + } } pub struct MessageAccumulator { @@ -437,7 +460,11 @@ impl Producer { Some(publishing_id) => *publishing_id, None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed), }; - let msg = ClientMessage::new(publishing_id, message.clone()); + let mut msg = ClientMessage::new(publishing_id, message.clone(), None); + + if let Some(f) = self.0.filter_value_extractor.as_ref() { + msg.filter_value(f) + } let waiter = ProducerMessageWaiter::waiter_with_cb(cb, message); self.0.waiting_confirmations.insert(publishing_id, waiter); @@ -470,7 +497,11 @@ impl Producer { None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed), }; - wrapped_msgs.push(ClientMessage::new(publishing_id, message)); + let mut client_message = ClientMessage::new(publishing_id, message, None); + if let Some(f) = self.0.filter_value_extractor.as_ref() { + client_message.filter_value(f) + } + wrapped_msgs.push(client_message); self.0 .waiting_confirmations From 982cdc214ab457e116bf994be7a19f77b7efc7f0 Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Mon, 12 Aug 2024 18:31:38 +0800 Subject: [PATCH 07/12] update: implement publish version 2 message --- examples/raw_client.rs | 1 + .../src/commands/exchange_command_versions.rs | 87 +++++++------------ src/client/message.rs | 3 +- src/client/mod.rs | 23 +++-- src/error.rs | 7 +- src/producer.rs | 33 +++---- tests/integration/client_test.rs | 4 +- tests/integration/producer_test.rs | 46 +++++++++- 8 files changed, 117 insertions(+), 87 deletions(-) diff --git a/examples/raw_client.rs b/examples/raw_client.rs index 5dc6790a..4759f697 100644 --- a/examples/raw_client.rs +++ b/examples/raw_client.rs @@ -39,6 +39,7 @@ async fn main() -> Result<(), Box> { Message::builder() .body(format!("message {}", i).as_bytes().to_vec()) .build(), + 1, ) .await .unwrap(); diff --git a/protocol/src/commands/exchange_command_versions.rs b/protocol/src/commands/exchange_command_versions.rs index a954e61f..821e790a 100644 --- a/protocol/src/commands/exchange_command_versions.rs +++ b/protocol/src/commands/exchange_command_versions.rs @@ -3,7 +3,7 @@ use std::io::Write; use crate::{ codec::{decoder::read_vec, Decoder, Encoder}, error::{DecodeError, EncodeError}, - protocol::commands::{COMMAND_EXCHANGE_COMMAND_VERSIONS, COMMAND_PUBLISH}, + protocol::commands::COMMAND_EXCHANGE_COMMAND_VERSIONS, response::{FromResponse, ResponseCode}, }; @@ 
-15,7 +15,13 @@ use fake::Fake; #[cfg_attr(test, derive(fake::Dummy))] #[derive(PartialEq, Eq, Debug)] -struct ExchangeCommandVersion(u16, u16, u16); +pub struct ExchangeCommandVersion(u16, u16, u16); + +impl ExchangeCommandVersion { + pub fn new(key: u16, min_version: u16, max_version: u16) -> Self { + return ExchangeCommandVersion(key, min_version, max_version); + } +} impl Encoder for ExchangeCommandVersion { fn encoded_size(&self) -> u32 { @@ -64,42 +70,27 @@ impl Decoder for Vec { #[cfg_attr(test, derive(fake::Dummy))] #[derive(PartialEq, Eq, Debug)] pub struct ExchangeCommandVersionsRequest { - correlation_id: u32, - key: u16, - min_version: u16, - max_version: u16, + pub(crate) correlation_id: u32, + commands: Vec, } impl ExchangeCommandVersionsRequest { - pub fn new(correlation_id: u32, min_version: u16, max_version: u16) -> Self { + pub fn new(correlation_id: u32, commands: Vec) -> Self { Self { correlation_id, - min_version, - max_version, - key: COMMAND_PUBLISH, + commands, } } } impl Encoder for ExchangeCommandVersionsRequest { fn encoded_size(&self) -> u32 { - self.correlation_id.encoded_size() - + vec![ExchangeCommandVersion( - self.key, - self.min_version, - self.max_version, - )] - .encoded_size() + self.correlation_id.encoded_size() + self.commands.encoded_size() } fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { self.correlation_id.encode(writer)?; - vec![ExchangeCommandVersion( - self.key, - self.min_version, - self.max_version, - )] - .encode(writer)?; + self.commands.encode(writer)?; Ok(()) } } @@ -114,27 +105,13 @@ impl Decoder for ExchangeCommandVersionsRequest { fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { let (input, correlation_id) = u32::decode(input)?; let (input, commands) = >::decode(input)?; - let command = commands.get(0); - match command { - Some(&ExchangeCommandVersion(key, min_version, max_version)) => Ok(( - input, - ExchangeCommandVersionsRequest { - correlation_id, - key, - min_version, - max_version, - }, - )), - None => Ok(( - input, - ExchangeCommandVersionsRequest { - correlation_id, - key: 0, - min_version: 0, - max_version: 0, - }, - )), - } + Ok(( + input, + ExchangeCommandVersionsRequest { + correlation_id, + commands, + }, + )) } } @@ -147,17 +124,17 @@ pub struct ExchangeCommandVersionsResponse { } impl ExchangeCommandVersionsResponse { - // pub fn new( - // correlation_id: u32, - // response_code: ResponseCode, - // commands: Vec, - // ) -> Self { - // Self { - // correlation_id, - // response_code, - // commands, - // } - // } + pub fn new( + correlation_id: u32, + response_code: ResponseCode, + commands: Vec, + ) -> Self { + Self { + correlation_id, + response_code, + commands, + } + } pub fn code(&self) -> &ResponseCode { &self.response_code diff --git a/src/client/message.rs b/src/client/message.rs index 2b45e651..7e2a6dee 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -1,5 +1,4 @@ use rabbitmq_stream_protocol::message::Message; -use std::sync::Arc; pub trait BaseMessage { fn publishing_id(&self) -> Option; @@ -37,7 +36,7 @@ impl ClientMessage { } } - pub fn filter_value(&mut self, filter_value_extractor: &fn(Message) -> String) { + pub fn filter_value_extract(&mut self, filter_value_extractor: &fn(Message) -> String) { self.filter_value = Some(filter_value_extractor(self.message.clone())); } } diff --git a/src/client/mod.rs b/src/client/mod.rs index a2541b21..83e2a096 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -193,6 +193,7 @@ pub struct Client { opts: 
ClientOptions, tune_notifier: Arc, publish_sequence: Arc, + filtering_supported: bool, } impl Client { @@ -219,10 +220,17 @@ impl Client { state: Arc::new(RwLock::new(state)), tune_notifier: Arc::new(Notify::new()), publish_sequence: Arc::new(AtomicU64::new(1)), + filtering_supported: false, }; client.initialize(receiver).await?; + let command_versions = client.exchange_command_versions().await?; + let (_, max_version) = command_versions.key_version(2); + if max_version >= 2 { + client.filtering_supported = true + } + Ok(client) } @@ -375,6 +383,7 @@ impl Client { &self, publisher_id: u8, messages: impl Into>, + version: u16, ) -> RabbitMQStreamResult> { let messages: Vec = messages .into() @@ -383,8 +392,8 @@ impl Client { let publishing_id: u64 = message .publishing_id() .unwrap_or_else(|| self.publish_sequence.fetch_add(1, Ordering::Relaxed)); - - PublishedMessage::new(publishing_id, message.to_message(), message.filter_value()) + let filter_value = message.filter_value(); + PublishedMessage::new(publishing_id, message.to_message(), filter_value) }) .collect(); let sequences = messages @@ -394,7 +403,7 @@ impl Client { let len = messages.len(); // TODO batch publish with max frame size check - self.send(PublishCommand::new(publisher_id, messages)) + self.send(PublishCommand::new(publisher_id, messages, version)) .await?; self.opts.collector.publish(len as u64).await; @@ -416,15 +425,17 @@ impl Client { pub async fn exchange_command_versions( &self, - min_version: u16, - max_version: u16, ) -> RabbitMQStreamResult { self.send_and_receive::(|correlation_id| { - ExchangeCommandVersionsRequest::new(correlation_id, min_version, max_version) + ExchangeCommandVersionsRequest::new(correlation_id, vec![]) }) .await } + pub fn filtering_supported(&self) -> bool { + self.filtering_supported + } + async fn create_connection( broker: &ClientOptions, ) -> Result< diff --git a/src/error.rs b/src/error.rs index b3b3a838..a00da3d7 100644 --- a/src/error.rs +++ b/src/error.rs @@ -89,11 +89,8 @@ pub enum ProducerCreateError { #[error(transparent)] Client(#[from] ClientError), - #[error("Server publish version {server_version} not support client pulish version {client_version}")] - VersionNotSupport { - client_version: u16, - server_version: u16, - } + #[error("Filtering is not supported by the broker (requires RabbitMQ 3.13+ and stream_filtering feature flag activated)")] + FilteringNotSupport, } #[derive(Error, Debug)] diff --git a/src/producer.rs b/src/producer.rs index 7f903b87..2a33ef96 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -83,7 +83,9 @@ impl ProducerInternal { if !messages.is_empty() { debug!("Sending batch of {} messages", messages.len()); - self.client.publish(self.producer_id, messages).await?; + self.client + .publish(self.producer_id, messages, self.publish_version) + .await?; } Ok(()) @@ -115,6 +117,17 @@ impl ProducerBuilder { // The leader is the recommended node for writing, because writing to a replica will redundantly pass these messages // to the leader anyway - it is the only one capable of writing. 
let mut client = self.environment.create_client().await?; + + let mut publish_version = 1; + + if self.filter_value_extractor.is_some() { + if client.filtering_supported() { + publish_version = 2 + } else { + return Err(ProducerCreateError::FilteringNotSupport); + } + } + let metrics_collector = self.environment.options.client_options.collector.clone(); if let Some(metadata) = client.metadata(vec![stream.to_string()]).await?.get(stream) { tracing::debug!( @@ -173,18 +186,6 @@ impl ProducerBuilder { }; if response.is_ok() { - let mut publish_version = 1; - if let Some(filter_value_extractor) = self.filter_value_extractor { - publish_version = 2; - let exchange_command_version = client.exchange_command_versions(1, 2).await?; - let (_, max_version) = exchange_command_version.key_version(2); - if max_version < publish_version { - return Err(ProducerCreateError::VersionNotSupport { - client_version: publish_version, - server_version: max_version, - }); - } - } let producer = ProducerInternal { producer_id, batch_size: self.batch_size, @@ -463,7 +464,7 @@ impl Producer { let mut msg = ClientMessage::new(publishing_id, message.clone(), None); if let Some(f) = self.0.filter_value_extractor.as_ref() { - msg.filter_value(f) + msg.filter_value_extract(f) } let waiter = ProducerMessageWaiter::waiter_with_cb(cb, message); @@ -499,7 +500,7 @@ impl Producer { let mut client_message = ClientMessage::new(publishing_id, message, None); if let Some(f) = self.0.filter_value_extractor.as_ref() { - client_message.filter_value(f) + client_message.filter_value_extract(f) } wrapped_msgs.push(client_message); @@ -510,7 +511,7 @@ impl Producer { self.0 .client - .publish(self.0.producer_id, wrapped_msgs) + .publish(self.0.producer_id, wrapped_msgs, self.0.publish_version) .await?; Ok(()) diff --git a/tests/integration/client_test.rs b/tests/integration/client_test.rs index e8847d1f..c9dae67f 100644 --- a/tests/integration/client_test.rs +++ b/tests/integration/client_test.rs @@ -351,7 +351,7 @@ async fn client_publish() { let sequences = test .client - .publish(1, Message::builder().body(b"message".to_vec()).build()) + .publish(1, Message::builder().body(b"message".to_vec()).build(), 1) .await .unwrap(); @@ -382,6 +382,6 @@ async fn client_handle_unexpected_connection_interruption() { async fn client_exchange_command_versions() { let test = TestClient::create().await; - let response = test.client.exchange_command_versions(1, 1).await.unwrap(); + let response = test.client.exchange_command_versions().await.unwrap(); assert_eq!(&ResponseCode::Ok, response.code()); } diff --git a/tests/integration/producer_test.rs b/tests/integration/producer_test.rs index 3836d564..2caf5751 100644 --- a/tests/integration/producer_test.rs +++ b/tests/integration/producer_test.rs @@ -3,7 +3,7 @@ use fake::{Fake, Faker}; use futures::StreamExt; use tokio::sync::mpsc::channel; -use rabbitmq_stream_client::types::{Message, OffsetSpecification}; +use rabbitmq_stream_client::types::{Message, OffsetSpecification, SimpleValue}; use crate::common::TestEnvironment; @@ -384,3 +384,47 @@ async fn producer_send_after_close_error() { true ); } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_send_filtering_message() { + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .filter_value_extractor(|message| { + let app_properties = message.application_properties(); + match app_properties { + Some(properties) => { + let value = properties.get("region").and_then(|item| match item { + SimpleValue::String(s) 
=> Some(s.clone()), + _ => None, + }); + value.unwrap_or(String::from("")) + } + None => String::from(""), + } + }) + .build(&env.stream) + .await + .unwrap(); + producer.clone().close().await.unwrap(); + + let message_builder = Message::builder(); + let mut application_properties = message_builder.application_properties(); + application_properties = application_properties.insert("region", "emea"); + + let message = application_properties + .message_builder() + .body(b"message".to_vec()) + .build(); + + let closed = producer.send_with_confirm(message).await.unwrap_err(); + + assert_eq!( + matches!( + closed, + rabbitmq_stream_client::error::ProducerPublishError::Closed + ), + true + ); +} From 457eb007358efe7f29d397523ed48f2fc4800937 Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Tue, 13 Aug 2024 16:49:23 +0800 Subject: [PATCH 08/12] update: implement consumer filtering options --- protocol/src/response/mod.rs | 1 - src/consumer.rs | 66 ++++++++++++++++++++++++++++-- src/environment.rs | 1 + src/error.rs | 3 ++ src/lib.rs | 2 +- tests/integration/consumer_test.rs | 66 +++++++++++++++++++++++++++++- 6 files changed, 133 insertions(+), 6 deletions(-) diff --git a/protocol/src/response/mod.rs b/protocol/src/response/mod.rs index 23b15a06..905a0954 100644 --- a/protocol/src/response/mod.rs +++ b/protocol/src/response/mod.rs @@ -120,7 +120,6 @@ impl Decoder for Response { let (input, _) = read_u32(input)?; let (input, header) = Header::decode(input)?; - let (input, kind) = match header.key() { COMMAND_OPEN => { OpenResponse::decode(input).map(|(i, kind)| (i, ResponseKind::Open(kind)))? diff --git a/src/consumer.rs b/src/consumer.rs index 409cd374..4399f2b7 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -12,7 +12,9 @@ use std::{ }; use rabbitmq_stream_protocol::{ - commands::subscribe::OffsetSpecification, message::Message, ResponseKind, + commands::subscribe::OffsetSpecification, + message::{self, Message}, + ResponseKind, }; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -45,6 +47,7 @@ struct ConsumerInternal { closed: Arc, waker: AtomicWaker, metrics_collector: Arc, + filter_configuration: Option, } impl ConsumerInternal { @@ -53,11 +56,34 @@ impl ConsumerInternal { } } +#[derive(Clone)] +pub struct FilterConfiguration { + filter_values: Vec, + pub predicate: Arc bool + Send + Sync>, + match_unfiltered: bool, +} + +impl FilterConfiguration { + pub fn new( + filter_values: Vec, + predicate: impl Fn(&Message) -> bool + 'static + Send + Sync, + match_unfiltered: bool, + ) -> Self { + let f = Arc::new(predicate); + Self { + filter_values, + match_unfiltered, + predicate: f, + } + } +} + /// Builder for [`Consumer`] pub struct ConsumerBuilder { pub(crate) consumer_name: Option, pub(crate) environment: Environment, pub(crate) offset_specification: OffsetSpecification, + pub(crate) filter_configuration: Option, } impl ConsumerBuilder { @@ -114,17 +140,35 @@ impl ConsumerBuilder { closed: Arc::new(AtomicBool::new(false)), waker: AtomicWaker::new(), metrics_collector: collector, + filter_configuration: self.filter_configuration.clone(), }); let msg_handler = ConsumerMessageHandler(consumer.clone()); client.set_handler(msg_handler).await; + let mut properties = HashMap::new(); + if let Some(filter_input) = self.filter_configuration { + if !client.filtering_supported() { + return Err(ConsumerCreateError::FilteringNotSupport); + } + for (index, item) in filter_input.filter_values.iter().enumerate() { + let key = format!("filter.{}", index); + properties.insert(key, item.to_owned()); + 
} + + let match_unfiltered_key = "match-unfiltered".to_string(); + properties.insert( + match_unfiltered_key, + filter_input.match_unfiltered.to_string(), + ); + } + let response = client .subscribe( subscription_id, stream, self.offset_specification, 1, - HashMap::new(), + properties, ) .await?; @@ -151,6 +195,11 @@ impl ConsumerBuilder { self.consumer_name = Some(String::from(consumer_name)); self } + + pub fn filter_input(mut self, filter_configuration: Option) -> Self { + self.filter_configuration = filter_configuration; + self + } } impl Consumer { @@ -242,7 +291,18 @@ impl MessageHandler for ConsumerMessageHandler { let len = delivery.messages.len(); trace!("Got delivery with messages {}", len); - for message in delivery.messages { + + // // client filter + let messages = match &self.0.filter_configuration { + Some(filter_input) => delivery + .messages + .into_iter() + .filter(|message| filter_input.predicate.as_ref()(message)) + .collect::>(), + None => delivery.messages, + }; + + for message in messages { let _ = self .0 .sender diff --git a/src/environment.rs b/src/environment.rs index 647f7224..8b1a590d 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -55,6 +55,7 @@ impl Environment { consumer_name: None, environment: self.clone(), offset_specification: OffsetSpecification::Next, + filter_configuration: None, } } pub(crate) async fn create_client(&self) -> RabbitMQStreamResult { diff --git a/src/error.rs b/src/error.rs index a00da3d7..22d5c7f1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -136,6 +136,9 @@ pub enum ConsumerCreateError { #[error(transparent)] Client(#[from] ClientError), + + #[error("Filtering is not supported by the broker (requires RabbitMQ 3.13+ and stream_filtering feature flag activated)")] + FilteringNotSupport, } #[derive(Error, Debug)] diff --git a/src/lib.rs b/src/lib.rs index bcf2a13f..b6d2ab24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -84,7 +84,7 @@ pub type RabbitMQStreamResult = Result; pub use crate::client::{Client, ClientOptions, MetricsCollector}; -pub use crate::consumer::{Consumer, ConsumerBuilder, ConsumerHandle}; +pub use crate::consumer::{Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration}; pub use crate::environment::{Environment, EnvironmentBuilder, TlsConfiguration}; pub use crate::producer::{Dedup, NoDedup, Producer, ProducerBuilder}; pub mod types { diff --git a/tests/integration/consumer_test.rs b/tests/integration/consumer_test.rs index 47984c10..11d519d4 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/integration/consumer_test.rs @@ -9,7 +9,7 @@ use rabbitmq_stream_client::{ ProducerCloseError, }, types::{Delivery, Message, OffsetSpecification}, - Consumer, NoDedup, Producer, + Consumer, FilterConfiguration, NoDedup, Producer, }; use rabbitmq_stream_protocol::ResponseCode; @@ -296,3 +296,67 @@ async fn consumer_test_with_store_offset() { consumer_query.handle().close().await.unwrap(); producer.close().await.unwrap(); } + +#[tokio::test(flavor = "multi_thread")] +async fn consumer_test_with_filtering() { + let env = TestEnvironment::create().await; + let reference: String = Faker.fake(); + + let message_count = 10; + let mut producer = env + .env + .producer() + .name(&reference) + .filter_value_extractor(|_| "filtering".to_string()) + .build(&env.stream) + .await + .unwrap(); + + let filter_configuration = FilterConfiguration::new( + vec!["filtering".to_string()], + |message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) + == "filtering".to_string() + }, + 
false, + ); + + let mut consumer = env + .env + .consumer() + .offset(OffsetSpecification::First) + .filter_input(Some(filter_configuration)) + .build(&env.stream) + .await + .unwrap(); + + for _ in 0..message_count { + let _ = producer + .send_with_confirm(Message::builder().body("filtering").build()) + .await + .unwrap(); + } + + for _ in 0..message_count { + let _ = producer + .send_with_confirm(Message::builder().body("not filtering").build()) + .await + .unwrap(); + } + + tokio::task::spawn(async move { + loop { + let delivery = consumer.next().await.unwrap(); + + let d = delivery.unwrap(); + let data = d + .message() + .data() + .map(|data| String::from_utf8(data.to_vec()).unwrap()) + .unwrap(); + + assert!(data == "filtering".to_string()); + } + }); + producer.close().await.unwrap(); +} From 88e5b41f6c22e439fdf7cf9b6e99139fcbdcd2ef Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Wed, 14 Aug 2024 09:06:48 +0800 Subject: [PATCH 09/12] update: optimize publish filter_value_extractor --- src/client/message.rs | 4 ++-- src/producer.rs | 16 ++++++++++------ tests/integration/consumer_test.rs | 23 +++++++++++++++++++---- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/client/message.rs b/src/client/message.rs index 7e2a6dee..7e38f2f0 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -36,8 +36,8 @@ impl ClientMessage { } } - pub fn filter_value_extract(&mut self, filter_value_extractor: &fn(Message) -> String) { - self.filter_value = Some(filter_value_extractor(self.message.clone())); + pub fn filter_value_extract(&mut self, filter_value_extractor: impl Fn(&Message) -> String) { + self.filter_value = Some(filter_value_extractor(&self.message)); } } diff --git a/src/producer.rs b/src/producer.rs index 2a33ef96..b6d1b66c 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -74,7 +74,7 @@ pub struct ProducerInternal { closed: Arc, accumulator: MessageAccumulator, publish_version: u16, - filter_value_extractor: Option String>, + filter_value_extractor: Option String + 'static + Send + Sync>>, } impl ProducerInternal { @@ -103,7 +103,7 @@ pub struct ProducerBuilder { pub batch_size: usize, pub batch_publishing_delay: Duration, pub(crate) data: PhantomData, - pub filter_value_extractor: Option String>, + pub filter_value_extractor: Option String + 'static + Send + Sync>>, } #[derive(Clone)] @@ -233,8 +233,12 @@ impl ProducerBuilder { } } - pub fn filter_value_extractor(mut self, filter_value_extractor: fn(Message) -> String) -> Self { - self.filter_value_extractor = Some(filter_value_extractor); + pub fn filter_value_extractor( + mut self, + filter_value_extractor: impl Fn(&Message) -> String + Send + Sync + 'static, + ) -> Self { + let f = Arc::new(filter_value_extractor); + self.filter_value_extractor = Some(f); self } } @@ -464,7 +468,7 @@ impl Producer { let mut msg = ClientMessage::new(publishing_id, message.clone(), None); if let Some(f) = self.0.filter_value_extractor.as_ref() { - msg.filter_value_extract(f) + msg.filter_value_extract(f.as_ref()) } let waiter = ProducerMessageWaiter::waiter_with_cb(cb, message); @@ -500,7 +504,7 @@ impl Producer { let mut client_message = ClientMessage::new(publishing_id, message, None); if let Some(f) = self.0.filter_value_extractor.as_ref() { - client_message.filter_value_extract(f) + client_message.filter_value_extract(f.as_ref()) } wrapped_msgs.push(client_message); diff --git a/tests/integration/consumer_test.rs b/tests/integration/consumer_test.rs index 11d519d4..3cbf7e16 100644 --- 
a/tests/integration/consumer_test.rs +++ b/tests/integration/consumer_test.rs @@ -13,6 +13,7 @@ use rabbitmq_stream_client::{ }; use rabbitmq_stream_protocol::ResponseCode; +use std::sync::Arc; #[tokio::test(flavor = "multi_thread")] async fn consumer_test() { @@ -335,16 +336,17 @@ async fn consumer_test_with_filtering() { .send_with_confirm(Message::builder().body("filtering").build()) .await .unwrap(); - } - for _ in 0..message_count { let _ = producer .send_with_confirm(Message::builder().body("not filtering").build()) .await .unwrap(); } - tokio::task::spawn(async move { + let response = Arc::new(tokio::sync::Mutex::new(vec![])); + let response_clone = Arc::clone(&response); + + let task = tokio::task::spawn(async move { loop { let delivery = consumer.next().await.unwrap(); @@ -355,8 +357,21 @@ async fn consumer_test_with_filtering() { .map(|data| String::from_utf8(data.to_vec()).unwrap()) .unwrap(); - assert!(data == "filtering".to_string()); + let mut r = response_clone.lock().await; + r.push(data); } }); + + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(3), task).await; + let repsonse_length = response.lock().await.len(); + let filtering_response_length = response + .lock() + .await + .iter() + .filter(|item| item == &&"filtering") + .collect::>() + .len(); + + assert!(repsonse_length == filtering_response_length); producer.close().await.unwrap(); } From ea34d1c30107960b59abdff04081e25fa600b610 Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Wed, 14 Aug 2024 09:30:34 +0800 Subject: [PATCH 10/12] fixup: cargo clippy error --- protocol/src/commands/exchange_command_versions.rs | 4 ++-- protocol/src/commands/publish.rs | 2 +- src/consumer.rs | 4 +--- src/producer.rs | 5 +++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/protocol/src/commands/exchange_command_versions.rs b/protocol/src/commands/exchange_command_versions.rs index 821e790a..786e6495 100644 --- a/protocol/src/commands/exchange_command_versions.rs +++ b/protocol/src/commands/exchange_command_versions.rs @@ -19,7 +19,7 @@ pub struct ExchangeCommandVersion(u16, u16, u16); impl ExchangeCommandVersion { pub fn new(key: u16, min_version: u16, max_version: u16) -> Self { - return ExchangeCommandVersion(key, min_version, max_version); + ExchangeCommandVersion(key, min_version, max_version) } } @@ -149,7 +149,7 @@ impl ExchangeCommandVersionsResponse { match i { ExchangeCommandVersion(match_key_command, min_version, max_version) => { if *match_key_command == key_command { - return (min_version.clone(), max_version.clone()); + return (*min_version, *max_version); } } } diff --git a/protocol/src/commands/publish.rs b/protocol/src/commands/publish.rs index 361806dd..6d256b27 100644 --- a/protocol/src/commands/publish.rs +++ b/protocol/src/commands/publish.rs @@ -4,7 +4,7 @@ use crate::{ codec::{Decoder, Encoder}, error::{DecodeError, EncodeError}, protocol::commands::COMMAND_PUBLISH, - protocol::version::{PROTOCOL_VERSION, PROTOCOL_VERSION_2}, + protocol::version::PROTOCOL_VERSION_2, }; use super::Command; diff --git a/src/consumer.rs b/src/consumer.rs index cc8aa857..8428147d 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -12,9 +12,7 @@ use std::{ }; use rabbitmq_stream_protocol::{ - commands::subscribe::OffsetSpecification, - message::{self, Message}, - ResponseKind, + commands::subscribe::OffsetSpecification, message::Message, ResponseKind, }; use tokio::sync::mpsc::{channel, Receiver, Sender}; diff --git a/src/producer.rs b/src/producer.rs index b6d1b66c..89126795 100644 --- 
a/src/producer.rs +++ b/src/producer.rs @@ -27,6 +27,7 @@ use crate::{ }; type WaiterMap = Arc>; +type FilterValueExtractor = Arc String + 'static + Send + Sync>; type ConfirmCallback = Arc< dyn Fn(Result) -> BoxFuture<'static, ()> @@ -74,7 +75,7 @@ pub struct ProducerInternal { closed: Arc, accumulator: MessageAccumulator, publish_version: u16, - filter_value_extractor: Option String + 'static + Send + Sync>>, + filter_value_extractor: Option, } impl ProducerInternal { @@ -103,7 +104,7 @@ pub struct ProducerBuilder { pub batch_size: usize, pub batch_publishing_delay: Duration, pub(crate) data: PhantomData, - pub filter_value_extractor: Option String + 'static + Send + Sync>>, + pub filter_value_extractor: Option, } #[derive(Clone)] From 9200342796b3a5509419ec089178f46701641d8e Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Wed, 14 Aug 2024 10:58:15 +0800 Subject: [PATCH 11/12] config publish command test version --- protocol/src/commands/publish.rs | 2 ++ protocol/src/types.rs | 1 + 2 files changed, 3 insertions(+) diff --git a/protocol/src/commands/publish.rs b/protocol/src/commands/publish.rs index 6d256b27..7c8f3f3b 100644 --- a/protocol/src/commands/publish.rs +++ b/protocol/src/commands/publish.rs @@ -10,6 +10,7 @@ use crate::{ use super::Command; use crate::types::PublishedMessage; + #[cfg(test)] use fake::Fake; @@ -18,6 +19,7 @@ use fake::Fake; pub struct PublishCommand { publisher_id: u8, published_messages: Vec, + #[cfg_attr(test, dummy(faker = "1"))] version: u16, } diff --git a/protocol/src/types.rs b/protocol/src/types.rs index d21a0bdf..8bf0ee3f 100644 --- a/protocol/src/types.rs +++ b/protocol/src/types.rs @@ -30,6 +30,7 @@ use crate::{message::Message, ResponseCode}; pub struct PublishedMessage { pub(crate) publishing_id: u64, pub(crate) message: Message, + #[cfg_attr(test, dummy(expr = "None"))] pub(crate) filter_value: Option, } From 87f09c92231c9aaccff65d7c6506a39770942ff6 Mon Sep 17 00:00:00 2001 From: JiaYing Zhang Date: Tue, 20 Aug 2024 17:55:08 +0800 Subject: [PATCH 12/12] update for pr suggested --- examples/filtering.rs | 84 ++++++++++++++++++++++++ src/consumer.rs | 38 +++++++---- tests/integration/consumer_test.rs | 101 +++++++++++++++++++++++++++-- 3 files changed, 204 insertions(+), 19 deletions(-) create mode 100644 examples/filtering.rs diff --git a/examples/filtering.rs b/examples/filtering.rs new file mode 100644 index 00000000..448a2b4b --- /dev/null +++ b/examples/filtering.rs @@ -0,0 +1,84 @@ +use futures::StreamExt; +use rabbitmq_stream_client::types::{Message, OffsetSpecification}; +use rabbitmq_stream_client::{Environment, FilterConfiguration}; +use tracing::info; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let environment = Environment::builder() + .host("localhost") + .port(5552) + .build() + .await?; + + let message_count = 10; + environment.stream_creator().create("test").await?; + + let mut producer = environment + .producer() + .name("test_producer") + .filter_value_extractor(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap() + }) + .build("test") + .await?; + + // publish filtering message + for i in 0..message_count { + producer + .send_with_confirm(Message::builder().body(i.to_string()).build()) + .await?; + } + + producer.close().await?; + + // publish filtering message + let mut producer = environment + .producer() + .name("test_producer") + .build("test") + .await?; + + // publish unset filter value + for i in 0..message_count { + producer + 
.send_with_confirm(Message::builder().body(i.to_string()).build()) + .await?; + } + + producer.close().await?; + + // filter configuration: https://www.rabbitmq.com/blog/2023/10/16/stream-filtering + let filter_configuration = + FilterConfiguration::new(vec!["1".to_string()], false).post_filter(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) + == "1".to_string() + }); + // let filter_configuration = FilterConfiguration::new(vec!["1".to_string()], true); + + let mut consumer = environment + .consumer() + .offset(OffsetSpecification::First) + .filter_input(Some(filter_configuration)) + .build("test") + .await + .unwrap(); + + let task = tokio::task::spawn(async move { + loop { + let delivery = consumer.next().await.unwrap().unwrap(); + info!( + "Got message : {:?} with offset {}", + delivery + .message() + .data() + .map(|data| String::from_utf8(data.to_vec())), + delivery.offset() + ); + } + }); + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(3), task).await; + + environment.delete_stream("test").await?; + Ok(()) +} diff --git a/src/consumer.rs b/src/consumer.rs index 8428147d..1b2a37b5 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -29,6 +29,8 @@ use futures::{task::AtomicWaker, Stream}; use rand::rngs::StdRng; use rand::{seq::SliceRandom, SeedableRng}; +type FilterPredicate = Option bool + Send + Sync>>; + /// API for consuming RabbitMQ stream messages pub struct Consumer { // Mandatory in case of manual offset tracking @@ -58,23 +60,26 @@ impl ConsumerInternal { #[derive(Clone)] pub struct FilterConfiguration { filter_values: Vec, - pub predicate: Arc bool + Send + Sync>, + pub predicate: FilterPredicate, match_unfiltered: bool, } impl FilterConfiguration { - pub fn new( - filter_values: Vec, - predicate: impl Fn(&Message) -> bool + 'static + Send + Sync, - match_unfiltered: bool, - ) -> Self { - let f = Arc::new(predicate); + pub fn new(filter_values: Vec, match_unfiltered: bool) -> Self { Self { filter_values, match_unfiltered, - predicate: f, + predicate: None, } } + + pub fn post_filter( + mut self, + predicate: impl Fn(&Message) -> bool + 'static + Send + Sync, + ) -> FilterConfiguration { + self.predicate = Some(Arc::new(predicate)); + self + } } /// Builder for [`Consumer`] @@ -294,11 +299,18 @@ impl MessageHandler for ConsumerMessageHandler { // // client filter let messages = match &self.0.filter_configuration { - Some(filter_input) => delivery - .messages - .into_iter() - .filter(|message| filter_input.predicate.as_ref()(message)) - .collect::>(), + Some(filter_input) => { + if let Some(f) = &filter_input.predicate { + delivery + .messages + .into_iter() + .filter(|message| f(message)) + .collect::>() + } else { + delivery.messages + } + } + None => delivery.messages, }; diff --git a/tests/integration/consumer_test.rs b/tests/integration/consumer_test.rs index 64bb602f..035c8c68 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/integration/consumer_test.rs @@ -372,14 +372,11 @@ async fn consumer_test_with_filtering() { .await .unwrap(); - let filter_configuration = FilterConfiguration::new( - vec!["filtering".to_string()], - |message| { + let filter_configuration = FilterConfiguration::new(vec!["filtering".to_string()], false) + .post_filter(|message| { String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) == "filtering".to_string() - }, - false, - ); + }); let mut consumer = env .env @@ -434,3 +431,95 @@ async fn consumer_test_with_filtering() { assert!(repsonse_length == 
filtering_response_length);
     producer.close().await.unwrap();
 }
+
+#[tokio::test(flavor = "multi_thread")]
+async fn consumer_test_with_filtering_match_unfiltered() {
+    let env = TestEnvironment::create().await;
+    let reference: String = Faker.fake();
+
+    let message_count = 10;
+    let mut producer = env
+        .env
+        .producer()
+        .name(&reference)
+        .filter_value_extractor(|message| {
+            String::from_utf8(message.data().unwrap().to_vec()).unwrap()
+        })
+        .build(&env.stream)
+        .await
+        .unwrap();
+
+    // publish filtering message
+    for i in 0..message_count {
+        producer
+            .send_with_confirm(Message::builder().body(i.to_string()).build())
+            .await
+            .unwrap();
+    }
+
+    producer.close().await.unwrap();
+
+    let mut producer = env
+        .env
+        .producer()
+        .name(&reference)
+        .build(&env.stream)
+        .await
+        .unwrap();
+
+    // publish unset filter value
+    for i in 0..message_count {
+        producer
+            .send_with_confirm(Message::builder().body(i.to_string()).build())
+            .await
+            .unwrap();
+    }
+
+    producer.close().await.unwrap();
+
+    let filter_configuration =
+        FilterConfiguration::new(vec!["1".to_string()], true).post_filter(|message| {
+            String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string())
+                == "1".to_string()
+        });
+
+    let mut consumer = env
+        .env
+        .consumer()
+        .offset(OffsetSpecification::First)
+        .filter_input(Some(filter_configuration))
+        .build(&env.stream)
+        .await
+        .unwrap();
+
+    let response = Arc::new(tokio::sync::Mutex::new(vec![]));
+    let response_clone = Arc::clone(&response);
+
+    let task = tokio::task::spawn(async move {
+        loop {
+            let delivery = consumer.next().await.unwrap();
+
+            let d = delivery.unwrap();
+            let data = d
+                .message()
+                .data()
+                .map(|data| String::from_utf8(data.to_vec()).unwrap())
+                .unwrap();
+
+            let mut r = response_clone.lock().await;
+            r.push(data);
+        }
+    });
+
+    let _ = tokio::time::timeout(tokio::time::Duration::from_secs(3), task).await;
+    let response_length = response.lock().await.len();
+    let filtering_response_length = response
+        .lock()
+        .await
+        .iter()
+        .filter(|item| item == &&"1")
+        .collect::<Vec<_>>()
+        .len();
+
+    assert!(response_length == filtering_response_length);
+}
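
A note on how the filter configuration reaches the broker: the sketch below is illustrative and not part of the patch. It shows the subscription properties a `FilterConfiguration` is expected to translate into before `client.subscribe(...)` is called. The `match-unfiltered` key appears verbatim in the consumer changes above; the `filter.<i>` key naming and the `subscription_properties` helper are assumptions based on the RabbitMQ stream filtering convention.

```rust
use std::collections::HashMap;

// Illustrative only: mirrors how the ConsumerBuilder is expected to map a
// FilterConfiguration onto subscribe() properties. The "filter.<i>" keys are
// an assumption based on the stream filtering convention; "match-unfiltered"
// is the key used in the patch above.
fn subscription_properties(
    filter_values: &[String],
    match_unfiltered: bool,
) -> HashMap<String, String> {
    let mut properties = HashMap::new();
    for (i, value) in filter_values.iter().enumerate() {
        properties.insert(format!("filter.{}", i), value.clone());
    }
    properties.insert(
        "match-unfiltered".to_string(),
        match_unfiltered.to_string(),
    );
    properties
}

fn main() {
    // A consumer built with FilterConfiguration::new(vec!["1".into()], true)
    // would subscribe with: {"filter.0": "1", "match-unfiltered": "true"}
    let props = subscription_properties(&["1".to_string()], true);
    println!("{:?}", props);
}
```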