Skip to content

Commit 279387e

Browse files
committed
add test case for long list of topics
this uncovered a bug in SUBACK processing
1 parent dceca0c commit 279387e

File tree

2 files changed

+91
-19
lines changed

2 files changed

+91
-19
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

+22-17
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ def _connect(
625625
var_header[7] |= 0x4 | (self._lw_qos & 0x1) << 3 | (self._lw_qos & 0x2) << 3
626626
var_header[7] |= self._lw_retain << 5
627627

628-
self.encode_remaining_length(fixed_header, remaining_length)
628+
self._encode_remaining_length(fixed_header, remaining_length)
629629
self.logger.debug("Sending CONNECT to broker...")
630630
self.logger.debug(f"Fixed Header: {fixed_header}")
631631
self.logger.debug(f"Variable Header: {var_header}")
@@ -663,10 +663,13 @@ def _connect(
663663
)
664664

665665
# pylint: disable=no-self-use
666-
def encode_remaining_length(self, fixed_header: bytearray, remaining_length: int):
667-
"""
668-
Encode Remaining Length [2.2.3]
669-
"""
666+
def _encode_remaining_length(
667+
self, fixed_header: bytearray, remaining_length: int
668+
) -> None:
669+
"""Encode Remaining Length [2.2.3]"""
670+
if remaining_length > 268_435_455:
671+
raise MMQTTException("invalid remaining length")
672+
670673
# Remaining length calculation
671674
if remaining_length > 0x7F:
672675
while remaining_length > 0:
@@ -765,7 +768,7 @@ def publish(
765768
pub_hdr_var.append(self._pid >> 8)
766769
pub_hdr_var.append(self._pid & 0xFF)
767770

768-
self.encode_remaining_length(pub_hdr_fixed, remaining_length)
771+
self._encode_remaining_length(pub_hdr_fixed, remaining_length)
769772

770773
self.logger.debug(
771774
"Sending PUBLISH\nTopic: %s\nMsg: %s\
@@ -836,7 +839,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
836839
fixed_header = bytearray([MQTT_SUB])
837840
packet_length = 2 + (2 * len(topics)) + (1 * len(topics))
838841
packet_length += sum(len(topic.encode("utf-8")) for topic, qos in topics)
839-
self.encode_remaining_length(fixed_header, remaining_length=packet_length)
842+
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
840843
self.logger.debug(f"Fixed Header: {fixed_header}")
841844
self._sock.send(fixed_header)
842845
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
@@ -864,13 +867,13 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
864867
)
865868
else:
866869
if op == 0x90:
867-
rc = self._sock_exact_recv(3)
868-
# Check packet identifier.
869-
assert rc[1] == var_header[0] and rc[2] == var_header[1]
870-
remaining_len = rc[0] - 2
870+
remaining_len = self._decode_remaining_length()
871871
assert remaining_len > 0
872-
rc = self._sock_exact_recv(remaining_len)
873-
for i in range(0, remaining_len):
872+
rc = self._sock_exact_recv(2)
873+
# Check packet identifier.
874+
assert rc[0] == var_header[0] and rc[1] == var_header[1]
875+
rc = self._sock_exact_recv(remaining_len - 2)
876+
for i in range(0, remaining_len - 2):
874877
if rc[i] not in [0, 1, 2]:
875878
raise MMQTTException(
876879
f"SUBACK Failure for topic {topics[i][0]}: {hex(rc[i])}"
@@ -915,7 +918,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
915918
fixed_header = bytearray([MQTT_UNSUB])
916919
packet_length = 2 + (2 * len(topics))
917920
packet_length += sum(len(topic.encode("utf-8")) for topic in topics)
918-
self.encode_remaining_length(fixed_header, remaining_length=packet_length)
921+
self._encode_remaining_length(fixed_header, remaining_length=packet_length)
919922
self.logger.debug(f"Fixed Header: {fixed_header}")
920923
self._sock.send(fixed_header)
921924
self._pid = self._pid + 1 if self._pid < 0xFFFF else 1
@@ -1090,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]:
10901093
return pkt_type
10911094

10921095
# Handle only the PUBLISH packet type from now on.
1093-
sz = self._recv_len()
1096+
sz = self._decode_remaining_length()
10941097
# topic length MSB & LSB
10951098
topic_len_buf = self._sock_exact_recv(2)
10961099
topic_len = int((topic_len_buf[0] << 8) | topic_len_buf[1])
@@ -1123,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]:
11231126

11241127
return pkt_type
11251128

1126-
def _recv_len(self) -> int:
1127-
"""Unpack MQTT message length."""
1129+
def _decode_remaining_length(self) -> int:
1130+
"""Decode Remaining Length [2.2.3]"""
11281131
n = 0
11291132
sh = 0
11301133
while True:
1134+
if sh > 28:
1135+
raise MMQTTException("invalid remaining length encoding")
11311136
b = self._sock_exact_recv(1)[0]
11321137
n |= (b & 0x7F) << sh
11331138
if not b & 0x80:

tests/test_subscribe.py

+69-2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@ def handle_subscribe(client, user_data, topic, qos):
4949
]
5050
),
5151
),
52+
# same as before but with tuple
53+
(
54+
("foo/bar", 0),
55+
bytearray([0x90, 0x03, 0x00, 0x01, 0x00]), # SUBACK
56+
bytearray(
57+
[
58+
0x82, # fixed header
59+
0x0C, # remaining length
60+
0x00,
61+
0x01, # message ID
62+
0x00,
63+
0x07, # topic length
64+
0x66, # topic
65+
0x6F,
66+
0x6F,
67+
0x2F,
68+
0x62,
69+
0x61,
70+
0x72,
71+
0x00, # QoS
72+
]
73+
),
74+
),
5275
# remaining length is encoded as 2 bytes due to long topic name.
5376
(
5477
"f" + "o" * 257,
@@ -113,13 +136,52 @@ def handle_subscribe(client, user_data, topic, qos):
113136
]
114137
),
115138
),
139+
# use list of topics for more coverage. If the range was (1, 10000), that would be
140+
# long enough to use 3 bytes for remaining length, however that would make the test
141+
# run for many minutes even on modern systems, so 1001 is used instead.
142+
# This results in 2 bytes for the remaining length.
143+
(
144+
[(f"foo/bar{x:04}", 0) for x in range(1, 1001)],
145+
bytearray(
146+
[
147+
0x90,
148+
0xEA, # remaining length
149+
0x07,
150+
0x00, # message ID
151+
0x01,
152+
]
153+
+ [0x00] * 1000 # success for all topics
154+
),
155+
bytearray(
156+
[
157+
0x82, # fixed header
158+
0xB2, # remaining length
159+
0x6D,
160+
0x00, # message ID
161+
0x01,
162+
]
163+
+ sum(
164+
[
165+
[0x00, 0x0B] + list(f"foo/bar{x:04}".encode("ascii")) + [0x00]
166+
for x in range(1, 1001)
167+
],
168+
[],
169+
)
170+
),
171+
),
116172
]
117173

118174

119175
@pytest.mark.parametrize(
120176
"topic,to_send,exp_recv",
121177
testdata,
122-
ids=["short_topic", "long_topic", "publish_first"],
178+
ids=[
179+
"short_topic",
180+
"short_topic_tuple",
181+
"long_topic",
182+
"publish_first",
183+
"topic_list_long",
184+
],
123185
)
124186
def test_subscribe(topic, to_send, exp_recv) -> None:
125187
"""
@@ -157,5 +219,10 @@ def test_subscribe(topic, to_send, exp_recv) -> None:
157219
logger.info(f"subscribing to {topic}")
158220
mqtt_client.subscribe(topic)
159221

160-
assert topic in subscribed_topics
222+
if isinstance(topic, str):
223+
assert topic in subscribed_topics
224+
elif isinstance(topic, list):
225+
for topic_name, _ in topic:
226+
assert topic_name in subscribed_topics
161227
assert mocket.sent == exp_recv
228+
assert len(mocket._to_send) == 0

0 commit comments

Comments
 (0)