Skip to content

Commit 0281f3f

Browse files
committed
Add MQTT Sink
1 parent ccef54a commit 0281f3f

File tree

4 files changed

+216
-1
lines changed

4 files changed

+216
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ all = [
3737
"psycopg2-binary>=2.9.9,<3",
3838
"boto3>=1.35.65,<2.0",
3939
"boto3-stubs>=1.35.65,<2.0",
40-
"redis[hiredis]>=5.2.0,<6"
40+
"redis[hiredis]>=5.2.0,<6",
41+
"paho-mqtt==2.1.0"
4142
]
4243

4344
avro = ["fastavro>=1.8,<2.0"]

quixstreams/sinks/community/mqtt.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from quixstreams.sinks.base.sink import BaseSink
2+
from quixstreams.sinks.base.exceptions import SinkBackpressureError
3+
from typing import List, Tuple, Any
4+
from quixstreams.models.types import HeaderValue
5+
from datetime import datetime
6+
import json
7+
8+
try:
9+
import paho.mqtt.client as paho
10+
from paho import mqtt
11+
except ImportError as exc:
12+
raise ImportError(
13+
'Package "paho-mqtt" is missing: '
14+
"run pip install quixstreams[paho-mqtt] to fix it"
15+
) from exc
16+
17+
class MQTTSink(BaseSink):
18+
"""
19+
A sink that publishes messages to an MQTT broker.
20+
"""
21+
22+
def __init__(self,
23+
mqtt_client_id: str,
24+
mqtt_server: str,
25+
mqtt_port: int,
26+
mqtt_topic_root: str,
27+
mqtt_username: str = None,
28+
mqtt_password: str = None,
29+
mqtt_version: str = "3.1.1",
30+
tls_enabled: bool = True,
31+
qos: int = 1):
32+
"""
33+
Initialize the MQTTSink.
34+
35+
:param mqtt_client_id: MQTT client identifier.
36+
:param mqtt_server: MQTT broker server address.
37+
:param mqtt_port: MQTT broker server port.
38+
:param mqtt_topic_root: Root topic to publish messages to.
39+
:param mqtt_username: Username for MQTT broker authentication. Defaults to None
40+
:param mqtt_password: Password for MQTT broker authentication. Defaults to None
41+
:param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
42+
:param tls_enabled: Whether to use TLS encryption. Defaults to True
43+
:param qos: Quality of Service level (0, 1, or 2). Defaults to 1
44+
"""
45+
46+
super().__init__()
47+
48+
self.mqtt_version = mqtt_version
49+
self.mqtt_username = mqtt_username
50+
self.mqtt_password = mqtt_password
51+
self.mqtt_topic_root = mqtt_topic_root
52+
self.tls_enabled = tls_enabled
53+
self.qos = qos
54+
55+
self.mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2,
56+
client_id = mqtt_client_id, userdata = None, protocol = self._mqtt_protocol_version())
57+
58+
if self.tls_enabled:
59+
self.mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls now
60+
61+
self.mqtt_client.reconnect_delay_set(5, 60)
62+
self._configure_authentication()
63+
self.mqtt_client.on_connect = self._mqtt_on_connect_cb
64+
self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb
65+
self.mqtt_client.connect(mqtt_server, int(mqtt_port))
66+
67+
# setting callbacks for different events to see if it works, print the message etc.
68+
def _mqtt_on_connect_cb(self, client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags,
69+
reason_code: paho.ReasonCode, properties: paho.Properties):
70+
if reason_code == 0:
71+
print("CONNECTED!") # required for Quix to know this has connected
72+
else:
73+
print(f"ERROR ({reason_code.value}). {reason_code.getName()}")
74+
75+
def _mqtt_on_disconnect_cb(self, client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags,
76+
reason_code: paho.ReasonCode, properties: paho.Properties):
77+
print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!")
78+
79+
def _mqtt_protocol_version(self):
80+
if self.mqtt_version == "3.1":
81+
return paho.MQTTv31
82+
elif self.mqtt_version == "3.1.1":
83+
return paho.MQTTv311
84+
elif self.mqtt_version == "5":
85+
return paho.MQTTv5
86+
else:
87+
raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}")
88+
89+
def _configure_authentication(self):
90+
if self.mqtt_username:
91+
self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password)
92+
93+
def _publish_to_mqtt(self, data: str, key: bytes, timestamp: datetime, headers: List[Tuple[str, HeaderValue]]):
94+
if isinstance(data, bytes):
95+
data = data.decode('utf-8') # Decode bytes to string using utf-8
96+
97+
json_data = json.dumps(data)
98+
message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding
99+
# publish to MQTT
100+
self.mqtt_client.publish(self.mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = self.qos)
101+
102+
103+
def add(self,
104+
topic: str,
105+
partition: int,
106+
offset: int,
107+
key: bytes,
108+
value: bytes,
109+
timestamp: datetime,
110+
headers: List[Tuple[str, HeaderValue]],
111+
**kwargs: Any):
112+
self._publish_to_mqtt(value, key, timestamp, headers)
113+
114+
def _construct_topic(self, key):
115+
if key:
116+
key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key)
117+
return f"{self.mqtt_topic_root}/{key_str}"
118+
else:
119+
return self.mqtt_topic_root
120+
121+
def on_paused(self, topic: str, partition: int):
122+
# not used
123+
pass
124+
125+
def flush(self, topic: str, partition: str):
126+
# not used
127+
pass
128+
129+
def cleanup(self):
130+
self.mqtt_client.loop_stop()
131+
self.mqtt_client.disconnect()
132+
133+
def __del__(self):
134+
self.cleanup()

tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ protobuf>=5.27.2
77
influxdb3-python>=0.7.0,<1.0
88
pyiceberg[pyarrow,glue]>=0.7,<0.8
99
redis[hiredis]>=5.2.0,<6
10+
paho-mqtt==2.1.0
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from unittest.mock import MagicMock, patch
2+
import pytest
3+
from datetime import datetime
4+
from quixstreams.sinks.community.mqtt import MQTTSink
5+
6+
@pytest.fixture()
7+
def mqtt_sink_factory():
8+
def factory(
9+
mqtt_client_id: str = "test_client",
10+
mqtt_server: str = "localhost",
11+
mqtt_port: int = 1883,
12+
mqtt_topic_root: str = "test/topic",
13+
mqtt_username: str = None,
14+
mqtt_password: str = None,
15+
mqtt_version: str = "3.1.1",
16+
tls_enabled: bool = True,
17+
qos: int = 1,
18+
) -> MQTTSink:
19+
with patch('paho.mqtt.client.Client') as MockClient:
20+
mock_mqtt_client = MockClient.return_value
21+
sink = MQTTSink(
22+
mqtt_client_id=mqtt_client_id,
23+
mqtt_server=mqtt_server,
24+
mqtt_port=mqtt_port,
25+
mqtt_topic_root=mqtt_topic_root,
26+
mqtt_username=mqtt_username,
27+
mqtt_password=mqtt_password,
28+
mqtt_version=mqtt_version,
29+
tls_enabled=tls_enabled,
30+
qos=qos
31+
)
32+
sink.mqtt_client = mock_mqtt_client
33+
return sink, mock_mqtt_client
34+
35+
return factory
36+
37+
class TestMQTTSink:
38+
def test_mqtt_connect(self, mqtt_sink_factory):
39+
sink, mock_mqtt_client = mqtt_sink_factory()
40+
mock_mqtt_client.connect.assert_called_once_with("localhost", 1883)
41+
42+
def test_mqtt_tls_enabled(self, mqtt_sink_factory):
43+
sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=True)
44+
mock_mqtt_client.tls_set.assert_called_once()
45+
46+
def test_mqtt_tls_disabled(self, mqtt_sink_factory):
47+
sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=False)
48+
mock_mqtt_client.tls_set.assert_not_called()
49+
50+
def test_mqtt_publish(self, mqtt_sink_factory):
51+
sink, mock_mqtt_client = mqtt_sink_factory()
52+
data = "test_data"
53+
key = b"test_key"
54+
timestamp = datetime.now()
55+
headers = []
56+
57+
sink.add(
58+
topic="test-topic",
59+
partition=0,
60+
offset=1,
61+
key=key,
62+
value=data.encode('utf-8'),
63+
timestamp=timestamp,
64+
headers=headers
65+
)
66+
67+
mock_mqtt_client.publish.assert_called_once_with(
68+
"test/topic/test_key", payload='"test_data"', qos=1
69+
)
70+
71+
def test_mqtt_authentication(self, mqtt_sink_factory):
72+
sink, mock_mqtt_client = mqtt_sink_factory(mqtt_username="user", mqtt_password="pass")
73+
mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass")
74+
75+
def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory):
76+
sink, mock_mqtt_client = mqtt_sink_factory()
77+
sink.cleanup() # Explicitly call cleanup
78+
mock_mqtt_client.loop_stop.assert_called_once()
79+
mock_mqtt_client.disconnect.assert_called_once()

0 commit comments

Comments
 (0)