Skip to content

Commit 3e58ec1

Browse files
author
cdeiters
committed
Bugfixes: Cleaning up Go-Routines Left Over in case of Bad Username and Password during login.
Adding ability to stop the messageid object when a connection is torn down, releasing the go-routine. Also, adding ability to stop the outgoing channel when no message is being sent (in the case of error during initial connection, like bad username/password). Added a timeout to all attempts to send on outgoing via oboung and oboundP, to prevent them blocking forever if outgoing is no longer running. Change-Id: I0f9c3b4154fdeda1d552f8e7c9c14a6f7620b368 Signed-off-by: Christie Deiters <cdeiters@lutron.com>
1 parent 0d6c6e7 commit 3e58ec1

File tree

6 files changed

+139
-48
lines changed

6 files changed

+139
-48
lines changed

client.go

+21-16
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,15 @@ func (c *MqttClient) Start() ([]Receipt, error) {
123123

124124
cm := newConnectMsgFromOptions(c.options)
125125
DEBUG.Println(CLI, "about to write new connect msg")
126-
c.oboundP <- cm
126+
127+
if err := sendMessageWithTimeout(c.oboundP, cm); err != nil {
128+
select {
129+
case c.errors <- err:
130+
default:
131+
// c.errors is a buffer of one, so there must already be an error closing this connection.
132+
}
133+
return nil, err
134+
}
127135

128136
rc := connect(c)
129137
if chkrc(rc) != nil {
@@ -200,12 +208,11 @@ func (c *MqttClient) disconnect() {
200208
c.connected = false
201209
dm := newDisconnectMsg()
202210

203-
// Stop all go routines except outgoing
211+
// send disconnect message
212+
sendMessageWithTimeout(c.oboundP, dm)
213+
// Stop all go routines
204214
close(c.stop)
205215

206-
// Send disconnect message and stop outgoing
207-
c.oboundP <- dm
208-
209216
DEBUG.Println(CLI, "disconnected")
210217
c.persist.Close()
211218
}
@@ -228,9 +235,7 @@ func (c *MqttClient) Publish(qos QoS, topic string, payload interface{}) <-chan
228235
r := make(chan Receipt, 1)
229236
DEBUG.Println(CLI, "sending publish message, topic:", topic)
230237

231-
select {
232-
case c.obound <- sendable{pub, r}:
233-
case <-time.After(time.Second):
238+
if err := sendSendableWithTimeout(c.obound, sendable{pub, r}); err != nil {
234239
close(r)
235240
}
236241
return r
@@ -249,13 +254,10 @@ func (c *MqttClient) PublishMessage(topic string, message *Message) <-chan Recei
249254

250255
DEBUG.Println(CLI, "sending publish message, topic:", topic)
251256

252-
select {
253-
case c.obound <- sendable{pub, r}:
254-
return r
255-
case <-time.After(time.Second):
257+
if err := sendSendableWithTimeout(c.obound, sendable{pub, r}); err != nil {
256258
close(r)
257-
return nil
258259
}
260+
return r
259261
}
260262

261263
// Start a new subscription. Provide a MessageHandler to be executed when
@@ -275,8 +277,9 @@ func (c *MqttClient) StartSubscription(callback MessageHandler, filters ...*Topi
275277
}
276278

277279
r := make(chan Receipt, 1)
278-
279-
c.obound <- sendable{submsg, r}
280+
if err := sendSendableWithTimeout(c.obound, sendable{submsg, r}); err != nil {
281+
close(r)
282+
}
280283

281284
DEBUG.Println(CLI, "exit StartSubscription")
282285
return r, nil
@@ -293,8 +296,10 @@ func (c *MqttClient) EndSubscription(topics ...string) (<-chan Receipt, error) {
293296
usmsg := newUnsubscribeMsg(topics...)
294297

295298
r := make(chan Receipt, 1)
299+
if err := sendSendableWithTimeout(c.obound, sendable{usmsg, r}); err != nil {
300+
close(r)
301+
}
296302

297-
c.obound <- sendable{usmsg, r}
298303
for _, topic := range topics {
299304
c.options.msgRouter.deleteRoute(topic)
300305
}

messageids.go

+28-11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package mqtt
1616

1717
import (
18+
"errors"
1819
"sync"
1920
)
2021

@@ -25,30 +26,41 @@ type MId uint16
2526

2627
type messageIds struct {
2728
sync.Mutex
28-
idChan chan MId
29-
index map[MId]bool
29+
idChan chan MId
30+
index map[MId]bool
31+
stopChan chan struct{}
3032
}
3133

3234
const (
3335
MId_MAX MId = 65535
3436
MId_MIN MId = 1
3537
)
3638

39+
func (mids *messageIds) stop() {
40+
close(mids.stopChan)
41+
}
42+
3743
func (mids *messageIds) generateMsgIds() {
44+
3845
mids.idChan = make(chan MId, 10)
39-
go func() {
46+
mids.stopChan = make(chan struct{})
47+
go func(mid *messageIds) {
4048
for {
41-
mids.Lock()
49+
mid.Lock()
4250
for i := MId_MIN; i < MId_MAX; i++ {
43-
if !mids.index[i] {
44-
mids.index[i] = true
45-
mids.Unlock()
46-
mids.idChan <- i
51+
if !mid.index[i] {
52+
mid.index[i] = true
53+
mid.Unlock()
54+
select {
55+
case mid.idChan <- i:
56+
case <-mid.stopChan:
57+
return
58+
}
4759
break
4860
}
4961
}
5062
}
51-
}()
63+
}(mids)
5264
}
5365

5466
func (mids *messageIds) freeId(id MId) {
@@ -58,6 +70,11 @@ func (mids *messageIds) freeId(id MId) {
5870
mids.index[id] = false
5971
}
6072

61-
func (mids *messageIds) getId() MId {
62-
return <-mids.idChan
73+
func (mids *messageIds) getId() (MId, error) {
74+
select {
75+
case i := <-mids.idChan:
76+
return i, nil
77+
case <-mids.stopChan:
78+
}
79+
return 0, errors.New("Failed to get next message id.")
6380
}

net.go

+43-9
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package mqtt
1717
import (
1818
"code.google.com/p/go.net/websocket"
1919
"crypto/tls"
20+
"errors"
2021
"io"
2122
"net"
2223
"net/url"
@@ -133,6 +134,32 @@ func incoming(c *MqttClient) {
133134
}
134135
}
135136

137+
func sendSendableWithTimeout(sChan chan sendable, s sendable) error {
138+
139+
t := time.NewTimer(time.Second)
140+
select {
141+
case sChan <- s:
142+
// stop timer so we don't leak it.
143+
t.Stop()
144+
case <-t.C:
145+
return errors.New("Timed out sending message")
146+
}
147+
return nil
148+
}
149+
150+
func sendMessageWithTimeout(mChan chan *Message, m *Message) error {
151+
152+
t := time.NewTimer(time.Second)
153+
select {
154+
case mChan <- m:
155+
// stop timer so we don't leak it.
156+
t.Stop()
157+
case <-t.C:
158+
return errors.New("Timed out sending message")
159+
}
160+
return nil
161+
}
162+
136163
// receive a Message object on obound, and then
137164
// actually send outgoing message to the wire
138165
func outgoing(c *MqttClient) {
@@ -142,12 +169,24 @@ func outgoing(c *MqttClient) {
142169
for {
143170
DEBUG.Println(NET, "outgoing waiting for an outbound message")
144171
select {
172+
case <-c.stop:
173+
// The connection has been closed, we won't be receiving any more requests.
174+
return
145175
case out := <-c.obound:
146176
msg := out.m
147177
msgtype := msg.msgType()
148178
DEBUG.Println(NET, "obound got msg to write, type:", msgtype)
149179
if msg.QoS() != QOS_ZERO && msg.MsgId() == 0 {
150-
msg.setMsgId(c.options.mids.getId())
180+
i, err := c.options.mids.getId()
181+
if err != nil {
182+
select {
183+
case c.errors <- err:
184+
default:
185+
// c.errors is a buffer of one, so there must already be an error closing this connection.
186+
}
187+
return
188+
}
189+
msg.setMsgId(i)
151190
}
152191
if out.r != nil {
153192
c.receipts.put(msg.MsgId(), out.r)
@@ -283,18 +322,12 @@ func alllogic(c *MqttClient) {
283322
id := msg.MsgId()
284323
pubrelMsg := newPubRelMsg()
285324
pubrelMsg.setMsgId(id)
286-
select {
287-
case c.obound <- sendable{pubrelMsg, nil}:
288-
case <-time.After(time.Second):
289-
}
325+
sendSendableWithTimeout(c.obound, sendable{pubrelMsg, nil})
290326
case PUBREL:
291327
DEBUG.Println(NET, "received pubrel, id:", msg.MsgId())
292328
pubcompMsg := newPubCompMsg()
293329
pubcompMsg.setMsgId(msg.MsgId())
294-
select {
295-
case c.obound <- sendable{pubcompMsg, nil}:
296-
case <-time.After(time.Second):
297-
}
330+
sendSendableWithTimeout(c.obound, sendable{pubcompMsg, nil})
298331
case PUBCOMP:
299332
DEBUG.Println(NET, "received pubcomp, id:", msg.MsgId())
300333
c.receipts.get(msg.MsgId()) <- Receipt{}
@@ -312,6 +345,7 @@ func alllogic(c *MqttClient) {
312345
// but let it know to stop anyways.
313346
close(c.options.stopRouter)
314347
close(c.stop)
348+
c.options.mids.stop()
315349
c.conn.Close()
316350

317351
// Call onConnectionLost or default error handler

ping.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,13 @@ func keepalive(c *MqttClient) {
5959
if !c.pingOutstanding {
6060
DEBUG.Println(PNG, "keepalive sending ping")
6161
ping := newPingReqMsg()
62-
c.oboundP <- ping
63-
c.pingOutstanding = true
62+
if err := sendMessageWithTimeout(c.oboundP, ping); err != nil {
63+
DEBUG.Println(PNG, "unable to send ping, disconnecting")
64+
go c.options.onconnlost(c, errors.New("unable to send ping, disconnecting"))
65+
c.disconnect()
66+
} else {
67+
c.pingOutstanding = true
68+
}
6469
} else {
6570
CRITICAL.Println(PNG, "pingresp not received, disconnecting")
6671
go c.options.onconnlost(c, errors.New("pingresp not received, disconnecting"))

state.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (c *MqttClient) resume() []Receipt {
6969
if c.receipts.get(m.MsgId()) == nil { // will be nil if client crashed
7070
c.receipts.put(m.MsgId(), make(chan Receipt, 1))
7171
}
72-
c.obound <- sendable{m, c.receipts.get(m.MsgId())}
72+
sendSendableWithTimeout(c.obound, sendable{m, c.receipts.get(m.MsgId())})
7373
}
7474
}
7575

@@ -81,7 +81,7 @@ func (c *MqttClient) resume() []Receipt {
8181
if c.receipts.get(m.MsgId()) == nil { // will be nil if client crashed
8282
c.receipts.put(m.MsgId(), make(chan Receipt, 1))
8383
}
84-
c.obound <- sendable{m, c.receipts.get(m.MsgId())}
84+
sendSendableWithTimeout(c.obound, sendable{m, c.receipts.get(m.MsgId())})
8585
}
8686
}
8787

unit_messageids_test.go

+38-8
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,29 @@ func Test_getId(t *testing.T) {
2323
mids := &messageIds{index: make(map[MId]bool)}
2424
mids.generateMsgIds()
2525

26-
i1 := mids.getId()
26+
i1, err := mids.getId()
27+
if err != nil {
28+
t.Fatalf("Failed to get id. %s", err.Error())
29+
}
2730

2831
if i1 != MId(1) {
2932
t.Fatalf("i1 was wrong: %v", i1)
3033
}
3134

32-
i2 := mids.getId()
35+
i2, err := mids.getId()
36+
if err != nil {
37+
t.Fatalf("Failed to get id. %s", err.Error())
38+
}
3339

3440
if i2 != MId(2) {
3541
t.Fatalf("i2 was wrong: %v", i2)
3642
}
3743

3844
for i := 3; i < 100; i++ {
39-
id := mids.getId()
45+
id, err := mids.getId()
46+
if err != nil {
47+
t.Fatalf("Failed to get id. %s", err.Error())
48+
}
4049
if id != MId(i) {
4150
t.Fatalf("id was wrong expected %v got %v", i, id)
4251
}
@@ -47,14 +56,20 @@ func Test_freeId(t *testing.T) {
4756
mids := &messageIds{index: make(map[MId]bool)}
4857
mids.generateMsgIds()
4958

50-
i1 := mids.getId()
59+
i1, err := mids.getId()
60+
if err != nil {
61+
t.Fatalf("Failed to get id. %s", err.Error())
62+
}
5163
mids.freeId(i1)
5264

5365
if i1 != MId(1) {
5466
t.Fatalf("i1 was wrong: %v", i1)
5567
}
5668

57-
i2 := mids.getId()
69+
i2, err := mids.getId()
70+
if err != nil {
71+
t.Fatalf("Failed to get id. %s", err.Error())
72+
}
5873
fmt.Printf("i2: %v\n", i2)
5974
}
6075

@@ -69,23 +84,38 @@ func Test_messageids_mix(t *testing.T) {
6984

7085
go func() {
7186
for i := 0; i < 10000; i++ {
72-
a <- mids.getId()
87+
id, err := mids.getId()
88+
if err != nil {
89+
t.Fatalf("Failed to get id. %s", err.Error())
90+
}
91+
a <- id
92+
7393
mids.freeId(<-b)
7494
}
7595
done <- true
7696
}()
7797

7898
go func() {
7999
for i := 0; i < 10000; i++ {
80-
b <- mids.getId()
100+
id, err := mids.getId()
101+
if err != nil {
102+
t.Fatalf("Failed to get id. %s", err.Error())
103+
}
104+
b <- id
105+
81106
mids.freeId(<-c)
82107
}
83108
done <- true
84109
}()
85110

86111
go func() {
87112
for i := 0; i < 10000; i++ {
88-
c <- mids.getId()
113+
id, err := mids.getId()
114+
if err != nil {
115+
t.Fatalf("Failed to get id. %s", err.Error())
116+
}
117+
c <- id
118+
89119
mids.freeId(<-a)
90120
}
91121
done <- true

0 commit comments

Comments
 (0)