Skip to content

Commit bcb8c1b

Browse files
authored
Merge pull request #3661 from jarhodes314/refactor/dynamic-mqtt-subs
refactor: support dynamic MQTT subscriptions in `mqtt-channel`/`tedge-mqtt-ext`
2 parents 734af48 + 1b249b3 commit bcb8c1b

File tree

9 files changed

+2825
-60
lines changed

9 files changed

+2825
-60
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/common/mqtt_channel/src/connection.rs

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::MqttError;
44
use crate::MqttMessage;
55
use crate::PubChannel;
66
use crate::SubChannel;
7+
use crate::TopicFilter;
78
use futures::channel::mpsc;
89
use futures::channel::oneshot;
910
use futures::SinkExt;
@@ -16,10 +17,12 @@ use rumqttc::Event;
1617
use rumqttc::EventLoop;
1718
use rumqttc::Outgoing;
1819
use rumqttc::Packet;
20+
use rumqttc::SubscribeFilter;
1921
use std::collections::HashSet;
2022
use std::sync::atomic::AtomicUsize;
2123
use std::sync::atomic::Ordering;
2224
use std::sync::Arc;
25+
use std::sync::Mutex;
2326
use std::time::Duration;
2427
use tokio::sync::OwnedSemaphorePermit;
2528
use tokio::sync::Semaphore;
@@ -38,6 +41,76 @@ pub struct Connection {
3841

3942
/// A channel to notify that all the published messages have been actually published.
4043
pub pub_done: oneshot::Receiver<()>,
44+
45+
pub subscriptions: SubscriberHandle,
46+
}
47+
48+
#[derive(Clone)]
49+
/// A client for changing the subscribed topics
50+
pub struct SubscriberHandle {
51+
client: AsyncClient,
52+
pub(crate) subscriptions: Arc<Mutex<TopicFilter>>,
53+
}
54+
55+
impl SubscriberHandle {
56+
pub fn new(client: AsyncClient, subscriptions: Arc<Mutex<TopicFilter>>) -> Self {
57+
Self {
58+
client,
59+
subscriptions,
60+
}
61+
}
62+
}
63+
64+
#[async_trait::async_trait]
65+
pub trait SubscriberOps {
66+
async fn subscribe_many(
67+
&self,
68+
topics: impl IntoIterator<Item = String> + Send,
69+
) -> Result<(), MqttError>;
70+
async fn unsubscribe_many(
71+
&self,
72+
topics: impl IntoIterator<Item = String> + Send,
73+
) -> Result<(), MqttError>;
74+
}
75+
76+
#[async_trait::async_trait]
77+
impl SubscriberOps for SubscriberHandle {
78+
async fn subscribe_many(
79+
&self,
80+
topics: impl IntoIterator<Item = String> + Send,
81+
) -> Result<(), MqttError> {
82+
let topics = topics.into_iter().collect::<Vec<_>>();
83+
{
84+
let mut subs = self.subscriptions.lock().unwrap();
85+
for topic in &topics {
86+
subs.add(topic)?;
87+
}
88+
}
89+
self.client
90+
.subscribe_many(topics.into_iter().map(|path| SubscribeFilter {
91+
path,
92+
qos: rumqttc::QoS::AtLeastOnce,
93+
}))
94+
.await?;
95+
Ok(())
96+
}
97+
98+
async fn unsubscribe_many(
99+
&self,
100+
topics: impl IntoIterator<Item = String> + Send,
101+
) -> Result<(), MqttError> {
102+
let topics = topics.into_iter().collect::<Vec<_>>();
103+
{
104+
let mut subs = self.subscriptions.lock().unwrap();
105+
for topic in &topics {
106+
subs.remove(topic);
107+
}
108+
}
109+
for topic in topics {
110+
self.client.unsubscribe(topic).await?;
111+
}
112+
Ok(())
113+
}
41114
}
42115

43116
impl Connection {
@@ -91,9 +164,15 @@ impl Connection {
91164
let (published_sender, published_receiver) = mpsc::unbounded();
92165
let (error_sender, error_receiver) = mpsc::unbounded();
93166
let (pub_done_sender, pub_done_receiver) = oneshot::channel();
167+
let subscriptions = Arc::new(Mutex::new(config.subscriptions.clone()));
94168

95-
let (mqtt_client, event_loop) =
96-
Connection::open(config, received_sender.clone(), error_sender.clone()).await?;
169+
let (mqtt_client, event_loop) = Connection::open(
170+
config,
171+
received_sender.clone(),
172+
error_sender.clone(),
173+
subscriptions.clone(),
174+
)
175+
.await?;
97176
let permits = Arc::new(Semaphore::new(1));
98177
let permit = permits.clone().acquire_owned().await.unwrap();
99178
let pub_count = Arc::new(AtomicUsize::new(0));
@@ -106,9 +185,10 @@ impl Connection {
106185
pub_done_sender,
107186
permits,
108187
pub_count.clone(),
188+
subscriptions.clone(),
109189
));
110190
tokio::spawn(Connection::sender_loop(
111-
mqtt_client,
191+
mqtt_client.clone(),
112192
published_receiver,
113193
error_sender,
114194
config.last_will_message.clone(),
@@ -121,6 +201,7 @@ impl Connection {
121201
published: published_sender,
122202
errors: error_receiver,
123203
pub_done: pub_done_receiver,
204+
subscriptions: SubscriberHandle::new(mqtt_client, subscriptions),
124205
})
125206
}
126207

@@ -133,6 +214,7 @@ impl Connection {
133214
config: &Config,
134215
mut message_sender: mpsc::UnboundedSender<MqttMessage>,
135216
mut error_sender: mpsc::UnboundedSender<MqttError>,
217+
subscriptions: Arc<Mutex<TopicFilter>>,
136218
) -> Result<(AsyncClient, EventLoop), MqttError> {
137219
const INSECURE_MQTT_PORT: u16 = 1883;
138220
const SECURE_MQTT_PORT: u16 = 8883;
@@ -160,7 +242,7 @@ impl Connection {
160242
};
161243
info!(target: "MQTT", "Connection established");
162244

163-
let subscriptions = config.subscriptions.filters();
245+
let subscriptions = subscriptions.lock().unwrap().filters();
164246

165247
// Need check here otherwise it will hang waiting for a SubAck, and none will come when there is no subscription.
166248
if subscriptions.is_empty() {
@@ -217,6 +299,7 @@ impl Connection {
217299
done: oneshot::Sender<()>,
218300
permits: Arc<Semaphore>,
219301
pub_count: Arc<AtomicUsize>,
302+
subscriptions: Arc<Mutex<TopicFilter>>,
220303
) -> Result<(), MqttError> {
221304
let mut triggered_disconnect = false;
222305
let mut disconnect_permit = None;
@@ -289,7 +372,7 @@ impl Connection {
289372
// If session_name is not provided or if the broker session persistence
290373
// is not enabled or working, then re-subscribe
291374

292-
let subscriptions = config.subscriptions.filters();
375+
let subscriptions = subscriptions.lock().unwrap().filters();
293376
// Need check here otherwise it will hang waiting for a SubAck, and none will come when there is no subscription.
294377
if subscriptions.is_empty() {
295378
break;

crates/common/mqtt_channel/src/tests.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,3 +474,97 @@ async fn test_max_packet_size_validation() -> Result<(), anyhow::Error> {
474474

475475
Ok(())
476476
}
477+
478+
#[tokio::test]
479+
async fn dynamic_subscriptions() {
480+
// Given an MQTT broker
481+
let broker = mqtt_tests::test_mqtt_broker();
482+
let mqtt_config = Config::default().with_port(broker.port);
483+
484+
// A client subscribes to a topic on connect
485+
let topic = uniquify!("a/test/topic");
486+
let topic2 = uniquify!("a/test/topic/2");
487+
488+
// Publish a retain message before any client connects
489+
broker
490+
.publish_with_opts(topic2, "msg 0", QoS::AtLeastOnce, true)
491+
.await
492+
.unwrap();
493+
494+
let mqtt_config = mqtt_config.with_session_name(uniquify!("test_client"));
495+
let mut con = Connection::new(&mqtt_config).await.unwrap();
496+
497+
assert_eq!(
498+
*con.subscriptions.subscriptions.lock().unwrap(),
499+
TopicFilter::empty()
500+
);
501+
con.subscriptions
502+
.subscribe_many([topic.to_owned()])
503+
.await
504+
.unwrap();
505+
506+
// Assert we have added the newly subscribed topic to the list we need to
507+
// resubscribe to if the connection drops
508+
assert_eq!(
509+
*con.subscriptions.subscriptions.lock().unwrap(),
510+
TopicFilter::new_unchecked(topic)
511+
);
512+
513+
broker
514+
.publish_with_opts(topic, "msg 1", QoS::AtLeastOnce, true)
515+
.await
516+
.unwrap();
517+
518+
// Assert just against the payload since the retain flag may be true or
519+
// false depending on the ordering of the publish/subscribe messages (retain
520+
// flag is only set on an incoming message if the message was published
521+
// before we subscribe)
522+
assert_payload_received(&mut con, "msg 1").await;
523+
524+
// Unsubscribe from one topic and subscribe to another
525+
con.subscriptions
526+
.unsubscribe_many([topic.to_owned()])
527+
.await
528+
.unwrap();
529+
con.subscriptions
530+
.subscribe_many([topic2.to_owned()])
531+
.await
532+
.unwrap();
533+
534+
// Check we've updated the subscription list with both those changes
535+
assert_eq!(
536+
*con.subscriptions.subscriptions.lock().unwrap(),
537+
TopicFilter::new_unchecked(topic2)
538+
);
539+
540+
// We expect now to receive the retained message that has been published before the connection created
541+
// on the topic2 which the connection only subscribed to now.
542+
assert_payload_received(&mut con, "msg 0").await;
543+
544+
// Wait for the new subscription to be enacted
545+
broker
546+
.publish_with_opts(topic2, "msg 2", QoS::AtLeastOnce, true)
547+
.await
548+
.unwrap();
549+
550+
assert_payload_received(&mut con, "msg 2").await;
551+
552+
// At this point, we are unsubscribed from topic, and therefore we shouldn't receive the messages
553+
broker
554+
.publish_with_opts(topic, "msg 3", QoS::AtLeastOnce, true)
555+
.await
556+
.unwrap();
557+
broker
558+
.publish_with_opts(topic2, "msg 4", QoS::AtLeastOnce, true)
559+
.await
560+
.unwrap();
561+
562+
assert_payload_received(&mut con, "msg 4").await;
563+
}
564+
565+
async fn assert_payload_received(con: &mut Connection, payload: &'static str) {
566+
match next_message(&mut con.received).await {
567+
MaybeMessage::Next(msg) => assert_eq!(msg.payload.as_str().unwrap(), payload),
568+
not_msg => panic!("Expected message to be received, got {not_msg:?}"),
569+
}
570+
}

crates/common/mqtt_channel/src/topics.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ impl TopicFilter {
169169
}
170170

171171
/// The list of `SubscribeFilter` expected by `mqttc`
172-
pub(crate) fn filters(&self) -> Vec<SubscribeFilter> {
172+
pub fn filters(&self) -> Vec<SubscribeFilter> {
173173
let qos = self.qos;
174174
self.patterns
175175
.iter()
@@ -183,6 +183,14 @@ impl TopicFilter {
183183
pub fn patterns(&self) -> &Vec<String> {
184184
&self.patterns
185185
}
186+
187+
pub fn remove(&mut self, topic: &str) -> Option<String> {
188+
if let Some((index, _)) = self.patterns.iter().enumerate().find(|(_, p)| *p == topic) {
189+
Some(self.patterns.swap_remove(index))
190+
} else {
191+
None
192+
}
193+
}
186194
}
187195

188196
impl TryInto<Topic> for &str {

crates/extensions/tedge_mqtt_ext/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ tracing = { workspace = true }
2727
[dev-dependencies]
2828
futures = { workspace = true }
2929
mqtt_tests = { path = "../../tests/mqtt_tests" }
30+
proptest = { workspace = true }
3031

3132
[lints]
3233
workspace = true
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Seeds for failure cases proptest has generated in the past. It is
2+
# automatically read and these particular cases re-run before any
3+
# novel cases are generated.
4+
#
5+
# It is recommended to check this file in to source control so that
6+
# everyone who runs the test benefits from these saved cases.
7+
cc c2af50542642d11276b3414120dba5d319b7e79bc5c4b593b0f0e3e551bb216e # shrinks to subscriptions = ["/#", "+/+"]
8+
cc 35ff8c20ac897723b2f45738d88b5c3fdfb3e5f1dbdbfd3a170a70913ebc368d # shrinks to subscriptions = ["+/a/a", "/#"]
9+
cc c305db36e62ac105b2011d441fca4e88224159738d637aae88679fdeb967de3d # shrinks to subscriptions = ["a/+/a"], unsubscriptions = ["a"]
10+
cc 3faad8e1ecbe0502bf89b9dfed357f774e1912712f66a63943df74a6310f93f1 # shrinks to subscriptions = ["+"], unsubscriptions = ["+/#"]
11+
cc 8509be48aeb210d202da60447c6fbe0b99d491808274aefe12a2a0e3828426cd # shrinks to subscriptions = [], unsubscriptions = ["+/+/a/+", "a/#"]
12+
cc e055cb4e6e9c214f73ee66f2080422dbf5ba5fffca89008a304081ceb942e0c2 # shrinks to subscriptions = [], unsubscriptions = ["+/+", "/#"]
13+
cc 0951d4466b6ca4c30105fee3c2d5fcc047f0133be17b69ee9b12a6ba7358f48a # shrinks to subscriptions = [], unsubscriptions = ["c/+", "c/#"]
14+
cc 77a623802e2f52b94bf7fcf10040927a15037ada1641a3209594985fadf76d14 # shrinks to subscriptions = ["a/+/a"], unsubscriptions = ["+/a/+"]
15+
cc ab19cb84992e7141ae210779904cd516408022c7bc85937942030ebec26ebef1 # shrinks to subscriptions = ["+/+/a/+", "a/a/+/a"]
16+
cc 5637b36f5f378ae9bcc4ed1d442f93368bcd7cc03cfdb052cc98ab942c4aa29d # shrinks to subscriptions = ["c/+/+", "c/a/+", "+/#"]
17+
cc 5255ea6f0768fd462c21e020972c4221a202345c0b898b502012f557d1542b76 # shrinks to subscriptions = ["/#", "+/a/+"]
18+
cc 9638071b32ed1148858f0b0a693cac06b2857b38a8181a56154429b181a61fc1 # shrinks to subscriptions = ["+", "", "/#"]
19+
cc 09fc525d653ff597e455644305a5bcf0a9143b5b50c5e0b668ab69b72a8094f4 # shrinks to subscriptions = ["/#", "", "+"]
20+
cc 72dbc7091f98279af02f2b825dcd9440526b9d475a0a33222e2a69389581af6e # shrinks to subscriptions = ["c/a", "+/+", "c/+"]
21+
cc e6fe7dfd3c0771d88dc0e1d8da448d9ec3e5402059423622b36d2dc6020bd607 # shrinks to subscriptions = ["b/a", "+/#", "b/+"]
22+
cc 7fcaadecf97f9f29c2b119f46c53992ff8b687dee107eec3e6bb775d4cf467e3 # shrinks to subscriptions = ["+/+/a", "+/+/#", "+/#"]
23+
cc 6d35905cd0547e000e5a89edfe24818f4fa2e2947ff8a40ad19f7cc89f75bb0d # shrinks to subscriptions = ["+/a/#", "b/a/+", "b/+/+"]
24+
cc 3368368c515c9bbb99afc5c7f7079f3930cd56049c3e164cb9be5ea916e3d183 # shrinks to subscriptions = ["c/+", "c/c", "+/c"]
25+
cc 83857d88a8be534674f2326f109b667f7076aecb6b8c8bcc0b3d282959fb3a1e # shrinks to subscriptions = ["a/+", "+/+", "+/+/#", "+/+"]
26+
cc 321d11e2a9be7ffbd83775f65e48763715103d79f1cf6d83938199b25fc66f99 # shrinks to subscriptions = ["+/c/a/+", "b/c/a/c", "b/+/+/c"]

0 commit comments

Comments
 (0)