1
+ from quixstreams .sinks .base .sink import BaseSink
2
+ from quixstreams .sinks .base .exceptions import SinkBackpressureError
3
+ from typing import List , Tuple , Any
4
+ from quixstreams .models .types import HeaderValue
5
+ from datetime import datetime
6
+ import json
7
+
8
+ try :
9
+ import paho .mqtt .client as paho
10
+ from paho import mqtt
11
+ except ImportError as exc :
12
+ raise ImportError (
13
+ 'Package "paho-mqtt" is missing: '
14
+ "run pip install quixstreams[paho-mqtt] to fix it"
15
+ ) from exc
16
+
17
+ class MQTTSink (BaseSink ):
18
+ """
19
+ A sink that publishes messages to an MQTT broker.
20
+ """
21
+
22
+ def __init__ (self ,
23
+ mqtt_client_id : str ,
24
+ mqtt_server : str ,
25
+ mqtt_port : int ,
26
+ mqtt_topic_root : str ,
27
+ mqtt_username : str = None ,
28
+ mqtt_password : str = None ,
29
+ mqtt_version : str = "3.1.1" ,
30
+ tls_enabled : bool = True ,
31
+ qos : int = 1 ):
32
+ """
33
+ Initialize the MQTTSink.
34
+
35
+ :param mqtt_client_id: MQTT client identifier.
36
+ :param mqtt_server: MQTT broker server address.
37
+ :param mqtt_port: MQTT broker server port.
38
+ :param mqtt_topic_root: Root topic to publish messages to.
39
+ :param mqtt_username: Username for MQTT broker authentication. Defaults to None
40
+ :param mqtt_password: Password for MQTT broker authentication. Defaults to None
41
+ :param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1
42
+ :param tls_enabled: Whether to use TLS encryption. Defaults to True
43
+ :param qos: Quality of Service level (0, 1, or 2). Defaults to 1
44
+ """
45
+
46
+ super ().__init__ ()
47
+
48
+ self .mqtt_version = mqtt_version
49
+ self .mqtt_username = mqtt_username
50
+ self .mqtt_password = mqtt_password
51
+ self .mqtt_topic_root = mqtt_topic_root
52
+ self .tls_enabled = tls_enabled
53
+ self .qos = qos
54
+
55
+ self .mqtt_client = paho .Client (callback_api_version = paho .CallbackAPIVersion .VERSION2 ,
56
+ client_id = mqtt_client_id , userdata = None , protocol = self ._mqtt_protocol_version ())
57
+
58
+ if self .tls_enabled :
59
+ self .mqtt_client .tls_set (tls_version = mqtt .client .ssl .PROTOCOL_TLS ) # we'll be using tls now
60
+
61
+ self .mqtt_client .reconnect_delay_set (5 , 60 )
62
+ self ._configure_authentication ()
63
+ self .mqtt_client .on_connect = self ._mqtt_on_connect_cb
64
+ self .mqtt_client .on_disconnect = self ._mqtt_on_disconnect_cb
65
+ self .mqtt_client .connect (mqtt_server , int (mqtt_port ))
66
+
67
+ # setting callbacks for different events to see if it works, print the message etc.
68
+ def _mqtt_on_connect_cb (self , client : paho .Client , userdata : any , connect_flags : paho .ConnectFlags ,
69
+ reason_code : paho .ReasonCode , properties : paho .Properties ):
70
+ if reason_code == 0 :
71
+ print ("CONNECTED!" ) # required for Quix to know this has connected
72
+ else :
73
+ print (f"ERROR ({ reason_code .value } ). { reason_code .getName ()} " )
74
+
75
+ def _mqtt_on_disconnect_cb (self , client : paho .Client , userdata : any , disconnect_flags : paho .DisconnectFlags ,
76
+ reason_code : paho .ReasonCode , properties : paho .Properties ):
77
+ print (f"DISCONNECTED! Reason code ({ reason_code .value } ) { reason_code .getName ()} !" )
78
+
79
+ def _mqtt_protocol_version (self ):
80
+ if self .mqtt_version == "3.1" :
81
+ return paho .MQTTv31
82
+ elif self .mqtt_version == "3.1.1" :
83
+ return paho .MQTTv311
84
+ elif self .mqtt_version == "5" :
85
+ return paho .MQTTv5
86
+ else :
87
+ raise ValueError (f"Unsupported MQTT version: { self .mqtt_version } " )
88
+
89
+ def _configure_authentication (self ):
90
+ if self .mqtt_username :
91
+ self .mqtt_client .username_pw_set (self .mqtt_username , self .mqtt_password )
92
+
93
+ def _publish_to_mqtt (self , data : str , key : bytes , timestamp : datetime , headers : List [Tuple [str , HeaderValue ]]):
94
+ if isinstance (data , bytes ):
95
+ data = data .decode ('utf-8' ) # Decode bytes to string using utf-8
96
+
97
+ json_data = json .dumps (data )
98
+ message_key_string = key .decode ('utf-8' ) # Convert to string using utf-8 encoding
99
+ # publish to MQTT
100
+ self .mqtt_client .publish (self .mqtt_topic_root + "/" + message_key_string , payload = json_data , qos = self .qos )
101
+
102
+
103
+ def add (self ,
104
+ topic : str ,
105
+ partition : int ,
106
+ offset : int ,
107
+ key : bytes ,
108
+ value : bytes ,
109
+ timestamp : datetime ,
110
+ headers : List [Tuple [str , HeaderValue ]],
111
+ ** kwargs : Any ):
112
+ self ._publish_to_mqtt (value , key , timestamp , headers )
113
+
114
+ def _construct_topic (self , key ):
115
+ if key :
116
+ key_str = key .decode ('utf-8' ) if isinstance (key , bytes ) else str (key )
117
+ return f"{ self .mqtt_topic_root } /{ key_str } "
118
+ else :
119
+ return self .mqtt_topic_root
120
+
121
+ def on_paused (self , topic : str , partition : int ):
122
+ # not used
123
+ pass
124
+
125
+ def flush (self , topic : str , partition : str ):
126
+ # not used
127
+ pass
128
+
129
+ def cleanup (self ):
130
+ self .mqtt_client .loop_stop ()
131
+ self .mqtt_client .disconnect ()
132
+
133
+ def __del__ (self ):
134
+ self .cleanup ()
0 commit comments