Skip to content

Commit 0d6e39f

Browse files
authored
transport: Send RST stream from the server when deadline expires (#8071)
1 parent 7505bf2 commit 0d6e39f

File tree

8 files changed

+324
-49
lines changed

8 files changed

+324
-49
lines changed

internal/grpctest/grpctest.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package grpctest
2121

2222
import (
23+
"context"
2324
"reflect"
2425
"strings"
2526
"sync/atomic"
@@ -58,15 +59,19 @@ func (Tester) Setup(t *testing.T) {
5859
// completely addressed, and this can be turned back on as soon as this issue is
5960
// fixed.
6061
leakcheck.SetTrackingBufferPool(logger{t: t})
62+
leakcheck.TrackTimers()
6163
}
6264

6365
// Teardown performs a leak check.
6466
func (Tester) Teardown(t *testing.T) {
6567
leakcheck.CheckTrackingBufferPool()
68+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
69+
defer cancel()
70+
leakcheck.CheckTimers(ctx, logger{t: t})
6671
if atomic.LoadUint32(&lcFailed) == 1 {
6772
return
6873
}
69-
leakcheck.CheckGoroutines(logger{t: t}, 10*time.Second)
74+
leakcheck.CheckGoroutines(ctx, logger{t: t})
7075
if atomic.LoadUint32(&lcFailed) == 1 {
7176
t.Log("Goroutine leak check disabled for future tests")
7277
}

internal/internal.go

+13
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,13 @@ var (
259259
// SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for
260260
// testing purposes.
261261
SetBufferPoolingThresholdForTesting any // func(int)
262+
263+
// TimeAfterFunc is used to create timers. During tests the function is
264+
// replaced to track allocated timers and fail the test if a timer isn't
265+
// cancelled.
266+
TimeAfterFunc = func(d time.Duration, f func()) Timer {
267+
return time.AfterFunc(d, f)
268+
}
262269
)
263270

264271
// HealthChecker defines the signature of the client-side LB channel health
@@ -300,3 +307,9 @@ type EnforceSubConnEmbedding interface {
300307
type EnforceClientConnEmbedding interface {
301308
enforceClientConnEmbedding()
302309
}
310+
311+
// Timer is an interface to allow injecting different time.Timer implementations
312+
// during tests. It is the return type of TimeAfterFunc, so whatever that
// function produces (a *time.Timer by default) must satisfy it.
313+
type Timer interface {
314+
// Stop has the same signature as (*time.Timer).Stop.
Stop() bool
315+
}

internal/leakcheck/leakcheck.go

+123-27
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
package leakcheck
2525

2626
import (
27+
"context"
28+
"fmt"
2729
"runtime"
2830
"runtime/debug"
2931
"slices"
@@ -53,6 +55,7 @@ func init() {
5355
}
5456

5557
var globalPool swappableBufferPool
58+
var globalTimerTracker *timerFactory
5659

5760
type swappableBufferPool struct {
5861
atomic.Pointer[mem.BufferPool]
@@ -81,7 +84,7 @@ func SetTrackingBufferPool(logger Logger) {
8184

8285
// CheckTrackingBufferPool undoes the effects of SetTrackingBufferPool, and fails
8386
// unit tests if not all buffers were returned. It is invalid to invoke this
84-
// method without previously having invoked SetTrackingBufferPool.
87+
// function without previously having invoked SetTrackingBufferPool.
8588
func CheckTrackingBufferPool() {
8689
p := (*globalPool.Load()).(*trackingBufferPool)
8790
p.lock.Lock()
@@ -148,24 +151,9 @@ type trackingBufferPool struct {
148151
func (p *trackingBufferPool) Get(length int) *[]byte {
149152
p.lock.Lock()
150153
defer p.lock.Unlock()
151-
152154
p.bufferCount++
153-
154155
buf := p.pool.Get(length)
155-
156-
var stackBuf [16]uintptr
157-
var stack []uintptr
158-
skip := 2
159-
for {
160-
n := runtime.Callers(skip, stackBuf[:])
161-
stack = append(stack, stackBuf[:n]...)
162-
if n < len(stackBuf) {
163-
break
164-
}
165-
skip += len(stackBuf)
166-
}
167-
p.allocatedBuffers[buf] = stack
168-
156+
p.allocatedBuffers[buf] = currentStack(2)
169157
return buf
170158
}
171159

@@ -257,12 +245,11 @@ type Logger interface {
257245
// CheckGoroutines looks at the currently-running goroutines and checks if there
258246
// are any interesting (created by gRPC) goroutines leaked. It keeps checking
259247
// until ctx is done in the error cases.
260-
func CheckGoroutines(logger Logger, timeout time.Duration) {
248+
func CheckGoroutines(ctx context.Context, logger Logger) {
261249
// Loop, waiting for goroutines to shut down.
262250
// Wait until ctx is done, but finish as quickly as possible.
263-
deadline := time.Now().Add(timeout)
264251
var leaked []string
265-
for time.Now().Before(deadline) {
252+
for ctx.Err() == nil {
266253
if leaked = interestingGoroutines(); len(leaked) == 0 {
267254
return
268255
}
@@ -279,13 +266,6 @@ type LeakChecker struct {
279266
logger Logger
280267
}
281268

282-
// Check executes the leak check tests, failing the unit test if any buffer or
283-
// goroutine leaks are detected.
284-
func (lc *LeakChecker) Check() {
285-
CheckTrackingBufferPool()
286-
CheckGoroutines(lc.logger, 10*time.Second)
287-
}
288-
289269
// NewLeakChecker offers a convenient way to set up the leak checks for a
290270
// specific unit test. It can be used as follows, at the beginning of tests:
291271
//
@@ -298,3 +278,119 @@ func NewLeakChecker(logger Logger) *LeakChecker {
298278
SetTrackingBufferPool(logger)
299279
return &LeakChecker{logger: logger}
300280
}
281+
282+
type timerFactory struct {
283+
mu sync.Mutex
284+
allocatedTimers map[internal.Timer][]uintptr
285+
}
286+
287+
func (tf *timerFactory) timeAfterFunc(d time.Duration, f func()) internal.Timer {
288+
tf.mu.Lock()
289+
defer tf.mu.Unlock()
290+
ch := make(chan internal.Timer, 1)
291+
timer := time.AfterFunc(d, func() {
292+
f()
293+
tf.remove(<-ch)
294+
})
295+
ch <- timer
296+
tf.allocatedTimers[timer] = currentStack(2)
297+
return &trackingTimer{
298+
Timer: timer,
299+
parent: tf,
300+
}
301+
}
302+
303+
func (tf *timerFactory) remove(timer internal.Timer) {
304+
tf.mu.Lock()
305+
defer tf.mu.Unlock()
306+
delete(tf.allocatedTimers, timer)
307+
}
308+
309+
func (tf *timerFactory) pendingTimers() []string {
310+
tf.mu.Lock()
311+
defer tf.mu.Unlock()
312+
leaked := []string{}
313+
for _, stack := range tf.allocatedTimers {
314+
leaked = append(leaked, fmt.Sprintf("Allocated timer never cancelled:\n%s", traceToString(stack)))
315+
}
316+
return leaked
317+
}
318+
319+
type trackingTimer struct {
320+
internal.Timer
321+
parent *timerFactory
322+
}
323+
324+
func (t *trackingTimer) Stop() bool {
325+
t.parent.remove(t.Timer)
326+
return t.Timer.Stop()
327+
}
328+
329+
// TrackTimers replaces internal.TimerAfterFunc with one that tracks timer
330+
// creations, stoppages and expirations. CheckTimers should then be invoked at
331+
// the end of the test to validate that all timers created have either executed
332+
// or are cancelled.
333+
func TrackTimers() {
334+
globalTimerTracker = &timerFactory{
335+
allocatedTimers: make(map[internal.Timer][]uintptr),
336+
}
337+
internal.TimeAfterFunc = globalTimerTracker.timeAfterFunc
338+
}
339+
340+
// CheckTimers undoes the effects of TrackTimers, and fails unit tests if not
341+
// all timers were cancelled or executed. It is invalid to invoke this function
342+
// without previously having invoked TrackTimers.
343+
func CheckTimers(ctx context.Context, logger Logger) {
344+
tt := globalTimerTracker
345+
346+
// Loop, waiting for timers to be cancelled.
347+
// Wait up to timeout, but finish as quickly as possible.
348+
var leaked []string
349+
for ctx.Err() == nil {
350+
if leaked = tt.pendingTimers(); len(leaked) == 0 {
351+
return
352+
}
353+
time.Sleep(50 * time.Millisecond)
354+
}
355+
for _, g := range leaked {
356+
logger.Errorf("Leaked timers: %v", g)
357+
}
358+
359+
// Reset the internal function.
360+
internal.TimeAfterFunc = func(d time.Duration, f func()) internal.Timer {
361+
return time.AfterFunc(d, f)
362+
}
363+
}
364+
365+
// currentStack returns the program counters of the calling goroutine's stack.
// skip is the number of frames to omit, where 0 makes currentStack itself the
// first recorded frame and 1 starts at currentStack's caller.
func currentStack(skip int) []uintptr {
	const chunk = 16
	var (
		window [chunk]uintptr
		pcs    []uintptr
	)
	// runtime.Callers counts its own frame as 0, so shift by one to make skip
	// relative to currentStack. The stack is read one fixed-size window at a
	// time; a window that is not completely filled is the last one.
	for offset := skip + 1; ; offset += chunk {
		n := runtime.Callers(offset, window[:])
		pcs = append(pcs, window[:n]...)
		if n < chunk {
			return pcs
		}
	}
}
379+
380+
// traceToString renders the program counters in stack as text, one frame per
// entry in the form "function\n\tfile:line\n". It returns the empty string for
// an empty stack.
func traceToString(stack []uintptr) string {
	if len(stack) == 0 {
		return ""
	}
	frames := runtime.CallersFrames(stack)
	var trace strings.Builder
	for {
		// The second result of Next reports whether another frame follows the
		// one just returned, not whether the returned frame is valid, so the
		// frame has to be written before it is consulted. Checking it first
		// would silently drop the last frame of every trace.
		f, more := frames.Next()
		trace.WriteString(f.Function)
		trace.WriteString("\n\t")
		trace.WriteString(f.File)
		trace.WriteString(":")
		trace.WriteString(strconv.Itoa(f.Line))
		trace.WriteString("\n")
		if !more {
			break
		}
	}
	return trace.String()
}

internal/leakcheck/leakcheck_test.go

+57-4
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919
package leakcheck
2020

2121
import (
22+
"context"
2223
"fmt"
2324
"strings"
25+
"sync"
2426
"testing"
2527
"time"
28+
29+
"google.golang.org/grpc/internal"
2630
)
2731

2832
type testLogger struct {
@@ -47,12 +51,16 @@ func TestCheck(t *testing.T) {
4751
t.Error("blah")
4852
}
4953
e := &testLogger{}
50-
CheckGoroutines(e, time.Second)
54+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
55+
defer cancel()
56+
CheckGoroutines(ctx, e)
5157
if e.errorCount != leakCount {
5258
t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
5359
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
5460
}
55-
CheckGoroutines(t, 3*time.Second)
61+
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
62+
defer cancel()
63+
CheckGoroutines(ctx, t)
5664
}
5765

5866
func ignoredTestingLeak(d time.Duration) {
@@ -70,10 +78,55 @@ func TestCheckRegisterIgnore(t *testing.T) {
7078
t.Error("blah")
7179
}
7280
e := &testLogger{}
73-
CheckGoroutines(e, time.Second)
81+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
82+
defer cancel()
83+
CheckGoroutines(ctx, e)
7484
if e.errorCount != leakCount {
7585
t.Errorf("CheckGoroutines found %v leaks, want %v leaks", e.errorCount, leakCount)
7686
t.Logf("leaked goroutines:\n%v", strings.Join(e.errors, "\n"))
7787
}
78-
CheckGoroutines(t, 3*time.Second)
88+
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
89+
defer cancel()
90+
CheckGoroutines(ctx, t)
91+
}
92+
93+
// TestTrackTimers verifies that only leaked timers are reported and expired,
94+
// stopped timers are ignored.
95+
func TestTrackTimers(t *testing.T) {
96+
TrackTimers()
97+
const leakCount = 3
98+
for i := 0; i < leakCount; i++ {
99+
internal.TimeAfterFunc(2*time.Second, func() {
100+
t.Logf("Timer %d fired.", i)
101+
})
102+
}
103+
wg := sync.WaitGroup{}
104+
// Let a couple of timers expire.
105+
for i := 0; i < 2; i++ {
106+
wg.Add(1)
107+
internal.TimeAfterFunc(time.Millisecond, func() {
108+
wg.Done()
109+
})
110+
}
111+
wg.Wait()
112+
113+
// Stop a couple of timers.
114+
for i := 0; i < leakCount; i++ {
115+
t := internal.TimeAfterFunc(time.Hour, func() {
116+
t.Error("Timer fired before test ended.")
117+
})
118+
t.Stop()
119+
}
120+
121+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
122+
defer cancel()
123+
e := &testLogger{}
124+
CheckTimers(ctx, e)
125+
if e.errorCount != leakCount {
126+
t.Errorf("CheckTimers found %v leaks, want %v leaks", e.errorCount, leakCount)
127+
t.Logf("leaked timers:\n%v", strings.Join(e.errors, "\n"))
128+
}
129+
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
130+
defer cancel()
131+
CheckTimers(ctx, t)
79132
}

internal/transport/client_stream.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (s *ClientStream) Read(n int) (mem.BufferSlice, error) {
5959
return b, err
6060
}
6161

62-
// Close closes the stream and popagates err to any readers.
62+
// Close closes the stream and propagates err to any readers.
6363
func (s *ClientStream) Close(err error) {
6464
var (
6565
rst bool

0 commit comments

Comments
 (0)