Skip to content

Commit 32c5f63

Browse files
committed
update to latest sink patterns and overhaul functionality
1 parent cbc2ebd commit 32c5f63

File tree

1 file changed

+198
-116
lines changed

1 file changed

+198
-116
lines changed

quixstreams/sinks/community/mqtt.py

Lines changed: 198 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,166 @@
11
import json
2+
import logging
3+
import time
24
from datetime import datetime
3-
from typing import Any, List, Tuple
5+
from typing import Any, Callable, Literal, Optional, Union, get_args
46

5-
from quixstreams.models.types import HeaderValue
6-
from quixstreams.sinks.base.sink import BaseSink
7+
from quixstreams.models.types import HeadersTuples
8+
from quixstreams.sinks import (
9+
BaseSink,
10+
ClientConnectFailureCallback,
11+
ClientConnectSuccessCallback,
12+
)
713

814
try:
915
import paho.mqtt.client as paho
10-
from paho import mqtt
1116
except ImportError as exc:
1217
raise ImportError(
1318
'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it"
1419
) from exc
1520

1621

22+
logger = logging.getLogger(__name__)
23+
24+
VERSION_MAP = {
25+
"3.1": paho.MQTTv31,
26+
"3.1.1": paho.MQTTv311,
27+
"5": paho.MQTTv5,
28+
}
29+
MQTT_SUCCESS = paho.MQTT_ERR_SUCCESS
30+
ProtocolVersion = Literal["3.1", "3.1.1", "5"]
31+
MqttPropertiesHandler = Union[paho.Properties, Callable[[Any], paho.Properties]]
32+
RetainHandler = Union[bool, Callable[[Any], bool]]
33+
34+
1735
class MQTTSink(BaseSink):
1836
"""
1937
A sink that publishes messages to an MQTT broker.
2038
"""
2139

2240
def __init__(
2341
self,
24-
mqtt_client_id: str,
25-
mqtt_server: str,
26-
mqtt_port: int,
27-
mqtt_topic_root: str,
28-
mqtt_username: str = None,
29-
mqtt_password: str = None,
30-
mqtt_version: str = "3.1.1",
42+
client_id: str,
43+
server: str,
44+
port: int,
45+
topic_root: str,
46+
username: str = None,
47+
password: str = None,
48+
version: ProtocolVersion = "3.1.1",
3149
tls_enabled: bool = True,
32-
qos: int = 1,
50+
key_serializer: Callable[[Any], str] = bytes.decode,
51+
value_serializer: Callable[[Any], str] = json.dumps,
52+
qos: Literal[0, 1] = 1,
53+
mqtt_flush_timeout_seconds: int = 10,
54+
retain: Union[bool, Callable[[Any], bool]] = False,
55+
properties: Optional[MqttPropertiesHandler] = None,
56+
on_client_connect_success: Optional[ClientConnectSuccessCallback] = None,
57+
on_client_connect_failure: Optional[ClientConnectFailureCallback] = None,
3358
):
3459
"""
3560
Initialize the MQTTSink.
3661
37-
:param mqtt_client_id: MQTT client identifier.
38-
:param mqtt_server: MQTT broker server address.
39-
:param mqtt_port: MQTT broker server port.
40-
:param mqtt_topic_root: Root topic to publish messages to.
41-
:param mqtt_username: Username for MQTT broker authentication. Defaults to None
42-
:param mqtt_password: Password for MQTT broker authentication. Defaults to None
43-
:param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
44-
:param tls_enabled: Whether to use TLS encryption. Defaults to True
45-
:param qos: Quality of Service level (0, 1, or 2). Defaults to 1
62+
:param client_id: MQTT client identifier.
63+
:param server: MQTT broker server address.
64+
:param port: MQTT broker server port.
65+
:param topic_root: Root topic to publish messages to.
66+
:param username: Username for MQTT broker authentication. Default = None
67+
:param password: Password for MQTT broker authentication. Default = None
68+
:param version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
69+
:param tls_enabled: Whether to use TLS encryption. Default = True
70+
:param key_serializer: How to serialize the MQTT message key for producing.
71+
:param value_serializer: How to serialize the MQTT message value for producing.
72+
:param qos: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
73+
:param mqtt_flush_timeout_seconds: how long to wait for publish acknowledgment
74+
of MQTT messages before failing. Default = 10.
75+
:param retain: Retain last message for new subscribers. Default = False.
76+
Also accepts a callable that uses the current message value as input.
77+
:param properties: An optional Properties instance for messages. Default = None.
78+
Also accepts a callable that uses the current message value as input.
4679
"""
47-
48-
super().__init__()
49-
50-
self.mqtt_version = mqtt_version
51-
self.mqtt_username = mqtt_username
52-
self.mqtt_password = mqtt_password
53-
self.mqtt_topic_root = mqtt_topic_root
54-
self.tls_enabled = tls_enabled
55-
self.qos = qos
56-
57-
self.mqtt_client = paho.Client(
80+
super().__init__(
81+
on_client_connect_success=on_client_connect_success,
82+
on_client_connect_failure=on_client_connect_failure,
83+
)
84+
if qos == 2:
85+
raise ValueError(f"MQTT QoS level {2} is currently not supported.")
86+
if not (protocol := VERSION_MAP.get(version)):
87+
raise ValueError(
88+
f"Invalid MQTT version {version}; valid: {get_args(ProtocolVersion)}"
89+
)
90+
if properties and protocol != "5":
91+
raise ValueError(
92+
"MQTT Properties can only be used with MQTT protocol version 5"
93+
)
94+
95+
self._version = version
96+
self._server = server
97+
self._port = port
98+
self._topic_root = topic_root
99+
self._key_serializer = key_serializer
100+
self._value_serializer = value_serializer
101+
self._qos = qos
102+
self._flush_timeout = mqtt_flush_timeout_seconds
103+
self._pending_acks: set[int] = set()
104+
self._retain = _get_retain_callable(retain)
105+
self._properties = _get_properties_callable(properties)
106+
107+
self._client = paho.Client(
58108
callback_api_version=paho.CallbackAPIVersion.VERSION2,
59-
client_id=mqtt_client_id,
109+
client_id=client_id,
60110
userdata=None,
61-
protocol=self._mqtt_protocol_version(),
111+
protocol=protocol,
62112
)
63113

64-
if self.tls_enabled:
65-
self.mqtt_client.tls_set(
66-
tls_version=mqtt.client.ssl.PROTOCOL_TLS
67-
) # we'll be using tls now
114+
if username:
115+
self._client.username_pw_set(username, password)
116+
if tls_enabled:
117+
self._client.tls_set(tls_version=paho.ssl.PROTOCOL_TLS)
118+
self._client.reconnect_delay_set(5, 60)
119+
self._client.on_connect = _mqtt_on_connect_cb
120+
self._client.on_disconnect = _mqtt_on_disconnect_cb
121+
self._client.on_publish = self._on_publish_cb
122+
self._publish_count = 0
68123

69-
self.mqtt_client.reconnect_delay_set(5, 60)
70-
self._configure_authentication()
71-
self.mqtt_client.on_connect = self._mqtt_on_connect_cb
72-
self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb
73-
self.mqtt_client.connect(mqtt_server, int(mqtt_port))
124+
def setup(self):
125+
self._client.connect(self._server, self._port)
126+
self._client.loop_start()
74127

75-
# setting callbacks for different events to see if it works, print the message etc.
76-
def _mqtt_on_connect_cb(
77-
self,
78-
client: paho.Client,
79-
userdata: any,
80-
connect_flags: paho.ConnectFlags,
81-
reason_code: paho.ReasonCode,
82-
properties: paho.Properties,
83-
):
84-
if reason_code == 0:
85-
print("CONNECTED!") # required for Quix to know this has connected
86-
else:
87-
print(f"ERROR ({reason_code.value}). {reason_code.getName()}")
88-
89-
def _mqtt_on_disconnect_cb(
128+
def _publish_to_mqtt(
90129
self,
91-
client: paho.Client,
92-
userdata: any,
93-
disconnect_flags: paho.DisconnectFlags,
94-
reason_code: paho.ReasonCode,
95-
properties: paho.Properties,
130+
data: Any,
131+
topic_suffix: Any,
96132
):
97-
print(
98-
f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!"
133+
properties = self._properties
134+
info = self._client.publish(
135+
f"{self._topic_root}/{self._key_serializer(topic_suffix)}",
136+
payload=self._value_serializer(data),
137+
qos=self._qos,
138+
properties=properties(data) if properties else None,
139+
retain=self._retain(data),
99140
)
100-
101-
def _mqtt_protocol_version(self):
102-
if self.mqtt_version == "3.1":
103-
return paho.MQTTv31
104-
elif self.mqtt_version == "3.1.1":
105-
return paho.MQTTv311
106-
elif self.mqtt_version == "5":
107-
return paho.MQTTv5
141+
if self._qos:
142+
if info.rc != MQTT_SUCCESS:
143+
raise MqttPublishEnqueueFailed(
144+
f"Failed adding message to MQTT publishing queue; "
145+
f"error code {info.rc}: {paho.error_string(info.rc)}"
146+
)
147+
self._pending_acks.add(info.mid)
108148
else:
109-
raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}")
149+
self._publish_count += 1
110150

111-
def _configure_authentication(self):
112-
if self.mqtt_username:
113-
self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password)
114-
115-
def _publish_to_mqtt(
151+
def _on_publish_cb(
116152
self,
117-
data: str,
118-
key: bytes,
119-
timestamp: datetime,
120-
headers: List[Tuple[str, HeaderValue]],
153+
client: paho.Client,
154+
userdata: Any,
155+
mid: int,
156+
rc: paho.ReasonCode,
157+
p: paho.Properties,
121158
):
122-
if isinstance(data, bytes):
123-
data = data.decode("utf-8") # Decode bytes to string using utf-8
124-
125-
json_data = json.dumps(data)
126-
message_key_string = key.decode(
127-
"utf-8"
128-
) # Convert to string using utf-8 encoding
129-
# publish to MQTT
130-
self.mqtt_client.publish(
131-
self.mqtt_topic_root + "/" + message_key_string,
132-
payload=json_data,
133-
qos=self.qos,
134-
)
159+
"""
160+
This is only triggered upon successful publish when self._qos > 0.
161+
"""
162+
self._publish_count += 1
163+
self._pending_acks.remove(mid)
135164

136165
def add(
137166
self,
@@ -141,29 +170,82 @@ def add(
141170
key: bytes,
142171
value: bytes,
143172
timestamp: datetime,
144-
headers: List[Tuple[str, HeaderValue]],
145-
**kwargs: Any,
173+
headers: HeadersTuples,
146174
):
147-
self._publish_to_mqtt(value, key, timestamp, headers)
175+
try:
176+
self._publish_to_mqtt(
177+
value,
178+
key,
179+
)
180+
except Exception as e:
181+
self._cleanup()
182+
raise e
183+
184+
def on_paused(self):
185+
pass
148186

149-
def _construct_topic(self, key):
150-
if key:
151-
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
152-
return f"{self.mqtt_topic_root}/{key_str}"
153-
else:
154-
return self.mqtt_topic_root
187+
def flush(self):
188+
if self._pending_acks:
189+
start_time = time.monotonic()
190+
timeout = start_time + self._flush_timeout
191+
while self._pending_acks and start_time < timeout:
192+
logger.debug(f"Pending acks remaining: {len(self._pending_acks)}")
193+
time.sleep(1)
194+
if self._pending_acks:
195+
self._cleanup()
196+
raise MqttPublishAckTimeout(
197+
f"Mqtt acknowledgement timeout of {self._flush_timeout}s reached."
198+
)
199+
logger.info(f"{self._publish_count} MQTT messages published.")
200+
self._publish_count = 0
201+
202+
def _cleanup(self):
203+
self._client.loop_stop()
204+
self._client.disconnect()
205+
206+
207+
class MqttPublishEnqueueFailed(Exception):
208+
pass
209+
210+
211+
class MqttPublishAckTimeout(Exception):
212+
pass
213+
214+
215+
def _mqtt_on_connect_cb(
216+
client: paho.Client,
217+
userdata: any,
218+
connect_flags: paho.ConnectFlags,
219+
reason_code: paho.ReasonCode,
220+
properties: paho.Properties,
221+
):
222+
if reason_code != 0:
223+
raise ConnectionError(
224+
f"Failed to connect to MQTT broker; ERROR: ({reason_code.value}).{reason_code.getName()}"
225+
)
155226

156-
def on_paused(self, topic: str, partition: int):
157-
# not used
158-
pass
159227

160-
def flush(self, topic: str, partition: str):
161-
# not used
162-
pass
228+
def _mqtt_on_disconnect_cb(
229+
client: paho.Client,
230+
userdata: any,
231+
disconnect_flags: paho.DisconnectFlags,
232+
reason_code: paho.ReasonCode,
233+
properties: paho.Properties,
234+
):
235+
logger.info(
236+
f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!"
237+
)
238+
239+
240+
def _get_properties_callable(
241+
properties: Optional[MqttPropertiesHandler],
242+
) -> Optional[Callable[[Any], paho.Properties]]:
243+
if isinstance(properties, paho.Properties):
244+
return lambda data: properties(data)
245+
return properties
163246

164-
def cleanup(self):
165-
self.mqtt_client.loop_stop()
166-
self.mqtt_client.disconnect()
167247

168-
def __del__(self):
169-
self.cleanup()
248+
def _get_retain_callable(retain: RetainHandler) -> Callable[[Any], bool]:
249+
if isinstance(retain, bool):
250+
return lambda data: retain
251+
return retain

0 commit comments

Comments
 (0)