24
24
package leakcheck
25
25
26
26
import (
27
+ "context"
28
+ "fmt"
27
29
"runtime"
28
30
"runtime/debug"
29
31
"slices"
@@ -53,6 +55,7 @@ func init() {
53
55
}
54
56
55
57
var globalPool swappableBufferPool
58
+ var globalTimerTracker * timerFactory
56
59
57
60
type swappableBufferPool struct {
58
61
atomic.Pointer [mem.BufferPool ]
@@ -81,7 +84,7 @@ func SetTrackingBufferPool(logger Logger) {
81
84
82
85
// CheckTrackingBufferPool undoes the effects of SetTrackingBufferPool, and fails
83
86
// unit tests if not all buffers were returned. It is invalid to invoke this
84
- // method without previously having invoked SetTrackingBufferPool.
87
+ // function without previously having invoked SetTrackingBufferPool.
85
88
func CheckTrackingBufferPool () {
86
89
p := (* globalPool .Load ()).(* trackingBufferPool )
87
90
p .lock .Lock ()
@@ -148,24 +151,9 @@ type trackingBufferPool struct {
148
151
func (p * trackingBufferPool ) Get (length int ) * []byte {
149
152
p .lock .Lock ()
150
153
defer p .lock .Unlock ()
151
-
152
154
p .bufferCount ++
153
-
154
155
buf := p .pool .Get (length )
155
-
156
- var stackBuf [16 ]uintptr
157
- var stack []uintptr
158
- skip := 2
159
- for {
160
- n := runtime .Callers (skip , stackBuf [:])
161
- stack = append (stack , stackBuf [:n ]... )
162
- if n < len (stackBuf ) {
163
- break
164
- }
165
- skip += len (stackBuf )
166
- }
167
- p .allocatedBuffers [buf ] = stack
168
-
156
+ p .allocatedBuffers [buf ] = currentStack (2 )
169
157
return buf
170
158
}
171
159
@@ -257,12 +245,11 @@ type Logger interface {
257
245
// CheckGoroutines looks at the currently-running goroutines and checks if there
258
246
// are any interesting (created by gRPC) goroutines leaked. It waits up to 10
259
247
// seconds in the error cases.
260
- func CheckGoroutines (logger Logger , timeout time. Duration ) {
248
+ func CheckGoroutines (ctx context. Context , logger Logger ) {
261
249
// Loop, waiting for goroutines to shut down.
262
250
// Wait up to timeout, but finish as quickly as possible.
263
- deadline := time .Now ().Add (timeout )
264
251
var leaked []string
265
- for time . Now (). Before ( deadline ) {
252
+ for ctx . Err () == nil {
266
253
if leaked = interestingGoroutines (); len (leaked ) == 0 {
267
254
return
268
255
}
@@ -279,13 +266,6 @@ type LeakChecker struct {
279
266
logger Logger
280
267
}
281
268
282
- // Check executes the leak check tests, failing the unit test if any buffer or
283
- // goroutine leaks are detected.
284
- func (lc * LeakChecker ) Check () {
285
- CheckTrackingBufferPool ()
286
- CheckGoroutines (lc .logger , 10 * time .Second )
287
- }
288
-
289
269
// NewLeakChecker offers a convenient way to set up the leak checks for a
290
270
// specific unit test. It can be used as follows, at the beginning of tests:
291
271
//
@@ -298,3 +278,119 @@ func NewLeakChecker(logger Logger) *LeakChecker {
298
278
SetTrackingBufferPool (logger )
299
279
return & LeakChecker {logger : logger }
300
280
}
281
+
282
+ type timerFactory struct {
283
+ mu sync.Mutex
284
+ allocatedTimers map [internal.Timer ][]uintptr
285
+ }
286
+
287
+ func (tf * timerFactory ) timeAfterFunc (d time.Duration , f func ()) internal.Timer {
288
+ tf .mu .Lock ()
289
+ defer tf .mu .Unlock ()
290
+ ch := make (chan internal.Timer , 1 )
291
+ timer := time .AfterFunc (d , func () {
292
+ f ()
293
+ tf .remove (<- ch )
294
+ })
295
+ ch <- timer
296
+ tf .allocatedTimers [timer ] = currentStack (2 )
297
+ return & trackingTimer {
298
+ Timer : timer ,
299
+ parent : tf ,
300
+ }
301
+ }
302
+
303
+ func (tf * timerFactory ) remove (timer internal.Timer ) {
304
+ tf .mu .Lock ()
305
+ defer tf .mu .Unlock ()
306
+ delete (tf .allocatedTimers , timer )
307
+ }
308
+
309
+ func (tf * timerFactory ) pendingTimers () []string {
310
+ tf .mu .Lock ()
311
+ defer tf .mu .Unlock ()
312
+ leaked := []string {}
313
+ for _ , stack := range tf .allocatedTimers {
314
+ leaked = append (leaked , fmt .Sprintf ("Allocated timer never cancelled:\n %s" , traceToString (stack )))
315
+ }
316
+ return leaked
317
+ }
318
+
319
+ type trackingTimer struct {
320
+ internal.Timer
321
+ parent * timerFactory
322
+ }
323
+
324
+ func (t * trackingTimer ) Stop () bool {
325
+ t .parent .remove (t .Timer )
326
+ return t .Timer .Stop ()
327
+ }
328
+
329
+ // TrackTimers replaces internal.TimerAfterFunc with one that tracks timer
330
+ // creations, stoppages and expirations. CheckTimers should then be invoked at
331
+ // the end of the test to validate that all timers created have either executed
332
+ // or are cancelled.
333
+ func TrackTimers () {
334
+ globalTimerTracker = & timerFactory {
335
+ allocatedTimers : make (map [internal.Timer ][]uintptr ),
336
+ }
337
+ internal .TimeAfterFunc = globalTimerTracker .timeAfterFunc
338
+ }
339
+
340
+ // CheckTimers undoes the effects of TrackTimers, and fails unit tests if not
341
+ // all timers were cancelled or executed. It is invalid to invoke this function
342
+ // without previously having invoked TrackTimers.
343
+ func CheckTimers (ctx context.Context , logger Logger ) {
344
+ tt := globalTimerTracker
345
+
346
+ // Loop, waiting for timers to be cancelled.
347
+ // Wait up to timeout, but finish as quickly as possible.
348
+ var leaked []string
349
+ for ctx .Err () == nil {
350
+ if leaked = tt .pendingTimers (); len (leaked ) == 0 {
351
+ return
352
+ }
353
+ time .Sleep (50 * time .Millisecond )
354
+ }
355
+ for _ , g := range leaked {
356
+ logger .Errorf ("Leaked timers: %v" , g )
357
+ }
358
+
359
+ // Reset the internal function.
360
+ internal .TimeAfterFunc = func (d time.Duration , f func ()) internal.Timer {
361
+ return time .AfterFunc (d , f )
362
+ }
363
+ }
364
+
365
+ func currentStack (skip int ) []uintptr {
366
+ var stackBuf [16 ]uintptr
367
+ var stack []uintptr
368
+ skip ++
369
+ for {
370
+ n := runtime .Callers (skip , stackBuf [:])
371
+ stack = append (stack , stackBuf [:n ]... )
372
+ if n < len (stackBuf ) {
373
+ break
374
+ }
375
+ skip += len (stackBuf )
376
+ }
377
+ return stack
378
+ }
379
+
380
+ func traceToString (stack []uintptr ) string {
381
+ frames := runtime .CallersFrames (stack )
382
+ var trace strings.Builder
383
+ for {
384
+ f , ok := frames .Next ()
385
+ if ! ok {
386
+ break
387
+ }
388
+ trace .WriteString (f .Function )
389
+ trace .WriteString ("\n \t " )
390
+ trace .WriteString (f .File )
391
+ trace .WriteString (":" )
392
+ trace .WriteString (strconv .Itoa (f .Line ))
393
+ trace .WriteString ("\n " )
394
+ }
395
+ return trace .String ()
396
+ }
0 commit comments