diff --git a/README.md b/README.md index 4ba1e33d..aa8b5ddc 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,19 @@ let environment = Environment::builder() .build() ``` +##### Building the environment with a load balancer + +```rust,no_run +use rabbitmq_stream_client::Environment; + + +let environment = Environment::builder() + .load_balancer_mode(true) + .build() +``` + + + ##### Publishing messages ```rust,no_run diff --git a/src/client/options.rs b/src/client/options.rs index 16ae1b72..829c0f1d 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -12,6 +12,7 @@ pub struct ClientOptions { pub(crate) v_host: String, pub(crate) heartbeat: u32, pub(crate) max_frame_size: u32, + pub(crate) load_balancer_mode: bool, pub(crate) tls: TlsConfiguration, pub(crate) collector: Arc, } @@ -39,6 +40,7 @@ impl Default for ClientOptions { v_host: "/".to_owned(), heartbeat: 60, max_frame_size: 1048576, + load_balancer_mode: false, collector: Arc::new(NopMetricsCollector {}), tls: TlsConfiguration { enabled: false, @@ -117,6 +119,11 @@ impl ClientOptionsBuilder { self } + pub fn load_balancer_mode(mut self, load_balancer_mode: bool) -> Self { + self.0.load_balancer_mode = load_balancer_mode; + self + } + pub fn build(self) -> ClientOptions { self.0 } @@ -145,6 +152,7 @@ mod tests { client_keys_path: String::from(""), }) .collector(Arc::new(NopMetricsCollector {})) + .load_balancer_mode(true) .build(); assert_eq!(options.host, "test"); assert_eq!(options.port, 8888); @@ -154,5 +162,6 @@ mod tests { assert_eq!(options.heartbeat, 10000); assert_eq!(options.max_frame_size, 1); assert_eq!(options.tls.enabled, true); + assert_eq!(options.load_balancer_mode, true); } } diff --git a/src/consumer.rs b/src/consumer.rs index 59971d3f..409cd374 100644 --- a/src/consumer.rs +++ b/src/consumer.rs @@ -76,12 +76,27 @@ impl ConsumerBuilder { metadata.replicas, stream ); - client = Client::connect(ClientOptions { - host: replica.host.clone(), - port: replica.port as u16, - ..self.environment.options.client_options - }) - .await?; + let load_balancer_mode = self.environment.options.client_options.load_balancer_mode; + if load_balancer_mode { + let options = self.environment.options.client_options.clone(); + loop { + let temp_client = Client::connect(options.clone()).await?; + let mapping = temp_client.connection_properties().await; + if let Some(advertised_host) = mapping.get("advertised_host") { + if *advertised_host == replica.host.clone() { + client = temp_client; + break; + } + } + } + } else { + client = Client::connect(ClientOptions { + host: replica.host.clone(), + port: replica.port as u16, + ..self.environment.options.client_options + }) + .await?; + } } } else { return Err(ConsumerCreateError::StreamDoesNotExist { @@ -100,7 +115,6 @@ impl ConsumerBuilder { waker: AtomicWaker::new(), metrics_collector: collector, }); - let msg_handler = ConsumerMessageHandler(consumer.clone()); client.set_handler(msg_handler).await; diff --git a/src/environment.rs b/src/environment.rs index 84bd11fd..cf472729 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -121,11 +121,16 @@ impl EnvironmentBuilder { } pub fn metrics_collector( mut self, - collector: impl MetricsCollector + Send + Sync + 'static, + collector: impl MetricsCollector + 'static, ) -> EnvironmentBuilder { self.0.client_options.collector = Arc::new(collector); self } + + pub fn load_balancer_mode(mut self, load_balancer_mode: bool) -> EnvironmentBuilder { + self.0.client_options.load_balancer_mode = load_balancer_mode; + self + } } #[derive(Clone, Default)] pub struct EnvironmentOptions { diff --git a/src/producer.rs b/src/producer.rs index 39ad8bce..db8169ca 100644 --- a/src/producer.rs +++ b/src/producer.rs @@ -119,13 +119,29 @@ impl ProducerBuilder { metadata.leader, stream ); - client.close().await?; - client = Client::connect(ClientOptions { - host: metadata.leader.host.clone(), - port: metadata.leader.port as u16, - ..self.environment.options.client_options - }) - .await?; + let load_balancer_mode = self.environment.options.client_options.load_balancer_mode; + if load_balancer_mode { + // Producer must connect to leader node + let options: ClientOptions = self.environment.options.client_options.clone(); + loop { + let temp_client = Client::connect(options.clone()).await?; + let mapping = temp_client.connection_properties().await; + if let Some(advertised_host) = mapping.get("advertised_host") { + if *advertised_host == metadata.leader.host.clone() { + client = temp_client; + break; + } + } + } + } else { + client.close().await?; + client = Client::connect(ClientOptions { + host: metadata.leader.host.clone(), + port: metadata.leader.port as u16, + ..self.environment.options.client_options + }) + .await? + }; } else { return Err(ProducerCreateError::StreamDoesNotExist { stream: stream.into(),