@@ -16,6 +16,7 @@ use mqtt_channel::SubscriberOps;
16
16
pub use mqtt_channel:: Topic ;
17
17
pub use mqtt_channel:: TopicFilter ;
18
18
use std:: convert:: Infallible ;
19
+ use std:: time:: Duration ;
19
20
use tedge_actors:: fan_in_message_type;
20
21
use tedge_actors:: futures:: channel:: mpsc;
21
22
use tedge_actors:: Actor ;
@@ -36,6 +37,7 @@ use tedge_actors::Server;
36
37
use tedge_actors:: ServerActorBuilder ;
37
38
use tedge_actors:: ServerConfig ;
38
39
use trie:: MqtTrie ;
40
+ use trie:: RankTopicFilter ;
39
41
use trie:: SubscriptionDiff ;
40
42
41
43
pub type MqttConfig = mqtt_channel:: Config ;
@@ -139,9 +141,10 @@ impl MqttActorBuilder {
139
141
tracing:: info!( target: "MQTT sub" , "{pattern}" ) ;
140
142
}
141
143
142
- let mqtt_config = self . mqtt_config . with_subscriptions ( topic_filter) ;
144
+ let mqtt_config = self . mqtt_config . clone ( ) . with_subscriptions ( topic_filter) ;
143
145
MqttActor :: new (
144
146
mqtt_config,
147
+ self . mqtt_config ,
145
148
self . input_receiver ,
146
149
self . subscriber_addresses ,
147
150
self . trie . builder ( ) ,
@@ -303,6 +306,7 @@ impl Builder<MqttActor> for MqttActorBuilder {
303
306
304
307
pub struct FromPeers {
305
308
input_receiver : InputCombiner ,
309
+ base_config : mqtt_channel:: Config ,
306
310
subscriptions : ClientMessageBox < TrieRequest , TrieResponse > ,
307
311
}
308
312
@@ -315,25 +319,41 @@ impl FromPeers {
315
319
async fn relay_messages_to (
316
320
& mut self ,
317
321
outgoing_mqtt : & mut mpsc:: UnboundedSender < MqttMessage > ,
322
+ tx_to_peers : & mut mpsc:: UnboundedSender < ( ClientId , MqttMessage ) > ,
318
323
client : impl SubscriberOps + Clone + Send + ' static ,
319
324
) -> Result < ( ) , RuntimeError > {
320
325
while let Ok ( Some ( message) ) = self . try_recv ( ) . await {
321
326
match message {
322
327
PublishOrSubscribe :: Publish ( message) => {
323
328
tracing:: debug!( target: "MQTT pub" , "{message}" ) ;
324
- SinkExt :: send ( outgoing_mqtt, message)
329
+ SinkExt :: send ( outgoing_mqtt, message)
325
330
. await
326
331
. map_err ( Box :: new) ?;
327
332
}
328
333
PublishOrSubscribe :: Subscribe ( request) => {
329
334
let TrieResponse :: Diff ( diff) = self
330
335
. subscriptions
331
- . await_response ( TrieRequest :: SubscriptionRequest ( request) )
336
+ . await_response ( TrieRequest :: SubscriptionRequest ( request. clone ( ) ) )
332
337
. await
333
338
. map_err ( Box :: new) ?
334
339
else {
335
340
unreachable ! ( "Subscription request always returns diff" )
336
341
} ;
342
+ let overlapping_subscriptions = request
343
+ . diff
344
+ . subscribe
345
+ . iter ( )
346
+ . filter ( |s| {
347
+ !diff
348
+ . subscribe
349
+ . iter ( )
350
+ . any ( |s2| RankTopicFilter ( s2) >= RankTopicFilter ( s) )
351
+ } )
352
+ . collect :: < Vec < _ > > ( ) ;
353
+ let mut tf = TopicFilter :: empty ( ) ;
354
+ for sub in overlapping_subscriptions {
355
+ tf. add_unchecked ( & sub) ;
356
+ }
337
357
let client = client. clone ( ) ;
338
358
tokio:: spawn ( async move {
339
359
// We're running outside the main task, so we can't return an error
@@ -345,6 +365,18 @@ impl FromPeers {
345
365
client. unsubscribe_many ( diff. unsubscribe ) . await . unwrap ( ) ;
346
366
}
347
367
} ) ;
368
+ let dynamic_connection_config = self . base_config . clone ( ) . with_subscriptions ( tf) ;
369
+ let mut sender = tx_to_peers. clone ( ) ;
370
+ tokio:: spawn ( async move {
371
+ let mut conn = mqtt_channel:: Connection :: new ( & dynamic_connection_config) . await . unwrap ( ) ;
372
+ while let Ok ( msg) = tokio:: time:: timeout ( Duration :: from_secs ( 10 ) , conn. received . next ( ) ) . await {
373
+ if let Some ( msg) = msg {
374
+ if msg. retain {
375
+ SinkExt :: send ( & mut sender, ( request. client_id , msg) ) . await . unwrap ( ) ;
376
+ }
377
+ }
378
+ }
379
+ } ) ;
348
380
}
349
381
}
350
382
}
@@ -373,10 +405,21 @@ impl ToPeers {
373
405
async fn relay_messages_from (
374
406
mut self ,
375
407
incoming_mqtt : & mut mpsc:: UnboundedReceiver < MqttMessage > ,
408
+ rx_from_peers : & mut mpsc:: UnboundedReceiver < ( ClientId , MqttMessage ) > ,
376
409
) -> Result < ( ) , RuntimeError > {
377
- while let Some ( message) = incoming_mqtt. next ( ) . await {
378
- tracing:: debug!( target: "MQTT recv" , "{message}" ) ;
379
- self . send ( message) . await ?;
410
+ loop {
411
+ tokio:: select! {
412
+ message = incoming_mqtt. next( ) => {
413
+ let Some ( message) = message else { break } ;
414
+ tracing:: debug!( target: "MQTT recv" , "{message}" ) ;
415
+ self . send( message) . await ?;
416
+ }
417
+ message = rx_from_peers. next( ) => {
418
+ let Some ( ( client, message) ) = message else { break } ;
419
+ tracing:: debug!( target: "MQTT recv" , "{message}" ) ;
420
+ self . sender_by_id( client) . send( message. clone( ) ) . await ?;
421
+ }
422
+ } ;
380
423
}
381
424
Ok ( ( ) )
382
425
}
@@ -390,10 +433,14 @@ impl ToPeers {
390
433
unreachable ! ( "MatchRequest always returns Matched" )
391
434
} ;
392
435
for client in matches {
393
- self . peer_senders [ client. 0 ] . send ( message. clone ( ) ) . await ?;
436
+ self . sender_by_id ( client) . send ( message. clone ( ) ) . await ?;
394
437
}
395
438
Ok ( ( ) )
396
439
}
440
+
441
+ fn sender_by_id ( & mut self , id : ClientId ) -> & mut Box < dyn CloneSender < MqttMessage > > {
442
+ & mut self . peer_senders [ id. 0 ]
443
+ }
397
444
}
398
445
399
446
#[ async_trait]
@@ -421,6 +468,7 @@ pub struct MqttActor {
421
468
impl MqttActor {
422
469
fn new (
423
470
mqtt_config : mqtt_channel:: Config ,
471
+ base_config : mqtt_channel:: Config ,
424
472
input_receiver : InputCombiner ,
425
473
peer_senders : Vec < DynSender < MqttMessage > > ,
426
474
mut trie_service : ServerActorBuilder < TrieService , Sequential > ,
@@ -429,6 +477,7 @@ impl MqttActor {
429
477
mqtt_config,
430
478
from_peers : FromPeers {
431
479
input_receiver,
480
+ base_config,
432
481
subscriptions : ClientMessageBox :: new ( & mut trie_service) ,
433
482
} ,
434
483
to_peers : ToPeers {
@@ -456,13 +505,14 @@ impl Actor for MqttActor {
456
505
return Ok ( ( ) )
457
506
}
458
507
} ;
508
+ let ( mut to_peer, mut from_peer) = mpsc:: unbounded ( ) ;
459
509
460
510
tokio:: spawn ( async move { self . trie_service . run ( ) . await } ) ;
461
511
462
512
tedge_utils:: futures:: select (
463
513
self . from_peers
464
- . relay_messages_to ( & mut mqtt_client. published , mqtt_client. subscriptions ) ,
465
- self . to_peers . relay_messages_from ( & mut mqtt_client. received ) ,
514
+ . relay_messages_to ( & mut mqtt_client. published , & mut to_peer , mqtt_client. subscriptions ) ,
515
+ self . to_peers . relay_messages_from ( & mut mqtt_client. received , & mut from_peer ) ,
466
516
)
467
517
. await
468
518
}
@@ -559,6 +609,8 @@ mod unit_tests {
559
609
. subscribe_client
560
610
. assert_subscribed_to ( [ "a/b" . into ( ) ] )
561
611
. await ;
612
+
613
+ actor. close ( ) . await ;
562
614
}
563
615
564
616
#[ tokio:: test]
@@ -572,11 +624,13 @@ mod unit_tests {
572
624
id: 0
573
625
) )
574
626
. await ;
575
-
627
+
576
628
actor
577
629
. subscribe_client
578
630
. assert_unsubscribed_from ( [ "a/b" . into ( ) ] )
579
631
. await ;
632
+
633
+ actor. close ( ) . await ;
580
634
}
581
635
582
636
#[ tokio:: test]
@@ -595,6 +649,8 @@ mod unit_tests {
595
649
. subscribe_client
596
650
. assert_subscribed_to ( [ "#" . into ( ) ] )
597
651
. await ;
652
+
653
+ actor. close ( ) . await ;
598
654
}
599
655
600
656
#[ tokio:: test]
@@ -612,7 +668,9 @@ mod unit_tests {
612
668
& Topic :: new( "a/b" ) . unwrap( ) ,
613
669
"test message"
614
670
) )
615
- )
671
+ ) ;
672
+
673
+ actor. close ( ) . await ;
616
674
}
617
675
618
676
#[ tokio:: test]
@@ -631,6 +689,8 @@ mod unit_tests {
631
689
. unwrap( )
632
690
. try_next( )
633
691
. is_err( ) ) ;
692
+
693
+ actor. close ( ) . await ;
634
694
}
635
695
636
696
struct MqttActorTest {
@@ -640,6 +700,17 @@ mod unit_tests {
640
700
sent_to_channel : mpsc:: UnboundedReceiver < MqttMessage > ,
641
701
sent_to_clients : HashMap < usize , mpsc:: Receiver < MqttMessage > > ,
642
702
inject_received_message : mpsc:: UnboundedSender < MqttMessage > ,
703
+ from_peers : Option < tokio:: task:: JoinHandle < Result < ( ) , RuntimeError > > > ,
704
+ to_peers : Option < tokio:: task:: JoinHandle < Result < ( ) , RuntimeError > > > ,
705
+ waited : bool ,
706
+ }
707
+
708
+ impl Drop for MqttActorTest {
709
+ fn drop ( & mut self ) {
710
+ if !self . waited {
711
+ panic ! ( "Call `MqttActorTest::close` at the end of the test" )
712
+ }
713
+ }
643
714
}
644
715
645
716
impl MqttActorTest {
@@ -658,6 +729,7 @@ mod unit_tests {
658
729
let mut ts = TrieService :: with_default_subscriptions ( default_subscriptions) ;
659
730
let mut fp = FromPeers {
660
731
input_receiver : input_combiner,
732
+ base_config : <_ >:: default ( ) ,
661
733
subscriptions : ClientMessageBox :: new ( & mut ts) ,
662
734
} ;
663
735
let mut sent_to_clients = HashMap :: new ( ) ;
@@ -676,12 +748,13 @@ mod unit_tests {
676
748
} ;
677
749
tokio:: spawn ( async move { ts. build ( ) . run ( ) . await } ) ;
678
750
751
+ let ( mut tx, mut rx) = mpsc:: unbounded ( ) ;
679
752
let subscribe_client = MockSubscriberOps :: default ( ) ;
680
- {
753
+ let from_peers = {
681
754
let client = subscribe_client. clone ( ) ;
682
- tokio:: spawn ( async move { fp. relay_messages_to ( & mut outgoing_mqtt, client) . await } ) ;
683
- }
684
- tokio:: spawn ( async move { tp. relay_messages_from ( & mut incoming_messages) . await } ) ;
755
+ tokio:: spawn ( async move { fp. relay_messages_to ( & mut outgoing_mqtt, & mut tx , client) . await } )
756
+ } ;
757
+ let to_peers = tokio:: spawn ( async move { tp. relay_messages_from ( & mut incoming_messages, & mut rx ) . await } ) ;
685
758
686
759
Self {
687
760
subscribe_client,
@@ -690,14 +763,32 @@ mod unit_tests {
690
763
sent_to_clients,
691
764
sent_to_channel : sent_messages,
692
765
inject_received_message,
766
+ from_peers : Some ( from_peers) ,
767
+ to_peers : Some ( to_peers) ,
768
+ waited : false ,
693
769
}
694
770
}
695
771
772
+ /// Closes the channels associated with this actor and waits for both
773
+ /// loops to finish executing
774
+ ///
775
+ /// This allows the `SubscriberOps::drop` implementation to reliably
776
+ /// flag any unasserted communication
777
+ pub async fn close ( mut self ) {
778
+ self . pub_tx . close_channel ( ) ;
779
+ self . sub_tx . close_channel ( ) ;
780
+ self . inject_received_message . close_channel ( ) ;
781
+ self . from_peers . take ( ) . unwrap ( ) . await . unwrap ( ) . unwrap ( ) ;
782
+ self . to_peers . take ( ) . unwrap ( ) . await . unwrap ( ) . unwrap ( ) ;
783
+ self . waited = true ;
784
+ }
785
+
696
786
/// Simulates a client sending a subscription request to the mqtt actor
697
787
pub async fn send_sub ( & mut self , req : SubscriptionRequest ) {
698
788
SinkExt :: send ( & mut self . sub_tx , req) . await . unwrap ( ) ;
699
789
}
700
790
791
+ /// Simulates a client sending a publish request to the mqtt actor
701
792
pub async fn publish ( & mut self , topic : & str , payload : & str ) {
702
793
SinkExt :: send (
703
794
& mut self . pub_tx ,
@@ -803,6 +894,9 @@ mod unit_tests {
803
894
if std:: thread:: panicking ( ) {
804
895
return ;
805
896
}
897
+ if Arc :: strong_count ( & self . subscribe_many ) > 1 {
898
+ return ;
899
+ }
806
900
let subscribe = self . subscribe_many . lock ( ) . unwrap ( ) . clone ( ) ;
807
901
let unsubscribe = self . unsubscribe_many . lock ( ) . unwrap ( ) . clone ( ) ;
808
902
if !subscribe. is_empty ( ) {
0 commit comments