Skip to content

Commit bbd5a4a

Browse files
committed
Handle errors when reading packets
1 parent 379fd9f commit bbd5a4a

File tree

13 files changed

+195
-104
lines changed

13 files changed

+195
-104
lines changed

packets/connack.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ func (ca *ConnackPacket) Write(w io.Writer) error {
3838
//Unpack decodes the details of a ControlPacket after the fixed
3939
//header has been read
4040
func (ca *ConnackPacket) Unpack(b io.Reader) error {
41-
ca.SessionPresent = 1&decodeByte(b) > 0
42-
ca.ReturnCode = decodeByte(b)
41+
flags, err := decodeByte(b)
42+
if err != nil {
43+
return err
44+
}
45+
ca.SessionPresent = 1&flags > 0
46+
ca.ReturnCode, err = decodeByte(b)
4347

44-
return nil
48+
return err
4549
}
4650

4751
//Details returns a Details struct containing the Qos and

packets/connect.go

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,55 @@ func (c *ConnectPacket) Write(w io.Writer) error {
6565
//Unpack decodes the details of a ControlPacket after the fixed
6666
//header has been read
6767
func (c *ConnectPacket) Unpack(b io.Reader) error {
68-
c.ProtocolName = decodeString(b)
69-
c.ProtocolVersion = decodeByte(b)
70-
options := decodeByte(b)
68+
var err error
69+
c.ProtocolName, err = decodeString(b)
70+
if err != nil {
71+
return err
72+
}
73+
c.ProtocolVersion, err = decodeByte(b)
74+
if err != nil {
75+
return err
76+
}
77+
options, err := decodeByte(b)
78+
if err != nil {
79+
return err
80+
}
7181
c.ReservedBit = 1 & options
7282
c.CleanSession = 1&(options>>1) > 0
7383
c.WillFlag = 1&(options>>2) > 0
7484
c.WillQos = 3 & (options >> 3)
7585
c.WillRetain = 1&(options>>5) > 0
7686
c.PasswordFlag = 1&(options>>6) > 0
7787
c.UsernameFlag = 1&(options>>7) > 0
78-
c.Keepalive = decodeUint16(b)
79-
c.ClientIdentifier = decodeString(b)
88+
c.Keepalive, err = decodeUint16(b)
89+
if err != nil {
90+
return err
91+
}
92+
c.ClientIdentifier, err = decodeString(b)
93+
if err != nil {
94+
return err
95+
}
8096
if c.WillFlag {
81-
c.WillTopic = decodeString(b)
82-
c.WillMessage = decodeBytes(b)
97+
c.WillTopic, err = decodeString(b)
98+
if err != nil {
99+
return err
100+
}
101+
c.WillMessage, err = decodeBytes(b)
102+
if err != nil {
103+
return err
104+
}
83105
}
84106
if c.UsernameFlag {
85-
c.Username = decodeString(b)
107+
c.Username, err = decodeString(b)
108+
if err != nil {
109+
return err
110+
}
86111
}
87112
if c.PasswordFlag {
88-
c.Password = decodeBytes(b)
113+
c.Password, err = decodeBytes(b)
114+
if err != nil {
115+
return err
116+
}
89117
}
90118

91119
return nil

packets/packets.go

Lines changed: 84 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,25 @@ var ConnErrors = map[byte]error{
9898
//to read an MQTT packet from the stream. It returns a ControlPacket
9999
//representing the decoded MQTT packet and an error. One of these returns will
100100
//always be nil, a nil ControlPacket indicating an error occurred.
101-
func ReadPacket(r io.Reader) (cp ControlPacket, err error) {
101+
func ReadPacket(r io.Reader) (ControlPacket, error) {
102102
var fh FixedHeader
103103
b := make([]byte, 1)
104104

105-
_, err = io.ReadFull(r, b)
105+
_, err := io.ReadFull(r, b)
106106
if err != nil {
107107
return nil, err
108108
}
109-
fh.unpack(b[0], r)
110-
cp = NewControlPacketWithHeader(fh)
111-
if cp == nil {
112-
return nil, errors.New("Bad data from client")
109+
110+
err = fh.unpack(b[0], r)
111+
if err != nil {
112+
return nil, err
113+
}
114+
115+
cp, err := NewControlPacketWithHeader(fh)
116+
if err != nil {
117+
return nil, err
113118
}
119+
114120
packetBytes := make([]byte, fh.RemainingLength)
115121
n, err := io.ReadFull(r, packetBytes)
116122
if err != nil {
@@ -128,79 +134,75 @@ func ReadPacket(r io.Reader) (cp ControlPacket, err error) {
128134
//by packetType, this is usually done by reference to the packet type constants
129135
//defined in packets.go. The newly created ControlPacket is empty and a pointer
130136
//is returned.
131-
func NewControlPacket(packetType byte) (cp ControlPacket) {
137+
func NewControlPacket(packetType byte) ControlPacket {
132138
switch packetType {
133139
case Connect:
134-
cp = &ConnectPacket{FixedHeader: FixedHeader{MessageType: Connect}}
140+
return &ConnectPacket{FixedHeader: FixedHeader{MessageType: Connect}}
135141
case Connack:
136-
cp = &ConnackPacket{FixedHeader: FixedHeader{MessageType: Connack}}
142+
return &ConnackPacket{FixedHeader: FixedHeader{MessageType: Connack}}
137143
case Disconnect:
138-
cp = &DisconnectPacket{FixedHeader: FixedHeader{MessageType: Disconnect}}
144+
return &DisconnectPacket{FixedHeader: FixedHeader{MessageType: Disconnect}}
139145
case Publish:
140-
cp = &PublishPacket{FixedHeader: FixedHeader{MessageType: Publish}}
146+
return &PublishPacket{FixedHeader: FixedHeader{MessageType: Publish}}
141147
case Puback:
142-
cp = &PubackPacket{FixedHeader: FixedHeader{MessageType: Puback}}
148+
return &PubackPacket{FixedHeader: FixedHeader{MessageType: Puback}}
143149
case Pubrec:
144-
cp = &PubrecPacket{FixedHeader: FixedHeader{MessageType: Pubrec}}
150+
return &PubrecPacket{FixedHeader: FixedHeader{MessageType: Pubrec}}
145151
case Pubrel:
146-
cp = &PubrelPacket{FixedHeader: FixedHeader{MessageType: Pubrel, Qos: 1}}
152+
return &PubrelPacket{FixedHeader: FixedHeader{MessageType: Pubrel, Qos: 1}}
147153
case Pubcomp:
148-
cp = &PubcompPacket{FixedHeader: FixedHeader{MessageType: Pubcomp}}
154+
return &PubcompPacket{FixedHeader: FixedHeader{MessageType: Pubcomp}}
149155
case Subscribe:
150-
cp = &SubscribePacket{FixedHeader: FixedHeader{MessageType: Subscribe, Qos: 1}}
156+
return &SubscribePacket{FixedHeader: FixedHeader{MessageType: Subscribe, Qos: 1}}
151157
case Suback:
152-
cp = &SubackPacket{FixedHeader: FixedHeader{MessageType: Suback}}
158+
return &SubackPacket{FixedHeader: FixedHeader{MessageType: Suback}}
153159
case Unsubscribe:
154-
cp = &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: Unsubscribe, Qos: 1}}
160+
return &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: Unsubscribe, Qos: 1}}
155161
case Unsuback:
156-
cp = &UnsubackPacket{FixedHeader: FixedHeader{MessageType: Unsuback}}
162+
return &UnsubackPacket{FixedHeader: FixedHeader{MessageType: Unsuback}}
157163
case Pingreq:
158-
cp = &PingreqPacket{FixedHeader: FixedHeader{MessageType: Pingreq}}
164+
return &PingreqPacket{FixedHeader: FixedHeader{MessageType: Pingreq}}
159165
case Pingresp:
160-
cp = &PingrespPacket{FixedHeader: FixedHeader{MessageType: Pingresp}}
161-
default:
162-
return nil
166+
return &PingrespPacket{FixedHeader: FixedHeader{MessageType: Pingresp}}
163167
}
164-
return cp
168+
return nil
165169
}
166170

167171
//NewControlPacketWithHeader is used to create a new ControlPacket of the type
168172
//specified within the FixedHeader that is passed to the function.
169173
//The newly created ControlPacket is empty and a pointer is returned.
170-
func NewControlPacketWithHeader(fh FixedHeader) (cp ControlPacket) {
174+
func NewControlPacketWithHeader(fh FixedHeader) (ControlPacket, error) {
171175
switch fh.MessageType {
172176
case Connect:
173-
cp = &ConnectPacket{FixedHeader: fh}
177+
return &ConnectPacket{FixedHeader: fh}, nil
174178
case Connack:
175-
cp = &ConnackPacket{FixedHeader: fh}
179+
return &ConnackPacket{FixedHeader: fh}, nil
176180
case Disconnect:
177-
cp = &DisconnectPacket{FixedHeader: fh}
181+
return &DisconnectPacket{FixedHeader: fh}, nil
178182
case Publish:
179-
cp = &PublishPacket{FixedHeader: fh}
183+
return &PublishPacket{FixedHeader: fh}, nil
180184
case Puback:
181-
cp = &PubackPacket{FixedHeader: fh}
185+
return &PubackPacket{FixedHeader: fh}, nil
182186
case Pubrec:
183-
cp = &PubrecPacket{FixedHeader: fh}
187+
return &PubrecPacket{FixedHeader: fh}, nil
184188
case Pubrel:
185-
cp = &PubrelPacket{FixedHeader: fh}
189+
return &PubrelPacket{FixedHeader: fh}, nil
186190
case Pubcomp:
187-
cp = &PubcompPacket{FixedHeader: fh}
191+
return &PubcompPacket{FixedHeader: fh}, nil
188192
case Subscribe:
189-
cp = &SubscribePacket{FixedHeader: fh}
193+
return &SubscribePacket{FixedHeader: fh}, nil
190194
case Suback:
191-
cp = &SubackPacket{FixedHeader: fh}
195+
return &SubackPacket{FixedHeader: fh}, nil
192196
case Unsubscribe:
193-
cp = &UnsubscribePacket{FixedHeader: fh}
197+
return &UnsubscribePacket{FixedHeader: fh}, nil
194198
case Unsuback:
195-
cp = &UnsubackPacket{FixedHeader: fh}
199+
return &UnsubackPacket{FixedHeader: fh}, nil
196200
case Pingreq:
197-
cp = &PingreqPacket{FixedHeader: fh}
201+
return &PingreqPacket{FixedHeader: fh}, nil
198202
case Pingresp:
199-
cp = &PingrespPacket{FixedHeader: fh}
200-
default:
201-
return nil
203+
return &PingrespPacket{FixedHeader: fh}, nil
202204
}
203-
return cp
205+
return nil, fmt.Errorf("unsupported packet type 0x%x", fh.MessageType)
204206
}
205207

206208
//Details struct returned by the Details() function called on
@@ -241,24 +243,34 @@ func (fh *FixedHeader) pack() bytes.Buffer {
241243
return header
242244
}
243245

244-
func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) {
246+
func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) error {
245247
fh.MessageType = typeAndFlags >> 4
246248
fh.Dup = (typeAndFlags>>3)&0x01 > 0
247249
fh.Qos = (typeAndFlags >> 1) & 0x03
248250
fh.Retain = typeAndFlags&0x01 > 0
249-
fh.RemainingLength = decodeLength(r)
251+
252+
var err error
253+
fh.RemainingLength, err = decodeLength(r)
254+
return err
250255
}
251256

252-
func decodeByte(b io.Reader) byte {
257+
func decodeByte(b io.Reader) (byte, error) {
253258
num := make([]byte, 1)
254-
b.Read(num)
255-
return num[0]
259+
_, err := b.Read(num)
260+
if err != nil {
261+
return 0, err
262+
}
263+
264+
return num[0], nil
256265
}
257266

258-
func decodeUint16(b io.Reader) uint16 {
267+
func decodeUint16(b io.Reader) (uint16, error) {
259268
num := make([]byte, 2)
260-
b.Read(num)
261-
return binary.BigEndian.Uint16(num)
269+
_, err := b.Read(num)
270+
if err != nil {
271+
return 0, err
272+
}
273+
return binary.BigEndian.Uint16(num), nil
262274
}
263275

264276
func encodeUint16(num uint16) []byte {
@@ -268,19 +280,27 @@ func encodeUint16(num uint16) []byte {
268280
}
269281

270282
func encodeString(field string) []byte {
271-
272283
return encodeBytes([]byte(field))
273284
}
274285

275-
func decodeString(b io.Reader) string {
276-
return string(decodeBytes(b))
286+
func decodeString(b io.Reader) (string, error) {
287+
buf, err := decodeBytes(b)
288+
return string(buf), err
277289
}
278290

279-
func decodeBytes(b io.Reader) []byte {
280-
fieldLength := decodeUint16(b)
291+
func decodeBytes(b io.Reader) ([]byte, error) {
292+
fieldLength, err := decodeUint16(b)
293+
if err != nil {
294+
return nil, err
295+
}
296+
281297
field := make([]byte, fieldLength)
282-
b.Read(field)
283-
return field
298+
_, err = b.Read(field)
299+
if err != nil {
300+
return nil, err
301+
}
302+
303+
return field, nil
284304
}
285305

286306
func encodeBytes(field []byte) []byte {
@@ -305,18 +325,22 @@ func encodeLength(length int) []byte {
305325
return encLength
306326
}
307327

308-
func decodeLength(r io.Reader) int {
328+
func decodeLength(r io.Reader) (int, error) {
309329
var rLength uint32
310330
var multiplier uint32
311331
b := make([]byte, 1)
312332
for multiplier < 27 { //fix: Infinite '(digit & 128) == 1' will cause the dead loop
313-
io.ReadFull(r, b)
333+
_, err := io.ReadFull(r, b)
334+
if err != nil {
335+
return 0, err
336+
}
337+
314338
digit := b[0]
315339
rLength |= uint32(digit&127) << multiplier
316340
if (digit & 128) == 0 {
317341
break
318342
}
319343
multiplier += 7
320344
}
321-
return int(rLength)
345+
return int(rLength), nil
322346
}

packets/packets_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,11 @@ func TestPackUnpackControlPackets(t *testing.T) {
192192
}
193193

194194
func TestEncoding(t *testing.T) {
195-
if res := decodeByte(bytes.NewBuffer([]byte{0x56})); res != 0x56 {
196-
t.Errorf("decodeByte([0x56]) did not return 0x56 but 0x%X", res)
195+
if res, err := decodeByte(bytes.NewBuffer([]byte{0x56})); res != 0x56 || err != nil {
196+
t.Errorf("decodeByte([0x56]) did not return (0x56, nil) but (0x%X, %v)", res, err)
197197
}
198-
if res := decodeUint16(bytes.NewBuffer([]byte{0x56, 0x78})); res != 22136 {
199-
t.Errorf("decodeUint16([0x5678]) did not return 22136 but %d", res)
198+
if res, err := decodeUint16(bytes.NewBuffer([]byte{0x56, 0x78})); res != 22136 || err != nil {
199+
t.Errorf("decodeUint16([0x5678]) did not return (22136, nil) but (%d, %v)", res, err)
200200
}
201201
if res := encodeUint16(22136); !bytes.Equal(res, []byte{0x56, 0x78}) {
202202
t.Errorf("encodeUint16(22136) did not return [0x5678] but [0x%X]", res)
@@ -208,11 +208,11 @@ func TestEncoding(t *testing.T) {
208208
"A\U0002A6D4": []byte{0x00, 0x05, 'A', 0xF0, 0xAA, 0x9B, 0x94},
209209
}
210210
for str, encoded := range strings {
211-
if res := decodeString(bytes.NewBuffer(encoded)); res != str {
212-
t.Errorf(`decodeString(%v) did not return "%s", but "%s"`, encoded, str, res)
211+
if res, err := decodeString(bytes.NewBuffer(encoded)); res != str || err != nil {
212+
t.Errorf("decodeString(%v) did not return (%q, nil), but (%q, %v)", encoded, str, res, err)
213213
}
214214
if res := encodeString(str); !bytes.Equal(res, encoded) {
215-
t.Errorf(`encodeString("%s") did not return [0x%X], but [0x%X]`, str, encoded, res)
215+
t.Errorf("encodeString(%q) did not return [0x%X], but [0x%X]", str, encoded, res)
216216
}
217217
}
218218

@@ -227,8 +227,8 @@ func TestEncoding(t *testing.T) {
227227
268435455: []byte{0xFF, 0xFF, 0xFF, 0x7F},
228228
}
229229
for length, encoded := range lengths {
230-
if res := decodeLength(bytes.NewBuffer(encoded)); res != length {
231-
t.Errorf("decodeLength([0x%X]) did not return %d, but %d", encoded, length, res)
230+
if res, err := decodeLength(bytes.NewBuffer(encoded)); res != length || err != nil {
231+
t.Errorf("decodeLength([0x%X]) did not return (%d, nil) but (%d, %v)", encoded, length, res, err)
232232
}
233233
if res := encodeLength(length); !bytes.Equal(res, encoded) {
234234
t.Errorf("encodeLength(%d) did not return [0x%X], but [0x%X]", length, encoded, res)

packets/puback.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ func (pa *PubackPacket) Write(w io.Writer) error {
3232
//Unpack decodes the details of a ControlPacket after the fixed
3333
//header has been read
3434
func (pa *PubackPacket) Unpack(b io.Reader) error {
35-
pa.MessageID = decodeUint16(b)
35+
var err error
36+
pa.MessageID, err = decodeUint16(b)
3637

37-
return nil
38+
return err
3839
}
3940

4041
//Details returns a Details struct containing the Qos and

packets/pubcomp.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ func (pc *PubcompPacket) Write(w io.Writer) error {
3232
//Unpack decodes the details of a ControlPacket after the fixed
3333
//header has been read
3434
func (pc *PubcompPacket) Unpack(b io.Reader) error {
35-
pc.MessageID = decodeUint16(b)
35+
var err error
36+
pc.MessageID, err = decodeUint16(b)
3637

37-
return nil
38+
return err
3839
}
3940

4041
//Details returns a Details struct containing the Qos and

0 commit comments

Comments
 (0)