1
1
import json
2
+ import logging
3
+ import time
2
4
from datetime import datetime
3
- from typing import Any , List , Tuple
5
+ from typing import Any , Callable , Literal , Optional , Union , get_args
4
6
5
- from quixstreams .models .types import HeaderValue
6
- from quixstreams .sinks .base .sink import BaseSink
7
+ from quixstreams .models .types import HeadersTuples
8
+ from quixstreams .sinks import (
9
+ BaseSink ,
10
+ ClientConnectFailureCallback ,
11
+ ClientConnectSuccessCallback ,
12
+ )
7
13
8
14
try :
9
15
import paho .mqtt .client as paho
10
- from paho import mqtt
11
16
except ImportError as exc :
12
17
raise ImportError (
13
18
'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it"
14
19
) from exc
15
20
16
21
22
+ logger = logging .getLogger (__name__ )
23
+
24
+ VERSION_MAP = {
25
+ "3.1" : paho .MQTTv31 ,
26
+ "3.1.1" : paho .MQTTv311 ,
27
+ "5" : paho .MQTTv5 ,
28
+ }
29
+ MQTT_SUCCESS = paho .MQTT_ERR_SUCCESS
30
+ ProtocolVersion = Literal ["3.1" , "3.1.1" , "5" ]
31
+ MqttPropertiesHandler = Union [paho .Properties , Callable [[Any ], paho .Properties ]]
32
+ RetainHandler = Union [bool , Callable [[Any ], bool ]]
33
+
34
+
17
35
class MQTTSink (BaseSink ):
18
36
"""
19
37
A sink that publishes messages to an MQTT broker.
20
38
"""
21
39
22
40
def __init__ (
23
41
self ,
24
- mqtt_client_id : str ,
25
- mqtt_server : str ,
26
- mqtt_port : int ,
27
- mqtt_topic_root : str ,
28
- mqtt_username : str = None ,
29
- mqtt_password : str = None ,
30
- mqtt_version : str = "3.1.1" ,
42
+ client_id : str ,
43
+ server : str ,
44
+ port : int ,
45
+ topic_root : str ,
46
+ username : str = None ,
47
+ password : str = None ,
48
+ version : ProtocolVersion = "3.1.1" ,
31
49
tls_enabled : bool = True ,
32
- qos : int = 1 ,
50
+ key_serializer : Callable [[Any ], str ] = bytes .decode ,
51
+ value_serializer : Callable [[Any ], str ] = json .dumps ,
52
+ qos : Literal [0 , 1 ] = 1 ,
53
+ mqtt_flush_timeout_seconds : int = 10 ,
54
+ retain : Union [bool , Callable [[Any ], bool ]] = False ,
55
+ properties : Optional [MqttPropertiesHandler ] = None ,
56
+ on_client_connect_success : Optional [ClientConnectSuccessCallback ] = None ,
57
+ on_client_connect_failure : Optional [ClientConnectFailureCallback ] = None ,
33
58
):
34
59
"""
35
60
Initialize the MQTTSink.
36
61
37
- :param mqtt_client_id: MQTT client identifier.
38
- :param mqtt_server: MQTT broker server address.
39
- :param mqtt_port: MQTT broker server port.
40
- :param mqtt_topic_root: Root topic to publish messages to.
41
- :param mqtt_username: Username for MQTT broker authentication. Defaults to None
42
- :param mqtt_password: Password for MQTT broker authentication. Defaults to None
43
- :param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
44
- :param tls_enabled: Whether to use TLS encryption. Defaults to True
45
- :param qos: Quality of Service level (0, 1, or 2). Defaults to 1
62
+ :param client_id: MQTT client identifier.
63
+ :param server: MQTT broker server address.
64
+ :param port: MQTT broker server port.
65
+ :param topic_root: Root topic to publish messages to.
66
+ :param username: Username for MQTT broker authentication. Default = None
67
+ :param password: Password for MQTT broker authentication. Default = None
68
+ :param version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
69
+ :param tls_enabled: Whether to use TLS encryption. Default = True
70
+ :param key_serializer: How to serialize the MQTT message key for producing.
71
+ :param value_serializer: How to serialize the MQTT message value for producing.
72
+ :param qos: Quality of Service level (0 or 1; 2 not yet supported) Default = 1.
73
+ :param mqtt_flush_timeout_seconds: how long to wait for publish acknowledgment
74
+ of MQTT messages before failing. Default = 10.
75
+ :param retain: Retain last message for new subscribers. Default = False.
76
+ Also accepts a callable that uses the current message value as input.
77
+ :param properties: An optional Properties instance for messages. Default = None.
78
+ Also accepts a callable that uses the current message value as input.
46
79
"""
47
-
48
- super ().__init__ ()
49
-
50
- self .mqtt_version = mqtt_version
51
- self .mqtt_username = mqtt_username
52
- self .mqtt_password = mqtt_password
53
- self .mqtt_topic_root = mqtt_topic_root
54
- self .tls_enabled = tls_enabled
55
- self .qos = qos
56
-
57
- self .mqtt_client = paho .Client (
80
+ super ().__init__ (
81
+ on_client_connect_success = on_client_connect_success ,
82
+ on_client_connect_failure = on_client_connect_failure ,
83
+ )
84
+ if qos == 2 :
85
+ raise ValueError (f"MQTT QoS level { 2 } is currently not supported." )
86
+ if not (protocol := VERSION_MAP .get (version )):
87
+ raise ValueError (
88
+ f"Invalid MQTT version { version } ; valid: { get_args (ProtocolVersion )} "
89
+ )
90
+ if properties and protocol != "5" :
91
+ raise ValueError (
92
+ "MQTT Properties can only be used with MQTT protocol version 5"
93
+ )
94
+
95
+ self ._version = version
96
+ self ._server = server
97
+ self ._port = port
98
+ self ._topic_root = topic_root
99
+ self ._key_serializer = key_serializer
100
+ self ._value_serializer = value_serializer
101
+ self ._qos = qos
102
+ self ._flush_timeout = mqtt_flush_timeout_seconds
103
+ self ._pending_acks : set [int ] = set ()
104
+ self ._retain = _get_retain_callable (retain )
105
+ self ._properties = _get_properties_callable (properties )
106
+
107
+ self ._client = paho .Client (
58
108
callback_api_version = paho .CallbackAPIVersion .VERSION2 ,
59
- client_id = mqtt_client_id ,
109
+ client_id = client_id ,
60
110
userdata = None ,
61
- protocol = self . _mqtt_protocol_version () ,
111
+ protocol = protocol ,
62
112
)
63
113
64
- if self .tls_enabled :
65
- self .mqtt_client .tls_set (
66
- tls_version = mqtt .client .ssl .PROTOCOL_TLS
67
- ) # we'll be using tls now
114
+ if username :
115
+ self ._client .username_pw_set (username , password )
116
+ if tls_enabled :
117
+ self ._client .tls_set (tls_version = paho .ssl .PROTOCOL_TLS )
118
+ self ._client .reconnect_delay_set (5 , 60 )
119
+ self ._client .on_connect = _mqtt_on_connect_cb
120
+ self ._client .on_disconnect = _mqtt_on_disconnect_cb
121
+ self ._client .on_publish = self ._on_publish_cb
122
+ self ._publish_count = 0
68
123
69
- self .mqtt_client .reconnect_delay_set (5 , 60 )
70
- self ._configure_authentication ()
71
- self .mqtt_client .on_connect = self ._mqtt_on_connect_cb
72
- self .mqtt_client .on_disconnect = self ._mqtt_on_disconnect_cb
73
- self .mqtt_client .connect (mqtt_server , int (mqtt_port ))
124
+ def setup (self ):
125
+ self ._client .connect (self ._server , self ._port )
126
+ self ._client .loop_start ()
74
127
75
- # setting callbacks for different events to see if it works, print the message etc.
76
- def _mqtt_on_connect_cb (
77
- self ,
78
- client : paho .Client ,
79
- userdata : any ,
80
- connect_flags : paho .ConnectFlags ,
81
- reason_code : paho .ReasonCode ,
82
- properties : paho .Properties ,
83
- ):
84
- if reason_code == 0 :
85
- print ("CONNECTED!" ) # required for Quix to know this has connected
86
- else :
87
- print (f"ERROR ({ reason_code .value } ). { reason_code .getName ()} " )
88
-
89
- def _mqtt_on_disconnect_cb (
128
+ def _publish_to_mqtt (
90
129
self ,
91
- client : paho .Client ,
92
- userdata : any ,
93
- disconnect_flags : paho .DisconnectFlags ,
94
- reason_code : paho .ReasonCode ,
95
- properties : paho .Properties ,
130
+ data : Any ,
131
+ topic_suffix : Any ,
96
132
):
97
- print (
98
- f"DISCONNECTED! Reason code ({ reason_code .value } ) { reason_code .getName ()} !"
133
+ properties = self ._properties
134
+ info = self ._client .publish (
135
+ f"{ self ._topic_root } /{ self ._key_serializer (topic_suffix )} " ,
136
+ payload = self ._value_serializer (data ),
137
+ qos = self ._qos ,
138
+ properties = properties (data ) if properties else None ,
139
+ retain = self ._retain (data ),
99
140
)
100
-
101
- def _mqtt_protocol_version (self ):
102
- if self .mqtt_version == "3.1" :
103
- return paho .MQTTv31
104
- elif self .mqtt_version == "3.1.1" :
105
- return paho .MQTTv311
106
- elif self .mqtt_version == "5" :
107
- return paho .MQTTv5
141
+ if self ._qos :
142
+ if info .rc != MQTT_SUCCESS :
143
+ raise MqttPublishEnqueueFailed (
144
+ f"Failed adding message to MQTT publishing queue; "
145
+ f"error code { info .rc } : { paho .error_string (info .rc )} "
146
+ )
147
+ self ._pending_acks .add (info .mid )
108
148
else :
109
- raise ValueError ( f"Unsupported MQTT version: { self .mqtt_version } " )
149
+ self ._publish_count += 1
110
150
111
- def _configure_authentication (self ):
112
- if self .mqtt_username :
113
- self .mqtt_client .username_pw_set (self .mqtt_username , self .mqtt_password )
114
-
115
- def _publish_to_mqtt (
151
+ def _on_publish_cb (
116
152
self ,
117
- data : str ,
118
- key : bytes ,
119
- timestamp : datetime ,
120
- headers : List [Tuple [str , HeaderValue ]],
153
+ client : paho .Client ,
154
+ userdata : Any ,
155
+ mid : int ,
156
+ rc : paho .ReasonCode ,
157
+ p : paho .Properties ,
121
158
):
122
- if isinstance (data , bytes ):
123
- data = data .decode ("utf-8" ) # Decode bytes to string using utf-8
124
-
125
- json_data = json .dumps (data )
126
- message_key_string = key .decode (
127
- "utf-8"
128
- ) # Convert to string using utf-8 encoding
129
- # publish to MQTT
130
- self .mqtt_client .publish (
131
- self .mqtt_topic_root + "/" + message_key_string ,
132
- payload = json_data ,
133
- qos = self .qos ,
134
- )
159
+ """
160
+ This is only triggered upon successful publish when self._qos > 0.
161
+ """
162
+ self ._publish_count += 1
163
+ self ._pending_acks .remove (mid )
135
164
136
165
def add (
137
166
self ,
@@ -141,29 +170,82 @@ def add(
141
170
key : bytes ,
142
171
value : bytes ,
143
172
timestamp : datetime ,
144
- headers : List [Tuple [str , HeaderValue ]],
145
- ** kwargs : Any ,
173
+ headers : HeadersTuples ,
146
174
):
147
- self ._publish_to_mqtt (value , key , timestamp , headers )
175
+ try :
176
+ self ._publish_to_mqtt (
177
+ value ,
178
+ key ,
179
+ )
180
+ except Exception as e :
181
+ self ._cleanup ()
182
+ raise e
183
+
184
+ def on_paused (self ):
185
+ pass
148
186
149
- def _construct_topic (self , key ):
150
- if key :
151
- key_str = key .decode ("utf-8" ) if isinstance (key , bytes ) else str (key )
152
- return f"{ self .mqtt_topic_root } /{ key_str } "
153
- else :
154
- return self .mqtt_topic_root
187
+ def flush (self ):
188
+ if self ._pending_acks :
189
+ start_time = time .monotonic ()
190
+ timeout = start_time + self ._flush_timeout
191
+ while self ._pending_acks and start_time < timeout :
192
+ logger .debug (f"Pending acks remaining: { len (self ._pending_acks )} " )
193
+ time .sleep (1 )
194
+ if self ._pending_acks :
195
+ self ._cleanup ()
196
+ raise MqttPublishAckTimeout (
197
+ f"Mqtt acknowledgement timeout of { self ._flush_timeout } s reached."
198
+ )
199
+ logger .info (f"{ self ._publish_count } MQTT messages published." )
200
+ self ._publish_count = 0
201
+
202
+ def _cleanup (self ):
203
+ self ._client .loop_stop ()
204
+ self ._client .disconnect ()
205
+
206
+
207
+ class MqttPublishEnqueueFailed (Exception ):
208
+ pass
209
+
210
+
211
+ class MqttPublishAckTimeout (Exception ):
212
+ pass
213
+
214
+
215
+ def _mqtt_on_connect_cb (
216
+ client : paho .Client ,
217
+ userdata : any ,
218
+ connect_flags : paho .ConnectFlags ,
219
+ reason_code : paho .ReasonCode ,
220
+ properties : paho .Properties ,
221
+ ):
222
+ if reason_code != 0 :
223
+ raise ConnectionError (
224
+ f"Failed to connect to MQTT broker; ERROR: ({ reason_code .value } ).{ reason_code .getName ()} "
225
+ )
155
226
156
- def on_paused (self , topic : str , partition : int ):
157
- # not used
158
- pass
159
227
160
- def flush (self , topic : str , partition : str ):
161
- # not used
162
- pass
228
+ def _mqtt_on_disconnect_cb (
229
+ client : paho .Client ,
230
+ userdata : any ,
231
+ disconnect_flags : paho .DisconnectFlags ,
232
+ reason_code : paho .ReasonCode ,
233
+ properties : paho .Properties ,
234
+ ):
235
+ logger .info (
236
+ f"DISCONNECTED! Reason code ({ reason_code .value } ) { reason_code .getName ()} !"
237
+ )
238
+
239
+
240
+ def _get_properties_callable (
241
+ properties : Optional [MqttPropertiesHandler ],
242
+ ) -> Optional [Callable [[Any ], paho .Properties ]]:
243
+ if isinstance (properties , paho .Properties ):
244
+ return lambda data : properties (data )
245
+ return properties
163
246
164
- def cleanup (self ):
165
- self .mqtt_client .loop_stop ()
166
- self .mqtt_client .disconnect ()
167
247
168
- def __del__ (self ):
169
- self .cleanup ()
248
+ def _get_retain_callable (retain : RetainHandler ) -> Callable [[Any ], bool ]:
249
+ if isinstance (retain , bool ):
250
+ return lambda data : retain
251
+ return retain
0 commit comments