diff --git a/Cargo.lock b/Cargo.lock index cbd7d2f79c0..4a96cfee57e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4897,6 +4897,7 @@ dependencies = [ name = "tedge_mqtt_ext" version = "1.5.1" dependencies = [ + "anyhow", "assert-json-diff", "async-trait", "futures", diff --git a/crates/common/mqtt_channel/src/tests.rs b/crates/common/mqtt_channel/src/tests.rs index 6c2029b5646..b18a4c6cec7 100644 --- a/crates/common/mqtt_channel/src/tests.rs +++ b/crates/common/mqtt_channel/src/tests.rs @@ -3,6 +3,7 @@ use futures::SinkExt; use futures::StreamExt; use std::convert::TryInto; use std::time::Duration; +use tokio::time::timeout; const TIMEOUT: Duration = Duration::from_millis(1000); @@ -568,3 +569,49 @@ async fn assert_payload_received(con: &mut Connection, payload: &'static str) { not_msg => panic!("Expected message to be received, got {not_msg:?}"), } } + +#[tokio::test] +async fn connections_from_cloned_configs_are_independent() -> Result<(), anyhow::Error> { + // This test arose from an issue with dynamic subscriptions where + // subscriptions were shared between different MQTT channel instances + let broker = mqtt_tests::test_mqtt_broker(); + let mqtt_config = Config::default().with_port(broker.port); + let mqtt_config_cloned = mqtt_config.clone(); + + let topic = uniquify!("a/test/topic"); + let other_topic = uniquify!("different/test/topic"); + let mqtt_config = mqtt_config.with_subscriptions(TopicFilter::new_unchecked(topic)); + let mqtt_config_cloned = + mqtt_config_cloned.with_subscriptions(TopicFilter::new_unchecked(other_topic)); + + let mut con = Connection::new(&mqtt_config).await?; + let mut other_con = Connection::new(&mqtt_config_cloned).await?; + + // Any messages published on that topic ... + broker.publish(topic, "original topic message").await?; + broker.publish(other_topic, "other topic message").await?; + + // ... must be received by the client + assert_eq!( + MaybeMessage::Next(message(topic, "original topic message")), + next_message(&mut con.received).await + ); + assert_eq!( + MaybeMessage::Next(message(other_topic, "other topic message")), + next_message(&mut other_con.received).await + ); + + assert!( + timeout(Duration::from_millis(200), next_message(&mut con.received)) + .await + .is_err() + ); + assert!(timeout( + Duration::from_millis(200), + next_message(&mut other_con.received) + ) + .await + .is_err()); + + Ok(()) +} diff --git a/crates/common/mqtt_channel/src/topics.rs b/crates/common/mqtt_channel/src/topics.rs index 35885660eee..3cfd69ac8a3 100644 --- a/crates/common/mqtt_channel/src/topics.rs +++ b/crates/common/mqtt_channel/src/topics.rs @@ -191,6 +191,10 @@ impl TopicFilter { None } } + + pub fn is_empty(&self) -> bool { + self.patterns.is_empty() + } } impl TryInto for &str { diff --git a/crates/core/tedge_actors/src/test_helpers.rs b/crates/core/tedge_actors/src/test_helpers.rs index 5348a798564..71e29e6c456 100644 --- a/crates/core/tedge_actors/src/test_helpers.rs +++ b/crates/core/tedge_actors/src/test_helpers.rs @@ -259,7 +259,6 @@ where } } - #[allow(clippy::needless_collect)] // To avoid issues with Send constraints async fn assert_received(&mut self, expected: Samples) where Samples: IntoIterator + Send, diff --git a/crates/core/tedge_agent/src/agent.rs b/crates/core/tedge_agent/src/agent.rs index 72829937879..510235fdc74 100644 --- a/crates/core/tedge_agent/src/agent.rs +++ b/crates/core/tedge_agent/src/agent.rs @@ -58,7 +58,6 @@ use tedge_log_manager::LogManagerConfig; use tedge_log_manager::LogManagerOptions; use tedge_mqtt_ext::MqttActorBuilder; use tedge_mqtt_ext::MqttConfig; -use tedge_mqtt_ext::MqttDynamicConnector; use tedge_mqtt_ext::TopicFilter; use tedge_script_ext::ScriptActor; use tedge_signal_ext::SignalActor; @@ -388,7 +387,6 @@ impl Agent { let entity_store_server = EntityStoreServer::new( entity_store, mqtt_schema.clone(), - Box::new(MqttDynamicConnector::new(self.config.mqtt_config)), &mut mqtt_actor_builder, self.config.entity_auto_register, ); diff --git a/crates/core/tedge_agent/src/entity_manager/server.rs b/crates/core/tedge_agent/src/entity_manager/server.rs index 5bdb3e2525c..950b33febac 100644 --- a/crates/core/tedge_agent/src/entity_manager/server.rs +++ b/crates/core/tedge_agent/src/entity_manager/server.rs @@ -1,10 +1,22 @@ +use std::convert::Infallible; + use async_trait::async_trait; use serde_json::Map; use serde_json::Value; +use tedge_actors::Actor; +use tedge_actors::Builder; use tedge_actors::LoggingSender; +use tedge_actors::MappingSender; +use tedge_actors::MessageReceiver; use tedge_actors::MessageSink; +use tedge_actors::MessageSource; +use tedge_actors::NoConfig; +use tedge_actors::NoMessage; +use tedge_actors::RuntimeError; use tedge_actors::Sender; use tedge_actors::Server; +use tedge_actors::SimpleMessageBox; +use tedge_actors::SimpleMessageBoxBuilder; use tedge_api::entity::EntityMetadata; use tedge_api::entity_store; use tedge_api::entity_store::EntityRegistrationMessage; @@ -18,11 +30,10 @@ use tedge_api::mqtt_topics::EntityTopicId; use tedge_api::mqtt_topics::MqttSchema; use tedge_api::pending_entity_store::RegisteredEntityData; use tedge_api::EntityStore; -use tedge_mqtt_ext::MqttConnector; +use tedge_mqtt_ext::DynSubscriptionsInner; use tedge_mqtt_ext::MqttMessage; +use tedge_mqtt_ext::MqttRequest; use tedge_mqtt_ext::TopicFilter; -use tokio::time::timeout; -use tokio::time::Duration; use tracing::error; #[derive(Debug)] @@ -56,27 +67,94 @@ pub enum EntityStoreResponse { pub struct EntityStoreServer { entity_store: EntityStore, mqtt_schema: MqttSchema, - mqtt_connector: Box, mqtt_publisher: LoggingSender, entity_auto_register: bool, + retain_requests: SimpleMessageBox, +} + +struct DeregistrationActorBuilder { + mqtt_publish: LoggingSender, + messages: SimpleMessageBoxBuilder, +} + +impl Builder for DeregistrationActorBuilder { + type Error = Infallible; + + fn try_build(self) -> Result { + Ok(DeregistrationActor { + mqtt_publish: self.mqtt_publish, + messages: self.messages.build(), + }) + } +} + +struct DeregistrationActor { + messages: SimpleMessageBox, + mqtt_publish: LoggingSender, +} + +#[async_trait::async_trait] +impl Actor for DeregistrationActor { + fn name(&self) -> &str { + todo!() + } + + async fn run(mut self) -> Result<(), RuntimeError> { + while let Ok(Some(msg)) = self.messages.try_recv().await { + if msg.retain && !msg.payload.as_bytes().is_empty() { + let clear_msg = MqttMessage::new(&msg.topic, "").with_retain(); + self.mqtt_publish.send(clear_msg).await.unwrap(); + } + } + Ok(()) + } +} + +impl MessageSink for DeregistrationActorBuilder { + fn get_sender(&self) -> tedge_actors::DynSender { + self.messages.get_sender() + } } impl EntityStoreServer { - pub fn new( + pub fn new( entity_store: EntityStore, mqtt_schema: MqttSchema, - mqtt_connector: Box, - mqtt_actor: &mut impl MessageSink, + mqtt_actor: &mut M, entity_auto_register: bool, - ) -> Self { - let mqtt_publisher = LoggingSender::new("MqttPublisher".into(), mqtt_actor.get_sender()); + ) -> Self + where + M: MessageSink + + for<'a> MessageSource, + { + let mqtt_publisher = LoggingSender::new( + "MqttPublisher".into(), + Box::new(MappingSender::new(mqtt_actor.get_sender(), |msg| { + [MqttRequest::Publish(msg)] + })), + ); + let mut retain_requests = SimpleMessageBoxBuilder::new("DeregistrationClient", 16); + let mut dyn_subs = DynSubscriptionsInner::new(TopicFilter::empty()); + let messages = SimpleMessageBoxBuilder::new("DeregistrationActor", 16); + let dereg_actor_builder = DeregistrationActorBuilder { + mqtt_publish: mqtt_publisher.clone(), + messages, + }; + mqtt_actor.connect_sink(&mut dyn_subs, &dereg_actor_builder); + let client_id = dyn_subs.client_id(); + mqtt_actor.connect_mapped_source(NoConfig, &mut retain_requests, move |topics| { + [MqttRequest::RetrieveRetain(client_id, topics)] + }); + + // TODO - no don't do this! + tokio::spawn(dereg_actor_builder.build().run()); Self { entity_store, mqtt_schema, - mqtt_connector, mqtt_publisher, entity_auto_register, + retain_requests: retain_requests.build(), } } @@ -321,6 +399,10 @@ impl EntityStoreServer { return deleted; } + if deleted.is_empty() { + return vec![]; + } + let mut topics = TopicFilter::empty(); for entity in deleted.iter() { for channel_filter in [ @@ -340,33 +422,7 @@ impl EntityStoreServer { } } - // A single connection to retrieve all retained metadata messages for all deleted entities - match self.mqtt_connector.connect(topics.clone()).await { - Ok(mut connection) => { - while let Ok(Some(message)) = - timeout(Duration::from_secs(1), connection.next_message()).await - { - if message.retain - && !message.payload_bytes().is_empty() - && topics.accept(&message) - { - let clear_msg = MqttMessage::new(&message.topic, "").with_retain(); - if let Err(err) = self.mqtt_publisher.send(clear_msg).await { - error!( - "Failed to clear retained message on topic {} while de-registering {} due to: {err}", - topic_id, - message.topic - ); - } - } - } - - connection.disconnect().await; - } - Err(err) => { - error!("Failed to create MQTT connection for clearing entity data: {err}"); - } - } + self.retain_requests.send(topics).await.unwrap(); // Clear the entity metadata of all deleted entities bottom up for entity in deleted.iter().rev() { diff --git a/crates/core/tedge_agent/src/entity_manager/tests.rs b/crates/core/tedge_agent/src/entity_manager/tests.rs index eccb270f2ac..f52cb1f34fc 100644 --- a/crates/core/tedge_agent/src/entity_manager/tests.rs +++ b/crates/core/tedge_agent/src/entity_manager/tests.rs @@ -6,10 +6,6 @@ use crate::entity_manager::tests::model::Command; use crate::entity_manager::tests::model::Commands; use crate::entity_manager::tests::model::Protocol::HTTP; use crate::entity_manager::tests::model::Protocol::MQTT; -use async_trait::async_trait; -use futures::channel::mpsc::UnboundedReceiver; -use futures::channel::mpsc::UnboundedSender; -use futures::StreamExt; use proptest::proptest; use serde_json::json; use std::collections::HashSet; @@ -20,15 +16,14 @@ use tedge_api::entity::EntityMetadata; use tedge_api::entity::EntityType; use tedge_api::mqtt_topics::EntityTopicId; use tedge_mqtt_ext::test_helpers::assert_received_contains_str; -use tedge_mqtt_ext::MqttConnection; -use tedge_mqtt_ext::MqttConnector; -use tedge_mqtt_ext::MqttError; +use tedge_mqtt_ext::ClientId; use tedge_mqtt_ext::MqttMessage; +use tedge_mqtt_ext::MqttRequest; use tedge_mqtt_ext::TopicFilter; #[tokio::test] async fn new_entity_store() { - let (mut entity_store, _mqtt_output, _mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, _mqtt_output) = entity::server("device-under-test"); assert_eq!( entity::get(&mut entity_store, "device/main//").await, @@ -74,7 +69,7 @@ async fn removing_a_child_using_mqtt() { #[tokio::test] async fn twin_fragment_updates_published_to_mqtt() { - let (mut entity_store, mut mqtt_box, _mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, mut mqtt_box) = entity::server("device-under-test"); entity::set_twin_fragments( &mut entity_store, "device/main//", @@ -92,7 +87,7 @@ async fn twin_fragment_updates_published_to_mqtt() { #[tokio::test] async fn delete_entity_clears_retained_data() { - let (mut entity_store, mut mqtt_box, mut mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, mut mqtt_box) = entity::server("device-under-test"); entity::create_entity( &mut entity_store, "device/child0//", @@ -130,7 +125,7 @@ async fn delete_entity_clears_retained_data() { true, ), ] { - mqtt_sender + mqtt_box .send(MqttMessage::from((topic, &payload.to_string())).with_retain_flag(retain)) .await .unwrap(); @@ -140,24 +135,45 @@ async fn delete_entity_clears_retained_data() { .await .unwrap(); + let expected_topics = [ + "te/device/child0///m/+/meta", + "te/device/child0///e/+/meta", + "te/device/child0///a/+/meta", + "te/device/child0///a/+", + "te/device/child0///twin/+", + "te/device/child0///cmd/+/+", + "te/device/child0///cmd/+", + "te/device/child0///status/health", + ]; + let mut topic_filter = TopicFilter::empty(); + for topic in expected_topics { + topic_filter.add_unchecked(topic); + } + mqtt_box + .assert_received([MqttRequest::RetrieveRetain(ClientId(0), topic_filter)]) + .await; + // Assert that all retained messages for the entity are cleared, // excluding the non-retained `temp` measurement and `temp_change` event mqtt_box - .assert_received([ - MqttMessage::from(("te/device/child0///twin/x", "")).with_retain(), - MqttMessage::from(("te/device/child0///twin/y", "")).with_retain(), - MqttMessage::from(("te/device/child0///a/high_temp", "")).with_retain(), - MqttMessage::from(("te/device/child0///cmd/restart", "")).with_retain(), - MqttMessage::from(("te/device/child0///cmd/restart/123", "")).with_retain(), - MqttMessage::from(("te/device/child0///status/health", "")).with_retain(), - MqttMessage::from(("te/device/child0//", "")).with_retain(), - ]) + .assert_received_unordered( + [ + MqttMessage::from(("te/device/child0///twin/x", "")).with_retain(), + MqttMessage::from(("te/device/child0///twin/y", "")).with_retain(), + MqttMessage::from(("te/device/child0///a/high_temp", "")).with_retain(), + MqttMessage::from(("te/device/child0///cmd/restart", "")).with_retain(), + MqttMessage::from(("te/device/child0///cmd/restart/123", "")).with_retain(), + MqttMessage::from(("te/device/child0///status/health", "")).with_retain(), + MqttMessage::from(("te/device/child0//", "")).with_retain(), + ] + .map(MqttRequest::Publish), + ) .await; } #[tokio::test] async fn delete_entity_tree_clears_entities_bottom_up() { - let (mut entity_store, mut mqtt_box, _mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, mut mqtt_box) = entity::server("device-under-test"); for entity in [ ("device/child0//", EntityType::ChildDevice, None), ("device/child1//", EntityType::ChildDevice, None), @@ -188,19 +204,62 @@ async fn delete_entity_tree_clears_entities_bottom_up() { entity::delete_entity(&mut entity_store, "device/child0//") .await .unwrap(); + let expected_topics = [ + "te/device/child0///m/+/meta", + "te/device/child0///e/+/meta", + "te/device/child0///a/+/meta", + "te/device/child0///a/+", + "te/device/child0///twin/+", + "te/device/child0///cmd/+/+", + "te/device/child0///cmd/+", + "te/device/child0///status/health", + "te/device/child00///m/+/meta", + "te/device/child00///e/+/meta", + "te/device/child00///a/+/meta", + "te/device/child00///a/+", + "te/device/child00///twin/+", + "te/device/child00///cmd/+/+", + "te/device/child00///cmd/+", + "te/device/child00///status/health", + "te/device/child000///m/+/meta", + "te/device/child000///e/+/meta", + "te/device/child000///a/+/meta", + "te/device/child000///a/+", + "te/device/child000///twin/+", + "te/device/child000///cmd/+/+", + "te/device/child000///cmd/+", + "te/device/child000///status/health", + "te/device/child000/service/service0/m/+/meta", + "te/device/child000/service/service0/e/+/meta", + "te/device/child000/service/service0/a/+/meta", + "te/device/child000/service/service0/a/+", + "te/device/child000/service/service0/twin/+", + "te/device/child000/service/service0/cmd/+/+", + "te/device/child000/service/service0/cmd/+", + "te/device/child000/service/service0/status/health", + ]; + let mut topic_filter = TopicFilter::empty(); + for topic in expected_topics { + topic_filter.add_unchecked(topic); + } + mqtt_box + .assert_received([MqttRequest::RetrieveRetain(ClientId(0), topic_filter)]) + .await; mqtt_box .assert_received([ - MqttMessage::from(("te/device/child000/service/service0", "")).with_retain(), - MqttMessage::from(("te/device/child000//", "")).with_retain(), - MqttMessage::from(("te/device/child00//", "")).with_retain(), - MqttMessage::from(("te/device/child0//", "")).with_retain(), + MqttRequest::Publish( + MqttMessage::from(("te/device/child000/service/service0", "")).with_retain(), + ), + MqttRequest::Publish(MqttMessage::from(("te/device/child000//", "")).with_retain()), + MqttRequest::Publish(MqttMessage::from(("te/device/child00//", "")).with_retain()), + MqttRequest::Publish(MqttMessage::from(("te/device/child0//", "")).with_retain()), ]) .await; } #[tokio::test] async fn clear_entity_twin_data() { - let (mut entity_store, _mqtt_box, _mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, _mqtt_box) = entity::server("device-under-test"); entity_store .process_mqtt_message(MqttMessage::from(("te/device/main///twin/x", "9")).with_retain()) @@ -232,7 +291,7 @@ proptest! { } async fn check_registrations(registrations: Commands) { - let (mut entity_store, _mqtt_output, _mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, _mqtt_output) = entity::server("device-under-test"); let mut state = model::State::new(); for Command { protocol, action } in registrations.0 { @@ -279,7 +338,7 @@ proptest! { } async fn check_registrations_from_user_pov(registrations: Commands) { - let (mut entity_store, _mqtt_output, _mqtt_sender) = entity::server("device-under-test"); + let (mut entity_store, _mqtt_output) = entity::server("device-under-test"); // Trigger all operations over HTTP to avoid pending entities (which are not visible to the user) for action in registrations.0.into_iter().map(|c| c.action) { @@ -327,13 +386,13 @@ mod entity { use crate::entity_manager::server::EntityStoreRequest; use crate::entity_manager::server::EntityStoreResponse; use crate::entity_manager::server::EntityStoreServer; - use crate::entity_manager::tests::TestMqttConnector; - use futures::channel::mpsc::UnboundedSender; use serde_json::Map; use serde_json::Value; use std::str::FromStr; use tedge_actors::Builder; - use tedge_actors::NoMessage; + use tedge_actors::MessageSink; + use tedge_actors::MessageSource; + use tedge_actors::NoConfig; use tedge_actors::Server; use tedge_actors::SimpleMessageBox; use tedge_actors::SimpleMessageBoxBuilder; @@ -343,7 +402,9 @@ mod entity { use tedge_api::mqtt_topics::EntityTopicId; use tedge_api::mqtt_topics::MqttSchema; use tedge_api::EntityStore; + use tedge_mqtt_ext::DynSubscriptionsInner; use tedge_mqtt_ext::MqttMessage; + use tedge_mqtt_ext::MqttRequest; use tempfile::TempDir; pub async fn get( @@ -414,8 +475,7 @@ mod entity { device_id: &str, ) -> ( EntityStoreServer, - SimpleMessageBox, - UnboundedSender, + SimpleMessageBox, ) { let mqtt_schema = MqttSchema::default(); let main_device = EntityRegistrationMessage::main_device(Some(device_id.to_string())); @@ -432,71 +492,40 @@ mod entity { ) .unwrap(); - let mqtt_connector = TestMqttConnector::new(); - let mqtt_sender = mqtt_connector.get_message_sender(); - let mut mqtt_actor = SimpleMessageBoxBuilder::new("MQTT", 64); + let mqtt_actor = SimpleMessageBoxBuilder::new("MQTT", 64); + let mut actor_builder = TestMqttActorBuilder { + messages: mqtt_actor, + }; let server = EntityStoreServer::new( entity_store, mqtt_schema, - Box::new(mqtt_connector), - &mut mqtt_actor, + &mut actor_builder, entity_auto_register, ); - let mqtt_output = mqtt_actor.build(); - (server, mqtt_output, mqtt_sender) + let mqtt_output = actor_builder.messages.build(); + (server, mqtt_output) } -} -#[derive(Debug)] -struct TestMqttConnector { - sender: UnboundedSender, - receiver: UnboundedReceiver, -} - -impl TestMqttConnector { - pub fn new() -> Self { - let (sender, receiver) = futures::channel::mpsc::unbounded(); - TestMqttConnector { sender, receiver } - } - - pub fn get_message_sender(&self) -> UnboundedSender { - self.sender.clone() + struct TestMqttActorBuilder { + messages: SimpleMessageBoxBuilder, } -} -#[async_trait] -impl MqttConnector for TestMqttConnector { - async fn connect( - &mut self, - _topics: TopicFilter, - ) -> Result, MqttError> { - let (mut sender, receiver) = futures::channel::mpsc::unbounded(); - while let Ok(Some(msg)) = self.receiver.try_next() { - sender.send(msg).await.unwrap(); + impl MessageSource for TestMqttActorBuilder { + fn connect_sink( + &mut self, + config: &mut DynSubscriptionsInner, + peer: &impl MessageSink, + ) { + config.set_client_id(0); + self.messages.connect_sink(NoConfig, peer); } - Ok(Box::new(TestMqttConnection::new(receiver))) - } -} - -struct TestMqttConnection { - receiver: UnboundedReceiver, -} - -impl TestMqttConnection { - pub fn new(receiver: UnboundedReceiver) -> Self { - TestMqttConnection { receiver } } -} -#[async_trait] -impl MqttConnection for TestMqttConnection { - async fn next_message(&mut self) -> Option { - self.receiver.next().await - } - - async fn disconnect(self: Box) { - // Do nothing + impl MessageSink for TestMqttActorBuilder { + fn get_sender(&self) -> tedge_actors::DynSender { + self.messages.get_sender() + } } } diff --git a/crates/extensions/tedge_mqtt_ext/Cargo.toml b/crates/extensions/tedge_mqtt_ext/Cargo.toml index df78f046939..38d69cfa4a8 100644 --- a/crates/extensions/tedge_mqtt_ext/Cargo.toml +++ b/crates/extensions/tedge_mqtt_ext/Cargo.toml @@ -12,9 +12,10 @@ repository = { workspace = true } [features] # No features on by default default = [] -test-helpers = ["dep:assert-json-diff"] +test-helpers = ["dep:assert-json-diff", "dep:anyhow"] [dependencies] +anyhow = { workspace = true, optional = true } assert-json-diff = { workspace = true, optional = true } async-trait = { workspace = true } mqtt_channel = { workspace = true } diff --git a/crates/extensions/tedge_mqtt_ext/src/lib.rs b/crates/extensions/tedge_mqtt_ext/src/lib.rs index 2be5b2f3902..fc3e19d6f82 100644 --- a/crates/extensions/tedge_mqtt_ext/src/lib.rs +++ b/crates/extensions/tedge_mqtt_ext/src/lib.rs @@ -5,7 +5,6 @@ mod tests; pub mod trie; use async_trait::async_trait; -use mqtt_channel::Connection; pub use mqtt_channel::DebugPayload; pub use mqtt_channel::MqttError; pub use mqtt_channel::MqttMessage; @@ -16,6 +15,9 @@ use mqtt_channel::SubscriberOps; pub use mqtt_channel::Topic; pub use mqtt_channel::TopicFilter; use std::convert::Infallible; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; use tedge_actors::fan_in_message_type; use tedge_actors::futures::channel::mpsc; use tedge_actors::Actor; @@ -24,6 +26,7 @@ use tedge_actors::ChannelError; use tedge_actors::ClientMessageBox; use tedge_actors::CloneSender; use tedge_actors::DynSender; +use tedge_actors::MappingSender; use tedge_actors::MessageReceiver; use tedge_actors::MessageSink; use tedge_actors::MessageSource; @@ -36,6 +39,7 @@ use tedge_actors::Server; use tedge_actors::ServerActorBuilder; use tedge_actors::ServerConfig; use trie::MqtTrie; +use trie::RankTopicFilter; use trie::SubscriptionDiff; pub type MqttConfig = mqtt_channel::Config; @@ -43,55 +47,172 @@ pub type MqttConfig = mqtt_channel::Config; pub struct MqttActorBuilder { mqtt_config: mqtt_channel::Config, input_receiver: InputCombiner, - pub_or_sub_sender: PubOrSubSender, - publish_sender: mpsc::Sender, + request_sender: mpsc::Sender, subscriber_addresses: Vec>, signal_sender: mpsc::Sender, trie: TrieService, current_id: usize, subscription_diff: SubscriptionDiff, + dynamic_connect_sender: + mpsc::Sender<(InsertRequest, Box + 'static>)>, + dynamic_connect_receiver: + mpsc::Receiver<(InsertRequest, Box + 'static>)>, +} + +impl MqttRequest { + pub fn subscribe(client_id: ClientId, diff: SubscriptionDiff) -> Self { + MqttRequest::Subscribe(SubscriptionRequest { diff, client_id }) + } } struct InputCombiner { - publish_receiver: mpsc::Receiver, - subscription_request_receiver: mpsc::Receiver, signal_receiver: mpsc::Receiver, + request_receiver: mpsc::Receiver, } -#[derive(Debug)] -pub enum PublishOrSubscribe { +impl MessageSource for MqttActorBuilder { + fn connect_sink( + &mut self, + subscriptions: DynSubscriptions, + peer: &impl MessageSink, + ) { + let client_id = self.connect_id_sink(subscriptions.init_topics(), peer); + subscriptions.set_client_id(client_id); + } +} + +impl MessageSource for MqttActorBuilder { + fn connect_sink( + &mut self, + subscriptions: &mut DynSubscriptionsInner, + peer: &impl MessageSink, + ) { + let client_id = self.connect_id_sink(subscriptions.init_topics.clone(), peer); + subscriptions.client_id = Some(client_id); + } +} + +#[derive(Clone)] +pub struct DynSubscriptions { + inner: Arc>, +} +pub struct DynSubscriptionsInner { + init_topics: TopicFilter, + client_id: Option, +} + +impl DynSubscriptionsInner { + pub fn new(init_topics: TopicFilter) -> Self { + DynSubscriptionsInner { + init_topics, + client_id: None, + } + } + + pub fn client_id(&self) -> ClientId { + self.client_id.unwrap() + } + + #[cfg(feature = "test-helpers")] + pub fn set_client_id(&mut self, value: usize) { + self.client_id = Some(ClientId(value)) + } +} + +#[cfg(feature = "test-helpers")] +impl TryFrom for MqttMessage { + type Error = anyhow::Error; + fn try_from(value: MqttRequest) -> Result { + if let MqttRequest::Publish(msg) = value { + Ok(msg) + } else { + Err(anyhow::anyhow!("{value:?} is not an MQTT message!")) + } + } +} + +impl DynSubscriptions { + pub fn new(init_topics: TopicFilter) -> Self { + let inner = DynSubscriptionsInner { + init_topics, + client_id: None, + }; + DynSubscriptions { + inner: Arc::new(Mutex::new(inner)), + } + } + + fn set_client_id(&self, client_id: ClientId) { + let mut inner = self.inner.lock().unwrap(); + inner.client_id = Some(client_id); + } + + fn init_topics(&self) -> TopicFilter { + self.inner.lock().unwrap().init_topics.clone() + } + + /// Return the client id + /// + /// Panic if not properly registered as a sink of the MqttActorBuilder + pub fn client_id(&self) -> ClientId { + self.inner.lock().unwrap().client_id.unwrap() + } +} + +#[derive(Debug, Eq, PartialEq)] +pub enum MqttRequest { Publish(MqttMessage), Subscribe(SubscriptionRequest), + RetrieveRetain(ClientId, TopicFilter), +} + +#[derive(Clone)] +pub struct DynamicMqttClientHandle { + current_id: Arc>, + tx: mpsc::Sender<(InsertRequest, Box + 'static>)>, +} + +impl DynamicMqttClientHandle { + pub async fn connect_sink_dynamic( + &mut self, + topics: TopicFilter, + peer: &impl MessageSink, + ) -> ClientId { + let mut current_id = self.current_id.lock().await; + let client_id = ClientId(*current_id); + self.tx + .send((InsertRequest { client_id, topics }, peer.get_sender())) + .await + .unwrap(); + *current_id += 1; + client_id + } } impl InputCombiner { pub fn close_input(&mut self) { - self.publish_receiver.close(); - self.subscription_request_receiver.close(); + self.request_receiver.close(); self.signal_receiver.close(); } } #[async_trait] -impl MessageReceiver for InputCombiner { - async fn try_recv(&mut self) -> Result, RuntimeRequest> { +impl MessageReceiver for InputCombiner { + async fn try_recv(&mut self) -> Result, RuntimeRequest> { tokio::select! { biased; Some(runtime_request) = self.signal_receiver.next() => { Err(runtime_request) } - Some(message) = self.publish_receiver.next() => { - Ok(Some(PublishOrSubscribe::Publish(message))) - } - Some(request) = self.subscription_request_receiver.next() => { - Ok(Some(PublishOrSubscribe::Subscribe(request))) + Some(request) = self.request_receiver.next() => { + Ok(Some(request)) } else => Ok(None) } } - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { match self.try_recv().await { Ok(Some(message)) => Some(message), _ => None, @@ -105,30 +226,26 @@ impl MessageReceiver for InputCombiner { impl MqttActorBuilder { pub fn new(config: mqtt_channel::Config) -> Self { - let (publish_sender, publish_receiver) = mpsc::channel(10); - let (subscription_request_sender, subscription_request_receiver) = mpsc::channel(10); + let (request_sender, request_receiver) = mpsc::channel(10); let (signal_sender, signal_receiver) = mpsc::channel(10); + let (dynamic_connect_sender, dynamic_connect_receiver) = mpsc::channel(10); let trie = TrieService::new(MqtTrie::default()); - let pub_or_sub_sender = PubOrSubSender { - subscription_request_sender, - publish_sender: publish_sender.clone(), - }; let input_receiver = InputCombiner { - publish_receiver, signal_receiver, - subscription_request_receiver, + request_receiver, }; MqttActorBuilder { mqtt_config: config, input_receiver, - publish_sender, subscriber_addresses: Vec::new(), signal_sender, - pub_or_sub_sender, + request_sender, trie, subscription_diff: SubscriptionDiff::empty(), current_id: 0, + dynamic_connect_sender, + dynamic_connect_receiver, } } @@ -139,12 +256,18 @@ impl MqttActorBuilder { tracing::info!(target: "MQTT sub", "{pattern}"); } - let mqtt_config = self.mqtt_config.with_subscriptions(topic_filter); + let mqtt_config = self.mqtt_config.clone().with_subscriptions(topic_filter); MqttActor::new( mqtt_config, + self.mqtt_config, self.input_receiver, self.subscriber_addresses, self.trie.builder(), + DynamicMqttClientHandle { + current_id: Arc::new(tokio::sync::Mutex::new(self.current_id)), + tx: self.dynamic_connect_sender, + }, + self.dynamic_connect_receiver, ) } } @@ -187,7 +310,10 @@ impl MqttActorBuilder { impl MessageSink for MqttActorBuilder { fn get_sender(&self) -> DynSender { - self.publish_sender.clone().into() + MappingSender::new(self.request_sender.clone().into(), |msg| { + [MqttRequest::Publish(msg)] + }) + .into() } } @@ -208,11 +334,20 @@ impl TrieService { } } +#[cfg(feature = "test-helpers")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ClientId(pub usize); +#[cfg(not(feature = "test-helpers"))] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ClientId(usize); type MatchRequest = String; +struct InsertRequest { + client_id: ClientId, + topics: TopicFilter, +} + fan_in_message_type!(TrieRequest[SubscriptionRequest, MatchRequest]: Clone, Debug); #[derive(Debug)] @@ -257,29 +392,9 @@ pub struct SubscriptionRequest { client_id: ClientId, } -#[derive(Clone, Debug)] -struct PubOrSubSender { - publish_sender: mpsc::Sender, - subscription_request_sender: mpsc::Sender, -} - -#[async_trait] -impl Sender for PubOrSubSender { - async fn send(&mut self, message: PublishOrSubscribe) -> Result<(), ChannelError> { - match message { - PublishOrSubscribe::Publish(msg) => { - Sender::<_>::send(&mut self.publish_sender, msg).await - } - PublishOrSubscribe::Subscribe(sub) => { - Sender::<_>::send(&mut self.subscription_request_sender, sub).await - } - } - } -} - -impl MessageSink for MqttActorBuilder { - fn get_sender(&self) -> DynSender { - self.pub_or_sub_sender.clone().into() +impl MessageSink for MqttActorBuilder { + fn get_sender(&self) -> DynSender { + self.request_sender.clone().into() } } @@ -303,6 +418,7 @@ impl Builder for MqttActorBuilder { pub struct FromPeers { input_receiver: InputCombiner, + base_config: mqtt_channel::Config, subscriptions: ClientMessageBox, } @@ -312,28 +428,51 @@ pub struct ToPeers { } impl FromPeers { + async fn try_recv( + &mut self, + rx_to_peers: &mut mpsc::UnboundedReceiver, + ) -> Result, RuntimeRequest> { + tokio::select! { + msg = self.input_receiver.try_recv() => msg, + msg = rx_to_peers.next() => Ok(msg), + } + } + async fn relay_messages_to( &mut self, outgoing_mqtt: &mut mpsc::UnboundedSender, + tx_to_peers: &mut mpsc::UnboundedSender<(ClientId, MqttMessage)>, client: impl SubscriberOps + Clone + Send + 'static, + rx_to_peers: &mut mpsc::UnboundedReceiver, ) -> Result<(), RuntimeError> { - while let Ok(Some(message)) = self.try_recv().await { + while let Ok(Some(message)) = self.try_recv(rx_to_peers).await { match message { - PublishOrSubscribe::Publish(message) => { + MqttRequest::Publish(message) => { tracing::debug!(target: "MQTT pub", "{message}"); SinkExt::send(outgoing_mqtt, message) .await .map_err(Box::new)?; } - PublishOrSubscribe::Subscribe(request) => { + MqttRequest::Subscribe(request) => { let TrieResponse::Diff(diff) = self .subscriptions - .await_response(TrieRequest::SubscriptionRequest(request)) + .await_response(TrieRequest::SubscriptionRequest(request.clone())) .await .map_err(Box::new)? else { unreachable!("Subscription request always returns diff") }; + let overlapping_subscriptions = request + .diff + .subscribe + .iter() + .filter(|s| { + !diff + .subscribe + .iter() + .any(|s2| RankTopicFilter(s2) >= RankTopicFilter(s)) + }) + .collect::>(); let client = client.clone(); tokio::spawn(async move { // We're running outside the main task, so we can't return an error @@ -345,6 +484,18 @@ impl FromPeers { client.unsubscribe_many(diff.unsubscribe).await.unwrap(); } }); + let mut tf = TopicFilter::empty(); + for sub in overlapping_subscriptions { + tf.add_unchecked(sub); + } + if !tf.is_empty() { + self.forward_retain_messages_to(tx_to_peers.clone(), tf, request.client_id); + } + } + MqttRequest::RetrieveRetain(client_id, topics) => { + // We don't need to create a long-lived subscription, just + // forward the retain messages for these topics + self.forward_retain_messages_to(tx_to_peers.clone(), topics, client_id); } } } @@ -355,28 +506,74 @@ impl FromPeers { // Then, publish all the messages awaiting to be sent over MQTT while let Some(message) = self.recv().await { match message { - PublishOrSubscribe::Publish(message) => { + MqttRequest::Publish(message) => { tracing::debug!(target: "MQTT pub", "{message}"); SinkExt::send(outgoing_mqtt, message) .await .map_err(Box::new)?; } // No point creating subscriptions at this point - PublishOrSubscribe::Subscribe(_) => (), + MqttRequest::Subscribe(_) => (), + MqttRequest::RetrieveRetain(_, _) => (), } } Ok(()) } + + fn forward_retain_messages_to( + &self, + mut sender: mpsc::UnboundedSender<(ClientId, MqttMessage)>, + topics: TopicFilter, + client_id: ClientId, + ) { + let dynamic_connection_config = self.base_config.clone().with_subscriptions(topics); + tokio::spawn(async move { + let mut conn = mqtt_channel::Connection::new(&dynamic_connection_config) + .await + .unwrap(); + while let Ok(msg) = + tokio::time::timeout(Duration::from_secs(10), conn.received.next()).await + { + if let Some(msg) = msg { + if msg.retain { + SinkExt::send(&mut sender, (client_id, msg)).await.unwrap(); + } + } + } + conn.close().await; + }); + } } impl ToPeers { async fn relay_messages_from( mut self, incoming_mqtt: &mut mpsc::UnboundedReceiver, + rx_from_peers: &mut mpsc::UnboundedReceiver<(ClientId, MqttMessage)>, + dynamic_connection_request: &mut mpsc::Receiver<( + InsertRequest, + Box + 'static>, + )>, + tx_from_peers: &mut mpsc::UnboundedSender, ) -> Result<(), RuntimeError> { - while let Some(message) = incoming_mqtt.next().await { - tracing::debug!(target: "MQTT recv", "{message}"); - self.send(message).await?; + loop { + tokio::select! { + message = incoming_mqtt.next() => { + let Some(message) = message else { break }; + tracing::debug!(target: "MQTT recv", "{message}"); + self.send(message).await?; + } + message = rx_from_peers.next() => { + let Some((client, message)) = message else { break }; + tracing::debug!(target: "MQTT recv", "{message}"); + self.sender_by_id(client).send(message.clone()).await?; + } + message = dynamic_connection_request.next() => { + let Some((insert_req, sender)) = message else { continue }; + self.peer_senders.push(sender); + SinkExt::send(tx_from_peers, MqttRequest::Subscribe(SubscriptionRequest { diff: SubscriptionDiff { subscribe: insert_req.topics.patterns().iter().cloned().collect(), unsubscribe: <_>::default() }, client_id: insert_req.client_id })).await?; + } + }; } Ok(()) } @@ -390,19 +587,23 @@ impl ToPeers { unreachable!("MatchRequest always returns Matched") }; for client in matches { - self.peer_senders[client.0].send(message.clone()).await?; + self.sender_by_id(client).send(message.clone()).await?; } Ok(()) } + + fn sender_by_id(&mut self, id: ClientId) -> &mut Box> { + &mut self.peer_senders[id.0] + } } #[async_trait] -impl MessageReceiver for FromPeers { - async fn try_recv(&mut self) -> Result, RuntimeRequest> { +impl MessageReceiver for FromPeers { + async fn try_recv(&mut self) -> Result, RuntimeRequest> { self.input_receiver.try_recv().await } - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { self.input_receiver.recv().await } @@ -416,19 +617,29 @@ pub struct MqttActor { from_peers: FromPeers, to_peers: ToPeers, trie_service: ServerActorBuilder, + dynamic_client_handle: DynamicMqttClientHandle, + dynamic_connect_receiver: + mpsc::Receiver<(InsertRequest, Box + 'static>)>, } impl MqttActor { fn new( mqtt_config: mqtt_channel::Config, + base_config: mqtt_channel::Config, input_receiver: InputCombiner, peer_senders: Vec>, mut trie_service: ServerActorBuilder, + dynamic_client_handle: DynamicMqttClientHandle, + dynamic_connect_receiver: mpsc::Receiver<( + InsertRequest, + Box + 'static>, + )>, ) -> Self { MqttActor { mqtt_config, from_peers: FromPeers { input_receiver, + base_config, subscriptions: ClientMessageBox::new(&mut trie_service), }, to_peers: ToPeers { @@ -436,8 +647,14 @@ impl MqttActor { subscriptions: ClientMessageBox::new(&mut trie_service), }, trie_service, + dynamic_client_handle, + dynamic_connect_receiver, } } + + pub fn dynamic_client_handle(&self) -> DynamicMqttClientHandle { + self.dynamic_client_handle.clone() + } } #[async_trait] @@ -456,70 +673,29 @@ impl Actor for MqttActor { return Ok(()) } }; + let (mut to_peer, mut from_peer) = mpsc::unbounded(); + let (mut to_from_peer, mut from_to_peer) = mpsc::unbounded(); tokio::spawn(async move { self.trie_service.run().await }); tedge_utils::futures::select( - self.from_peers - .relay_messages_to(&mut mqtt_client.published, mqtt_client.subscriptions), - self.to_peers.relay_messages_from(&mut mqtt_client.received), + self.from_peers.relay_messages_to( + &mut mqtt_client.published, + &mut to_peer, + mqtt_client.subscriptions, + &mut from_to_peer, + ), + self.to_peers.relay_messages_from( + &mut mqtt_client.received, + &mut from_peer, + &mut self.dynamic_connect_receiver, + &mut to_from_peer, + ), ) .await } } -#[async_trait] -pub trait MqttConnector: Send { - async fn connect(&mut self, topics: TopicFilter) -> Result, MqttError>; -} - -#[async_trait] -pub trait MqttConnection: Send { - async fn next_message(&mut self) -> Option; - - async fn disconnect(self: Box); -} - -pub struct MqttConnectionImpl { - connection: Connection, -} - -impl MqttConnectionImpl { - fn new(connection: Connection) -> Self { - Self { connection } - } -} - -#[async_trait] -impl MqttConnection for MqttConnectionImpl { - async fn next_message(&mut self) -> Option { - self.connection.received.next().await - } - - async fn disconnect(self: Box) { - self.connection.close().await; - } -} - -pub struct MqttDynamicConnector { - base_mqtt_config: MqttConfig, -} - -impl MqttDynamicConnector { - pub fn new(base_mqtt_config: MqttConfig) -> Self { - Self { base_mqtt_config } - } -} - -#[async_trait] -impl MqttConnector for MqttDynamicConnector { - async fn connect(&mut self, topics: TopicFilter) -> Result, MqttError> { - let mqtt_config = self.base_mqtt_config.clone().with_subscriptions(topics); - let connection = mqtt_channel::Connection::new(&mqtt_config).await?; - Ok(Box::new(MqttConnectionImpl::new(connection))) - } -} - #[cfg(test)] mod unit_tests { use std::collections::HashMap; @@ -559,6 +735,8 @@ mod unit_tests { .subscribe_client .assert_subscribed_to(["a/b".into()]) .await; + + actor.close().await; } #[tokio::test] @@ -577,6 +755,8 @@ mod unit_tests { .subscribe_client .assert_unsubscribed_from(["a/b".into()]) .await; + + actor.close().await; } #[tokio::test] @@ -595,6 +775,8 @@ mod unit_tests { .subscribe_client .assert_subscribed_to(["#".into()]) .await; + + actor.close().await; } #[tokio::test] @@ -612,7 +794,9 @@ mod unit_tests { &Topic::new("a/b").unwrap(), "test message" )) - ) + ); + + actor.close().await; } #[tokio::test] @@ -631,33 +815,133 @@ mod unit_tests { .unwrap() .try_next() .is_err()); + + actor.close().await; + } + + #[tokio::test] + async fn publishes_messages_to_dynamically_subscribed_clients() { + let mut actor = MqttActorTest::new(&[]); + + let client_id = actor + .connect_dynamic(TopicFilter::new_unchecked("b/c")) + .await; + + actor + .subscribe_client + .assert_subscribed_to(["b/c".into()]) + .await; + + actor.receive("b/c", "test message").await; + + assert_eq!( + actor.next_message_for(client_id).await, + MqttMessage::new(&Topic::new("b/c").unwrap(), "test message") + ); + + actor.close().await; + } + + #[tokio::test] + async fn publishes_messages_only_to_subscribed_dynamic_client() { + let mut actor = MqttActorTest::new(&[]); + + let client_id = actor + .connect_dynamic(TopicFilter::new_unchecked("a/b")) + .await; + let client_id_2 = actor + .connect_dynamic(TopicFilter::new_unchecked("b/c")) + .await; + + actor + .subscribe_client + .assert_subscribed_to(["a/b".into()]) + .await; + actor + .subscribe_client + .assert_subscribed_to(["b/c".into()]) + .await; + + actor.receive("b/c", "test message").await; + + assert_eq!( + actor.next_message_for(client_id_2).await, + MqttMessage::new(&Topic::new("b/c").unwrap(), "test message") + ); + assert!(actor + .sent_to_clients + .get_mut(&client_id) + .unwrap() + .try_next() + .is_err()); + + actor.close().await; + } + + #[tokio::test] + async fn publishes_messages_separately_to_dynamic_and_non_dynamic_clients() { + let mut actor = MqttActorTest::new(&[("a/b", 0)]); + + let static_id = 0; + let dynamic_id = actor + .connect_dynamic(TopicFilter::new_unchecked("b/c")) + .await; + + actor + .subscribe_client + .assert_subscribed_to(["b/c".into()]) + .await; + + actor.receive("a/b", "test message").await; + actor.receive("b/c", "test message").await; + + assert_eq!( + actor.next_message_for(static_id).await, + MqttMessage::new(&Topic::new("a/b").unwrap(), "test message") + ); + assert_eq!( + actor.next_message_for(dynamic_id).await, + MqttMessage::new(&Topic::new("b/c").unwrap(), "test message") + ); + + actor.close().await; } struct MqttActorTest { subscribe_client: MockSubscriberOps, - sub_tx: mpsc::Sender, - pub_tx: mpsc::Sender, + req_tx: mpsc::Sender, sent_to_channel: mpsc::UnboundedReceiver, sent_to_clients: HashMap>, inject_received_message: mpsc::UnboundedSender, + from_peers: Option>>, + to_peers: Option>>, + waited: bool, + dyn_connect: DynamicMqttClientHandle, + } + + impl Drop for MqttActorTest { + fn drop(&mut self) { + if !std::thread::panicking() && !self.waited { + panic!("Call `MqttActorTest::close` at the end of the test") + } + } } impl MqttActorTest { pub fn new(default_subscriptions: &[(&str, usize)]) -> Self { - let (pub_tx, pub_rx) = mpsc::channel(10); - let (sub_tx, sub_rx) = mpsc::channel(10); + let (req_tx, req_rx) = mpsc::channel(10); let (_sig_tx, sig_rx) = mpsc::channel(10); let (mut outgoing_mqtt, sent_messages) = mpsc::unbounded(); let (inject_received_message, mut incoming_messages) = mpsc::unbounded(); let input_combiner = InputCombiner { - publish_receiver: pub_rx, - subscription_request_receiver: sub_rx, signal_receiver: sig_rx, + request_receiver: req_rx, }; let mut ts = TrieService::with_default_subscriptions(default_subscriptions); let mut fp = FromPeers { input_receiver: input_combiner, + base_config: <_>::default(), subscriptions: ClientMessageBox::new(&mut ts), }; let mut sent_to_clients = HashMap::new(); @@ -676,32 +960,89 @@ mod unit_tests { }; tokio::spawn(async move { ts.build().run().await }); + let (mut tx, mut rx) = mpsc::unbounded(); + let (mut tx2, mut rx2) = mpsc::unbounded(); + let (dyn_connect_tx, mut dyn_connect_rx) = mpsc::channel(10); + let subscribe_client = MockSubscriberOps::default(); - { + let from_peers = { let client = subscribe_client.clone(); - tokio::spawn(async move { fp.relay_messages_to(&mut outgoing_mqtt, client).await }); - } - tokio::spawn(async move { tp.relay_messages_from(&mut incoming_messages).await }); + tokio::spawn(async move { + fp.relay_messages_to(&mut outgoing_mqtt, &mut tx, client, &mut rx2) + .await + }) + }; + let to_peers = tokio::spawn(async move { + tp.relay_messages_from( + &mut incoming_messages, + &mut rx, + &mut dyn_connect_rx, + &mut tx2, + ) + .await + }); Self { subscribe_client, - sub_tx, - pub_tx, + req_tx, sent_to_clients, sent_to_channel: sent_messages, inject_received_message, + from_peers: Some(from_peers), + to_peers: Some(to_peers), + waited: false, + dyn_connect: DynamicMqttClientHandle { + current_id: Arc::new(tokio::sync::Mutex::new( + max_client_id.map_or(0, |&max| max + 1), + )), + tx: dyn_connect_tx, + }, } } + pub async fn connect_dynamic(&mut self, topics: TopicFilter) -> usize { + struct ChannelSink(mpsc::Sender); + + impl MessageSink for ChannelSink { + fn get_sender(&self) -> DynSender { + Box::new(self.0.clone()) + } + } + + let (tx, rx) = mpsc::channel(10); + let dyn_client_id = self + .dyn_connect + .connect_sink_dynamic(topics, &ChannelSink(tx)) + .await; + self.sent_to_clients.insert(dyn_client_id.0, rx); + dyn_client_id.0 + } + + /// Closes the channels associated with this actor and waits for both + /// loops to finish executing + /// + /// This allows the `SubscriberOps::drop` implementation to reliably + /// flag any unasserted communication + pub async fn close(mut self) { + self.req_tx.close_channel(); + self.inject_received_message.close_channel(); + self.from_peers.take().unwrap().await.unwrap().unwrap(); + self.to_peers.take().unwrap().await.unwrap().unwrap(); + self.waited = true; + } + /// Simulates a client sending a subscription request to the mqtt actor pub async fn send_sub(&mut self, req: SubscriptionRequest) { - SinkExt::send(&mut self.sub_tx, req).await.unwrap(); + SinkExt::send(&mut self.req_tx, MqttRequest::Subscribe(req)) + .await + .unwrap(); } + /// Simulates a client sending a publish request to the mqtt actor pub async fn publish(&mut self, topic: &str, payload: &str) { SinkExt::send( - &mut self.pub_tx, - MqttMessage::new(&Topic::new(topic).unwrap(), payload), + &mut self.req_tx, + MqttRequest::Publish(MqttMessage::new(&Topic::new(topic).unwrap(), payload)), ) .await .unwrap(); @@ -803,6 +1144,9 @@ mod unit_tests { if std::thread::panicking() { return; } + if Arc::strong_count(&self.subscribe_many) > 1 { + return; + } let subscribe = self.subscribe_many.lock().unwrap().clone(); let unsubscribe = self.unsubscribe_many.lock().unwrap().clone(); if !subscribe.is_empty() { diff --git a/crates/extensions/tedge_mqtt_ext/src/test_helpers.rs b/crates/extensions/tedge_mqtt_ext/src/test_helpers.rs index 44870f5b1f9..46f1a3b9aa3 100644 --- a/crates/extensions/tedge_mqtt_ext/src/test_helpers.rs +++ b/crates/extensions/tedge_mqtt_ext/src/test_helpers.rs @@ -8,7 +8,8 @@ pub async fn assert_received_contains_str<'a, M, I>( messages: &mut dyn MessageReceiver, expected: I, ) where - M: Into, + M: TryInto, + M::Error: std::fmt::Debug, I: IntoIterator, { for expected_msg in expected.into_iter() { @@ -19,7 +20,7 @@ pub async fn assert_received_contains_str<'a, M, I>( expected_msg ); let message = message.unwrap(); - assert_message_contains_str(&message.into(), expected_msg); + assert_message_contains_str(&message.try_into().unwrap(), expected_msg); } } diff --git a/crates/extensions/tedge_mqtt_ext/src/tests.rs b/crates/extensions/tedge_mqtt_ext/src/tests.rs index 1a224907863..33bdfb875e8 100644 --- a/crates/extensions/tedge_mqtt_ext/src/tests.rs +++ b/crates/extensions/tedge_mqtt_ext/src/tests.rs @@ -197,8 +197,8 @@ async fn dynamic_subscriptions() { let mqtt_config = MqttConfig::default().with_port(broker.port); let mut mqtt = MqttActorBuilder::new(mqtt_config); - let mut client_0 = SimpleMessageBoxBuilder::<_, PublishOrSubscribe>::new("dyn-subscriber", 16); - let mut client_1 = SimpleMessageBoxBuilder::<_, PublishOrSubscribe>::new("dyn-subscriber1", 16); + let mut client_0 = SimpleMessageBoxBuilder::<_, MqttRequest>::new("dyn-subscriber", 16); + let mut client_1 = SimpleMessageBoxBuilder::<_, MqttRequest>::new("dyn-subscriber1", 16); let client_id_0 = mqtt.connect_id_sink(TopicFilter::new_unchecked("a/b"), &client_0); let _client_id_1 = mqtt.connect_id_sink(TopicFilter::new_unchecked("a/+"), &client_1); client_0.connect_sink(NoConfig, &mqtt); @@ -210,14 +210,21 @@ async fn dynamic_subscriptions() { let msg = MqttMessage::new(&Topic::new_unchecked("a/b"), "hello"); client_0 - .send(PublishOrSubscribe::Publish(msg.clone())) + .send(MqttRequest::Publish(msg.clone())) .await .unwrap(); assert_eq!(timeout(client_0.recv()).await.unwrap(), msg); assert_eq!(timeout(client_1.recv()).await.unwrap(), msg); + // Send the messages as retain so we don't have a race for the subscription + let msg = MqttMessage::new(&Topic::new_unchecked("b/c"), "hello").with_retain(); + client_0 + .send(MqttRequest::Publish(msg.clone())) + .await + .unwrap(); + client_0 - .send(PublishOrSubscribe::Subscribe(SubscriptionRequest { + .send(MqttRequest::Subscribe(SubscriptionRequest { diff: SubscriptionDiff { subscribe: ["b/c".into()].into(), unsubscribe: [].into(), @@ -227,18 +234,12 @@ async fn dynamic_subscriptions() { .await .unwrap(); - // Send the messages as retain so we don't have a race for the subscription - let msg = MqttMessage::new(&Topic::new_unchecked("b/c"), "hello").with_retain(); - client_0 - .send(PublishOrSubscribe::Publish(msg.clone())) - .await - .unwrap(); assert_eq!(timeout(client_0.recv()).await.unwrap(), msg); // Verify that messages aren't sent to clients let msg = MqttMessage::new(&Topic::new_unchecked("a/c"), "hello"); client_0 - .send(PublishOrSubscribe::Publish(msg.clone())) + .send(MqttRequest::Publish(msg.clone())) .await .unwrap(); assert_eq!(timeout(client_1.recv()).await.unwrap(), msg); @@ -249,6 +250,115 @@ async fn dynamic_subscriptions() { ); } +#[tokio::test] +async fn dynamic_subscribers_receive_retain_messages() { + let broker = mqtt_tests::test_mqtt_broker(); + let mqtt_config = MqttConfig::default().with_port(broker.port); + let mut mqtt = MqttActorBuilder::new(mqtt_config); + + broker + .publish_with_opts("a/b", "retain", QoS::AtLeastOnce, true) + .await + .unwrap(); + broker + .publish_with_opts("b/c", "retain", QoS::AtLeastOnce, true) + .await + .unwrap(); + + let mut client_0 = SimpleMessageBoxBuilder::<_, MqttRequest>::new("dyn-subscriber", 16); + let mut client_1 = SimpleMessageBoxBuilder::<_, MqttRequest>::new("dyn-subscriber1", 16); + let _client_id_0 = mqtt.connect_id_sink(TopicFilter::new_unchecked("a/b"), &client_0); + let client_id_1 = mqtt.connect_id_sink(TopicFilter::empty(), &client_1); + client_0.connect_sink(NoConfig, &mqtt); + client_1.connect_sink(NoConfig, &mqtt); + let mqtt = mqtt.build(); + tokio::spawn(async move { mqtt.run().await.unwrap() }); + let mut client_0 = client_0.build(); + let mut client_1 = client_1.build(); + + let msg = MqttMessage::new(&Topic::new_unchecked("a/b"), "retain").with_retain(); + let msg2 = MqttMessage::new(&Topic::new_unchecked("b/c"), "retain").with_retain(); + + // client_0 receives retain message upon subscribing to "a/b" + assert_eq!(timeout(client_0.recv()).await.unwrap(), msg); + + client_1 + .send(MqttRequest::Subscribe(SubscriptionRequest { + diff: SubscriptionDiff { + subscribe: ["a/b".into(), "b/c".into()].into(), + unsubscribe: [].into(), + }, + client_id: client_id_1, + })) + .await + .unwrap(); + + // client_1 should receive both "a/b" and "b/c" retain messages upon subscribing + let recv = timeout(client_1.recv()).await.unwrap(); + let recv2 = timeout(client_1.recv()).await.unwrap(); + + // Retain message should not be redelivered to client_0 + assert!( + tokio::time::timeout(Duration::from_millis(200), client_0.recv()) + .await + .is_err() + ); + + if recv.topic.name == "a/b" { + assert_eq!(recv, msg); + assert_eq!(recv2, msg2); + } else { + assert_eq!(recv, msg2); + assert_eq!(recv2, msg); + } +} + +#[tokio::test] +async fn dynamic_subscribers_receive_retain_messages_when_upgrading_topic() { + let broker = mqtt_tests::test_mqtt_broker(); + let mqtt_config = MqttConfig::default().with_port(broker.port); + let mut mqtt = MqttActorBuilder::new(mqtt_config); + + broker + .publish_with_opts("a/b", "retain", QoS::AtLeastOnce, true) + .await + .unwrap(); + + let mut client_0 = SimpleMessageBoxBuilder::<_, MqttRequest>::new("dyn-subscriber", 16); + let mut client_1 = SimpleMessageBoxBuilder::<_, MqttRequest>::new("dyn-subscriber1", 16); + let _client_id_0 = mqtt.connect_id_sink(TopicFilter::new_unchecked("a/b"), &client_0); + let client_id_1 = mqtt.connect_id_sink(TopicFilter::empty(), &client_1); + client_0.connect_sink(NoConfig, &mqtt); + client_1.connect_sink(NoConfig, &mqtt); + let mqtt = mqtt.build(); + tokio::spawn(async move { mqtt.run().await.unwrap() }); + let mut client_0 = client_0.build(); + let mut client_1 = client_1.build(); + + let msg = MqttMessage::new(&Topic::new_unchecked("a/b"), "retain").with_retain(); + + // client_0 receives retain message upon subscribing to "a/b" + assert_eq!(timeout(client_0.recv()).await.unwrap(), msg); + + client_1 + .send(MqttRequest::Subscribe(SubscriptionRequest { + diff: SubscriptionDiff { + subscribe: ["a/+".into()].into(), + unsubscribe: [].into(), + }, + client_id: client_id_1, + })) + .await + .unwrap(); + + // client_1 should receive both "a/b" and "b/c" retain messages upon subscribing + let recv = timeout(client_1.recv()).await.unwrap(); + assert_eq!(recv, msg); + + // Retain message might be redelivered to client_0, but this is + // implementation dependent, don't assert either way +} + async fn timeout(fut: impl Future) -> T { tokio::time::timeout(Duration::from_secs(1), fut) .await diff --git a/crates/extensions/tedge_mqtt_ext/src/trie.rs b/crates/extensions/tedge_mqtt_ext/src/trie.rs index 4e1d1f47e6e..c6432e15a72 100644 --- a/crates/extensions/tedge_mqtt_ext/src/trie.rs +++ b/crates/extensions/tedge_mqtt_ext/src/trie.rs @@ -193,7 +193,7 @@ impl SubscriptionDiff { /// "a/+" does not compare to "a/b/c" /// "a/+/c" does not compare to "a/b/+" /// "a/b" does not compare to "c/d" -struct RankTopicFilter<'a>(&'a str); +pub(crate) struct RankTopicFilter<'a>(pub &'a str); impl PartialOrd for RankTopicFilter<'_> { fn partial_cmp(&self, other: &Self) -> Option {