Skip to content

Commit 669d38a

Browse files
committed
run linters
1 parent 52e71f2 commit 669d38a

File tree

2 files changed

+93
-51
lines changed

2 files changed

+93
-51
lines changed

quixstreams/sinks/community/mqtt.py

Lines changed: 80 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from quixstreams.sinks.base.sink import BaseSink
2-
from quixstreams.sinks.base.exceptions import SinkBackpressureError
3-
from typing import List, Tuple, Any
4-
from quixstreams.models.types import HeaderValue
5-
from datetime import datetime
61
import json
2+
from datetime import datetime
3+
from typing import Any, List, Tuple
4+
5+
from quixstreams.models.types import HeaderValue
6+
from quixstreams.sinks.base.sink import BaseSink
77

88
try:
99
import paho.mqtt.client as paho
@@ -14,21 +14,24 @@
1414
"run pip install quixstreams[paho-mqtt] to fix it"
1515
) from exc
1616

17+
1718
class MQTTSink(BaseSink):
1819
"""
1920
A sink that publishes messages to an MQTT broker.
2021
"""
2122

22-
def __init__(self,
23-
mqtt_client_id: str,
24-
mqtt_server: str,
25-
mqtt_port: int,
26-
mqtt_topic_root: str,
27-
mqtt_username: str = None,
28-
mqtt_password: str = None,
29-
mqtt_version: str = "3.1.1",
30-
tls_enabled: bool = True,
31-
qos: int = 1):
23+
def __init__(
24+
self,
25+
mqtt_client_id: str,
26+
mqtt_server: str,
27+
mqtt_port: int,
28+
mqtt_topic_root: str,
29+
mqtt_username: str = None,
30+
mqtt_password: str = None,
31+
mqtt_version: str = "3.1.1",
32+
tls_enabled: bool = True,
33+
qos: int = 1,
34+
):
3235
"""
3336
Initialize the MQTTSink.
3437
@@ -42,21 +45,27 @@ def __init__(self,
4245
:param tls_enabled: Whether to use TLS encryption. Defaults to True
4346
:param qos: Quality of Service level (0, 1, or 2). Defaults to 1
4447
"""
45-
48+
4649
super().__init__()
47-
50+
4851
self.mqtt_version = mqtt_version
4952
self.mqtt_username = mqtt_username
5053
self.mqtt_password = mqtt_password
5154
self.mqtt_topic_root = mqtt_topic_root
5255
self.tls_enabled = tls_enabled
5356
self.qos = qos
5457

55-
self.mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2,
56-
client_id = mqtt_client_id, userdata = None, protocol = self._mqtt_protocol_version())
58+
self.mqtt_client = paho.Client(
59+
callback_api_version=paho.CallbackAPIVersion.VERSION2,
60+
client_id=mqtt_client_id,
61+
userdata=None,
62+
protocol=self._mqtt_protocol_version(),
63+
)
5764

5865
if self.tls_enabled:
59-
self.mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls now
66+
self.mqtt_client.tls_set(
67+
tls_version=mqtt.client.ssl.PROTOCOL_TLS
68+
) # we'll be using tls now
6069

6170
self.mqtt_client.reconnect_delay_set(5, 60)
6271
self._configure_authentication()
@@ -65,17 +74,31 @@ def __init__(self,
6574
self.mqtt_client.connect(mqtt_server, int(mqtt_port))
6675

6776
# setting callbacks for different events to see if it works, print the message etc.
68-
def _mqtt_on_connect_cb(self, client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags,
69-
reason_code: paho.ReasonCode, properties: paho.Properties):
77+
def _mqtt_on_connect_cb(
78+
self,
79+
client: paho.Client,
80+
userdata: any,
81+
connect_flags: paho.ConnectFlags,
82+
reason_code: paho.ReasonCode,
83+
properties: paho.Properties,
84+
):
7085
if reason_code == 0:
71-
print("CONNECTED!") # required for Quix to know this has connected
86+
print("CONNECTED!") # required for Quix to know this has connected
7287
else:
7388
print(f"ERROR ({reason_code.value}). {reason_code.getName()}")
7489

75-
def _mqtt_on_disconnect_cb(self, client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags,
76-
reason_code: paho.ReasonCode, properties: paho.Properties):
77-
print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!")
78-
90+
def _mqtt_on_disconnect_cb(
91+
self,
92+
client: paho.Client,
93+
userdata: any,
94+
disconnect_flags: paho.DisconnectFlags,
95+
reason_code: paho.ReasonCode,
96+
properties: paho.Properties,
97+
):
98+
print(
99+
f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!"
100+
)
101+
79102
def _mqtt_protocol_version(self):
80103
if self.mqtt_version == "3.1":
81104
return paho.MQTTv31
@@ -90,38 +113,51 @@ def _configure_authentication(self):
90113
if self.mqtt_username:
91114
self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password)
92115

93-
def _publish_to_mqtt(self, data: str, key: bytes, timestamp: datetime, headers: List[Tuple[str, HeaderValue]]):
116+
def _publish_to_mqtt(
117+
self,
118+
data: str,
119+
key: bytes,
120+
timestamp: datetime,
121+
headers: List[Tuple[str, HeaderValue]],
122+
):
94123
if isinstance(data, bytes):
95-
data = data.decode('utf-8') # Decode bytes to string using utf-8
124+
data = data.decode("utf-8") # Decode bytes to string using utf-8
96125

97126
json_data = json.dumps(data)
98-
message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding
127+
message_key_string = key.decode(
128+
"utf-8"
129+
) # Convert to string using utf-8 encoding
99130
# publish to MQTT
100-
self.mqtt_client.publish(self.mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = self.qos)
101-
102-
103-
def add(self,
104-
topic: str,
105-
partition: int,
106-
offset: int,
107-
key: bytes,
108-
value: bytes,
109-
timestamp: datetime,
110-
headers: List[Tuple[str, HeaderValue]],
111-
**kwargs: Any):
131+
self.mqtt_client.publish(
132+
self.mqtt_topic_root + "/" + message_key_string,
133+
payload=json_data,
134+
qos=self.qos,
135+
)
136+
137+
def add(
138+
self,
139+
topic: str,
140+
partition: int,
141+
offset: int,
142+
key: bytes,
143+
value: bytes,
144+
timestamp: datetime,
145+
headers: List[Tuple[str, HeaderValue]],
146+
**kwargs: Any,
147+
):
112148
self._publish_to_mqtt(value, key, timestamp, headers)
113149

114150
def _construct_topic(self, key):
115151
if key:
116-
key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key)
152+
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
117153
return f"{self.mqtt_topic_root}/{key_str}"
118154
else:
119155
return self.mqtt_topic_root
120156

121157
def on_paused(self, topic: str, partition: int):
122158
# not used
123159
pass
124-
160+
125161
def flush(self, topic: str, partition: str):
126162
# not used
127163
pass

tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from unittest.mock import MagicMock, patch
2-
import pytest
31
from datetime import datetime
2+
from unittest.mock import patch
3+
4+
import pytest
5+
46
from quixstreams.sinks.community.mqtt import MQTTSink
57

8+
69
@pytest.fixture()
710
def mqtt_sink_factory():
811
def factory(
@@ -16,7 +19,7 @@ def factory(
1619
tls_enabled: bool = True,
1720
qos: int = 1,
1821
) -> MQTTSink:
19-
with patch('paho.mqtt.client.Client') as MockClient:
22+
with patch("paho.mqtt.client.Client") as MockClient:
2023
mock_mqtt_client = MockClient.return_value
2124
sink = MQTTSink(
2225
mqtt_client_id=mqtt_client_id,
@@ -27,13 +30,14 @@ def factory(
2730
mqtt_password=mqtt_password,
2831
mqtt_version=mqtt_version,
2932
tls_enabled=tls_enabled,
30-
qos=qos
33+
qos=qos,
3134
)
3235
sink.mqtt_client = mock_mqtt_client
3336
return sink, mock_mqtt_client
3437

3538
return factory
3639

40+
3741
class TestMQTTSink:
3842
def test_mqtt_connect(self, mqtt_sink_factory):
3943
sink, mock_mqtt_client = mqtt_sink_factory()
@@ -59,17 +63,19 @@ def test_mqtt_publish(self, mqtt_sink_factory):
5963
partition=0,
6064
offset=1,
6165
key=key,
62-
value=data.encode('utf-8'),
66+
value=data.encode("utf-8"),
6367
timestamp=timestamp,
64-
headers=headers
68+
headers=headers,
6569
)
6670

6771
mock_mqtt_client.publish.assert_called_once_with(
6872
"test/topic/test_key", payload='"test_data"', qos=1
6973
)
7074

7175
def test_mqtt_authentication(self, mqtt_sink_factory):
72-
sink, mock_mqtt_client = mqtt_sink_factory(mqtt_username="user", mqtt_password="pass")
76+
sink, mock_mqtt_client = mqtt_sink_factory(
77+
mqtt_username="user", mqtt_password="pass"
78+
)
7379
mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass")
7480

7581
def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory):

0 commit comments

Comments
 (0)