From abc19373c8fb6fcc1c6b5a46019ffae0f4bcd3af Mon Sep 17 00:00:00 2001 From: SteveRosam Date: Wed, 27 Nov 2024 16:23:58 +0000 Subject: [PATCH 01/10] Add MQTT Sink --- pyproject.toml | 3 +- quixstreams/sinks/community/mqtt.py | 134 ++++++++++++++++++ .../test_community/test_mqtt_sink.py | 79 +++++++++++ 3 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 quixstreams/sinks/community/mqtt.py create mode 100644 tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py diff --git a/pyproject.toml b/pyproject.toml index b95560611..d583cb672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,8 @@ all = [ "pymongo>=4.11,<5", "pandas>=1.0.0,<3.0", "elasticsearch>=8.17,<9", - "influxdb>=5.3,<6" + "influxdb>=5.3,<6", + "paho-mqtt==2.1.0" ] avro = ["fastavro>=1.8,<2.0"] diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py new file mode 100644 index 000000000..ff624ac2a --- /dev/null +++ b/quixstreams/sinks/community/mqtt.py @@ -0,0 +1,134 @@ +from quixstreams.sinks.base.sink import BaseSink +from quixstreams.sinks.base.exceptions import SinkBackpressureError +from typing import List, Tuple, Any +from quixstreams.models.types import HeaderValue +from datetime import datetime +import json + +try: + import paho.mqtt.client as paho + from paho import mqtt +except ImportError as exc: + raise ImportError( + 'Package "paho-mqtt" is missing: ' + "run pip install quixstreams[paho-mqtt] to fix it" + ) from exc + +class MQTTSink(BaseSink): + """ + A sink that publishes messages to an MQTT broker. + """ + + def __init__(self, + mqtt_client_id: str, + mqtt_server: str, + mqtt_port: int, + mqtt_topic_root: str, + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1): + """ + Initialize the MQTTSink. + + :param mqtt_client_id: MQTT client identifier. + :param mqtt_server: MQTT broker server address. + :param mqtt_port: MQTT broker server port. + :param mqtt_topic_root: Root topic to publish messages to. + :param mqtt_username: Username for MQTT broker authentication. Defaults to None + :param mqtt_password: Password for MQTT broker authentication. Defaults to None + :param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1 + :param tls_enabled: Whether to use TLS encryption. Defaults to True + :param qos: Quality of Service level (0, 1, or 2). Defaults to 1 + """ + + super().__init__() + + self.mqtt_version = mqtt_version + self.mqtt_username = mqtt_username + self.mqtt_password = mqtt_password + self.mqtt_topic_root = mqtt_topic_root + self.tls_enabled = tls_enabled + self.qos = qos + + self.mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2, + client_id = mqtt_client_id, userdata = None, protocol = self._mqtt_protocol_version()) + + if self.tls_enabled: + self.mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls now + + self.mqtt_client.reconnect_delay_set(5, 60) + self._configure_authentication() + self.mqtt_client.on_connect = self._mqtt_on_connect_cb + self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb + self.mqtt_client.connect(mqtt_server, int(mqtt_port)) + + # setting callbacks for different events to see if it works, print the message etc. + def _mqtt_on_connect_cb(self, client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags, + reason_code: paho.ReasonCode, properties: paho.Properties): + if reason_code == 0: + print("CONNECTED!") # required for Quix to know this has connected + else: + print(f"ERROR ({reason_code.value}). {reason_code.getName()}") + + def _mqtt_on_disconnect_cb(self, client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags, + reason_code: paho.ReasonCode, properties: paho.Properties): + print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!") + + def _mqtt_protocol_version(self): + if self.mqtt_version == "3.1": + return paho.MQTTv31 + elif self.mqtt_version == "3.1.1": + return paho.MQTTv311 + elif self.mqtt_version == "5": + return paho.MQTTv5 + else: + raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}") + + def _configure_authentication(self): + if self.mqtt_username: + self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) + + def _publish_to_mqtt(self, data: str, key: bytes, timestamp: datetime, headers: List[Tuple[str, HeaderValue]]): + if isinstance(data, bytes): + data = data.decode('utf-8') # Decode bytes to string using utf-8 + + json_data = json.dumps(data) + message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding + # publish to MQTT + self.mqtt_client.publish(self.mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = self.qos) + + + def add(self, + topic: str, + partition: int, + offset: int, + key: bytes, + value: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + **kwargs: Any): + self._publish_to_mqtt(value, key, timestamp, headers) + + def _construct_topic(self, key): + if key: + key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + return f"{self.mqtt_topic_root}/{key_str}" + else: + return self.mqtt_topic_root + + def on_paused(self, topic: str, partition: int): + # not used + pass + + def flush(self, topic: str, partition: str): + # not used + pass + + def cleanup(self): + self.mqtt_client.loop_stop() + self.mqtt_client.disconnect() + + def __del__(self): + self.cleanup() \ No newline at end of file diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py new file mode 100644 index 000000000..6f9e80d25 --- /dev/null +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -0,0 +1,79 @@ +from unittest.mock import MagicMock, patch +import pytest +from datetime import datetime +from quixstreams.sinks.community.mqtt import MQTTSink + +@pytest.fixture() +def mqtt_sink_factory(): + def factory( + mqtt_client_id: str = "test_client", + mqtt_server: str = "localhost", + mqtt_port: int = 1883, + mqtt_topic_root: str = "test/topic", + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1, + ) -> MQTTSink: + with patch('paho.mqtt.client.Client') as MockClient: + mock_mqtt_client = MockClient.return_value + sink = MQTTSink( + mqtt_client_id=mqtt_client_id, + mqtt_server=mqtt_server, + mqtt_port=mqtt_port, + mqtt_topic_root=mqtt_topic_root, + mqtt_username=mqtt_username, + mqtt_password=mqtt_password, + mqtt_version=mqtt_version, + tls_enabled=tls_enabled, + qos=qos + ) + sink.mqtt_client = mock_mqtt_client + return sink, mock_mqtt_client + + return factory + +class TestMQTTSink: + def test_mqtt_connect(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + mock_mqtt_client.connect.assert_called_once_with("localhost", 1883) + + def test_mqtt_tls_enabled(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=True) + mock_mqtt_client.tls_set.assert_called_once() + + def test_mqtt_tls_disabled(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=False) + mock_mqtt_client.tls_set.assert_not_called() + + def test_mqtt_publish(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + data = "test_data" + key = b"test_key" + timestamp = datetime.now() + headers = [] + + sink.add( + topic="test-topic", + partition=0, + offset=1, + key=key, + value=data.encode('utf-8'), + timestamp=timestamp, + headers=headers + ) + + mock_mqtt_client.publish.assert_called_once_with( + "test/topic/test_key", payload='"test_data"', qos=1 + ) + + def test_mqtt_authentication(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(mqtt_username="user", mqtt_password="pass") + mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") + + def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + sink.cleanup() # Explicitly call cleanup + mock_mqtt_client.loop_stop.assert_called_once() + mock_mqtt_client.disconnect.assert_called_once() \ No newline at end of file From 204d3e4a66c52e90fc48751d493c10407f300831 Mon Sep 17 00:00:00 2001 From: SteveRosam Date: Wed, 27 Nov 2024 16:31:34 +0000 Subject: [PATCH 02/10] Add new line --- .../test_sinks/test_community/test_mqtt_sink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py index 6f9e80d25..4efe1d694 100644 --- a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -76,4 +76,4 @@ def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): sink, mock_mqtt_client = mqtt_sink_factory() sink.cleanup() # Explicitly call cleanup mock_mqtt_client.loop_stop.assert_called_once() - mock_mqtt_client.disconnect.assert_called_once() \ No newline at end of file + mock_mqtt_client.disconnect.assert_called_once() From 9e7bb4c4dc6356ac3a264f27fcf7fa8ccf7344e4 Mon Sep 17 00:00:00 2001 From: SteveRosam Date: Wed, 27 Nov 2024 16:33:36 +0000 Subject: [PATCH 03/10] EoF New Line --- quixstreams/sinks/community/mqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index ff624ac2a..b4d305651 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -131,4 +131,4 @@ def cleanup(self): self.mqtt_client.disconnect() def __del__(self): - self.cleanup() \ No newline at end of file + self.cleanup() From 8066f1d0fcf8bb8526ef66927311375d979bd781 Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Thu, 28 Nov 2024 10:08:32 +0100 Subject: [PATCH 04/10] run linters --- quixstreams/sinks/community/mqtt.py | 124 +++++++++++------- .../test_community/test_mqtt_sink.py | 20 ++- 2 files changed, 93 insertions(+), 51 deletions(-) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index b4d305651..51ab840a7 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -1,9 +1,9 @@ -from quixstreams.sinks.base.sink import BaseSink -from quixstreams.sinks.base.exceptions import SinkBackpressureError -from typing import List, Tuple, Any -from quixstreams.models.types import HeaderValue -from datetime import datetime import json +from datetime import datetime +from typing import Any, List, Tuple + +from quixstreams.models.types import HeaderValue +from quixstreams.sinks.base.sink import BaseSink try: import paho.mqtt.client as paho @@ -14,21 +14,24 @@ "run pip install quixstreams[paho-mqtt] to fix it" ) from exc + class MQTTSink(BaseSink): """ A sink that publishes messages to an MQTT broker. """ - def __init__(self, - mqtt_client_id: str, - mqtt_server: str, - mqtt_port: int, - mqtt_topic_root: str, - mqtt_username: str = None, - mqtt_password: str = None, - mqtt_version: str = "3.1.1", - tls_enabled: bool = True, - qos: int = 1): + def __init__( + self, + mqtt_client_id: str, + mqtt_server: str, + mqtt_port: int, + mqtt_topic_root: str, + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1, + ): """ Initialize the MQTTSink. @@ -42,9 +45,9 @@ def __init__(self, :param tls_enabled: Whether to use TLS encryption. Defaults to True :param qos: Quality of Service level (0, 1, or 2). Defaults to 1 """ - + super().__init__() - + self.mqtt_version = mqtt_version self.mqtt_username = mqtt_username self.mqtt_password = mqtt_password @@ -52,11 +55,17 @@ def __init__(self, self.tls_enabled = tls_enabled self.qos = qos - self.mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2, - client_id = mqtt_client_id, userdata = None, protocol = self._mqtt_protocol_version()) + self.mqtt_client = paho.Client( + callback_api_version=paho.CallbackAPIVersion.VERSION2, + client_id=mqtt_client_id, + userdata=None, + protocol=self._mqtt_protocol_version(), + ) if self.tls_enabled: - self.mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls now + self.mqtt_client.tls_set( + tls_version=mqtt.client.ssl.PROTOCOL_TLS + ) # we'll be using tls now self.mqtt_client.reconnect_delay_set(5, 60) self._configure_authentication() @@ -65,17 +74,31 @@ def __init__(self, self.mqtt_client.connect(mqtt_server, int(mqtt_port)) # setting callbacks for different events to see if it works, print the message etc. - def _mqtt_on_connect_cb(self, client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags, - reason_code: paho.ReasonCode, properties: paho.Properties): + def _mqtt_on_connect_cb( + self, + client: paho.Client, + userdata: any, + connect_flags: paho.ConnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, + ): if reason_code == 0: - print("CONNECTED!") # required for Quix to know this has connected + print("CONNECTED!") # required for Quix to know this has connected else: print(f"ERROR ({reason_code.value}). {reason_code.getName()}") - def _mqtt_on_disconnect_cb(self, client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags, - reason_code: paho.ReasonCode, properties: paho.Properties): - print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!") - + def _mqtt_on_disconnect_cb( + self, + client: paho.Client, + userdata: any, + disconnect_flags: paho.DisconnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, + ): + print( + f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!" + ) + def _mqtt_protocol_version(self): if self.mqtt_version == "3.1": return paho.MQTTv31 @@ -90,30 +113,43 @@ def _configure_authentication(self): if self.mqtt_username: self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) - def _publish_to_mqtt(self, data: str, key: bytes, timestamp: datetime, headers: List[Tuple[str, HeaderValue]]): + def _publish_to_mqtt( + self, + data: str, + key: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + ): if isinstance(data, bytes): - data = data.decode('utf-8') # Decode bytes to string using utf-8 + data = data.decode("utf-8") # Decode bytes to string using utf-8 json_data = json.dumps(data) - message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding + message_key_string = key.decode( + "utf-8" + ) # Convert to string using utf-8 encoding # publish to MQTT - self.mqtt_client.publish(self.mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = self.qos) - - - def add(self, - topic: str, - partition: int, - offset: int, - key: bytes, - value: bytes, - timestamp: datetime, - headers: List[Tuple[str, HeaderValue]], - **kwargs: Any): + self.mqtt_client.publish( + self.mqtt_topic_root + "/" + message_key_string, + payload=json_data, + qos=self.qos, + ) + + def add( + self, + topic: str, + partition: int, + offset: int, + key: bytes, + value: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + **kwargs: Any, + ): self._publish_to_mqtt(value, key, timestamp, headers) def _construct_topic(self, key): if key: - key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key) return f"{self.mqtt_topic_root}/{key_str}" else: return self.mqtt_topic_root @@ -121,7 +157,7 @@ def _construct_topic(self, key): def on_paused(self, topic: str, partition: int): # not used pass - + def flush(self, topic: str, partition: str): # not used pass diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py index 4efe1d694..05b6b332b 100644 --- a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -1,8 +1,11 @@ -from unittest.mock import MagicMock, patch -import pytest from datetime import datetime +from unittest.mock import patch + +import pytest + from quixstreams.sinks.community.mqtt import MQTTSink + @pytest.fixture() def mqtt_sink_factory(): def factory( @@ -16,7 +19,7 @@ def factory( tls_enabled: bool = True, qos: int = 1, ) -> MQTTSink: - with patch('paho.mqtt.client.Client') as MockClient: + with patch("paho.mqtt.client.Client") as MockClient: mock_mqtt_client = MockClient.return_value sink = MQTTSink( mqtt_client_id=mqtt_client_id, @@ -27,13 +30,14 @@ def factory( mqtt_password=mqtt_password, mqtt_version=mqtt_version, tls_enabled=tls_enabled, - qos=qos + qos=qos, ) sink.mqtt_client = mock_mqtt_client return sink, mock_mqtt_client return factory + class TestMQTTSink: def test_mqtt_connect(self, mqtt_sink_factory): sink, mock_mqtt_client = mqtt_sink_factory() @@ -59,9 +63,9 @@ def test_mqtt_publish(self, mqtt_sink_factory): partition=0, offset=1, key=key, - value=data.encode('utf-8'), + value=data.encode("utf-8"), timestamp=timestamp, - headers=headers + headers=headers, ) mock_mqtt_client.publish.assert_called_once_with( @@ -69,7 +73,9 @@ def test_mqtt_publish(self, mqtt_sink_factory): ) def test_mqtt_authentication(self, mqtt_sink_factory): - sink, mock_mqtt_client = mqtt_sink_factory(mqtt_username="user", mqtt_password="pass") + sink, mock_mqtt_client = mqtt_sink_factory( + mqtt_username="user", mqtt_password="pass" + ) mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): From cbc2ebd35e7c938d4fdadba333e889a13d0d71f2 Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Thu, 28 Nov 2024 10:27:01 +0100 Subject: [PATCH 05/10] requirements --- conda/post-link.sh | 3 ++- pyproject.toml | 3 ++- quixstreams/sinks/community/mqtt.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/conda/post-link.sh b/conda/post-link.sh index 9d19e640d..70198ade0 100644 --- a/conda/post-link.sh +++ b/conda/post-link.sh @@ -8,4 +8,5 @@ $PREFIX/bin/pip install \ 'redis[hiredis]>=5.2.0,<6' \ 'confluent-kafka[avro,json,protobuf,schemaregistry]>=2.8.2,<2.10' \ 'influxdb>=5.3,<6' \ -'jsonpath_ng>=1.7.0,<2' +'jsonpath_ng>=1.7.0,<2' \ +'paho-mqtt>=2.1.0,<3' diff --git a/pyproject.toml b/pyproject.toml index d583cb672..7c5916ce3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ all = [ "pandas>=1.0.0,<3.0", "elasticsearch>=8.17,<9", "influxdb>=5.3,<6", - "paho-mqtt==2.1.0" + "paho-mqtt>=2.1.0,<3" ] avro = ["fastavro>=1.8,<2.0"] @@ -62,6 +62,7 @@ neo4j = ["neo4j>=5.27.0,<6"] mongodb = ["pymongo>=4.11,<5"] pandas = ["pandas>=1.0.0,<3.0"] elasticsearch = ["elasticsearch>=8.17,<9"] +mqtt = ["paho-mqtt>=2.1.0,<3"] # AWS dependencies are separated by service to support # different requirements in the future. diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index 51ab840a7..51a2284a0 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -10,8 +10,7 @@ from paho import mqtt except ImportError as exc: raise ImportError( - 'Package "paho-mqtt" is missing: ' - "run pip install quixstreams[paho-mqtt] to fix it" + 'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it" ) from exc From 32c5f63450781ca75f1bb7792c637ce10852b9bc Mon Sep 17 00:00:00 2001 From: Tim Sawicki Date: Thu, 3 Jul 2025 11:51:17 -0400 Subject: [PATCH 06/10] update to latest sink patterns and overhaul functionality --- quixstreams/sinks/community/mqtt.py | 314 ++++++++++++++++++---------- 1 file changed, 198 insertions(+), 116 deletions(-) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index 51a2284a0..b1adb67b5 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -1,19 +1,37 @@ import json +import logging +import time from datetime import datetime -from typing import Any, List, Tuple +from typing import Any, Callable, Literal, Optional, Union, get_args -from quixstreams.models.types import HeaderValue -from quixstreams.sinks.base.sink import BaseSink +from quixstreams.models.types import HeadersTuples +from quixstreams.sinks import ( + BaseSink, + ClientConnectFailureCallback, + ClientConnectSuccessCallback, +) try: import paho.mqtt.client as paho - from paho import mqtt except ImportError as exc: raise ImportError( 'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it" ) from exc +logger = logging.getLogger(__name__) + +VERSION_MAP = { + "3.1": paho.MQTTv31, + "3.1.1": paho.MQTTv311, + "5": paho.MQTTv5, +} +MQTT_SUCCESS = paho.MQTT_ERR_SUCCESS +ProtocolVersion = Literal["3.1", "3.1.1", "5"] +MqttPropertiesHandler = Union[paho.Properties, Callable[[Any], paho.Properties]] +RetainHandler = Union[bool, Callable[[Any], bool]] + + class MQTTSink(BaseSink): """ A sink that publishes messages to an MQTT broker. @@ -21,117 +39,128 @@ class MQTTSink(BaseSink): def __init__( self, - mqtt_client_id: str, - mqtt_server: str, - mqtt_port: int, - mqtt_topic_root: str, - mqtt_username: str = None, - mqtt_password: str = None, - mqtt_version: str = "3.1.1", + client_id: str, + server: str, + port: int, + topic_root: str, + username: str = None, + password: str = None, + version: ProtocolVersion = "3.1.1", tls_enabled: bool = True, - qos: int = 1, + key_serializer: Callable[[Any], str] = bytes.decode, + value_serializer: Callable[[Any], str] = json.dumps, + qos: Literal[0, 1] = 1, + mqtt_flush_timeout_seconds: int = 10, + retain: Union[bool, Callable[[Any], bool]] = False, + properties: Optional[MqttPropertiesHandler] = None, + on_client_connect_success: Optional[ClientConnectSuccessCallback] = None, + on_client_connect_failure: Optional[ClientConnectFailureCallback] = None, ): """ Initialize the MQTTSink. - :param mqtt_client_id: MQTT client identifier. - :param mqtt_server: MQTT broker server address. - :param mqtt_port: MQTT broker server port. - :param mqtt_topic_root: Root topic to publish messages to. - :param mqtt_username: Username for MQTT broker authentication. Defaults to None - :param mqtt_password: Password for MQTT broker authentication. Defaults to None - :param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1 - :param tls_enabled: Whether to use TLS encryption. Defaults to True - :param qos: Quality of Service level (0, 1, or 2). Defaults to 1 + :param client_id: MQTT client identifier. + :param server: MQTT broker server address. + :param port: MQTT broker server port. + :param topic_root: Root topic to publish messages to. + :param username: Username for MQTT broker authentication. Default = None + :param password: Password for MQTT broker authentication. Default = None + :param version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1 + :param tls_enabled: Whether to use TLS encryption. Default = True + :param key_serializer: How to serialize the MQTT message key for producing. + :param value_serializer: How to serialize the MQTT message value for producing. + :param qos: Quality of Service level (0 or 1; 2 not yet supported) Default = 1. + :param mqtt_flush_timeout_seconds: how long to wait for publish acknowledgment + of MQTT messages before failing. Default = 10. + :param retain: Retain last message for new subscribers. Default = False. + Also accepts a callable that uses the current message value as input. + :param properties: An optional Properties instance for messages. Default = None. + Also accepts a callable that uses the current message value as input. """ - - super().__init__() - - self.mqtt_version = mqtt_version - self.mqtt_username = mqtt_username - self.mqtt_password = mqtt_password - self.mqtt_topic_root = mqtt_topic_root - self.tls_enabled = tls_enabled - self.qos = qos - - self.mqtt_client = paho.Client( + super().__init__( + on_client_connect_success=on_client_connect_success, + on_client_connect_failure=on_client_connect_failure, + ) + if qos == 2: + raise ValueError(f"MQTT QoS level {2} is currently not supported.") + if not (protocol := VERSION_MAP.get(version)): + raise ValueError( + f"Invalid MQTT version {version}; valid: {get_args(ProtocolVersion)}" + ) + if properties and protocol != "5": + raise ValueError( + "MQTT Properties can only be used with MQTT protocol version 5" + ) + + self._version = version + self._server = server + self._port = port + self._topic_root = topic_root + self._key_serializer = key_serializer + self._value_serializer = value_serializer + self._qos = qos + self._flush_timeout = mqtt_flush_timeout_seconds + self._pending_acks: set[int] = set() + self._retain = _get_retain_callable(retain) + self._properties = _get_properties_callable(properties) + + self._client = paho.Client( callback_api_version=paho.CallbackAPIVersion.VERSION2, - client_id=mqtt_client_id, + client_id=client_id, userdata=None, - protocol=self._mqtt_protocol_version(), + protocol=protocol, ) - if self.tls_enabled: - self.mqtt_client.tls_set( - tls_version=mqtt.client.ssl.PROTOCOL_TLS - ) # we'll be using tls now + if username: + self._client.username_pw_set(username, password) + if tls_enabled: + self._client.tls_set(tls_version=paho.ssl.PROTOCOL_TLS) + self._client.reconnect_delay_set(5, 60) + self._client.on_connect = _mqtt_on_connect_cb + self._client.on_disconnect = _mqtt_on_disconnect_cb + self._client.on_publish = self._on_publish_cb + self._publish_count = 0 - self.mqtt_client.reconnect_delay_set(5, 60) - self._configure_authentication() - self.mqtt_client.on_connect = self._mqtt_on_connect_cb - self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb - self.mqtt_client.connect(mqtt_server, int(mqtt_port)) + def setup(self): + self._client.connect(self._server, self._port) + self._client.loop_start() - # setting callbacks for different events to see if it works, print the message etc. - def _mqtt_on_connect_cb( - self, - client: paho.Client, - userdata: any, - connect_flags: paho.ConnectFlags, - reason_code: paho.ReasonCode, - properties: paho.Properties, - ): - if reason_code == 0: - print("CONNECTED!") # required for Quix to know this has connected - else: - print(f"ERROR ({reason_code.value}). {reason_code.getName()}") - - def _mqtt_on_disconnect_cb( + def _publish_to_mqtt( self, - client: paho.Client, - userdata: any, - disconnect_flags: paho.DisconnectFlags, - reason_code: paho.ReasonCode, - properties: paho.Properties, + data: Any, + topic_suffix: Any, ): - print( - f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!" + properties = self._properties + info = self._client.publish( + f"{self._topic_root}/{self._key_serializer(topic_suffix)}", + payload=self._value_serializer(data), + qos=self._qos, + properties=properties(data) if properties else None, + retain=self._retain(data), ) - - def _mqtt_protocol_version(self): - if self.mqtt_version == "3.1": - return paho.MQTTv31 - elif self.mqtt_version == "3.1.1": - return paho.MQTTv311 - elif self.mqtt_version == "5": - return paho.MQTTv5 + if self._qos: + if info.rc != MQTT_SUCCESS: + raise MqttPublishEnqueueFailed( + f"Failed adding message to MQTT publishing queue; " + f"error code {info.rc}: {paho.error_string(info.rc)}" + ) + self._pending_acks.add(info.mid) else: - raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}") + self._publish_count += 1 - def _configure_authentication(self): - if self.mqtt_username: - self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) - - def _publish_to_mqtt( + def _on_publish_cb( self, - data: str, - key: bytes, - timestamp: datetime, - headers: List[Tuple[str, HeaderValue]], + client: paho.Client, + userdata: Any, + mid: int, + rc: paho.ReasonCode, + p: paho.Properties, ): - if isinstance(data, bytes): - data = data.decode("utf-8") # Decode bytes to string using utf-8 - - json_data = json.dumps(data) - message_key_string = key.decode( - "utf-8" - ) # Convert to string using utf-8 encoding - # publish to MQTT - self.mqtt_client.publish( - self.mqtt_topic_root + "/" + message_key_string, - payload=json_data, - qos=self.qos, - ) + """ + This is only triggered upon successful publish when self._qos > 0. + """ + self._publish_count += 1 + self._pending_acks.remove(mid) def add( self, @@ -141,29 +170,82 @@ def add( key: bytes, value: bytes, timestamp: datetime, - headers: List[Tuple[str, HeaderValue]], - **kwargs: Any, + headers: HeadersTuples, ): - self._publish_to_mqtt(value, key, timestamp, headers) + try: + self._publish_to_mqtt( + value, + key, + ) + except Exception as e: + self._cleanup() + raise e + + def on_paused(self): + pass - def _construct_topic(self, key): - if key: - key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key) - return f"{self.mqtt_topic_root}/{key_str}" - else: - return self.mqtt_topic_root + def flush(self): + if self._pending_acks: + start_time = time.monotonic() + timeout = start_time + self._flush_timeout + while self._pending_acks and start_time < timeout: + logger.debug(f"Pending acks remaining: {len(self._pending_acks)}") + time.sleep(1) + if self._pending_acks: + self._cleanup() + raise MqttPublishAckTimeout( + f"Mqtt acknowledgement timeout of {self._flush_timeout}s reached." + ) + logger.info(f"{self._publish_count} MQTT messages published.") + self._publish_count = 0 + + def _cleanup(self): + self._client.loop_stop() + self._client.disconnect() + + +class MqttPublishEnqueueFailed(Exception): + pass + + +class MqttPublishAckTimeout(Exception): + pass + + +def _mqtt_on_connect_cb( + client: paho.Client, + userdata: any, + connect_flags: paho.ConnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, +): + if reason_code != 0: + raise ConnectionError( + f"Failed to connect to MQTT broker; ERROR: ({reason_code.value}).{reason_code.getName()}" + ) - def on_paused(self, topic: str, partition: int): - # not used - pass - def flush(self, topic: str, partition: str): - # not used - pass +def _mqtt_on_disconnect_cb( + client: paho.Client, + userdata: any, + disconnect_flags: paho.DisconnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, +): + logger.info( + f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!" + ) + + +def _get_properties_callable( + properties: Optional[MqttPropertiesHandler], +) -> Optional[Callable[[Any], paho.Properties]]: + if isinstance(properties, paho.Properties): + return lambda data: properties(data) + return properties - def cleanup(self): - self.mqtt_client.loop_stop() - self.mqtt_client.disconnect() - def __del__(self): - self.cleanup() +def _get_retain_callable(retain: RetainHandler) -> Callable[[Any], bool]: + if isinstance(retain, bool): + return lambda data: retain + return retain From 97781ac29cc376d48c8ad86cdb4b25049a947489 Mon Sep 17 00:00:00 2001 From: Tim Sawicki Date: Thu, 3 Jul 2025 11:54:55 -0400 Subject: [PATCH 07/10] tiny format tweak --- quixstreams/sinks/community/mqtt.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index b1adb67b5..88cc10afa 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -173,17 +173,11 @@ def add( headers: HeadersTuples, ): try: - self._publish_to_mqtt( - value, - key, - ) + self._publish_to_mqtt(value, key) except Exception as e: self._cleanup() raise e - def on_paused(self): - pass - def flush(self): if self._pending_acks: start_time = time.monotonic() @@ -199,6 +193,9 @@ def flush(self): logger.info(f"{self._publish_count} MQTT messages published.") self._publish_count = 0 + def on_paused(self): + pass + def _cleanup(self): self._client.loop_stop() self._client.disconnect() From 692d6d510bcc5a051b9a3fb60e3f960516aeb4c6 Mon Sep 17 00:00:00 2001 From: Tim Sawicki Date: Thu, 3 Jul 2025 11:58:04 -0400 Subject: [PATCH 08/10] add missing docstring --- quixstreams/sinks/community/mqtt.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index 88cc10afa..7b90a7554 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -76,6 +76,12 @@ def __init__( Also accepts a callable that uses the current message value as input. :param properties: An optional Properties instance for messages. Default = None. Also accepts a callable that uses the current message value as input. + :param on_client_connect_success: An optional callback made after successful + client authentication, primarily for additional logging. + :param on_client_connect_failure: An optional callback made after failed + client authentication (which should raise an Exception). + Callback should accept the raised Exception as an argument. + Callback must resolve (or propagate/re-raise) the Exception. """ super().__init__( on_client_connect_success=on_client_connect_success, From 39753eec9af183ac4a2540c713eeb9947c143e2c Mon Sep 17 00:00:00 2001 From: Tim Sawicki Date: Thu, 3 Jul 2025 12:33:24 -0400 Subject: [PATCH 09/10] fix current tests --- .../test_community/test_mqtt_sink.py | 61 +++++++++++++------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py index 05b6b332b..81a6632d7 100644 --- a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from unittest.mock import patch import pytest @@ -9,26 +10,26 @@ @pytest.fixture() def mqtt_sink_factory(): def factory( - mqtt_client_id: str = "test_client", - mqtt_server: str = "localhost", - mqtt_port: int = 1883, - mqtt_topic_root: str = "test/topic", - mqtt_username: str = None, - mqtt_password: str = None, - mqtt_version: str = "3.1.1", + client_id: str = "test_client", + server: str = "localhost", + port: int = 1883, + username: Optional[str] = None, + password: Optional[str] = None, + topic_root: str = "test/topic", + version: str = "3.1.1", tls_enabled: bool = True, qos: int = 1, ) -> MQTTSink: with patch("paho.mqtt.client.Client") as MockClient: mock_mqtt_client = MockClient.return_value sink = MQTTSink( - mqtt_client_id=mqtt_client_id, - mqtt_server=mqtt_server, - mqtt_port=mqtt_port, - mqtt_topic_root=mqtt_topic_root, - mqtt_username=mqtt_username, - mqtt_password=mqtt_password, - mqtt_version=mqtt_version, + client_id=client_id, + server=server, + port=port, + topic_root=topic_root, + username=username, + password=password, + version=version, tls_enabled=tls_enabled, qos=qos, ) @@ -41,6 +42,7 @@ def factory( class TestMQTTSink: def test_mqtt_connect(self, mqtt_sink_factory): sink, mock_mqtt_client = mqtt_sink_factory() + sink.setup() mock_mqtt_client.connect.assert_called_once_with("localhost", 1883) def test_mqtt_tls_enabled(self, mqtt_sink_factory): @@ -58,28 +60,47 @@ def test_mqtt_publish(self, mqtt_sink_factory): timestamp = datetime.now() headers = [] + class MockInfo: + def __init__(self): + self.rc = 0 + self.mid = 123 + + mock_mqtt_client.publish.return_value = MockInfo() sink.add( topic="test-topic", partition=0, offset=1, key=key, - value=data.encode("utf-8"), + value=data, timestamp=timestamp, headers=headers, ) mock_mqtt_client.publish.assert_called_once_with( - "test/topic/test_key", payload='"test_data"', qos=1 + "test/topic/test_key", + payload='"test_data"', + qos=1, + retain=False, + properties=None, ) def test_mqtt_authentication(self, mqtt_sink_factory): - sink, mock_mqtt_client = mqtt_sink_factory( - mqtt_username="user", mqtt_password="pass" - ) + sink, mock_mqtt_client = mqtt_sink_factory(username="user", password="pass") mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): sink, mock_mqtt_client = mqtt_sink_factory() - sink.cleanup() # Explicitly call cleanup + mock_mqtt_client.publish.side_effect = ConnectionError("publish error") + with pytest.raises(ConnectionError): + sink.add( + topic="test-topic", + partition=0, + offset=1, + key=b"key", + value="data", + timestamp=12345, + headers=(), + ) + mock_mqtt_client.loop_stop.assert_called_once() mock_mqtt_client.disconnect.assert_called_once() From 04610871e6bbad54d412d24517b8421ff721605c Mon Sep 17 00:00:00 2001 From: Tim Sawicki Date: Thu, 3 Jul 2025 12:38:37 -0400 Subject: [PATCH 10/10] add missing test requirement --- tests/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/requirements.txt b/tests/requirements.txt index b764aca0e..923594cb8 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,3 +10,4 @@ redis[hiredis]>=5.2.0,<6 pandas>=1.0.0,<3.0 psycopg2-binary>=2.9,<3 types-psycopg2>=2.9,<3 +paho-mqtt>=2.1.0,<3