1
- from quixstreams .sinks .base .sink import BaseSink
2
- from quixstreams .sinks .base .exceptions import SinkBackpressureError
3
- from typing import List , Tuple , Any
4
- from quixstreams .models .types import HeaderValue
5
- from datetime import datetime
6
1
import json
2
+ from datetime import datetime
3
+ from typing import Any , List , Tuple
4
+
5
+ from quixstreams .models .types import HeaderValue
6
+ from quixstreams .sinks .base .sink import BaseSink
7
7
8
8
try :
9
9
import paho .mqtt .client as paho
14
14
"run pip install quixstreams[paho-mqtt] to fix it"
15
15
) from exc
16
16
17
+
17
18
class MQTTSink (BaseSink ):
18
19
"""
19
20
A sink that publishes messages to an MQTT broker.
20
21
"""
21
22
22
- def __init__ (self ,
23
- mqtt_client_id : str ,
24
- mqtt_server : str ,
25
- mqtt_port : int ,
26
- mqtt_topic_root : str ,
27
- mqtt_username : str = None ,
28
- mqtt_password : str = None ,
29
- mqtt_version : str = "3.1.1" ,
30
- tls_enabled : bool = True ,
31
- qos : int = 1 ):
23
+ def __init__ (
24
+ self ,
25
+ mqtt_client_id : str ,
26
+ mqtt_server : str ,
27
+ mqtt_port : int ,
28
+ mqtt_topic_root : str ,
29
+ mqtt_username : str = None ,
30
+ mqtt_password : str = None ,
31
+ mqtt_version : str = "3.1.1" ,
32
+ tls_enabled : bool = True ,
33
+ qos : int = 1 ,
34
+ ):
32
35
"""
33
36
Initialize the MQTTSink.
34
37
@@ -42,21 +45,27 @@ def __init__(self,
42
45
:param tls_enabled: Whether to use TLS encryption. Defaults to True
43
46
:param qos: Quality of Service level (0, 1, or 2). Defaults to 1
44
47
"""
45
-
48
+
46
49
super ().__init__ ()
47
-
50
+
48
51
self .mqtt_version = mqtt_version
49
52
self .mqtt_username = mqtt_username
50
53
self .mqtt_password = mqtt_password
51
54
self .mqtt_topic_root = mqtt_topic_root
52
55
self .tls_enabled = tls_enabled
53
56
self .qos = qos
54
57
55
- self .mqtt_client = paho .Client (callback_api_version = paho .CallbackAPIVersion .VERSION2 ,
56
- client_id = mqtt_client_id , userdata = None , protocol = self ._mqtt_protocol_version ())
58
+ self .mqtt_client = paho .Client (
59
+ callback_api_version = paho .CallbackAPIVersion .VERSION2 ,
60
+ client_id = mqtt_client_id ,
61
+ userdata = None ,
62
+ protocol = self ._mqtt_protocol_version (),
63
+ )
57
64
58
65
if self .tls_enabled :
59
- self .mqtt_client .tls_set (tls_version = mqtt .client .ssl .PROTOCOL_TLS ) # we'll be using tls now
66
+ self .mqtt_client .tls_set (
67
+ tls_version = mqtt .client .ssl .PROTOCOL_TLS
68
+ ) # we'll be using tls now
60
69
61
70
self .mqtt_client .reconnect_delay_set (5 , 60 )
62
71
self ._configure_authentication ()
@@ -65,17 +74,31 @@ def __init__(self,
65
74
self .mqtt_client .connect (mqtt_server , int (mqtt_port ))
66
75
67
76
# setting callbacks for different events to see if it works, print the message etc.
68
- def _mqtt_on_connect_cb (self , client : paho .Client , userdata : any , connect_flags : paho .ConnectFlags ,
69
- reason_code : paho .ReasonCode , properties : paho .Properties ):
77
+ def _mqtt_on_connect_cb (
78
+ self ,
79
+ client : paho .Client ,
80
+ userdata : any ,
81
+ connect_flags : paho .ConnectFlags ,
82
+ reason_code : paho .ReasonCode ,
83
+ properties : paho .Properties ,
84
+ ):
70
85
if reason_code == 0 :
71
- print ("CONNECTED!" ) # required for Quix to know this has connected
86
+ print ("CONNECTED!" ) # required for Quix to know this has connected
72
87
else :
73
88
print (f"ERROR ({ reason_code .value } ). { reason_code .getName ()} " )
74
89
75
- def _mqtt_on_disconnect_cb (self , client : paho .Client , userdata : any , disconnect_flags : paho .DisconnectFlags ,
76
- reason_code : paho .ReasonCode , properties : paho .Properties ):
77
- print (f"DISCONNECTED! Reason code ({ reason_code .value } ) { reason_code .getName ()} !" )
78
-
90
+ def _mqtt_on_disconnect_cb (
91
+ self ,
92
+ client : paho .Client ,
93
+ userdata : any ,
94
+ disconnect_flags : paho .DisconnectFlags ,
95
+ reason_code : paho .ReasonCode ,
96
+ properties : paho .Properties ,
97
+ ):
98
+ print (
99
+ f"DISCONNECTED! Reason code ({ reason_code .value } ) { reason_code .getName ()} !"
100
+ )
101
+
79
102
def _mqtt_protocol_version (self ):
80
103
if self .mqtt_version == "3.1" :
81
104
return paho .MQTTv31
@@ -90,38 +113,51 @@ def _configure_authentication(self):
90
113
if self .mqtt_username :
91
114
self .mqtt_client .username_pw_set (self .mqtt_username , self .mqtt_password )
92
115
93
- def _publish_to_mqtt (self , data : str , key : bytes , timestamp : datetime , headers : List [Tuple [str , HeaderValue ]]):
116
+ def _publish_to_mqtt (
117
+ self ,
118
+ data : str ,
119
+ key : bytes ,
120
+ timestamp : datetime ,
121
+ headers : List [Tuple [str , HeaderValue ]],
122
+ ):
94
123
if isinstance (data , bytes ):
95
- data = data .decode (' utf-8' ) # Decode bytes to string using utf-8
124
+ data = data .decode (" utf-8" ) # Decode bytes to string using utf-8
96
125
97
126
json_data = json .dumps (data )
98
- message_key_string = key .decode ('utf-8' ) # Convert to string using utf-8 encoding
127
+ message_key_string = key .decode (
128
+ "utf-8"
129
+ ) # Convert to string using utf-8 encoding
99
130
# publish to MQTT
100
- self .mqtt_client .publish (self .mqtt_topic_root + "/" + message_key_string , payload = json_data , qos = self .qos )
101
-
102
-
103
- def add (self ,
104
- topic : str ,
105
- partition : int ,
106
- offset : int ,
107
- key : bytes ,
108
- value : bytes ,
109
- timestamp : datetime ,
110
- headers : List [Tuple [str , HeaderValue ]],
111
- ** kwargs : Any ):
131
+ self .mqtt_client .publish (
132
+ self .mqtt_topic_root + "/" + message_key_string ,
133
+ payload = json_data ,
134
+ qos = self .qos ,
135
+ )
136
+
137
+ def add (
138
+ self ,
139
+ topic : str ,
140
+ partition : int ,
141
+ offset : int ,
142
+ key : bytes ,
143
+ value : bytes ,
144
+ timestamp : datetime ,
145
+ headers : List [Tuple [str , HeaderValue ]],
146
+ ** kwargs : Any ,
147
+ ):
112
148
self ._publish_to_mqtt (value , key , timestamp , headers )
113
149
114
150
def _construct_topic (self , key ):
115
151
if key :
116
- key_str = key .decode (' utf-8' ) if isinstance (key , bytes ) else str (key )
152
+ key_str = key .decode (" utf-8" ) if isinstance (key , bytes ) else str (key )
117
153
return f"{ self .mqtt_topic_root } /{ key_str } "
118
154
else :
119
155
return self .mqtt_topic_root
120
156
121
157
def on_paused (self , topic : str , partition : int ):
122
158
# not used
123
159
pass
124
-
160
+
125
161
def flush (self , topic : str , partition : str ):
126
162
# not used
127
163
pass
0 commit comments