Skip to content

Commit 03ef4c5

Browse files
committed
refactor: handle retain messages with dynamic subscriptions
Signed-off-by: James Rhodes <jarhodes314@gmail.com>
1 parent bcb8c1b commit 03ef4c5

File tree

3 files changed

+154
-16
lines changed

3 files changed

+154
-16
lines changed

crates/extensions/tedge_mqtt_ext/src/lib.rs

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use mqtt_channel::SubscriberOps;
1616
pub use mqtt_channel::Topic;
1717
pub use mqtt_channel::TopicFilter;
1818
use std::convert::Infallible;
19+
use std::time::Duration;
1920
use tedge_actors::fan_in_message_type;
2021
use tedge_actors::futures::channel::mpsc;
2122
use tedge_actors::Actor;
@@ -36,6 +37,7 @@ use tedge_actors::Server;
3637
use tedge_actors::ServerActorBuilder;
3738
use tedge_actors::ServerConfig;
3839
use trie::MqtTrie;
40+
use trie::RankTopicFilter;
3941
use trie::SubscriptionDiff;
4042

4143
pub type MqttConfig = mqtt_channel::Config;
@@ -139,9 +141,10 @@ impl MqttActorBuilder {
139141
tracing::info!(target: "MQTT sub", "{pattern}");
140142
}
141143

142-
let mqtt_config = self.mqtt_config.with_subscriptions(topic_filter);
144+
let mqtt_config = self.mqtt_config.clone().with_subscriptions(topic_filter);
143145
MqttActor::new(
144146
mqtt_config,
147+
self.mqtt_config,
145148
self.input_receiver,
146149
self.subscriber_addresses,
147150
self.trie.builder(),
@@ -303,6 +306,7 @@ impl Builder<MqttActor> for MqttActorBuilder {
303306

304307
pub struct FromPeers {
305308
input_receiver: InputCombiner,
309+
base_config: mqtt_channel::Config,
306310
subscriptions: ClientMessageBox<TrieRequest, TrieResponse>,
307311
}
308312

@@ -315,25 +319,41 @@ impl FromPeers {
315319
async fn relay_messages_to(
316320
&mut self,
317321
outgoing_mqtt: &mut mpsc::UnboundedSender<MqttMessage>,
322+
tx_to_peers: &mut mpsc::UnboundedSender<(ClientId, MqttMessage)>,
318323
client: impl SubscriberOps + Clone + Send + 'static,
319324
) -> Result<(), RuntimeError> {
320325
while let Ok(Some(message)) = self.try_recv().await {
321326
match message {
322327
PublishOrSubscribe::Publish(message) => {
323328
tracing::debug!(target: "MQTT pub", "{message}");
324-
SinkExt::send(outgoing_mqtt, message)
329+
SinkExt::send(outgoing_mqtt, message)
325330
.await
326331
.map_err(Box::new)?;
327332
}
328333
PublishOrSubscribe::Subscribe(request) => {
329334
let TrieResponse::Diff(diff) = self
330335
.subscriptions
331-
.await_response(TrieRequest::SubscriptionRequest(request))
336+
.await_response(TrieRequest::SubscriptionRequest(request.clone()))
332337
.await
333338
.map_err(Box::new)?
334339
else {
335340
unreachable!("Subscription request always returns diff")
336341
};
342+
let overlapping_subscriptions = request
343+
.diff
344+
.subscribe
345+
.iter()
346+
.filter(|s| {
347+
!diff
348+
.subscribe
349+
.iter()
350+
.any(|s2| RankTopicFilter(s2) >= RankTopicFilter(s))
351+
})
352+
.collect::<Vec<_>>();
353+
let mut tf = TopicFilter::empty();
354+
for sub in overlapping_subscriptions {
355+
tf.add_unchecked(&sub);
356+
}
337357
let client = client.clone();
338358
tokio::spawn(async move {
339359
// We're running outside the main task, so we can't return an error
@@ -345,6 +365,18 @@ impl FromPeers {
345365
client.unsubscribe_many(diff.unsubscribe).await.unwrap();
346366
}
347367
});
368+
let dynamic_connection_config = self.base_config.clone().with_subscriptions(tf);
369+
let mut sender = tx_to_peers.clone();
370+
tokio::spawn(async move {
371+
let mut conn = mqtt_channel::Connection::new(&dynamic_connection_config).await.unwrap();
372+
while let Ok(msg) = tokio::time::timeout(Duration::from_secs(10), conn.received.next()).await {
373+
if let Some(msg) = msg {
374+
if msg.retain {
375+
SinkExt::send(&mut sender, (request.client_id, msg)).await.unwrap();
376+
}
377+
}
378+
}
379+
});
348380
}
349381
}
350382
}
@@ -373,10 +405,21 @@ impl ToPeers {
373405
async fn relay_messages_from(
374406
mut self,
375407
incoming_mqtt: &mut mpsc::UnboundedReceiver<MqttMessage>,
408+
rx_from_peers: &mut mpsc::UnboundedReceiver<(ClientId, MqttMessage)>,
376409
) -> Result<(), RuntimeError> {
377-
while let Some(message) = incoming_mqtt.next().await {
378-
tracing::debug!(target: "MQTT recv", "{message}");
379-
self.send(message).await?;
410+
loop {
411+
tokio::select! {
412+
message = incoming_mqtt.next() => {
413+
let Some(message) = message else { break };
414+
tracing::debug!(target: "MQTT recv", "{message}");
415+
self.send(message).await?;
416+
}
417+
message = rx_from_peers.next() => {
418+
let Some((client, message)) = message else { break };
419+
tracing::debug!(target: "MQTT recv", "{message}");
420+
self.sender_by_id(client).send(message.clone()).await?;
421+
}
422+
};
380423
}
381424
Ok(())
382425
}
@@ -390,10 +433,14 @@ impl ToPeers {
390433
unreachable!("MatchRequest always returns Matched")
391434
};
392435
for client in matches {
393-
self.peer_senders[client.0].send(message.clone()).await?;
436+
self.sender_by_id(client).send(message.clone()).await?;
394437
}
395438
Ok(())
396439
}
440+
441+
fn sender_by_id(&mut self, id: ClientId) -> &mut Box<dyn CloneSender<MqttMessage>> {
442+
&mut self.peer_senders[id.0]
443+
}
397444
}
398445

399446
#[async_trait]
@@ -421,6 +468,7 @@ pub struct MqttActor {
421468
impl MqttActor {
422469
fn new(
423470
mqtt_config: mqtt_channel::Config,
471+
base_config: mqtt_channel::Config,
424472
input_receiver: InputCombiner,
425473
peer_senders: Vec<DynSender<MqttMessage>>,
426474
mut trie_service: ServerActorBuilder<TrieService, Sequential>,
@@ -429,6 +477,7 @@ impl MqttActor {
429477
mqtt_config,
430478
from_peers: FromPeers {
431479
input_receiver,
480+
base_config,
432481
subscriptions: ClientMessageBox::new(&mut trie_service),
433482
},
434483
to_peers: ToPeers {
@@ -456,13 +505,14 @@ impl Actor for MqttActor {
456505
return Ok(())
457506
}
458507
};
508+
let (mut to_peer, mut from_peer) = mpsc::unbounded();
459509

460510
tokio::spawn(async move { self.trie_service.run().await });
461511

462512
tedge_utils::futures::select(
463513
self.from_peers
464-
.relay_messages_to(&mut mqtt_client.published, mqtt_client.subscriptions),
465-
self.to_peers.relay_messages_from(&mut mqtt_client.received),
514+
.relay_messages_to(&mut mqtt_client.published, &mut to_peer, mqtt_client.subscriptions),
515+
self.to_peers.relay_messages_from(&mut mqtt_client.received, &mut from_peer),
466516
)
467517
.await
468518
}
@@ -559,6 +609,8 @@ mod unit_tests {
559609
.subscribe_client
560610
.assert_subscribed_to(["a/b".into()])
561611
.await;
612+
613+
actor.close().await;
562614
}
563615

564616
#[tokio::test]
@@ -572,11 +624,13 @@ mod unit_tests {
572624
id: 0
573625
))
574626
.await;
575-
627+
576628
actor
577629
.subscribe_client
578630
.assert_unsubscribed_from(["a/b".into()])
579631
.await;
632+
633+
actor.close().await;
580634
}
581635

582636
#[tokio::test]
@@ -595,6 +649,8 @@ mod unit_tests {
595649
.subscribe_client
596650
.assert_subscribed_to(["#".into()])
597651
.await;
652+
653+
actor.close().await;
598654
}
599655

600656
#[tokio::test]
@@ -612,7 +668,9 @@ mod unit_tests {
612668
&Topic::new("a/b").unwrap(),
613669
"test message"
614670
))
615-
)
671+
);
672+
673+
actor.close().await;
616674
}
617675

618676
#[tokio::test]
@@ -631,6 +689,8 @@ mod unit_tests {
631689
.unwrap()
632690
.try_next()
633691
.is_err());
692+
693+
actor.close().await;
634694
}
635695

636696
struct MqttActorTest {
@@ -640,6 +700,17 @@ mod unit_tests {
640700
sent_to_channel: mpsc::UnboundedReceiver<MqttMessage>,
641701
sent_to_clients: HashMap<usize, mpsc::Receiver<MqttMessage>>,
642702
inject_received_message: mpsc::UnboundedSender<MqttMessage>,
703+
from_peers: Option<tokio::task::JoinHandle<Result<(), RuntimeError>>>,
704+
to_peers: Option<tokio::task::JoinHandle<Result<(), RuntimeError>>>,
705+
waited: bool,
706+
}
707+
708+
impl Drop for MqttActorTest {
709+
fn drop(&mut self) {
710+
if !self.waited {
711+
panic!("Call `MqttActorTest::close` at the end of the test")
712+
}
713+
}
643714
}
644715

645716
impl MqttActorTest {
@@ -658,6 +729,7 @@ mod unit_tests {
658729
let mut ts = TrieService::with_default_subscriptions(default_subscriptions);
659730
let mut fp = FromPeers {
660731
input_receiver: input_combiner,
732+
base_config: <_>::default(),
661733
subscriptions: ClientMessageBox::new(&mut ts),
662734
};
663735
let mut sent_to_clients = HashMap::new();
@@ -676,12 +748,13 @@ mod unit_tests {
676748
};
677749
tokio::spawn(async move { ts.build().run().await });
678750

751+
let (mut tx, mut rx) = mpsc::unbounded();
679752
let subscribe_client = MockSubscriberOps::default();
680-
{
753+
let from_peers = {
681754
let client = subscribe_client.clone();
682-
tokio::spawn(async move { fp.relay_messages_to(&mut outgoing_mqtt, client).await });
683-
}
684-
tokio::spawn(async move { tp.relay_messages_from(&mut incoming_messages).await });
755+
tokio::spawn(async move { fp.relay_messages_to(&mut outgoing_mqtt, &mut tx, client).await })
756+
};
757+
let to_peers = tokio::spawn(async move { tp.relay_messages_from(&mut incoming_messages, &mut rx).await });
685758

686759
Self {
687760
subscribe_client,
@@ -690,14 +763,32 @@ mod unit_tests {
690763
sent_to_clients,
691764
sent_to_channel: sent_messages,
692765
inject_received_message,
766+
from_peers: Some(from_peers),
767+
to_peers: Some(to_peers),
768+
waited: false,
693769
}
694770
}
695771

772+
/// Closes the channels associated with this actor and waits for both
773+
/// loops to finish executing
774+
///
775+
/// This allows the `SubscriberOps::drop` implementation to reliably
776+
/// flag any unasserted communication
777+
pub async fn close(mut self) {
778+
self.pub_tx.close_channel();
779+
self.sub_tx.close_channel();
780+
self.inject_received_message.close_channel();
781+
self.from_peers.take().unwrap().await.unwrap().unwrap();
782+
self.to_peers.take().unwrap().await.unwrap().unwrap();
783+
self.waited = true;
784+
}
785+
696786
/// Simulates a client sending a subscription request to the mqtt actor
697787
pub async fn send_sub(&mut self, req: SubscriptionRequest) {
698788
SinkExt::send(&mut self.sub_tx, req).await.unwrap();
699789
}
700790

791+
/// Simulates a client sending a publish request to the mqtt actor
701792
pub async fn publish(&mut self, topic: &str, payload: &str) {
702793
SinkExt::send(
703794
&mut self.pub_tx,
@@ -803,6 +894,9 @@ mod unit_tests {
803894
if std::thread::panicking() {
804895
return;
805896
}
897+
if Arc::strong_count(&self.subscribe_many) > 1 {
898+
return;
899+
}
806900
let subscribe = self.subscribe_many.lock().unwrap().clone();
807901
let unsubscribe = self.unsubscribe_many.lock().unwrap().clone();
808902
if !subscribe.is_empty() {

crates/extensions/tedge_mqtt_ext/src/tests.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,50 @@ async fn dynamic_subscriptions() {
249249
);
250250
}
251251

252+
#[tokio::test]
253+
async fn dynamic_subscribers_receive_retain_messages() {
254+
let broker = mqtt_tests::test_mqtt_broker();
255+
let mqtt_config = MqttConfig::default().with_port(broker.port);
256+
let mut mqtt = MqttActorBuilder::new(mqtt_config);
257+
258+
broker.publish_with_opts("a/b", "retain", QoS::AtLeastOnce, true).await.unwrap();
259+
broker.publish_with_opts("b/c", "retain", QoS::AtLeastOnce, true).await.unwrap();
260+
261+
let mut client_0 = SimpleMessageBoxBuilder::<_, PublishOrSubscribe>::new("dyn-subscriber", 16);
262+
let mut client_1 = SimpleMessageBoxBuilder::<_, PublishOrSubscribe>::new("dyn-subscriber1", 16);
263+
let _client_id_0 = mqtt.connect_id_sink(TopicFilter::new_unchecked("a/b"), &client_0);
264+
let client_id_1 = mqtt.connect_id_sink(TopicFilter::empty(), &client_1);
265+
client_0.connect_sink(NoConfig, &mqtt);
266+
client_1.connect_sink(NoConfig, &mqtt);
267+
let mqtt = mqtt.build();
268+
tokio::spawn(async move { mqtt.run().await.unwrap() });
269+
let mut client_0 = client_0.build();
270+
let mut client_1 = client_1.build();
271+
272+
let msg = MqttMessage::new(&Topic::new_unchecked("a/b"), "retain").with_retain();
273+
let msg2 = MqttMessage::new(&Topic::new_unchecked("b/c"), "retain").with_retain();
274+
275+
// client_0 receives retain message upon subscribing to "a/b"
276+
assert_eq!(timeout(client_0.recv()).await.unwrap(), msg);
277+
278+
client_1.send(PublishOrSubscribe::Subscribe(SubscriptionRequest { diff: SubscriptionDiff { subscribe: ["a/b".into(), "b/c".into()].into(), unsubscribe: [].into() }, client_id: client_id_1 })).await.unwrap();
279+
280+
// client_1 should receive both "a/b" and "b/c" retain messages upon subscribing
281+
let recv = timeout(client_1.recv()).await.unwrap();
282+
let recv2 = timeout(client_1.recv()).await.unwrap();
283+
284+
// Retain message should not be redelivered to client_0
285+
assert!(tokio::time::timeout(Duration::from_millis(200), client_0.recv()).await.is_err());
286+
287+
if recv.topic.name == "a/b" {
288+
assert_eq!(recv, msg);
289+
assert_eq!(recv2, msg2);
290+
} else {
291+
assert_eq!(recv, msg2);
292+
assert_eq!(recv2, msg);
293+
}
294+
}
295+
252296
async fn timeout<T>(fut: impl Future<Output = T>) -> T {
253297
tokio::time::timeout(Duration::from_secs(1), fut)
254298
.await

crates/extensions/tedge_mqtt_ext/src/trie.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ impl SubscriptionDiff {
193193
/// "a/+" does not compare to "a/b/c"
194194
/// "a/+/c" does not compare to "a/b/+"
195195
/// "a/b" does not compare to "c/d"
196-
struct RankTopicFilter<'a>(&'a str);
196+
pub(crate) struct RankTopicFilter<'a>(pub &'a str);
197197

198198
impl PartialOrd for RankTopicFilter<'_> {
199199
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {

0 commit comments

Comments
 (0)