@@ -625,7 +625,7 @@ def _connect(
625625 var_header [7 ] |= 0x4 | (self ._lw_qos & 0x1 ) << 3 | (self ._lw_qos & 0x2 ) << 3
626626 var_header [7 ] |= self ._lw_retain << 5
627627
628- self .encode_remaining_length (fixed_header , remaining_length )
628+ self ._encode_remaining_length (fixed_header , remaining_length )
629629 self .logger .debug ("Sending CONNECT to broker..." )
630630 self .logger .debug (f"Fixed Header: { fixed_header } " )
631631 self .logger .debug (f"Variable Header: { var_header } " )
@@ -663,10 +663,13 @@ def _connect(
663663 )
664664
665665 # pylint: disable=no-self-use
666- def encode_remaining_length (self , fixed_header : bytearray , remaining_length : int ):
667- """
668- Encode Remaining Length [2.2.3]
669- """
666+ def _encode_remaining_length (
667+ self , fixed_header : bytearray , remaining_length : int
668+ ) -> None :
669+ """Encode Remaining Length [2.2.3]"""
670+ if remaining_length > 268_435_455 :
671+ raise MMQTTException ("invalid remaining length" )
672+
670673 # Remaining length calculation
671674 if remaining_length > 0x7F :
672675 while remaining_length > 0 :
@@ -765,7 +768,7 @@ def publish(
765768 pub_hdr_var .append (self ._pid >> 8 )
766769 pub_hdr_var .append (self ._pid & 0xFF )
767770
768- self .encode_remaining_length (pub_hdr_fixed , remaining_length )
771+ self ._encode_remaining_length (pub_hdr_fixed , remaining_length )
769772
770773 self .logger .debug (
771774 "Sending PUBLISH\n Topic: %s\n Msg: %s\
@@ -836,7 +839,7 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
836839 fixed_header = bytearray ([MQTT_SUB ])
837840 packet_length = 2 + (2 * len (topics )) + (1 * len (topics ))
838841 packet_length += sum (len (topic .encode ("utf-8" )) for topic , qos in topics )
839- self .encode_remaining_length (fixed_header , remaining_length = packet_length )
842+ self ._encode_remaining_length (fixed_header , remaining_length = packet_length )
840843 self .logger .debug (f"Fixed Header: { fixed_header } " )
841844 self ._sock .send (fixed_header )
842845 self ._pid = self ._pid + 1 if self ._pid < 0xFFFF else 1
@@ -864,13 +867,13 @@ def subscribe(self, topic: Optional[Union[tuple, str, list]], qos: int = 0) -> N
864867 )
865868 else :
866869 if op == 0x90 :
867- rc = self ._sock_exact_recv (3 )
868- # Check packet identifier.
869- assert rc [1 ] == var_header [0 ] and rc [2 ] == var_header [1 ]
870- remaining_len = rc [0 ] - 2
870+ remaining_len = self ._decode_remaining_length ()
871871 assert remaining_len > 0
872- rc = self ._sock_exact_recv (remaining_len )
873- for i in range (0 , remaining_len ):
872+ rc = self ._sock_exact_recv (2 )
873+ # Check packet identifier.
874+ assert rc [0 ] == var_header [0 ] and rc [1 ] == var_header [1 ]
875+ rc = self ._sock_exact_recv (remaining_len - 2 )
876+ for i in range (0 , remaining_len - 2 ):
874877 if rc [i ] not in [0 , 1 , 2 ]:
875878 raise MMQTTException (
876879 f"SUBACK Failure for topic { topics [i ][0 ]} : { hex (rc [i ])} "
@@ -915,7 +918,7 @@ def unsubscribe(self, topic: Optional[Union[str, list]]) -> None:
915918 fixed_header = bytearray ([MQTT_UNSUB ])
916919 packet_length = 2 + (2 * len (topics ))
917920 packet_length += sum (len (topic .encode ("utf-8" )) for topic in topics )
918- self .encode_remaining_length (fixed_header , remaining_length = packet_length )
921+ self ._encode_remaining_length (fixed_header , remaining_length = packet_length )
919922 self .logger .debug (f"Fixed Header: { fixed_header } " )
920923 self ._sock .send (fixed_header )
921924 self ._pid = self ._pid + 1 if self ._pid < 0xFFFF else 1
@@ -1090,7 +1093,7 @@ def _wait_for_msg(self) -> Optional[int]:
10901093 return pkt_type
10911094
10921095 # Handle only the PUBLISH packet type from now on.
1093- sz = self ._recv_len ()
1096+ sz = self ._decode_remaining_length ()
10941097 # topic length MSB & LSB
10951098 topic_len_buf = self ._sock_exact_recv (2 )
10961099 topic_len = int ((topic_len_buf [0 ] << 8 ) | topic_len_buf [1 ])
@@ -1123,11 +1126,13 @@ def _wait_for_msg(self) -> Optional[int]:
11231126
11241127 return pkt_type
11251128
1126- def _recv_len (self ) -> int :
1127- """Unpack MQTT message length. """
1129+ def _decode_remaining_length (self ) -> int :
1130+ """Decode Remaining Length [2.2.3] """
11281131 n = 0
11291132 sh = 0
11301133 while True :
1134+ if sh > 28 :
1135+ raise MMQTTException ("invalid remaining length encoding" )
11311136 b = self ._sock_exact_recv (1 )[0 ]
11321137 n |= (b & 0x7F ) << sh
11331138 if not b & 0x80 :
0 commit comments