diff --git a/examples/filtering.rs b/examples/filtering.rs new file mode 100644 index 00000000..448a2b4b --- /dev/null +++ b/examples/filtering.rs @@ -0,0 +1,84 @@ +use futures::StreamExt; +use rabbitmq_stream_client::types::{Message, OffsetSpecification}; +use rabbitmq_stream_client::{Environment, FilterConfiguration}; +use tracing::info; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let environment = Environment::builder() + .host("localhost") + .port(5552) + .build() + .await?; + + let message_count = 10; + environment.stream_creator().create("test").await?; + + let mut producer = environment + .producer() + .name("test_producer") + .filter_value_extractor(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap() + }) + .build("test") + .await?; + + // publish filtering message + for i in 0..message_count { + producer + .send_with_confirm(Message::builder().body(i.to_string()).build()) + .await?; + } + + producer.close().await?; + + // publish filtering message + let mut producer = environment + .producer() + .name("test_producer") + .build("test") + .await?; + + // publish unset filter value + for i in 0..message_count { + producer + .send_with_confirm(Message::builder().body(i.to_string()).build()) + .await?; + } + + producer.close().await?; + + // filter configuration: https://www.rabbitmq.com/blog/2023/10/16/stream-filtering + let filter_configuration = + FilterConfiguration::new(vec!["1".to_string()], false).post_filter(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) + == "1".to_string() + }); + // let filter_configuration = FilterConfiguration::new(vec!["1".to_string()], true); + + let mut consumer = environment + .consumer() + .offset(OffsetSpecification::First) + .filter_input(Some(filter_configuration)) + .build("test") + .await + .unwrap(); + + let task = tokio::task::spawn(async move { + loop { + let delivery = consumer.next().await.unwrap().unwrap(); + info!( + "Got message : {:?} with offset {}", + delivery + .message() + .data() + .map(|data| String::from_utf8(data.to_vec())), + delivery.offset() + ); + } + }); + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(3), task).await; + + environment.delete_stream("test").await?; + Ok(()) +} diff --git a/examples/raw_client.rs b/examples/raw_client.rs index 5dc6790a..4759f697 100644 --- a/examples/raw_client.rs +++ b/examples/raw_client.rs @@ -39,6 +39,7 @@ async fn main() -> Result<(), Box> { Message::builder() .body(format!("message {}", i).as_bytes().to_vec()) .build(), + 1, ) .await .unwrap(); diff --git a/protocol/src/codec/decoder.rs b/protocol/src/codec/decoder.rs index 64bb9052..5aead389 100644 --- a/protocol/src/codec/decoder.rs +++ b/protocol/src/codec/decoder.rs @@ -106,7 +106,18 @@ impl Decoder for PublishedMessage { let (input, publishing_id) = u64::decode(input)?; let (input, body) = read_vec::(input)?; let (_, message) = Message::decode(&body)?; - Ok((input, PublishedMessage::new(publishing_id, message))) + Ok((input, PublishedMessage::new(publishing_id, message, None))) + } + + fn decode_version_2(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, publishing_id) = u64::decode(input)?; + let (input, body) = read_vec::(input)?; + let (input, filter_value) = >::decode(input)?; + let (_, message) = Message::decode(&body)?; + Ok(( + input, + PublishedMessage::new(publishing_id, message, filter_value), + )) } } diff --git a/protocol/src/codec/encoder.rs b/protocol/src/codec/encoder.rs index 6f992d81..13943204 100644 
--- a/protocol/src/codec/encoder.rs +++ b/protocol/src/codec/encoder.rs @@ -109,6 +109,21 @@ impl Encoder for PublishedMessage { self.message.encode(writer)?; Ok(()) } + + fn encoded_size_version_2(&self) -> u32 { + self.publishing_id.encoded_size() + + self.filter_value.encoded_size() + + 4 + + self.message.encoded_size() + } + + fn encode_version_2(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.publishing_id.encode(writer)?; + self.filter_value.encode(writer)?; + self.message.encoded_size().encode(writer)?; + self.message.encode(writer)?; + Ok(()) + } } impl Encoder for Vec { @@ -123,6 +138,20 @@ impl Encoder for Vec { } Ok(()) } + + fn encoded_size_version_2(&self) -> u32 { + 4 + self + .iter() + .fold(0, |acc, v| acc + v.encoded_size_version_2()) + } + + fn encode_version_2(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + writer.write_u32::(self.len() as u32)?; + for x in self { + x.encode_version_2(writer)?; + } + Ok(()) + } } impl Encoder for &str { diff --git a/protocol/src/codec/mod.rs b/protocol/src/codec/mod.rs index fc2701f8..ec6fddce 100644 --- a/protocol/src/codec/mod.rs +++ b/protocol/src/codec/mod.rs @@ -8,6 +8,12 @@ pub mod encoder; pub trait Encoder { fn encoded_size(&self) -> u32; fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError>; + fn encoded_size_version_2(&self) -> u32 { + self.encoded_size() + } + fn encode_version_2(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.encode(writer) + } } pub trait Decoder @@ -15,4 +21,7 @@ where Self: Sized, { fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError>; + fn decode_version_2(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + Decoder::decode(input) + } } diff --git a/protocol/src/commands/exchange_command_versions.rs b/protocol/src/commands/exchange_command_versions.rs new file mode 100644 index 00000000..786e6495 --- /dev/null +++ b/protocol/src/commands/exchange_command_versions.rs @@ -0,0 +1,221 @@ +use std::io::Write; + +use crate::{ + codec::{decoder::read_vec, Decoder, Encoder}, + error::{DecodeError, EncodeError}, + protocol::commands::COMMAND_EXCHANGE_COMMAND_VERSIONS, + response::{FromResponse, ResponseCode}, +}; + +use super::Command; +use byteorder::{BigEndian, WriteBytesExt}; + +#[cfg(test)] +use fake::Fake; + +#[cfg_attr(test, derive(fake::Dummy))] +#[derive(PartialEq, Eq, Debug)] +pub struct ExchangeCommandVersion(u16, u16, u16); + +impl ExchangeCommandVersion { + pub fn new(key: u16, min_version: u16, max_version: u16) -> Self { + ExchangeCommandVersion(key, min_version, max_version) + } +} + +impl Encoder for ExchangeCommandVersion { + fn encoded_size(&self) -> u32 { + self.0.encoded_size() + self.1.encoded_size() + self.2.encoded_size() + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.0.encode(writer)?; + self.1.encode(writer)?; + self.2.encode(writer)?; + + Ok(()) + } +} + +impl Decoder for ExchangeCommandVersion { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, key) = u16::decode(input)?; + let (input, min_version) = u16::decode(input)?; + let (input, max_version) = u16::decode(input)?; + Ok((input, ExchangeCommandVersion(key, min_version, max_version))) + } +} + +impl Encoder for Vec { + fn encoded_size(&self) -> u32 { + 4 + self.iter().fold(0, |acc, v| acc + v.encoded_size()) + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + writer.write_u32::(self.len() as u32)?; + for x in self { + x.encode(writer)?; + } + Ok(()) + } +} + 
+impl Decoder for Vec { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, result) = read_vec(input)?; + Ok((input, result)) + } +} + +#[cfg_attr(test, derive(fake::Dummy))] +#[derive(PartialEq, Eq, Debug)] +pub struct ExchangeCommandVersionsRequest { + pub(crate) correlation_id: u32, + commands: Vec, +} + +impl ExchangeCommandVersionsRequest { + pub fn new(correlation_id: u32, commands: Vec) -> Self { + Self { + correlation_id, + commands, + } + } +} + +impl Encoder for ExchangeCommandVersionsRequest { + fn encoded_size(&self) -> u32 { + self.correlation_id.encoded_size() + self.commands.encoded_size() + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.correlation_id.encode(writer)?; + self.commands.encode(writer)?; + Ok(()) + } +} + +impl Command for ExchangeCommandVersionsRequest { + fn key(&self) -> u16 { + COMMAND_EXCHANGE_COMMAND_VERSIONS + } +} + +impl Decoder for ExchangeCommandVersionsRequest { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, correlation_id) = u32::decode(input)?; + let (input, commands) = >::decode(input)?; + Ok(( + input, + ExchangeCommandVersionsRequest { + correlation_id, + commands, + }, + )) + } +} + +#[cfg_attr(test, derive(fake::Dummy))] +#[derive(PartialEq, Eq, Debug)] +pub struct ExchangeCommandVersionsResponse { + pub(crate) correlation_id: u32, + response_code: ResponseCode, + commands: Vec, +} + +impl ExchangeCommandVersionsResponse { + pub fn new( + correlation_id: u32, + response_code: ResponseCode, + commands: Vec, + ) -> Self { + Self { + correlation_id, + response_code, + commands, + } + } + + pub fn code(&self) -> &ResponseCode { + &self.response_code + } + + pub fn is_ok(&self) -> bool { + self.response_code == ResponseCode::Ok + } + + pub fn key_version(&self, key_command: u16) -> (u16, u16) { + for i in &self.commands { + match i { + ExchangeCommandVersion(match_key_command, min_version, max_version) => { + if *match_key_command == key_command { + return (*min_version, *max_version); + } + } + } + } + + (1, 1) + } +} + +impl Encoder for ExchangeCommandVersionsResponse { + fn encoded_size(&self) -> u32 { + self.correlation_id.encoded_size() + + self.response_code.encoded_size() + + self.commands.encoded_size() + } + + fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { + self.correlation_id.encode(writer)?; + self.response_code.encode(writer)?; + self.commands.encode(writer)?; + Ok(()) + } +} + +impl Decoder for ExchangeCommandVersionsResponse { + fn decode(input: &[u8]) -> Result<(&[u8], Self), DecodeError> { + let (input, correlation_id) = u32::decode(input)?; + let (input, response_code) = ResponseCode::decode(input)?; + let (input, commands) = >::decode(input)?; + + Ok(( + input, + ExchangeCommandVersionsResponse { + correlation_id, + response_code, + commands, + }, + )) + } +} + +impl FromResponse for ExchangeCommandVersionsResponse { + fn from_response(response: crate::Response) -> Option { + match response.kind { + crate::ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + Some(exchange_command_versions) + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + + use crate::commands::tests::command_encode_decode_test; + + use super::{ExchangeCommandVersionsRequest, ExchangeCommandVersionsResponse}; + + #[test] + fn exchange_command_versions_request_test() { + command_encode_decode_test::(); + } + + #[test] + fn exchange_command_versions_response_test() { + command_encode_decode_test::(); + } +} diff 
--git a/protocol/src/commands/mod.rs b/protocol/src/commands/mod.rs index 5f07ded0..0a4ae942 100644 --- a/protocol/src/commands/mod.rs +++ b/protocol/src/commands/mod.rs @@ -1,3 +1,5 @@ +use crate::protocol::version::PROTOCOL_VERSION; + pub mod close; pub mod create_stream; pub mod credit; @@ -5,6 +7,7 @@ pub mod declare_publisher; pub mod delete; pub mod delete_publisher; pub mod deliver; +pub mod exchange_command_versions; pub mod generic; pub mod heart_beat; pub mod metadata; @@ -25,6 +28,9 @@ pub mod unsubscribe; pub trait Command { fn key(&self) -> u16; + fn version(&self) -> u16 { + PROTOCOL_VERSION + } } #[cfg(test)] diff --git a/protocol/src/commands/publish.rs b/protocol/src/commands/publish.rs index 22d2337a..7c8f3f3b 100644 --- a/protocol/src/commands/publish.rs +++ b/protocol/src/commands/publish.rs @@ -4,11 +4,13 @@ use crate::{ codec::{Decoder, Encoder}, error::{DecodeError, EncodeError}, protocol::commands::COMMAND_PUBLISH, + protocol::version::PROTOCOL_VERSION_2, }; use super::Command; use crate::types::PublishedMessage; + #[cfg(test)] use fake::Fake; @@ -17,25 +19,36 @@ use fake::Fake; pub struct PublishCommand { publisher_id: u8, published_messages: Vec, + #[cfg_attr(test, dummy(faker = "1"))] + version: u16, } impl PublishCommand { - pub fn new(publisher_id: u8, published_messages: Vec) -> Self { + pub fn new(publisher_id: u8, published_messages: Vec, version: u16) -> Self { Self { publisher_id, published_messages, + version, } } } impl Encoder for PublishCommand { fn encoded_size(&self) -> u32 { - self.publisher_id.encoded_size() + self.published_messages.encoded_size() + if self.version == PROTOCOL_VERSION_2 { + self.publisher_id.encoded_size() + self.published_messages.encoded_size_version_2() + } else { + self.publisher_id.encoded_size() + self.published_messages.encoded_size() + } } fn encode(&self, writer: &mut impl Write) -> Result<(), EncodeError> { self.publisher_id.encode(writer)?; - self.published_messages.encode(writer)?; + if self.version == PROTOCOL_VERSION_2 { + self.published_messages.encode_version_2(writer)?; + } else { + self.published_messages.encode(writer)?; + } Ok(()) } } @@ -44,6 +57,10 @@ impl Command for PublishCommand { fn key(&self) -> u16 { COMMAND_PUBLISH } + + fn version(&self) -> u16 { + self.version + } } impl Decoder for PublishCommand { @@ -56,6 +73,7 @@ impl Decoder for PublishCommand { PublishCommand { publisher_id, published_messages, + version: 1, }, )) } diff --git a/protocol/src/protocol.rs b/protocol/src/protocol.rs index b1484a40..d87be920 100644 --- a/protocol/src/protocol.rs +++ b/protocol/src/protocol.rs @@ -25,6 +25,11 @@ pub mod commands { pub const COMMAND_OPEN: u16 = 21; pub const COMMAND_CLOSE: u16 = 22; pub const COMMAND_HEARTBEAT: u16 = 23; + pub const COMMAND_ROUTE: u16 = 24; + pub const COMMAND_PARTITIONS: u16 = 25; + pub const COMMAND_CONSUMER_UPDATE: u16 = 26; + pub const COMMAND_EXCHANGE_COMMAND_VERSIONS: u16 = 27; + pub const COMMAND_STREAMS_STATS: u16 = 28; } // server responses @@ -57,4 +62,5 @@ pub mod responses { #[allow(unused)] pub mod version { pub const PROTOCOL_VERSION: u16 = 1; + pub const PROTOCOL_VERSION_2: u16 = 2; } diff --git a/protocol/src/request/mod.rs b/protocol/src/request/mod.rs index 05d52608..fc616436 100644 --- a/protocol/src/request/mod.rs +++ b/protocol/src/request/mod.rs @@ -5,7 +5,8 @@ use crate::{ commands::{ close::CloseRequest, create_stream::CreateStreamCommand, credit::CreditCommand, declare_publisher::DeclarePublisherCommand, delete::Delete, - 
delete_publisher::DeletePublisherCommand, heart_beat::HeartBeatCommand, + delete_publisher::DeletePublisherCommand, + exchange_command_versions::ExchangeCommandVersionsRequest, heart_beat::HeartBeatCommand, metadata::MetadataCommand, open::OpenCommand, peer_properties::PeerPropertiesCommand, publish::PublishCommand, query_offset::QueryOffsetRequest, query_publisher_sequence::QueryPublisherRequest, @@ -58,6 +59,7 @@ pub enum RequestKind { QueryPublisherSequence(QueryPublisherRequest), StoreOffset(StoreOffset), Unsubscribe(UnSubscribeCommand), + ExchangeCommandVersions(ExchangeCommandVersionsRequest), } impl Encoder for RequestKind { @@ -82,6 +84,9 @@ impl Encoder for RequestKind { RequestKind::QueryPublisherSequence(query_publisher) => query_publisher.encoded_size(), RequestKind::StoreOffset(store_offset) => store_offset.encoded_size(), RequestKind::Unsubscribe(unsubscribe) => unsubscribe.encoded_size(), + RequestKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encoded_size() + } } } @@ -106,6 +111,9 @@ impl Encoder for RequestKind { RequestKind::QueryPublisherSequence(query_publisher) => query_publisher.encode(writer), RequestKind::StoreOffset(store_offset) => store_offset.encode(writer), RequestKind::Unsubscribe(unsubcribe) => unsubcribe.encode(writer), + RequestKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encode(writer) + } } } } @@ -171,6 +179,9 @@ impl Decoder for Request { COMMAND_UNSUBSCRIBE => { UnSubscribeCommand::decode(input).map(|(i, kind)| (i, kind.into()))? } + COMMAND_EXCHANGE_COMMAND_VERSIONS => { + ExchangeCommandVersionsRequest::decode(input).map(|(i, kind)| (i, kind.into()))? + } n => return Err(DecodeError::UnsupportedResponseType(n)), }; Ok((input, Request { header, kind: cmd })) @@ -185,10 +196,11 @@ mod tests { commands::{ close::CloseRequest, create_stream::CreateStreamCommand, credit::CreditCommand, declare_publisher::DeclarePublisherCommand, delete::Delete, - delete_publisher::DeletePublisherCommand, heart_beat::HeartBeatCommand, - metadata::MetadataCommand, open::OpenCommand, peer_properties::PeerPropertiesCommand, - publish::PublishCommand, query_offset::QueryOffsetRequest, - query_publisher_sequence::QueryPublisherRequest, + delete_publisher::DeletePublisherCommand, + exchange_command_versions::ExchangeCommandVersionsRequest, + heart_beat::HeartBeatCommand, metadata::MetadataCommand, open::OpenCommand, + peer_properties::PeerPropertiesCommand, publish::PublishCommand, + query_offset::QueryOffsetRequest, query_publisher_sequence::QueryPublisherRequest, sasl_authenticate::SaslAuthenticateCommand, sasl_handshake::SaslHandshakeCommand, store_offset::StoreOffset, subscribe::SubscribeCommand, tune::TunesCommand, unsubscribe::UnSubscribeCommand, Command, @@ -307,4 +319,9 @@ mod tests { assert!(remaining.is_empty()); } + + #[test] + fn request_exchange_command_versions_test() { + request_encode_decode_test::() + } } diff --git a/protocol/src/request/shims.rs b/protocol/src/request/shims.rs index 4acc7624..323d4d51 100644 --- a/protocol/src/request/shims.rs +++ b/protocol/src/request/shims.rs @@ -2,7 +2,8 @@ use crate::{ commands::{ close::CloseRequest, create_stream::CreateStreamCommand, credit::CreditCommand, declare_publisher::DeclarePublisherCommand, delete::Delete, - delete_publisher::DeletePublisherCommand, heart_beat::HeartBeatCommand, + delete_publisher::DeletePublisherCommand, + exchange_command_versions::ExchangeCommandVersionsRequest, heart_beat::HeartBeatCommand, 
metadata::MetadataCommand, open::OpenCommand, peer_properties::PeerPropertiesCommand, publish::PublishCommand, query_offset::QueryOffsetRequest, query_publisher_sequence::QueryPublisherRequest, @@ -10,7 +11,6 @@ use crate::{ store_offset::StoreOffset, subscribe::SubscribeCommand, tune::TunesCommand, unsubscribe::UnSubscribeCommand, Command, }, - protocol::version::PROTOCOL_VERSION, types::Header, Request, RequestKind, }; @@ -20,7 +20,7 @@ where { fn from(cmd: T) -> Self { Request { - header: Header::new(cmd.key(), PROTOCOL_VERSION), + header: Header::new(cmd.key(), cmd.version()), kind: cmd.into(), } } @@ -130,3 +130,8 @@ impl From for RequestKind { RequestKind::Unsubscribe(cmd) } } +impl From for RequestKind { + fn from(cmd: ExchangeCommandVersionsRequest) -> Self { + RequestKind::ExchangeCommandVersions(cmd) + } +} diff --git a/protocol/src/response/mod.rs b/protocol/src/response/mod.rs index a779c7f6..905a0954 100644 --- a/protocol/src/response/mod.rs +++ b/protocol/src/response/mod.rs @@ -7,7 +7,8 @@ use crate::{ }, commands::{ close::CloseResponse, credit::CreditResponse, deliver::DeliverCommand, - generic::GenericResponse, heart_beat::HeartbeatResponse, metadata::MetadataResponse, + exchange_command_versions::ExchangeCommandVersionsResponse, generic::GenericResponse, + heart_beat::HeartbeatResponse, metadata::MetadataResponse, metadata_update::MetadataUpdateCommand, open::OpenResponse, peer_properties::PeerPropertiesResponse, publish_confirm::PublishConfirm, publish_error::PublishErrorResponse, query_offset::QueryOffsetResponse, @@ -66,6 +67,7 @@ pub enum ResponseKind { QueryOffset(QueryOffsetResponse), QueryPublisherSequence(QueryPublisherResponse), Credit(CreditResponse), + ExchangeCommandVersions(ExchangeCommandVersionsResponse), } impl Response { @@ -92,6 +94,9 @@ impl Response { ResponseKind::Heartbeat(_) => None, ResponseKind::Deliver(_) => None, ResponseKind::Credit(_) => None, + ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + Some(exchange_command_versions.correlation_id) + } } } @@ -115,7 +120,6 @@ impl Decoder for Response { let (input, _) = read_u32(input)?; let (input, header) = Header::decode(input)?; - let (input, kind) = match header.key() { COMMAND_OPEN => { OpenResponse::decode(input).map(|(i, kind)| (i, ResponseKind::Open(kind)))? 
@@ -164,7 +168,10 @@ impl Decoder for Response { COMMAND_QUERY_PUBLISHER_SEQUENCE => QueryPublisherResponse::decode(input) .map(|(remaining, kind)| (remaining, ResponseKind::QueryPublisherSequence(kind)))?, - + COMMAND_EXCHANGE_COMMAND_VERSIONS => ExchangeCommandVersionsResponse::decode(input) + .map(|(remaining, kind)| { + (remaining, ResponseKind::ExchangeCommandVersions(kind)) + })?, n => return Err(DecodeError::UnsupportedResponseType(n)), }; Ok((input, Response { header, kind })) @@ -195,7 +202,8 @@ mod tests { use crate::{ codec::{Decoder, Encoder}, commands::{ - close::CloseResponse, deliver::DeliverCommand, generic::GenericResponse, + close::CloseResponse, deliver::DeliverCommand, + exchange_command_versions::ExchangeCommandVersionsResponse, generic::GenericResponse, heart_beat::HeartbeatResponse, metadata::MetadataResponse, metadata_update::MetadataUpdateCommand, open::OpenResponse, peer_properties::PeerPropertiesResponse, publish_confirm::PublishConfirm, @@ -213,6 +221,7 @@ mod tests { }, version::PROTOCOL_VERSION, }, + response::COMMAND_EXCHANGE_COMMAND_VERSIONS, types::Header, ResponseCode, }; @@ -238,6 +247,9 @@ mod tests { query_publisher.encoded_size() } ResponseKind::Credit(credit) => credit.encoded_size(), + ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encoded_size() + } } } @@ -263,6 +275,9 @@ mod tests { query_publisher.encode(writer) } ResponseKind::Credit(credit) => credit.encode(writer), + ResponseKind::ExchangeCommandVersions(exchange_command_versions) => { + exchange_command_versions.encode(writer) + } } } } @@ -423,4 +438,13 @@ mod tests { COMMAND_HEARTBEAT ); } + + #[test] + fn exchange_command_versions_response_test() { + response_test!( + ExchangeCommandVersionsResponse, + ResponseKind::ExchangeCommandVersions, + COMMAND_EXCHANGE_COMMAND_VERSIONS + ); + } } diff --git a/protocol/src/types.rs b/protocol/src/types.rs index e8fe6277..8bf0ee3f 100644 --- a/protocol/src/types.rs +++ b/protocol/src/types.rs @@ -30,13 +30,16 @@ use crate::{message::Message, ResponseCode}; pub struct PublishedMessage { pub(crate) publishing_id: u64, pub(crate) message: Message, + #[cfg_attr(test, dummy(expr = "None"))] + pub(crate) filter_value: Option, } impl PublishedMessage { - pub fn new(publishing_id: u64, message: Message) -> Self { + pub fn new(publishing_id: u64, message: Message, filter_value: Option) -> Self { Self { publishing_id, message, + filter_value, } } diff --git a/src/client/message.rs b/src/client/message.rs index 40bd4a9b..7e38f2f0 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -3,6 +3,7 @@ use rabbitmq_stream_protocol::message::Message; pub trait BaseMessage { fn publishing_id(&self) -> Option; fn to_message(self) -> Message; + fn filter_value(&self) -> Option; } impl BaseMessage for Message { @@ -13,21 +14,31 @@ impl BaseMessage for Message { fn to_message(self) -> Message { self } + + fn filter_value(&self) -> Option { + None + } } #[derive(Debug)] pub struct ClientMessage { publishing_id: u64, message: Message, + filter_value: Option, } impl ClientMessage { - pub fn new(publishing_id: u64, message: Message) -> Self { + pub fn new(publishing_id: u64, message: Message, filter_value: Option) -> Self { Self { publishing_id, message, + filter_value, } } + + pub fn filter_value_extract(&mut self, filter_value_extractor: impl Fn(&Message) -> String) { + self.filter_value = Some(filter_value_extractor(&self.message)); + } } impl BaseMessage for ClientMessage { @@ -38,4 +49,8 @@ impl BaseMessage for 
ClientMessage { fn to_message(self) -> Message { self.message } + + fn filter_value(&self) -> Option { + self.filter_value.clone() + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index f0c671d6..83e2a096 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -16,6 +16,9 @@ use futures::{ Stream, StreamExt, TryFutureExt, }; use pin_project::pin_project; +use rabbitmq_stream_protocol::commands::exchange_command_versions::{ + ExchangeCommandVersionsRequest, ExchangeCommandVersionsResponse, +}; use rustls::PrivateKey; use rustls::ServerName; use std::{fs::File, io::BufReader, path::Path}; @@ -190,6 +193,7 @@ pub struct Client { opts: ClientOptions, tune_notifier: Arc, publish_sequence: Arc, + filtering_supported: bool, } impl Client { @@ -216,10 +220,17 @@ impl Client { state: Arc::new(RwLock::new(state)), tune_notifier: Arc::new(Notify::new()), publish_sequence: Arc::new(AtomicU64::new(1)), + filtering_supported: false, }; client.initialize(receiver).await?; + let command_versions = client.exchange_command_versions().await?; + let (_, max_version) = command_versions.key_version(2); + if max_version >= 2 { + client.filtering_supported = true + } + Ok(client) } @@ -372,16 +383,17 @@ impl Client { &self, publisher_id: u8, messages: impl Into>, + version: u16, ) -> RabbitMQStreamResult> { let messages: Vec = messages .into() .into_iter() .map(|message| { - let publishing_id = message + let publishing_id: u64 = message .publishing_id() .unwrap_or_else(|| self.publish_sequence.fetch_add(1, Ordering::Relaxed)); - - PublishedMessage::new(publishing_id, message.to_message()) + let filter_value = message.filter_value(); + PublishedMessage::new(publishing_id, message.to_message(), filter_value) }) .collect(); let sequences = messages @@ -391,7 +403,7 @@ impl Client { let len = messages.len(); // TODO batch publish with max frame size check - self.send(PublishCommand::new(publisher_id, messages)) + self.send(PublishCommand::new(publisher_id, messages, version)) .await?; self.opts.collector.publish(len as u64).await; @@ -411,6 +423,19 @@ impl Client { .map(|sequence| sequence.from_response()) } + pub async fn exchange_command_versions( + &self, + ) -> RabbitMQStreamResult { + self.send_and_receive::(|correlation_id| { + ExchangeCommandVersionsRequest::new(correlation_id, vec![]) + }) + .await + } + + pub fn filtering_supported(&self) -> bool { + self.filtering_supported + } + async fn create_connection( broker: &ClientOptions, ) -> Result< diff --git a/src/consumer.rs b/src/consumer.rs index 13123d04..1b2a37b5 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -29,6 +29,8 @@ use futures::{task::AtomicWaker, Stream}; use rand::rngs::StdRng; use rand::{seq::SliceRandom, SeedableRng}; +type FilterPredicate = Option bool + Send + Sync>>; + /// API for consuming RabbitMQ stream messages pub struct Consumer { // Mandatory in case of manual offset tracking @@ -46,6 +48,7 @@ struct ConsumerInternal { closed: Arc, waker: AtomicWaker, metrics_collector: Arc, + filter_configuration: Option, } impl ConsumerInternal { @@ -54,11 +57,37 @@ impl ConsumerInternal { } } +#[derive(Clone)] +pub struct FilterConfiguration { + filter_values: Vec, + pub predicate: FilterPredicate, + match_unfiltered: bool, +} + +impl FilterConfiguration { + pub fn new(filter_values: Vec, match_unfiltered: bool) -> Self { + Self { + filter_values, + match_unfiltered, + predicate: None, + } + } + + pub fn post_filter( + mut self, + predicate: impl Fn(&Message) -> bool + 'static + Send + Sync, + ) -> FilterConfiguration { + 
self.predicate = Some(Arc::new(predicate)); + self + } +} + /// Builder for [`Consumer`] pub struct ConsumerBuilder { pub(crate) consumer_name: Option, pub(crate) environment: Environment, pub(crate) offset_specification: OffsetSpecification, + pub(crate) filter_configuration: Option, } impl ConsumerBuilder { @@ -116,17 +145,35 @@ impl ConsumerBuilder { closed: Arc::new(AtomicBool::new(false)), waker: AtomicWaker::new(), metrics_collector: collector, + filter_configuration: self.filter_configuration.clone(), }); let msg_handler = ConsumerMessageHandler(consumer.clone()); client.set_handler(msg_handler).await; + let mut properties = HashMap::new(); + if let Some(filter_input) = self.filter_configuration { + if !client.filtering_supported() { + return Err(ConsumerCreateError::FilteringNotSupport); + } + for (index, item) in filter_input.filter_values.iter().enumerate() { + let key = format!("filter.{}", index); + properties.insert(key, item.to_owned()); + } + + let match_unfiltered_key = "match-unfiltered".to_string(); + properties.insert( + match_unfiltered_key, + filter_input.match_unfiltered.to_string(), + ); + } + let response = client .subscribe( subscription_id, stream, self.offset_specification, 1, - HashMap::new(), + properties, ) .await?; @@ -153,6 +200,11 @@ impl ConsumerBuilder { self.consumer_name = Some(String::from(consumer_name)); self } + + pub fn filter_input(mut self, filter_configuration: Option) -> Self { + self.filter_configuration = filter_configuration; + self + } } impl Consumer { @@ -244,7 +296,25 @@ impl MessageHandler for ConsumerMessageHandler { let len = delivery.messages.len(); trace!("Got delivery with messages {}", len); - for message in delivery.messages { + + // // client filter + let messages = match &self.0.filter_configuration { + Some(filter_input) => { + if let Some(f) = &filter_input.predicate { + delivery + .messages + .into_iter() + .filter(|message| f(message)) + .collect::>() + } else { + delivery.messages + } + } + + None => delivery.messages, + }; + + for message in messages { if let OffsetSpecification::Offset(offset_) = self.0.offset_specification { if offset_ > offset { offset += 1; diff --git a/src/environment.rs b/src/environment.rs index cf472729..8b1a590d 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -45,6 +45,7 @@ impl Environment { batch_size: 100, batch_publishing_delay: Duration::from_millis(100), data: PhantomData, + filter_value_extractor: None, } } @@ -54,6 +55,7 @@ impl Environment { consumer_name: None, environment: self.clone(), offset_specification: OffsetSpecification::Next, + filter_configuration: None, } } pub(crate) async fn create_client(&self) -> RabbitMQStreamResult { diff --git a/src/error.rs b/src/error.rs index 8b3b9fbd..22d5c7f1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -88,6 +88,9 @@ pub enum ProducerCreateError { #[error(transparent)] Client(#[from] ClientError), + + #[error("Filtering is not supported by the broker (requires RabbitMQ 3.13+ and stream_filtering feature flag activated)")] + FilteringNotSupport, } #[derive(Error, Debug)] @@ -133,6 +136,9 @@ pub enum ConsumerCreateError { #[error(transparent)] Client(#[from] ClientError), + + #[error("Filtering is not supported by the broker (requires RabbitMQ 3.13+ and stream_filtering feature flag activated)")] + FilteringNotSupport, } #[derive(Error, Debug)] diff --git a/src/lib.rs b/src/lib.rs index bcf2a13f..b6d2ab24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -84,7 +84,7 @@ pub type RabbitMQStreamResult = Result; pub use 
crate::client::{Client, ClientOptions, MetricsCollector}; -pub use crate::consumer::{Consumer, ConsumerBuilder, ConsumerHandle}; +pub use crate::consumer::{Consumer, ConsumerBuilder, ConsumerHandle, FilterConfiguration}; pub use crate::environment::{Environment, EnvironmentBuilder, TlsConfiguration}; pub use crate::producer::{Dedup, NoDedup, Producer, ProducerBuilder}; pub mod types { diff --git a/src/producer.rs b/src/producer.rs index db8169ca..89126795 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -27,6 +27,7 @@ use crate::{ }; type WaiterMap = Arc>; +type FilterValueExtractor = Arc String + 'static + Send + Sync>; type ConfirmCallback = Arc< dyn Fn(Result) -> BoxFuture<'static, ()> @@ -73,6 +74,8 @@ pub struct ProducerInternal { waiting_confirmations: WaiterMap, closed: Arc, accumulator: MessageAccumulator, + publish_version: u16, + filter_value_extractor: Option, } impl ProducerInternal { @@ -81,7 +84,9 @@ impl ProducerInternal { if !messages.is_empty() { debug!("Sending batch of {} messages", messages.len()); - self.client.publish(self.producer_id, messages).await?; + self.client + .publish(self.producer_id, messages, self.publish_version) + .await?; } Ok(()) @@ -99,6 +104,7 @@ pub struct ProducerBuilder { pub batch_size: usize, pub batch_publishing_delay: Duration, pub(crate) data: PhantomData, + pub filter_value_extractor: Option, } #[derive(Clone)] @@ -112,6 +118,17 @@ impl ProducerBuilder { // The leader is the recommended node for writing, because writing to a replica will redundantly pass these messages // to the leader anyway - it is the only one capable of writing. let mut client = self.environment.create_client().await?; + + let mut publish_version = 1; + + if self.filter_value_extractor.is_some() { + if client.filtering_supported() { + publish_version = 2 + } else { + return Err(ProducerCreateError::FilteringNotSupport); + } + } + let metrics_collector = self.environment.options.client_options.collector.clone(); if let Some(metadata) = client.metadata(vec![stream.to_string()]).await?.get(stream) { tracing::debug!( @@ -177,8 +194,10 @@ impl ProducerBuilder { client, publish_sequence, waiting_confirmations, + publish_version, closed: Arc::new(AtomicBool::new(false)), accumulator: MessageAccumulator::new(self.batch_size), + filter_value_extractor: self.filter_value_extractor, }; let internal_producer = Arc::new(producer); @@ -211,8 +230,18 @@ impl ProducerBuilder { batch_size: self.batch_size, batch_publishing_delay: self.batch_publishing_delay, data: PhantomData, + filter_value_extractor: None, } } + + pub fn filter_value_extractor( + mut self, + filter_value_extractor: impl Fn(&Message) -> String + Send + Sync + 'static, + ) -> Self { + let f = Arc::new(filter_value_extractor); + self.filter_value_extractor = Some(f); + self + } } pub struct MessageAccumulator { @@ -437,7 +466,11 @@ impl Producer { Some(publishing_id) => *publishing_id, None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed), }; - let msg = ClientMessage::new(publishing_id, message.clone()); + let mut msg = ClientMessage::new(publishing_id, message.clone(), None); + + if let Some(f) = self.0.filter_value_extractor.as_ref() { + msg.filter_value_extract(f.as_ref()) + } let waiter = ProducerMessageWaiter::waiter_with_cb(cb, message); self.0.waiting_confirmations.insert(publishing_id, waiter); @@ -470,7 +503,11 @@ impl Producer { None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed), }; - wrapped_msgs.push(ClientMessage::new(publishing_id, message)); + let mut client_message = 
ClientMessage::new(publishing_id, message, None); + if let Some(f) = self.0.filter_value_extractor.as_ref() { + client_message.filter_value_extract(f.as_ref()) + } + wrapped_msgs.push(client_message); self.0 .waiting_confirmations @@ -479,7 +516,7 @@ impl Producer { self.0 .client - .publish(self.0.producer_id, wrapped_msgs) + .publish(self.0.producer_id, wrapped_msgs, self.0.publish_version) .await?; Ok(()) diff --git a/tests/integration/client_test.rs b/tests/integration/client_test.rs index b75eef8f..c9dae67f 100644 --- a/tests/integration/client_test.rs +++ b/tests/integration/client_test.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use fake::{Fake, Faker}; -use rabbitmq_stream_protocol::commands::close::CloseRequest; use tokio::sync::mpsc::channel; use rabbitmq_stream_client::error::ClientError; @@ -352,7 +351,7 @@ async fn client_publish() { let sequences = test .client - .publish(1, Message::builder().body(b"message".to_vec()).build()) + .publish(1, Message::builder().body(b"message".to_vec()).build(), 1) .await .unwrap(); @@ -378,3 +377,11 @@ async fn client_handle_unexpected_connection_interruption() { let res = Client::connect(options).await; assert!(matches!(res, Err(ClientError::ConnectionClosed))); } + +#[tokio::test(flavor = "multi_thread")] +async fn client_exchange_command_versions() { + let test = TestClient::create().await; + + let response = test.client.exchange_command_versions().await.unwrap(); + assert_eq!(&ResponseCode::Ok, response.code()); +} diff --git a/tests/integration/consumer_test.rs b/tests/integration/consumer_test.rs index fc0e0cc2..035c8c68 100644 --- a/tests/integration/consumer_test.rs +++ b/tests/integration/consumer_test.rs @@ -9,10 +9,11 @@ use rabbitmq_stream_client::{ ProducerCloseError, }, types::{Delivery, Message, OffsetSpecification}, - Consumer, NoDedup, Producer, + Consumer, FilterConfiguration, NoDedup, Producer, }; use rabbitmq_stream_protocol::ResponseCode; +use std::sync::Arc; #[tokio::test(flavor = "multi_thread")] async fn consumer_test() { @@ -339,7 +340,7 @@ async fn consumer_test_with_store_offset() { consumer_store.handle().close().await.unwrap(); - let mut consumer_query = env + let consumer_query = env .env .consumer() .offset(OffsetSpecification::First) @@ -355,3 +356,170 @@ async fn consumer_test_with_store_offset() { consumer_query.handle().close().await.unwrap(); producer.close().await.unwrap(); } + +#[tokio::test(flavor = "multi_thread")] +async fn consumer_test_with_filtering() { + let env = TestEnvironment::create().await; + let reference: String = Faker.fake(); + + let message_count = 10; + let mut producer = env + .env + .producer() + .name(&reference) + .filter_value_extractor(|_| "filtering".to_string()) + .build(&env.stream) + .await + .unwrap(); + + let filter_configuration = FilterConfiguration::new(vec!["filtering".to_string()], false) + .post_filter(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) + == "filtering".to_string() + }); + + let mut consumer = env + .env + .consumer() + .offset(OffsetSpecification::First) + .filter_input(Some(filter_configuration)) + .build(&env.stream) + .await + .unwrap(); + + for _ in 0..message_count { + let _ = producer + .send_with_confirm(Message::builder().body("filtering").build()) + .await + .unwrap(); + + let _ = producer + .send_with_confirm(Message::builder().body("not filtering").build()) + .await + .unwrap(); + } + + let response = Arc::new(tokio::sync::Mutex::new(vec![])); + let response_clone = Arc::clone(&response); + + 
let task = tokio::task::spawn(async move { + loop { + let delivery = consumer.next().await.unwrap(); + + let d = delivery.unwrap(); + let data = d + .message() + .data() + .map(|data| String::from_utf8(data.to_vec()).unwrap()) + .unwrap(); + + let mut r = response_clone.lock().await; + r.push(data); + } + }); + + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(3), task).await; + let repsonse_length = response.lock().await.len(); + let filtering_response_length = response + .lock() + .await + .iter() + .filter(|item| item == &&"filtering") + .collect::>() + .len(); + + assert!(repsonse_length == filtering_response_length); + producer.close().await.unwrap(); +} + +#[tokio::test(flavor = "multi_thread")] +async fn consumer_test_with_filtering_match_unfiltered() { + let env = TestEnvironment::create().await; + let reference: String = Faker.fake(); + + let message_count = 10; + let mut producer = env + .env + .producer() + .name(&reference) + .filter_value_extractor(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap() + }) + .build(&env.stream) + .await + .unwrap(); + + // publish filtering message + for i in 0..message_count { + producer + .send_with_confirm(Message::builder().body(i.to_string()).build()) + .await + .unwrap(); + } + + producer.close().await.unwrap(); + + let mut producer = env + .env + .producer() + .name(&reference) + .build(&env.stream) + .await + .unwrap(); + + // publish unset filter value + for i in 0..message_count { + producer + .send_with_confirm(Message::builder().body(i.to_string()).build()) + .await + .unwrap(); + } + + producer.close().await.unwrap(); + + let filter_configuration = + FilterConfiguration::new(vec!["1".to_string()], true).post_filter(|message| { + String::from_utf8(message.data().unwrap().to_vec()).unwrap_or("".to_string()) + == "1".to_string() + }); + + let mut consumer = env + .env + .consumer() + .offset(OffsetSpecification::First) + .filter_input(Some(filter_configuration)) + .build(&env.stream) + .await + .unwrap(); + + let response = Arc::new(tokio::sync::Mutex::new(vec![])); + let response_clone = Arc::clone(&response); + + let task = tokio::task::spawn(async move { + loop { + let delivery = consumer.next().await.unwrap(); + + let d = delivery.unwrap(); + let data = d + .message() + .data() + .map(|data| String::from_utf8(data.to_vec()).unwrap()) + .unwrap(); + + let mut r = response_clone.lock().await; + r.push(data); + } + }); + + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(3), task).await; + let repsonse_length = response.lock().await.len(); + let filtering_response_length = response + .lock() + .await + .iter() + .filter(|item| item == &&"1") + .collect::>() + .len(); + + assert!(repsonse_length == filtering_response_length); +} diff --git a/tests/integration/producer_test.rs b/tests/integration/producer_test.rs index 3836d564..2caf5751 100644 --- a/tests/integration/producer_test.rs +++ b/tests/integration/producer_test.rs @@ -3,7 +3,7 @@ use fake::{Fake, Faker}; use futures::StreamExt; use tokio::sync::mpsc::channel; -use rabbitmq_stream_client::types::{Message, OffsetSpecification}; +use rabbitmq_stream_client::types::{Message, OffsetSpecification, SimpleValue}; use crate::common::TestEnvironment; @@ -384,3 +384,47 @@ async fn producer_send_after_close_error() { true ); } + +#[tokio::test(flavor = "multi_thread")] +async fn producer_send_filtering_message() { + let env = TestEnvironment::create().await; + let producer = env + .env + .producer() + .filter_value_extractor(|message| { + 
let app_properties = message.application_properties(); + match app_properties { + Some(properties) => { + let value = properties.get("region").and_then(|item| match item { + SimpleValue::String(s) => Some(s.clone()), + _ => None, + }); + value.unwrap_or(String::from("")) + } + None => String::from(""), + } + }) + .build(&env.stream) + .await + .unwrap(); + producer.clone().close().await.unwrap(); + + let message_builder = Message::builder(); + let mut application_properties = message_builder.application_properties(); + application_properties = application_properties.insert("region", "emea"); + + let message = application_properties + .message_builder() + .body(b"message".to_vec()) + .build(); + + let closed = producer.send_with_confirm(message).await.unwrap_err(); + + assert_eq!( + matches!( + closed, + rabbitmq_stream_client::error::ProducerPublishError::Closed + ), + true + ); +}
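
Reviewer note: the sketch below is a minimal, untested illustration of how the pieces added in this patch fit together. The producer derives the server-side filter value from an application property, and the consumer pairs a matching FilterConfiguration with a client-side post-filter. The stream name "filtering-by-region" and the "region"/"emea" property values are illustrative assumptions; the APIs used (filter_value_extractor, FilterConfiguration::new, post_filter, filter_input, SimpleValue) are the ones introduced or exercised by this diff.

use futures::StreamExt;
use rabbitmq_stream_client::types::{Message, OffsetSpecification, SimpleValue};
use rabbitmq_stream_client::{Environment, FilterConfiguration};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let environment = Environment::builder()
        .host("localhost")
        .port(5552)
        .build()
        .await?;
    // Assumed stream name, only for illustration.
    environment
        .stream_creator()
        .create("filtering-by-region")
        .await?;

    // Producer: extract the broker-side filter value from the "region" application property.
    let mut producer = environment
        .producer()
        .filter_value_extractor(|message| {
            message
                .application_properties()
                .and_then(|props| props.get("region"))
                .and_then(|value| match value {
                    SimpleValue::String(s) => Some(s.clone()),
                    _ => None,
                })
                .unwrap_or_default()
        })
        .build("filtering-by-region")
        .await?;

    let message = Message::builder()
        .application_properties()
        .insert("region", "emea")
        .message_builder()
        .body(b"invoice".to_vec())
        .build();
    producer.send_with_confirm(message).await?;
    producer.close().await?;

    // Consumer: request the "emea" filter value from the broker and post-filter
    // client side, since chunk-level filtering can still deliver false positives.
    let filter_configuration = FilterConfiguration::new(vec!["emea".to_string()], false)
        .post_filter(|message| {
            message
                .application_properties()
                .and_then(|props| props.get("region"))
                .map(|value| matches!(value, SimpleValue::String(region) if region == "emea"))
                .unwrap_or(false)
        });

    let mut consumer = environment
        .consumer()
        .offset(OffsetSpecification::First)
        .filter_input(Some(filter_configuration))
        .build("filtering-by-region")
        .await?;

    if let Some(Ok(delivery)) = consumer.next().await {
        println!(
            "received filtered message at offset {}: {:?}",
            delivery.offset(),
            delivery
                .message()
                .data()
                .map(|data| String::from_utf8_lossy(data).to_string())
        );
    }

    environment.delete_stream("filtering-by-region").await?;
    Ok(())
}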