Skip to content

refactor will_set function to match the publish function and allow the msg/payload to be encoded bytes, not just str, int or float. #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,40 +267,76 @@ def mqtt_msg(self, msg_size: int) -> None:
if msg_size < MQTT_MSG_MAX_SZ:
self._msg_size_lim = msg_size

# pylint: disable=too-many-branches, too-many-statements
def will_set(
self,
topic: Optional[str] = None,
payload: Optional[Union[int, float, str]] = None,
qos: int = 0,
topic: str,
msg: Union[str, int, float, bytes],
retain: bool = False,
qos: int = 0,
) -> None:
"""Sets the last will and testament properties. MUST be called before `connect()`.

:param str topic: MQTT Broker topic.
:param int|float|str payload: Last will disconnection payload.
payloads of type int & float are converted to a string.
:param str|int|float|bytes msg: Last will disconnection msg.
msgs of type int & float are converted to a string.
msgs of type byetes are left unchanged, as it is in the publish function.
:param int qos: Quality of Service level, defaults to
zero. Conventional options are ``0`` (send at most once), ``1``
(send at least once), or ``2`` (send exactly once).

.. note:: Only options ``1`` or ``0`` are QoS levels supported by this library.
:param bool retain: Specifies if the payload is to be retained when
:param bool retain: Specifies if the msg is to be retained when
it is published.
"""
self.logger.debug("Setting last will properties")
self._valid_qos(qos)
if self._is_connected:
raise MMQTTException("Last Will should only be called before connect().")
if payload is None:
payload = ""
if isinstance(payload, (int, float, str)):
payload = str(payload).encode()

# check topic/msg/qos kwargs
self._valid_topic(topic)
if "+" in topic or "#" in topic:
raise MMQTTException("Publish topic can not contain wildcards.")

if msg is None:
raise MMQTTException("Message can not be None.")
if isinstance(msg, (int, float)):
msg = str(msg).encode("ascii")
elif isinstance(msg, str):
msg = str(msg).encode("utf-8")
elif isinstance(msg, bytes):
pass
else:
raise MMQTTException("Invalid message data type.")
if len(msg) > MQTT_MSG_MAX_SZ:
raise MMQTTException(f"Message size larger than {MQTT_MSG_MAX_SZ} bytes.")

self._valid_qos(qos)
assert (
0 <= qos <= 1
), "Quality of Service Level 2 is unsupported by this library."

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])

# variable header = 2-byte Topic length (big endian)
pub_hdr_var = bytearray(struct.pack(">H", len(topic.encode("utf-8"))))
pub_hdr_var.extend(topic.encode("utf-8")) # Topic name

remaining_length = 2 + len(msg) + len(topic.encode("utf-8"))
if qos > 0:
# packet identifier where QoS level is 1 or 2. [3.3.2.2]
remaining_length += 2
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
pub_hdr_var.append(self._pid >> 8)
pub_hdr_var.append(self._pid & 0xFF)

self._encode_remaining_length(pub_hdr_fixed, remaining_length)

self._lw_qos = qos
self._lw_topic = topic
self._lw_msg = payload
self._lw_msg = msg
self._lw_retain = retain
self.logger.debug("Last will properties successfully set")

def add_topic_callback(self, mqtt_topic: str, callback_method) -> None:
"""Registers a callback_method for a specific MQTT topic.
Expand Down
6 changes: 2 additions & 4 deletions adafruit_minimqtt/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,9 @@ def rec(node: MQTTMatcher.Node, i: int = 0):
else:
part = lst[i]
if part in node.children:
for content in rec(node.children[part], i + 1):
yield content
yield from rec(node.children[part], i + 1)
if "+" in node.children and (normal or i > 0):
for content in rec(node.children["+"], i + 1):
yield content
yield from rec(node.children["+"], i + 1)
if "#" in node.children and (normal or i > 0):
content = node.children["#"].content
if content is not None:
Expand Down