@@ -625,7 +625,7 @@ def _connect(
625
625
var_header [7 ] |= 0x4 | (self ._lw_qos & 0x1 ) << 3 | (self ._lw_qos & 0x2 ) << 3
626
626
var_header [7 ] |= self ._lw_retain << 5
627
627
628
- self .encode_remaining_length (fixed_header , remaining_length )
628
+ self ._encode_remaining_length (fixed_header , remaining_length )
629
629
self .logger .debug ("Sending CONNECT to broker..." )
630
630
self .logger .debug (f"Fixed Header: { fixed_header } " )
631
631
self .logger .debug (f"Variable Header: { var_header } " )
@@ -663,10 +663,13 @@ def _connect(
663
663
)
664
664
665
665
# pylint: disable=no-self-use
666
- def encode_remaining_length (self , fixed_header : bytearray , remaining_length : int ):
667
- """
668
- Encode Remaining Length [2.2.3]
669
- """
666
+ def _encode_remaining_length (
667
+ self , fixed_header : bytearray , remaining_length : int
668
+ ) -> None :
669
+ """Encode Remaining Length [2.2.3]"""
670
+ if remaining_length > 268_435_455 :
671
+ raise MMQTTException ("invalid remaining length" )
672
+
670
673
# Remaining length calculation
671
674
if remaining_length > 0x7F :
672
675
while remaining_length > 0 :
@@ -765,7 +768,7 @@ def publish(
765
768
pub_hdr_var .append (self ._pid >> 8 )
766
769
pub_hdr_var .append (self ._pid & 0xFF )
767
770
768
- self .encode_remaining_length (pub_hdr_fixed , remaining_length )
771
+ self ._encode_remaining_length (pub_hdr_fixed , remaining_length )
769
772
770
773
self .logger .debug (
771
774
"Sending PUBLISH\n Topic: %s\n Msg: %s\
@@ -836,7 +839,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
836
839
fixed_header = bytearray ([MQTT_SUB ])
837
840
packet_length = 2 + (2 * len (topics )) + (1 * len (topics ))
838
841
packet_length += sum (len (topic .encode ("utf-8" )) for topic , qos in topics )
839
- self .encode_remaining_length (fixed_header , remaining_length = packet_length )
842
+ self ._encode_remaining_length (fixed_header , remaining_length = packet_length )
840
843
self .logger .debug (f"Fixed Header: { fixed_header } " )
841
844
self ._sock .send (fixed_header )
842
845
self ._pid = self ._pid + 1 if self ._pid < 0xFFFF else 1
@@ -864,13 +867,13 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
864
867
)
865
868
else :
866
869
if op == 0x90 :
867
- rc = self ._sock_exact_recv (3 )
868
- # Check packet identifier.
869
- assert rc [1 ] == var_header [0 ] and rc [2 ] == var_header [1 ]
870
- remaining_len = rc [0 ] - 2
870
+ remaining_len = self ._decode_remaining_length ()
871
871
assert remaining_len > 0
872
- rc = self ._sock_exact_recv (remaining_len )
873
- for i in range (0 , remaining_len ):
872
+ rc = self ._sock_exact_recv (2 )
873
+ # Check packet identifier.
874
+ assert rc [0 ] == var_header [0 ] and rc [1 ] == var_header [1 ]
875
+ rc = self ._sock_exact_recv (remaining_len - 2 )
876
+ for i in range (0 , remaining_len - 2 ):
874
877
if rc [i ] not in [0 , 1 , 2 ]:
875
878
raise MMQTTException (
876
879
f"SUBACK Failure for topic { topics [i ][0 ]} : { hex (rc [i ])} "
@@ -915,7 +918,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
915
918
fixed_header = bytearray ([MQTT_UNSUB ])
916
919
packet_length = 2 + (2 * len (topics ))
917
920
packet_length += sum (len (topic .encode ("utf-8" )) for topic in topics )
918
- self .encode_remaining_length (fixed_header , remaining_length = packet_length )
921
+ self ._encode_remaining_length (fixed_header , remaining_length = packet_length )
919
922
self .logger .debug (f"Fixed Header: { fixed_header } " )
920
923
self ._sock .send (fixed_header )
921
924
self ._pid = self ._pid + 1 if self ._pid < 0xFFFF else 1
@@ -1090,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]:
1090
1093
return pkt_type
1091
1094
1092
1095
# Handle only the PUBLISH packet type from now on.
1093
- sz = self ._recv_len ()
1096
+ sz = self ._decode_remaining_length ()
1094
1097
# topic length MSB & LSB
1095
1098
topic_len_buf = self ._sock_exact_recv (2 )
1096
1099
topic_len = int ((topic_len_buf [0 ] << 8 ) | topic_len_buf [1 ])
@@ -1123,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]:
1123
1126
1124
1127
return pkt_type
1125
1128
1126
- def _recv_len (self ) -> int :
1127
- """Unpack MQTT message length. """
1129
+ def _decode_remaining_length (self ) -> int :
1130
+ """Decode Remaining Length [2.2.3] """
1128
1131
n = 0
1129
1132
sh = 0
1130
1133
while True :
1134
+ if sh > 28 :
1135
+ raise MMQTTException ("invalid remaining length encoding" )
1131
1136
b = self ._sock_exact_recv (1 )[0 ]
1132
1137
n |= (b & 0x7F ) << sh
1133
1138
if not b & 0x80 :
0 commit comments