diff --git a/context/context.go b/context/context.go index db1c95fab1..d3cb951752 100644 --- a/context/context.go +++ b/context/context.go @@ -6,7 +6,7 @@ // cancellation signals, and other request-scoped values across API boundaries // and between processes. // As of Go 1.7 this package is available in the standard library under the -// name [context], and migrating to it can be done automatically with [go fix]. +// name [context]. // // Incoming requests to a server should create a [Context], and outgoing // calls to servers should accept a Context. The chain of function @@ -38,8 +38,6 @@ // // See https://go.dev/blog/context for example code for a server that uses // Contexts. -// -// [go fix]: https://go.dev/cmd/go#hdr-Update_packages_to_use_new_APIs package context import ( @@ -51,36 +49,37 @@ import ( // API boundaries. // // Context's methods may be called by multiple goroutines simultaneously. +// +//go:fix inline type Context = context.Context // Canceled is the error returned by [Context.Err] when the context is canceled // for some reason other than its deadline passing. +// +//go:fix inline var Canceled = context.Canceled // DeadlineExceeded is the error returned by [Context.Err] when the context is canceled // due to its deadline passing. +// +//go:fix inline var DeadlineExceeded = context.DeadlineExceeded // Background returns a non-nil, empty Context. It is never canceled, has no // values, and has no deadline. It is typically used by the main function, // initialization, and tests, and as the top-level Context for incoming // requests. -func Background() Context { - return background -} +// +//go:fix inline +func Background() Context { return context.Background() } // TODO returns a non-nil, empty Context. Code should use context.TODO when // it's unclear which Context to use or it is not yet available (because the // surrounding function has not yet been extended to accept a Context // parameter). 
-func TODO() Context { - return todo -} - -var ( - background = context.Background() - todo = context.TODO() -) +// +//go:fix inline +func TODO() Context { return context.TODO() } // A CancelFunc tells an operation to abandon its work. // A CancelFunc does not wait for the work to stop. @@ -95,6 +94,8 @@ type CancelFunc = context.CancelFunc // // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this [Context] complete. +// +//go:fix inline func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { return context.WithCancel(parent) } @@ -108,6 +109,8 @@ func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { // // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this [Context] complete. +// +//go:fix inline func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { return context.WithDeadline(parent, d) } @@ -122,6 +125,8 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { // defer cancel() // releases resources if slowOperation completes before timeout elapses // return slowOperation(ctx) // } +// +//go:fix inline func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { return context.WithTimeout(parent, timeout) } @@ -139,6 +144,8 @@ func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { // interface{}, context keys often have concrete type // struct{}. Alternatively, exported context key variables' static // type should be a pointer or interface. 
+// +//go:fix inline func WithValue(parent Context, key, val interface{}) Context { return context.WithValue(parent, key, val) } diff --git a/go.mod b/go.mod index 39cac244a1..944cfb6dfa 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,10 @@ module golang.org/x/net -go 1.23.0 +go 1.24.0 require ( - golang.org/x/crypto v0.41.0 - golang.org/x/sys v0.35.0 - golang.org/x/term v0.34.0 - golang.org/x/text v0.28.0 + golang.org/x/crypto v0.42.0 + golang.org/x/sys v0.36.0 + golang.org/x/term v0.35.0 + golang.org/x/text v0.29.0 ) diff --git a/go.sum b/go.sum index 1ce4678e25..fc842c25ab 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= -golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= +golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= +golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index f9e9a2fdaa..de2b93501e 100644 --- 
a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -5,6 +5,8 @@ // Infrastructure for testing ClientConn.RoundTrip. // Put actual tests in transport_test.go. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -17,6 +19,7 @@ import ( "reflect" "sync/atomic" "testing" + "testing/synctest" "time" "golang.org/x/net/http2/hpack" @@ -24,7 +27,8 @@ import ( ) // TestTestClientConn demonstrates usage of testClientConn. -func TestTestClientConn(t *testing.T) { +func TestTestClientConn(t *testing.T) { synctestTest(t, testTestClientConn) } +func testTestClientConn(t testing.TB) { // newTestClientConn creates a *ClientConn and surrounding test infrastructure. tc := newTestClientConn(t) @@ -91,12 +95,11 @@ func TestTestClientConn(t *testing.T) { // testClientConn manages synchronization, so tests can generally be written as // a linear sequence of actions and validations without additional synchronization. type testClientConn struct { - t *testing.T + t testing.TB - tr *Transport - fr *Framer - cc *ClientConn - group *synctestGroup + tr *Transport + fr *Framer + cc *ClientConn testConnFramer encbuf bytes.Buffer @@ -107,12 +110,11 @@ type testClientConn struct { netconn *synctestNetConn } -func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn { +func newTestClientConnFromClientConn(t testing.TB, cc *ClientConn) *testClientConn { tc := &testClientConn{ - t: t, - tr: cc.t, - cc: cc, - group: cc.t.transportTestHooks.group.(*synctestGroup), + t: t, + tr: cc.t, + cc: cc, } // srv is the side controlled by the test. @@ -121,7 +123,7 @@ func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientCo // If cc.tconn is nil, we're being called with a new conn created by the // Transport's client pool. This path skips dialing the server, and we // create a test connection pair here. 
- cc.tconn, srv = synctestNetPipe(tc.group) + cc.tconn, srv = synctestNetPipe() } else { // If cc.tconn is non-nil, we're in a test which provides a conn to the // Transport via a TLSNextProto hook. Extract the test connection pair. @@ -133,7 +135,7 @@ func newTestClientConnFromClientConn(t testing.TB, cc *testClientCo srv = cc.tconn.(*synctestNetConn).peer } - srv.SetReadDeadline(tc.group.Now()) + srv.SetReadDeadline(time.Now()) srv.autoWait = true tc.netconn = srv tc.enc = hpack.NewEncoder(&tc.encbuf) @@ -163,7 +165,7 @@ func (tc *testClientConn) readClientPreface() { } } -func newTestClientConn(t *testing.T, opts ...any) *testClientConn { +func newTestClientConn(t testing.TB, opts ...any) *testClientConn { t.Helper() tt := newTestTransport(t, opts...) @@ -176,18 +178,6 @@ func newTestClientConn(t *testing.T, opts ...any) *testClientConn { return tt.getConn() } -// sync waits for the ClientConn under test to reach a stable state, -// with all goroutines blocked on some input. -func (tc *testClientConn) sync() { - tc.group.Wait() -} - -// advance advances synthetic time by a duration. -func (tc *testClientConn) advance(d time.Duration) { - tc.group.AdvanceTime(d) - tc.sync() -} - // hasFrame reports whether a frame is available to be read. func (tc *testClientConn) hasFrame() bool { return len(tc.netconn.Peek()) > 0 @@ -204,6 +194,13 @@ func (tc *testClientConn) closeWrite() { tc.netconn.Close() } +// closeWriteWithError causes the net.Conn used by the ClientConn to return an error +// from Write calls. +func (tc *testClientConn) closeWriteWithError(err error) { + tc.netconn.loc.setReadError(io.EOF) + tc.netconn.loc.setWriteError(err) +} + // testRequestBody is a Request.Body for use in tests. type testRequestBody struct { tc *testClientConn @@ -258,17 +255,17 @@ func (b *testRequestBody) Close() error { // writeBytes adds n arbitrary bytes to the body.
func (b *testRequestBody) writeBytes(n int) { - defer b.tc.sync() + defer synctest.Wait() b.gate.Lock() defer b.unlock() b.bytes += n b.checkWrite() - b.tc.sync() + synctest.Wait() } // Write adds bytes to the body. func (b *testRequestBody) Write(p []byte) (int, error) { - defer b.tc.sync() + defer synctest.Wait() b.gate.Lock() defer b.unlock() n, err := b.buf.Write(p) @@ -287,7 +284,7 @@ func (b *testRequestBody) checkWrite() { // closeWithError sets an error which will be returned by Read. func (b *testRequestBody) closeWithError(err error) { - defer b.tc.sync() + defer synctest.Wait() b.gate.Lock() defer b.unlock() b.err = err @@ -304,13 +301,12 @@ func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip { } tc.roundtrips = append(tc.roundtrips, rt) go func() { - tc.group.Join() defer close(rt.donec) rt.resp, rt.respErr = tc.cc.roundTrip(req, func(cs *clientStream) { rt.id.Store(cs.ID) }) }() - tc.sync() + synctest.Wait() tc.t.Cleanup(func() { if !rt.done() { @@ -366,7 +362,7 @@ func (tc *testClientConn) inflowWindow(streamID uint32) int32 { // testRoundTrip manages a RoundTrip in progress. type testRoundTrip struct { - t *testing.T + t testing.TB resp *http.Response respErr error donec chan struct{} @@ -396,6 +392,7 @@ func (rt *testRoundTrip) done() bool { func (rt *testRoundTrip) result() (*http.Response, error) { t := rt.t t.Helper() + synctest.Wait() select { case <-rt.donec: default: @@ -494,19 +491,16 @@ func diffHeaders(got, want http.Header) string { // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling // should use testClientConn instead. 
type testTransport struct { - t *testing.T - tr *Transport - group *synctestGroup + t testing.TB + tr *Transport ccs []*testClientConn } -func newTestTransport(t *testing.T, opts ...any) *testTransport { +func newTestTransport(t testing.TB, opts ...any) *testTransport { tt := &testTransport{ - t: t, - group: newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)), + t: t, } - tt.group.Join() tr := &Transport{} for _, o := range opts { @@ -525,7 +519,6 @@ func newTestTransport(t *testing.T, opts ...any) *testTransport { tt.tr = tr tr.transportTestHooks = &transportTestHooks{ - group: tt.group, newclientconn: func(cc *ClientConn) { tc := newTestClientConnFromClientConn(t, cc) tt.ccs = append(tt.ccs, tc) @@ -533,25 +526,15 @@ func newTestTransport(t *testing.T, opts ...any) *testTransport { } t.Cleanup(func() { - tt.sync() + synctest.Wait() if len(tt.ccs) > 0 { t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs)) } - tt.group.Close(t) }) return tt } -func (tt *testTransport) sync() { - tt.group.Wait() -} - -func (tt *testTransport) advance(d time.Duration) { - tt.group.AdvanceTime(d) - tt.sync() -} - func (tt *testTransport) hasConn() bool { return len(tt.ccs) > 0 } @@ -563,9 +546,9 @@ func (tt *testTransport) getConn() *testClientConn { } tc := tt.ccs[0] tt.ccs = tt.ccs[1:] - tc.sync() + synctest.Wait() tc.readClientPreface() - tc.sync() + synctest.Wait() return tc } @@ -575,11 +558,10 @@ func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip { donec: make(chan struct{}), } go func() { - tt.group.Join() defer close(rt.donec) rt.resp, rt.respErr = tt.tr.RoundTrip(req) }() - tt.sync() + synctest.Wait() tt.t.Cleanup(func() { if !rt.done() { diff --git a/http2/config.go b/http2/config.go index ca645d9a1a..02fe0c2d48 100644 --- a/http2/config.go +++ b/http2/config.go @@ -55,7 +55,7 @@ func configFromServer(h1 *http.Server, h2 *Server) http2Config { PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites, CountError: 
h2.CountError, } - fillNetHTTPServerConfig(&conf, h1) + fillNetHTTPConfig(&conf, h1.HTTP2) setConfigDefaults(&conf, true) return conf } @@ -81,7 +81,7 @@ func configFromTransport(h2 *Transport) http2Config { } if h2.t1 != nil { - fillNetHTTPTransportConfig(&conf, h2.t1) + fillNetHTTPConfig(&conf, h2.t1.HTTP2) } setConfigDefaults(&conf, false) return conf @@ -120,3 +120,45 @@ func adjustHTTP1MaxHeaderSize(n int64) int64 { const typicalHeaders = 10 // conservative return n + typicalHeaders*perFieldOverhead } + +func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) { + if h2 == nil { + return + } + if h2.MaxConcurrentStreams != 0 { + conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) + } + if h2.MaxEncoderHeaderTableSize != 0 { + conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize) + } + if h2.MaxDecoderHeaderTableSize != 0 { + conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize) + } + if h2.MaxConcurrentStreams != 0 { + conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) + } + if h2.MaxReadFrameSize != 0 { + conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize) + } + if h2.MaxReceiveBufferPerConnection != 0 { + conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection) + } + if h2.MaxReceiveBufferPerStream != 0 { + conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream) + } + if h2.SendPingTimeout != 0 { + conf.SendPingTimeout = h2.SendPingTimeout + } + if h2.PingTimeout != 0 { + conf.PingTimeout = h2.PingTimeout + } + if h2.WriteByteTimeout != 0 { + conf.WriteByteTimeout = h2.WriteByteTimeout + } + if h2.PermitProhibitedCipherSuites { + conf.PermitProhibitedCipherSuites = true + } + if h2.CountError != nil { + conf.CountError = h2.CountError + } +} diff --git a/http2/config_go124.go b/http2/config_go124.go deleted file mode 100644 index 5b516c55ff..0000000000 --- a/http2/config_go124.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. 
-// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.24 - -package http2 - -import "net/http" - -// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2. -func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) { - fillNetHTTPConfig(conf, srv.HTTP2) -} - -// fillNetHTTPTransportConfig sets fields in conf from tr.HTTP2. -func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) { - fillNetHTTPConfig(conf, tr.HTTP2) -} - -func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) { - if h2 == nil { - return - } - if h2.MaxConcurrentStreams != 0 { - conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) - } - if h2.MaxEncoderHeaderTableSize != 0 { - conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize) - } - if h2.MaxDecoderHeaderTableSize != 0 { - conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize) - } - if h2.MaxConcurrentStreams != 0 { - conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams) - } - if h2.MaxReadFrameSize != 0 { - conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize) - } - if h2.MaxReceiveBufferPerConnection != 0 { - conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection) - } - if h2.MaxReceiveBufferPerStream != 0 { - conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream) - } - if h2.SendPingTimeout != 0 { - conf.SendPingTimeout = h2.SendPingTimeout - } - if h2.PingTimeout != 0 { - conf.PingTimeout = h2.PingTimeout - } - if h2.WriteByteTimeout != 0 { - conf.WriteByteTimeout = h2.WriteByteTimeout - } - if h2.PermitProhibitedCipherSuites { - conf.PermitProhibitedCipherSuites = true - } - if h2.CountError != nil { - conf.CountError = h2.CountError - } -} diff --git a/http2/config_pre_go124.go b/http2/config_pre_go124.go deleted file mode 100644 index 060fd6c64c..0000000000 --- a/http2/config_pre_go124.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2024 The Go Authors. 
All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.24 - -package http2 - -import "net/http" - -// Pre-Go 1.24 fallback. -// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24. - -func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {} - -func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {} diff --git a/http2/config_test.go b/http2/config_test.go index b8e7a7b043..88e05e0aa4 100644 --- a/http2/config_test.go +++ b/http2/config_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.24 +//go:build go1.25 || goexperiment.synctest package http2 @@ -12,7 +12,8 @@ import ( "time" ) -func TestConfigServerSettings(t *testing.T) { +func TestConfigServerSettings(t *testing.T) { synctestTest(t, testConfigServerSettings) } +func testConfigServerSettings(t testing.TB) { config := &http.HTTP2Config{ MaxConcurrentStreams: 1, MaxDecoderHeaderTableSize: 1<<20 + 2, @@ -37,7 +38,8 @@ func TestConfigServerSettings(t *testing.T) { }) } -func TestConfigTransportSettings(t *testing.T) { +func TestConfigTransportSettings(t *testing.T) { synctestTest(t, testConfigTransportSettings) } +func testConfigTransportSettings(t testing.TB) { config := &http.HTTP2Config{ MaxConcurrentStreams: 1, // ignored by Transport MaxDecoderHeaderTableSize: 1<<20 + 2, @@ -60,7 +62,8 @@ func TestConfigTransportSettings(t *testing.T) { tc.wantWindowUpdate(0, uint32(config.MaxReceiveBufferPerConnection)) } -func TestConfigPingTimeoutServer(t *testing.T) { +func TestConfigPingTimeoutServer(t *testing.T) { synctestTest(t, testConfigPingTimeoutServer) } +func testConfigPingTimeoutServer(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { s.ReadIdleTimeout = 2 * time.Second @@ -68,13 +71,14 @@ func TestConfigPingTimeoutServer(t 
*testing.T) { }) st.greet() - st.advance(2 * time.Second) + time.Sleep(2 * time.Second) _ = readFrame[*PingFrame](t, st) - st.advance(3 * time.Second) + time.Sleep(3 * time.Second) st.wantClosed() } -func TestConfigPingTimeoutTransport(t *testing.T) { +func TestConfigPingTimeoutTransport(t *testing.T) { synctestTest(t, testConfigPingTimeoutTransport) } +func testConfigPingTimeoutTransport(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.ReadIdleTimeout = 2 * time.Second tr.PingTimeout = 3 * time.Second @@ -85,9 +89,9 @@ func TestConfigPingTimeoutTransport(t *testing.T) { rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) - tc.advance(2 * time.Second) + time.Sleep(2 * time.Second) tc.wantFrameType(FramePing) - tc.advance(3 * time.Second) + time.Sleep(3 * time.Second) err := rt.err() if err == nil { t.Fatalf("expected connection to close") diff --git a/http2/connframes_test.go b/http2/connframes_test.go index 2c4532571a..e3f8a96e52 100644 --- a/http2/connframes_test.go +++ b/http2/connframes_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( diff --git a/http2/frame_test.go b/http2/frame_test.go index 68505317e1..dfeff53a8f 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 || goexperiment.synctest + package http2 import ( diff --git a/http2/gotrack.go b/http2/gotrack.go index 9933c9f8c7..9921ca096d 100644 --- a/http2/gotrack.go +++ b/http2/gotrack.go @@ -15,21 +15,32 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" ) var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" +// Setting DebugGoroutines to false during a test to disable goroutine debugging +// results in race detector complaints when a test leaves goroutines running before +// returning. Tests shouldn't do this, of course, but when they do it generally shows +// up as infrequent, hard-to-debug flakes. (See #66519.) +// +// Disable goroutine debugging during individual tests with an atomic bool. +// (Note that it's safe to enable/disable debugging mid-test, so the actual race condition +// here is harmless.) +var disableDebugGoroutines atomic.Bool + type goroutineLock uint64 func newGoroutineLock() goroutineLock { - if !DebugGoroutines { + if !DebugGoroutines || disableDebugGoroutines.Load() { return 0 } return goroutineLock(curGoroutineID()) } func (g goroutineLock) check() { - if !DebugGoroutines { + if !DebugGoroutines || disableDebugGoroutines.Load() { return } if curGoroutineID() != uint64(g) { @@ -38,7 +49,7 @@ func (g goroutineLock) check() { } func (g goroutineLock) checkNotOn() { - if !DebugGoroutines { + if !DebugGoroutines || disableDebugGoroutines.Load() { return } if curGoroutineID() == uint64(g) { diff --git a/http2/http2.go b/http2/http2.go index ea5ae629fd..6878f8ecc9 100644 --- a/http2/http2.go +++ b/http2/http2.go @@ -15,7 +15,6 @@ package http2 // import "golang.org/x/net/http2" import ( "bufio" - "context" "crypto/tls" "errors" "fmt" @@ -255,15 +254,13 @@ func (cw closeWaiter) Wait() { // idle memory usage with many connections. 
type bufferedWriter struct { _ incomparable - group synctestGroupInterface // immutable - conn net.Conn // immutable - bw *bufio.Writer // non-nil when data is buffered - byteTimeout time.Duration // immutable, WriteByteTimeout + conn net.Conn // immutable + bw *bufio.Writer // non-nil when data is buffered + byteTimeout time.Duration // immutable, WriteByteTimeout } -func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter { +func newBufferedWriter(conn net.Conn, timeout time.Duration) *bufferedWriter { return &bufferedWriter{ - group: group, conn: conn, byteTimeout: timeout, } @@ -314,24 +311,18 @@ func (w *bufferedWriter) Flush() error { type bufferedWriterTimeoutWriter bufferedWriter func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) { - return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p) + return writeWithByteTimeout(w.conn, w.byteTimeout, p) } // writeWithByteTimeout writes to conn. // If more than timeout passes without any bytes being written to the connection, // the write fails. -func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) { +func writeWithByteTimeout(conn net.Conn, timeout time.Duration, p []byte) (n int, err error) { if timeout <= 0 { return conn.Write(p) } for { - var now time.Time - if group == nil { - now = time.Now() - } else { - now = group.Now() - } - conn.SetWriteDeadline(now.Add(timeout)) + conn.SetWriteDeadline(time.Now().Add(timeout)) nn, err := conn.Write(p[n:]) n += nn if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) { @@ -417,14 +408,3 @@ func (s *sorter) SortStrings(ss []string) { // makes that struct also non-comparable, and generally doesn't add // any size (as long as it's first). type incomparable [0]func() - -// synctestGroupInterface is the methods of synctestGroup used by Server and Transport. 
-// It's defined as an interface here to let us keep synctestGroup entirely test-only -// and not a part of non-test builds. -type synctestGroupInterface interface { - Join() - Now() time.Time - NewTimer(d time.Duration) timer - AfterFunc(d time.Duration, f func()) timer - ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) -} diff --git a/http2/http2_test.go b/http2/http2_test.go index c7774133a7..5fec656401 100644 --- a/http2/http2_test.go +++ b/http2/http2_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -68,7 +70,7 @@ func (w twriter) Write(p []byte) (n int, err error) { } // like encodeHeader, but don't add implicit pseudo headers. -func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte { +func encodeHeaderNoImplicit(t testing.TB, headers ...string) []byte { var buf bytes.Buffer enc := hpack.NewEncoder(&buf) for len(headers) > 0 { @@ -81,35 +83,6 @@ func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte { return buf.Bytes() } -type puppetCommand struct { - fn func(w http.ResponseWriter, r *http.Request) - done chan<- bool -} - -type handlerPuppet struct { - ch chan puppetCommand -} - -func newHandlerPuppet() *handlerPuppet { - return &handlerPuppet{ - ch: make(chan puppetCommand), - } -} - -func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) { - for cmd := range p.ch { - cmd.fn(w, r) - cmd.done <- true - } -} - -func (p *handlerPuppet) done() { close(p.ch) } -func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) { - done := make(chan bool) - p.ch <- puppetCommand{fn, done} - <-done -} - func cleanDate(res *http.Response) { if d := res.Header["Date"]; len(d) == 1 { d[0] = "XXX" @@ -285,7 +258,7 @@ func TestNoUnicodeStrings(t *testing.T) { } // setForTest sets *p = v, and restores its original value in t.Cleanup. 
-func setForTest[T any](t *testing.T, p *T, v T) { +func setForTest[T any](t testing.TB, p *T, v T) { orig := *p t.Cleanup(func() { *p = orig @@ -300,3 +273,11 @@ func must[T any](v T, err error) T { } return v } + +// synctestSubtest starts a subtest and runs f in a synctest bubble within it. +func synctestSubtest(t *testing.T, name string, f func(testing.TB)) { + t.Helper() + t.Run(name, func(t *testing.T) { + synctestTest(t, f) + }) +} diff --git a/http2/netconn_test.go b/http2/netconn_test.go index 5a1759579e..ffa87ec7ab 100644 --- a/http2/netconn_test.go +++ b/http2/netconn_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -14,6 +16,7 @@ import ( "net/netip" "os" "sync" + "testing/synctest" "time" ) @@ -23,13 +26,13 @@ import ( // Unlike net.Pipe, the connection is not synchronous. // Writes are made to a buffer, and return immediately. // By default, the buffer size is unlimited. -func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) { +func synctestNetPipe() (r, w *synctestNetConn) { s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")) s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001")) s1 := newSynctestNetConnHalf(s1addr) s2 := newSynctestNetConnHalf(s2addr) - r = &synctestNetConn{group: group, loc: s1, rem: s2} - w = &synctestNetConn{group: group, loc: s2, rem: s1} + r = &synctestNetConn{loc: s1, rem: s2} + w = &synctestNetConn{loc: s2, rem: s1} r.peer = w w.peer = r return r, w @@ -37,8 +40,6 @@ func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) { // A synctestNetConn is one endpoint of the connection created by synctestNetPipe. type synctestNetConn struct { - group *synctestGroup - // local and remote connection halves. // Each half contains a buffer. // Reads pull from the local buffer, and writes push to the remote buffer. 
@@ -54,7 +55,7 @@ type synctestNetConn struct { // Read reads data from the connection. func (c *synctestNetConn) Read(b []byte) (n int, err error) { if c.autoWait { - c.group.Wait() + synctest.Wait() } return c.loc.read(b) } @@ -63,7 +64,7 @@ func (c *synctestNetConn) Read(b []byte) (n int, err error) { // without consuming its contents. func (c *synctestNetConn) Peek() []byte { if c.autoWait { - c.group.Wait() + synctest.Wait() } return c.loc.peek() } @@ -71,7 +72,7 @@ func (c *synctestNetConn) Peek() []byte { // Write writes data to the connection. func (c *synctestNetConn) Write(b []byte) (n int, err error) { if c.autoWait { - defer c.group.Wait() + defer synctest.Wait() } return c.rem.write(b) } @@ -79,7 +80,7 @@ func (c *synctestNetConn) Write(b []byte) (n int, err error) { // IsClosedByPeer reports whether the peer has closed its end of the connection. func (c *synctestNetConn) IsClosedByPeer() bool { if c.autoWait { - c.group.Wait() + synctest.Wait() } return c.loc.isClosedByPeer() } @@ -89,7 +90,7 @@ func (c *synctestNetConn) Close() error { c.loc.setWriteError(errors.New("connection closed by peer")) c.rem.setReadError(io.EOF) if c.autoWait { - c.group.Wait() + synctest.Wait() } return nil } @@ -113,13 +114,13 @@ func (c *synctestNetConn) SetDeadline(t time.Time) error { // SetReadDeadline sets the read deadline for the connection. func (c *synctestNetConn) SetReadDeadline(t time.Time) error { - c.loc.rctx.setDeadline(c.group, t) + c.loc.rctx.setDeadline(t) return nil } // SetWriteDeadline sets the write deadline for the connection. func (c *synctestNetConn) SetWriteDeadline(t time.Time) error { - c.rem.wctx.setDeadline(c.group, t) + c.rem.wctx.setDeadline(t) return nil } @@ -305,7 +306,7 @@ type deadlineContext struct { mu sync.Mutex ctx context.Context cancel context.CancelCauseFunc - timer timer + timer *time.Timer } // context returns a Context which expires when the deadline does. 
@@ -319,7 +320,7 @@ func (t *deadlineContext) context() context.Context { } // setDeadline sets the current deadline. -func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) { +func (t *deadlineContext) setDeadline(deadline time.Time) { t.mu.Lock() defer t.mu.Unlock() // If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled @@ -335,7 +336,7 @@ func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) // No deadline. return } - if !deadline.After(group.Now()) { + if !deadline.After(time.Now()) { // Deadline has already expired. t.cancel(os.ErrDeadlineExceeded) t.cancel = nil @@ -343,11 +344,11 @@ func (t *deadlineContext) setDeadline(group *synctestGroup, deadline time.Time) } if t.timer != nil { // Reuse existing deadline timer. - t.timer.Reset(deadline.Sub(group.Now())) + t.timer.Reset(deadline.Sub(time.Now())) return } // Create a new timer to cancel the context at the deadline. - t.timer = group.AfterFunc(deadline.Sub(group.Now()), func() { + t.timer = time.AfterFunc(deadline.Sub(time.Now()), func() { t.mu.Lock() defer t.mu.Unlock() t.cancel(os.ErrDeadlineExceeded) diff --git a/http2/server.go b/http2/server.go index 51fca38f61..64085f6e16 100644 --- a/http2/server.go +++ b/http2/server.go @@ -176,39 +176,6 @@ type Server struct { // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. state *serverInternalState - - // Synchronization group used for testing. - // Outside of tests, this is nil. - group synctestGroupInterface -} - -func (s *Server) markNewGoroutine() { - if s.group != nil { - s.group.Join() - } -} - -func (s *Server) now() time.Time { - if s.group != nil { - return s.group.Now() - } - return time.Now() -} - -// newTimer creates a new time.Timer, or a synthetic timer in tests. 
-func (s *Server) newTimer(d time.Duration) timer { - if s.group != nil { - return s.group.NewTimer(d) - } - return timeTimer{time.NewTimer(d)} -} - -// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. -func (s *Server) afterFunc(d time.Duration, f func()) timer { - if s.group != nil { - return s.group.AfterFunc(d, f) - } - return timeTimer{time.AfterFunc(d, f)} } type serverInternalState struct { @@ -423,6 +390,9 @@ func (o *ServeConnOpts) handler() http.Handler { // // The opts parameter is optional. If nil, default values are used. func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { + if opts == nil { + opts = &ServeConnOpts{} + } s.serveConn(c, opts, nil) } @@ -438,7 +408,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon conn: c, baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), - bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout), + bw: newBufferedWriter(c, conf.WriteByteTimeout), handler: opts.handler(), streams: make(map[uint32]*stream), readFrameCh: make(chan readFrameResult), @@ -638,11 +608,11 @@ type serverConn struct { pingSent bool sentPingData [8]byte goAwayCode ErrCode - shutdownTimer timer // nil until used - idleTimer timer // nil if unused + shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused readIdleTimeout time.Duration pingTimeout time.Duration - readIdleTimer timer // nil if unused + readIdleTimer *time.Timer // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -687,12 +657,12 @@ type stream struct { flow outflow // limits writing from Handler to client inflow inflow // what the client is allowed to POST/etc to us state streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - readDeadline timer // nil if unused - writeDeadline timer 
// nil if unused - closeErr error // set before cw is closed + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + readDeadline *time.Timer // nil if unused + writeDeadline *time.Timer // nil if unused + closeErr error // set before cw is closed trailer http.Header // accumulated trailers reqTrailer http.Header // handler's Request.Trailer @@ -848,7 +818,6 @@ type readFrameResult struct { // consumer is done with the frame. // It's run on its own goroutine. func (sc *serverConn) readFrames() { - sc.srv.markNewGoroutine() gate := make(chan struct{}) gateDone := func() { gate <- struct{}{} } for { @@ -881,7 +850,6 @@ type frameWriteResult struct { // At most one goroutine can be running writeFrameAsync at a time per // serverConn. func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) { - sc.srv.markNewGoroutine() var err error if wd == nil { err = wr.write.writeFrame(sc) @@ -965,22 +933,22 @@ func (sc *serverConn) serve(conf http2Config) { sc.setConnState(http.StateIdle) if sc.srv.IdleTimeout > 0 { - sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) + sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } if conf.SendPingTimeout > 0 { sc.readIdleTimeout = conf.SendPingTimeout - sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer) + sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer) defer sc.readIdleTimer.Stop() } go sc.readFrames() // closed by defer sc.conn.Close above - settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer) + settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer) defer settingsTimer.Stop() - lastFrameTime := sc.srv.now() + lastFrameTime := time.Now() loopNum := 0 for { loopNum++ @@ -994,7 +962,7 @@ func (sc *serverConn) serve(conf 
http2Config) { case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: - lastFrameTime = sc.srv.now() + lastFrameTime = time.Now() // Process any written frames before reading new frames from the client since a // written frame could have triggered a new stream to be started. if sc.writingFrameAsync { @@ -1077,7 +1045,7 @@ func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) { } pingAt := lastFrameReadTime.Add(sc.readIdleTimeout) - now := sc.srv.now() + now := time.Now() if pingAt.After(now) { // We received frames since arming the ping timer. // Reset it for the next possible timeout. @@ -1141,10 +1109,10 @@ func (sc *serverConn) readPreface() error { errc <- nil } }() - timer := sc.srv.newTimer(prefaceTimeout) // TODO: configurable on *Server? + timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? defer timer.Stop() select { - case <-timer.C(): + case <-timer.C: return errPrefaceTimeout case err := <-errc: if err == nil { @@ -1160,6 +1128,21 @@ var errChanPool = sync.Pool{ New: func() interface{} { return make(chan error, 1) }, } +func getErrChan() chan error { + if inTests { + // Channels cannot be reused across synctest tests. + return make(chan error, 1) + } else { + return errChanPool.Get().(chan error) + } +} + +func putErrChan(ch chan error) { + if !inTests { + errChanPool.Put(ch) + } +} + var writeDataPool = sync.Pool{ New: func() interface{} { return new(writeData) }, } @@ -1167,7 +1150,7 @@ var writeDataPool = sync.Pool{ // writeDataFromHandler writes DATA response frames from a handler on // the given stream. 
func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error { - ch := errChanPool.Get().(chan error) + ch := getErrChan() writeArg := writeDataPool.Get().(*writeData) *writeArg = writeData{stream.id, data, endStream} err := sc.writeFrameFromHandler(FrameWriteRequest{ @@ -1199,7 +1182,7 @@ func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStrea return errStreamClosed } } - errChanPool.Put(ch) + putErrChan(ch) if frameWriteDone { writeDataPool.Put(writeArg) } @@ -1513,7 +1496,7 @@ func (sc *serverConn) goAway(code ErrCode) { func (sc *serverConn) shutDownIn(d time.Duration) { sc.serveG.check() - sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer) + sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) } func (sc *serverConn) resetStream(se StreamError) { @@ -2118,7 +2101,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) - st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout) + st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } return sc.scheduleHandler(id, rw, req, handler) @@ -2216,7 +2199,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.initialStreamRecvWindowSize) if sc.hs.WriteTimeout > 0 { - st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } sc.streams[id] = st @@ -2405,7 +2388,6 @@ func (sc *serverConn) handlerDone() { // Run on its own goroutine. 
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { - sc.srv.markNewGoroutine() defer sc.sendServeMsg(handlerDoneMsg) didPanic := true defer func() { @@ -2454,7 +2436,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro // waiting for this frame to be written, so an http.Flush mid-handler // writes out the correct value of keys, before a handler later potentially // mutates it. - errc = errChanPool.Get().(chan error) + errc = getErrChan() } if err := sc.writeFrameFromHandler(FrameWriteRequest{ write: headerData, @@ -2466,7 +2448,7 @@ func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) erro if errc != nil { select { case err := <-errc: - errChanPool.Put(errc) + putErrChan(errc) return err case <-sc.doneServing: return errClientDisconnected @@ -2573,7 +2555,7 @@ func (b *requestBody) Read(p []byte) (n int, err error) { if err == io.EOF { b.sawEOF = true } - if b.conn == nil && inTests { + if b.conn == nil { return } b.conn.noteBodyReadFromHandler(b.stream, n, err) @@ -2702,7 +2684,7 @@ func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { var date string if _, ok := rws.snapHeader["Date"]; !ok { // TODO(bradfitz): be faster here, like net/http? measure. - date = rws.conn.srv.now().UTC().Format(http.TimeFormat) + date = time.Now().UTC().Format(http.TimeFormat) } for _, v := range rws.snapHeader["Trailer"] { @@ -2824,7 +2806,7 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() { func (w *responseWriter) SetReadDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { + if !deadline.IsZero() && deadline.Before(time.Now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. 
st.onReadTimeout() @@ -2840,9 +2822,9 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error { if deadline.IsZero() { st.readDeadline = nil } else if st.readDeadline == nil { - st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout) + st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout) } else { - st.readDeadline.Reset(deadline.Sub(sc.srv.now())) + st.readDeadline.Reset(deadline.Sub(time.Now())) } }) return nil @@ -2850,7 +2832,7 @@ func (w *responseWriter) SetReadDeadline(deadline time.Time) error { func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { st := w.rws.stream - if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) { + if !deadline.IsZero() && deadline.Before(time.Now()) { // If we're setting a deadline in the past, reset the stream immediately // so writes after SetWriteDeadline returns will fail. st.onWriteTimeout() @@ -2866,9 +2848,9 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error { if deadline.IsZero() { st.writeDeadline = nil } else if st.writeDeadline == nil { - st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout) + st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout) } else { - st.writeDeadline.Reset(deadline.Sub(sc.srv.now())) + st.writeDeadline.Reset(deadline.Sub(time.Now())) } }) return nil @@ -3147,7 +3129,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { method: opts.Method, url: u, header: cloneHeader(opts.Header), - done: errChanPool.Get().(chan error), + done: getErrChan(), } select { @@ -3164,7 +3146,7 @@ func (w *responseWriter) Push(target string, opts *http.PushOptions) error { case <-st.cw: return errStreamClosed case err := <-msg.done: - errChanPool.Put(msg.done) + putErrChan(msg.done) return err } } diff --git a/http2/server_push_test.go b/http2/server_push_test.go index 69e4c3b12d..ea0a1b260c 100644 --- a/http2/server_push_test.go +++ 
b/http2/server_push_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -10,14 +12,14 @@ import ( "io" "net/http" "reflect" - "runtime" "strconv" - "sync" "testing" + "testing/synctest" "time" ) -func TestServer_Push_Success(t *testing.T) { +func TestServer_Push_Success(t *testing.T) { synctestTest(t, testServer_Push_Success) } +func testServer_Push_Success(t testing.TB) { const ( mainBody = "index page" pushedBody = "pushed page" @@ -242,7 +244,8 @@ func TestServer_Push_Success(t *testing.T) { } } -func TestServer_Push_SuccessNoRace(t *testing.T) { +func TestServer_Push_SuccessNoRace(t *testing.T) { synctestTest(t, testServer_Push_SuccessNoRace) } +func testServer_Push_SuccessNoRace(t testing.TB) { // Regression test for issue #18326. Ensure the request handler can mutate // pushed request headers without racing with the PUSH_PROMISE write. errc := make(chan error, 2) @@ -287,6 +290,9 @@ func TestServer_Push_SuccessNoRace(t *testing.T) { } func TestServer_Push_RejectRecursivePush(t *testing.T) { + synctestTest(t, testServer_Push_RejectRecursivePush) +} +func testServer_Push_RejectRecursivePush(t testing.TB) { // Expect two requests, but might get three if there's a bug and the second push succeeds. errc := make(chan error, 3) handler := func(w http.ResponseWriter, r *http.Request) error { @@ -323,6 +329,11 @@ func TestServer_Push_RejectRecursivePush(t *testing.T) { } func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { + synctestTest(t, func(t testing.TB) { + testServer_Push_RejectSingleRequest_Bubble(t, doPush, settings...) + }) +} +func testServer_Push_RejectSingleRequest_Bubble(t testing.TB, doPush func(http.Pusher, *http.Request) error, settings ...Setting) { // Expect one request, but might get two if there's a bug and the push succeeds. 
errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -426,6 +437,9 @@ func TestServer_Push_RejectForbiddenHeader(t *testing.T) { } func TestServer_Push_StateTransitions(t *testing.T) { + synctestTest(t, testServer_Push_StateTransitions) +} +func testServer_Push_StateTransitions(t testing.TB) { const body = "foo" gotPromise := make(chan bool) @@ -479,7 +493,9 @@ func TestServer_Push_StateTransitions(t *testing.T) { } func TestServer_Push_RejectAfterGoAway(t *testing.T) { - var readyOnce sync.Once + synctestTest(t, testServer_Push_RejectAfterGoAway) +} +func testServer_Push_RejectAfterGoAway(t testing.TB) { ready := make(chan struct{}) errc := make(chan error, 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -495,30 +511,15 @@ func TestServer_Push_RejectAfterGoAway(t *testing.T) { // Send GOAWAY and wait for it to be processed. st.fr.WriteGoAway(1, ErrCodeNo, nil) - go func() { - for { - select { - case <-ready: - return - default: - if runtime.GOARCH == "wasm" { - // Work around https://go.dev/issue/65178 to avoid goroutine starvation. - runtime.Gosched() - } - } - st.sc.serveMsgCh <- func(loopNum int) { - if !st.sc.pushEnabled { - readyOnce.Do(func() { close(ready) }) - } - } - } - }() + synctest.Wait() + close(ready) if err := <-errc; err != nil { t.Error(err) } } -func TestServer_Push_Underflow(t *testing.T) { +func TestServer_Push_Underflow(t *testing.T) { synctestTest(t, testServer_Push_Underflow) } +func testServer_Push_Underflow(t testing.TB) { // Test for #63511: Send several requests which generate PUSH_PROMISE responses, // verify they all complete successfully. 
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { diff --git a/http2/server_test.go b/http2/server_test.go index b27a127a5e..71287d1e56 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -26,6 +28,7 @@ import ( "strings" "sync" "testing" + "testing/synctest" "time" "golang.org/x/net/http2/hpack" @@ -67,7 +70,6 @@ func (sb *safeBuffer) Len() int { type serverTester struct { cc net.Conn // client conn t testing.TB - group *synctestGroup h1server *http.Server h2server *Server serverLogBuf safeBuffer // logger for httptest.Server @@ -76,6 +78,9 @@ type serverTester struct { sc *serverConn testConnFramer + callsMu sync.Mutex + calls []*serverHandlerCall + // If http2debug!=2, then we capture Frame debug logs that will be written // to t.Log after a test fails. The read and write logs use separate locks // and buffers so we don't accidentally introduce synchronization between @@ -149,15 +154,9 @@ var optQuiet = func(server *http.Server) { func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { t.Helper() - g := newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)) - t.Cleanup(func() { - g.Close(t) - }) h1server := &http.Server{} - h2server := &Server{ - group: g, - } + h2server := &Server{} tlsState := tls.ConnectionState{ Version: tls.VersionTLS13, ServerName: "go.dev", @@ -177,14 +176,13 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } ConfigureServer(h1server, h2server) - cli, srv := synctestNetPipe(g) - cli.SetReadDeadline(g.Now()) + cli, srv := synctestNetPipe() + cli.SetReadDeadline(time.Now()) cli.autoWait = true st := &serverTester{ t: t, cc: cli, - group: g, h1server: h1server, h2server: h2server, } @@ -193,14 +191,17 @@ func newServerTester(t testing.TB, 
handler http.HandlerFunc, opts ...interface{} h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags) } + if handler == nil { + handler = serverTesterHandler{st}.ServeHTTP + } + t.Cleanup(func() { st.Close() - g.AdvanceTime(goAwayTimeout) // give server time to shut down + time.Sleep(goAwayTimeout) // give server time to shut down }) connc := make(chan *serverConn) go func() { - g.Join() h2server.serveConn(&netConnWithConnectionState{ Conn: srv, state: tlsState, @@ -219,7 +220,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} fr: NewFramer(st.cc, st.cc), dec: hpack.NewDecoder(initialHeaderTableSize, nil), } - g.Wait() + synctest.Wait() return st } @@ -232,6 +233,50 @@ func (c *netConnWithConnectionState) ConnectionState() tls.ConnectionState { return c.state } +type serverTesterHandler struct { + st *serverTester +} + +func (h serverTesterHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + call := &serverHandlerCall{ + w: w, + req: req, + ch: make(chan func()), + } + h.st.t.Cleanup(call.exit) + h.st.callsMu.Lock() + h.st.calls = append(h.st.calls, call) + h.st.callsMu.Unlock() + for f := range call.ch { + f() + } +} + +// serverHandlerCall is a call to the server handler's ServeHTTP method. +type serverHandlerCall struct { + w http.ResponseWriter + req *http.Request + closeOnce sync.Once + ch chan func() +} + +// do executes f in the handler's goroutine. +func (call *serverHandlerCall) do(f func(http.ResponseWriter, *http.Request)) { + donec := make(chan struct{}) + call.ch <- func() { + defer close(donec) + f(call.w, call.req) + } + <-donec +} + +// exit causes the handler to return. +func (call *serverHandlerCall) exit() { + call.closeOnce.Do(func() { + close(call.ch) + }) +} + // newServerTesterWithRealConn creates a test server listening on a localhost port. 
// Mostly superseded by newServerTester, which creates a test server using a fake // net.Conn and synthetic time. This function is still around because some benchmarks @@ -333,14 +378,13 @@ func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts .. // sync waits for all goroutines to idle. func (st *serverTester) sync() { - if st.group != nil { - st.group.Wait() - } + synctest.Wait() } // advance advances synthetic time by a duration. func (st *serverTester) advance(d time.Duration) { - st.group.AdvanceTime(d) + time.Sleep(d) + synctest.Wait() } func (st *serverTester) authority() string { @@ -357,6 +401,19 @@ func (st *serverTester) addLogFilter(phrase string) { st.logFilter = append(st.logFilter, phrase) } +func (st *serverTester) nextHandlerCall() *serverHandlerCall { + st.t.Helper() + synctest.Wait() + st.callsMu.Lock() + defer st.callsMu.Unlock() + if len(st.calls) == 0 { + st.t.Fatal("expected server handler call, got none") + } + call := st.calls[0] + st.calls = st.calls[1:] + return call +} + func (st *serverTester) stream(id uint32) *stream { ch := make(chan *stream, 1) st.sc.serveMsgCh <- func(int) { @@ -383,23 +440,6 @@ func (st *serverTester) loopNum() int { return <-lastc } -// awaitIdle heuristically awaits for the server conn's select loop to be idle. -// The heuristic is that the server connection's serve loop must schedule -// 50 times in a row without any channel sends or receives occurring. 
-func (st *serverTester) awaitIdle() { - remain := 50 - last := st.loopNum() - for remain > 0 { - n := st.loopNum() - if n == last+1 { - remain-- - } else { - remain = 50 - } - last = n - } -} - func (st *serverTester) Close() { if st.t.Failed() { st.frameReadLogMu.Lock() @@ -591,30 +631,23 @@ func (st *serverTester) bodylessReq1(headers ...string) { }) } -func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) { +func (st *serverTester) wantConnFlowControlConsumed(consumed int32) { conf := configFromServer(st.sc.hs, st.sc.srv) - var initial int32 - if streamID == 0 { - initial = conf.MaxUploadBufferPerConnection - } else { - initial = conf.MaxUploadBufferPerStream - } donec := make(chan struct{}) st.sc.sendServeMsg(func(sc *serverConn) { defer close(donec) var avail int32 - if streamID == 0 { - avail = sc.inflow.avail + sc.inflow.unsent - } else { - } + initial := conf.MaxUploadBufferPerConnection + avail = sc.inflow.avail + sc.inflow.unsent if got, want := initial-avail, consumed; got != want { - st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want) + st.t.Errorf("connection flow control consumed: %v, want %v", got, want) } }) <-donec } -func TestServer(t *testing.T) { +func TestServer(t *testing.T) { synctestTest(t, testServer) } +func testServer(t testing.TB) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Foo", "Bar") @@ -633,7 +666,8 @@ func TestServer(t *testing.T) { <-gotReq } -func TestServer_Request_Get(t *testing.T) { +func TestServer_Request_Get(t *testing.T) { synctestTest(t, testServer_Request_Get) } +func testServer_Request_Get(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -673,6 +707,9 @@ func TestServer_Request_Get(t *testing.T) { } func TestServer_Request_Get_PathSlashes(t *testing.T) { + synctestTest(t, testServer_Request_Get_PathSlashes) 
+} +func testServer_Request_Get_PathSlashes(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -695,6 +732,9 @@ func TestServer_Request_Get_PathSlashes(t *testing.T) { // zero? func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) { + synctestTest(t, testServer_Request_Post_NoContentLength_EndStream) +} +func testServer_Request_Post_NoContentLength_EndStream(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -716,6 +756,9 @@ func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) { } func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ImmediateEOF) +} +func testServer_Request_Post_Body_ImmediateEOF(t testing.TB) { testBodyContents(t, -1, "", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -728,6 +771,9 @@ func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) { } func TestServer_Request_Post_Body_OneData(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_OneData) +} +func testServer_Request_Post_Body_OneData(t testing.TB) { const content = "Some content" testBodyContents(t, -1, content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -741,6 +787,9 @@ func TestServer_Request_Post_Body_OneData(t *testing.T) { } func TestServer_Request_Post_Body_TwoData(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_TwoData) +} +func testServer_Request_Post_Body_TwoData(t testing.TB) { const content = "Some content" testBodyContents(t, -1, content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -755,6 +804,9 @@ func TestServer_Request_Post_Body_TwoData(t *testing.T) { } func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) { + synctestTest(t, 
testServer_Request_Post_Body_ContentLength_Correct) +} +func testServer_Request_Post_Body_ContentLength_Correct(t testing.TB) { const content = "Some content" testBodyContents(t, int64(len(content)), content, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -771,6 +823,9 @@ func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) { } func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ContentLength_TooLarge) +} +func testServer_Request_Post_Body_ContentLength_TooLarge(t testing.TB) { testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -787,6 +842,9 @@ func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) { } func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { + synctestTest(t, testServer_Request_Post_Body_ContentLength_TooSmall) +} +func testServer_Request_Post_Body_ContentLength_TooSmall(t testing.TB) { testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes", func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -802,11 +860,11 @@ func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { // Return flow control bytes back, since the data handler closed // the stream. 
st.wantRSTStream(1, ErrCodeProtocol) - st.wantFlowControlConsumed(0, 0) + st.wantConnFlowControlConsumed(0) }) } -func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) { +func testBodyContents(t testing.TB, wantContentLength int64, wantBody string, write func(st *serverTester)) { testServerRequest(t, write, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) @@ -827,7 +885,7 @@ func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, wr }) } -func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) { +func testBodyContentsFail(t testing.TB, wantContentLength int64, wantReadError string, write func(st *serverTester)) { testServerRequest(t, write, func(r *http.Request) { if r.Method != "POST" { t.Errorf("Method = %q; want POST", r.Method) @@ -850,7 +908,8 @@ func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError s } // Using a Host header, instead of :authority -func TestServer_Request_Get_Host(t *testing.T) { +func TestServer_Request_Get_Host(t *testing.T) { synctestTest(t, testServer_Request_Get_Host) } +func testServer_Request_Get_Host(t testing.TB) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -868,6 +927,9 @@ func TestServer_Request_Get_Host(t *testing.T) { // Using an :authority pseudo-header, instead of Host func TestServer_Request_Get_Authority(t *testing.T) { + synctestTest(t, testServer_Request_Get_Authority) +} +func testServer_Request_Get_Authority(t testing.TB) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -884,6 +946,9 @@ func TestServer_Request_Get_Authority(t *testing.T) { } func TestServer_Request_WithContinuation(t *testing.T) { + synctestTest(t, testServer_Request_WithContinuation) +} +func 
testServer_Request_WithContinuation(t testing.TB) { wantHeader := http.Header{ "Foo-One": []string{"value-one"}, "Foo-Two": []string{"value-two"}, @@ -931,7 +996,8 @@ func TestServer_Request_WithContinuation(t *testing.T) { } // Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field") -func TestServer_Request_CookieConcat(t *testing.T) { +func TestServer_Request_CookieConcat(t *testing.T) { synctestTest(t, testServer_Request_CookieConcat) } +func testServer_Request_CookieConcat(t testing.TB) { const host = "example.com" testServerRequest(t, func(st *serverTester) { st.bodylessReq1( @@ -1053,17 +1119,19 @@ func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) { } func testRejectRequest(t *testing.T, send func(*serverTester)) { - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - t.Error("server request made it to handler; should've been rejected") - }) - defer st.Close() + synctestTest(t, func(t testing.TB) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + t.Error("server request made it to handler; should've been rejected") + }) + defer st.Close() - st.greet() - send(st) - st.wantRSTStream(1, ErrCodeProtocol) + st.greet() + send(st) + st.wantRSTStream(1, ErrCodeProtocol) + }) } -func newServerTesterForError(t *testing.T) *serverTester { +func newServerTesterForError(t testing.TB) *serverTester { t.Helper() st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("server request made it to handler; should've been rejected") @@ -1076,22 +1144,28 @@ func newServerTesterForError(t *testing.T) *serverTester { // HEADERS or PRIORITY on a stream in this state MUST be treated as a // connection error (Section 5.4.1) of type PROTOCOL_ERROR." 
func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) { + synctestTest(t, testRejectFrameOnIdle_WindowUpdate) +} +func testRejectFrameOnIdle_WindowUpdate(t testing.TB) { st := newServerTesterForError(t) st.fr.WriteWindowUpdate(123, 456) st.wantGoAway(123, ErrCodeProtocol) } -func TestRejectFrameOnIdle_Data(t *testing.T) { +func TestRejectFrameOnIdle_Data(t *testing.T) { synctestTest(t, testRejectFrameOnIdle_Data) } +func testRejectFrameOnIdle_Data(t testing.TB) { st := newServerTesterForError(t) st.fr.WriteData(123, true, nil) st.wantGoAway(123, ErrCodeProtocol) } -func TestRejectFrameOnIdle_RSTStream(t *testing.T) { +func TestRejectFrameOnIdle_RSTStream(t *testing.T) { synctestTest(t, testRejectFrameOnIdle_RSTStream) } +func testRejectFrameOnIdle_RSTStream(t testing.TB) { st := newServerTesterForError(t) st.fr.WriteRSTStream(123, ErrCodeCancel) st.wantGoAway(123, ErrCodeProtocol) } -func TestServer_Request_Connect(t *testing.T) { +func TestServer_Request_Connect(t *testing.T) { synctestTest(t, testServer_Request_Connect) } +func testServer_Request_Connect(t testing.TB) { testServerRequest(t, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1116,6 +1190,9 @@ func TestServer_Request_Connect(t *testing.T) { } func TestServer_Request_Connect_InvalidPath(t *testing.T) { + synctestTest(t, testServer_Request_Connect_InvalidPath) +} +func testServer_Request_Connect_InvalidPath(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1131,6 +1208,9 @@ func TestServer_Request_Connect_InvalidPath(t *testing.T) { } func TestServer_Request_Connect_InvalidScheme(t *testing.T) { + synctestTest(t, testServer_Request_Connect_InvalidScheme) +} +func testServer_Request_Connect_InvalidScheme(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1145,7 +1225,8 @@ func 
TestServer_Request_Connect_InvalidScheme(t *testing.T) { }) } -func TestServer_Ping(t *testing.T) { +func TestServer_Ping(t *testing.T) { synctestTest(t, testServer_Ping) } +func testServer_Ping(t testing.TB) { st := newServerTester(t, nil) defer st.Close() st.greet() @@ -1185,6 +1266,9 @@ func (l *filterListener) Accept() (net.Conn, error) { } func TestServer_MaxQueuedControlFrames(t *testing.T) { + synctestTest(t, testServer_MaxQueuedControlFrames) +} +func testServer_MaxQueuedControlFrames(t testing.TB) { // Goroutine debugging makes this test very slow. disableGoroutineTracking(t) @@ -1201,7 +1285,7 @@ func TestServer_MaxQueuedControlFrames(t *testing.T) { pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} st.fr.WritePing(false, pingData) } - st.group.Wait() + synctest.Wait() // Unblock the server. // It should have closed the connection after exceeding the control frame limit. @@ -1217,7 +1301,8 @@ func TestServer_MaxQueuedControlFrames(t *testing.T) { st.wantClosed() } -func TestServer_RejectsLargeFrames(t *testing.T) { +func TestServer_RejectsLargeFrames(t *testing.T) { synctestTest(t, testServer_RejectsLargeFrames) } +func testServer_RejectsLargeFrames(t testing.TB) { if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "zos" { t.Skip("see golang.org/issue/13434, golang.org/issue/37321") } @@ -1236,20 +1321,19 @@ func TestServer_RejectsLargeFrames(t *testing.T) { } func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { + synctestTest(t, testServer_Handler_Sends_WindowUpdate) +} +func testServer_Handler_Sends_WindowUpdate(t testing.TB) { // Need to set this to at least twice the initial window size, // or st.greet gets stuck waiting for a WINDOW_UPDATE. // // This also needs to be less than MAX_FRAME_SIZE. 
const windowSize = 65535 * 2 - puppet := newHandlerPuppet() - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - puppet.act(w, r) - }, func(s *Server) { + st := newServerTester(t, nil, func(s *Server) { s.MaxUploadBufferPerConnection = windowSize s.MaxUploadBufferPerStream = windowSize }) defer st.Close() - defer puppet.done() st.greet() st.writeHeaders(HeadersFrameParam{ @@ -1258,13 +1342,14 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { EndStream: false, // data coming EndHeaders: true, }) + call := st.nextHandlerCall() // Write less than half the max window of data and consume it. // The server doesn't return flow control yet, buffering the 1024 bytes to // combine with a future update. data := make([]byte, windowSize) st.writeData(1, false, data[:1024]) - puppet.do(readBodyHandler(t, string(data[:1024]))) + call.do(readBodyHandler(t, string(data[:1024]))) // Write up to the window limit. // The server returns the buffered credit. @@ -1273,7 +1358,7 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { st.wantWindowUpdate(1, 1024) // The handler consumes the data and the server returns credit. - puppet.do(readBodyHandler(t, string(data[1024:]))) + call.do(readBodyHandler(t, string(data[1024:]))) st.wantWindowUpdate(0, windowSize-1024) st.wantWindowUpdate(1, windowSize-1024) } @@ -1281,16 +1366,15 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { // the version of the TestServer_Handler_Sends_WindowUpdate with padding. 
// See golang.org/issue/16556 func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { + synctestTest(t, testServer_Handler_Sends_WindowUpdate_Padding) +} +func testServer_Handler_Sends_WindowUpdate_Padding(t testing.TB) { const windowSize = 65535 * 2 - puppet := newHandlerPuppet() - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - puppet.act(w, r) - }, func(s *Server) { + st := newServerTester(t, nil, func(s *Server) { s.MaxUploadBufferPerConnection = windowSize s.MaxUploadBufferPerStream = windowSize }) defer st.Close() - defer puppet.done() st.greet() st.writeHeaders(HeadersFrameParam{ @@ -1299,6 +1383,7 @@ func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { EndStream: false, EndHeaders: true, }) + call := st.nextHandlerCall() // Write half a window of data, with some padding. // The server doesn't return the padding yet, buffering the 5 bytes to combine @@ -1310,12 +1395,15 @@ func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { // The handler consumes the body. // The server returns flow control for the body and padding // (4 bytes of padding + 1 byte of length). 
- puppet.do(readBodyHandler(t, string(data))) + call.do(readBodyHandler(t, string(data))) st.wantWindowUpdate(0, uint32(len(data)+1+len(pad))) st.wantWindowUpdate(1, uint32(len(data)+1+len(pad))) } func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { + synctestTest(t, testServer_Send_GoAway_After_Bogus_WindowUpdate) +} +func testServer_Send_GoAway_After_Bogus_WindowUpdate(t testing.TB) { st := newServerTester(t, nil) defer st.Close() st.greet() @@ -1326,6 +1414,9 @@ func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { } func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) { + synctestTest(t, testServer_Send_RstStream_After_Bogus_WindowUpdate) +} +func testServer_Send_RstStream_After_Bogus_WindowUpdate(t testing.TB) { inHandler := make(chan bool) blockHandler := make(chan bool) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -1352,7 +1443,7 @@ func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) { // testServerPostUnblock sends a hanging POST with unsent data to handler, // then runs fn once in the handler, and verifies that the error returned from // handler is acceptable. It fails if takes over 5 seconds for handler to exit. 
-func testServerPostUnblock(t *testing.T, +func testServerPostUnblock(t testing.TB, handler func(http.ResponseWriter, *http.Request) error, fn func(*serverTester), checkErr func(error), @@ -1380,6 +1471,9 @@ func testServerPostUnblock(t *testing.T, } func TestServer_RSTStream_Unblocks_Read(t *testing.T) { + synctestTest(t, testServer_RSTStream_Unblocks_Read) +} +func testServer_RSTStream_Unblocks_Read(t testing.TB) { testServerPostUnblock(t, func(w http.ResponseWriter, r *http.Request) (err error) { _, err = r.Body.Read(make([]byte, 1)) @@ -1407,11 +1501,11 @@ func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) { n = 5 } for i := 0; i < n; i++ { - testServer_RSTStream_Unblocks_Header_Write(t) + synctestTest(t, testServer_RSTStream_Unblocks_Header_Write) } } -func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) { +func testServer_RSTStream_Unblocks_Header_Write(t testing.TB) { inHandler := make(chan bool, 1) unblockHandler := make(chan bool, 1) headerWritten := make(chan bool, 1) @@ -1440,12 +1534,15 @@ func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) { t.Fatal(err) } wroteRST <- true - st.awaitIdle() + synctest.Wait() <-headerWritten unblockHandler <- true } func TestServer_DeadConn_Unblocks_Read(t *testing.T) { + synctestTest(t, testServer_DeadConn_Unblocks_Read) +} +func testServer_DeadConn_Unblocks_Read(t testing.TB) { testServerPostUnblock(t, func(w http.ResponseWriter, r *http.Request) (err error) { _, err = r.Body.Read(make([]byte, 1)) @@ -1466,6 +1563,9 @@ var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error { } func TestServer_CloseNotify_After_RSTStream(t *testing.T) { + synctestTest(t, testServer_CloseNotify_After_RSTStream) +} +func testServer_CloseNotify_After_RSTStream(t testing.TB) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { t.Fatal(err) @@ -1474,6 +1574,9 @@ func TestServer_CloseNotify_After_RSTStream(t 
*testing.T) { } func TestServer_CloseNotify_After_ConnClose(t *testing.T) { + synctestTest(t, testServer_CloseNotify_After_ConnClose) +} +func testServer_CloseNotify_After_ConnClose(t testing.TB) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil) } @@ -1481,13 +1584,17 @@ func TestServer_CloseNotify_After_ConnClose(t *testing.T) { // problem that's unrelated to them explicitly canceling it (which is // TestServer_CloseNotify_After_RSTStream above) func TestServer_CloseNotify_After_StreamError(t *testing.T) { + synctestTest(t, testServer_CloseNotify_After_StreamError) +} +func testServer_CloseNotify_After_StreamError(t testing.TB) { testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { // data longer than declared Content-Length => stream error st.writeData(1, true, []byte("1234")) }, nil, "content-length", "3") } -func TestServer_StateTransitions(t *testing.T) { +func TestServer_StateTransitions(t *testing.T) { synctestTest(t, testServer_StateTransitions) } +func testServer_StateTransitions(t testing.TB) { var st *serverTester inHandler := make(chan bool) writeData := make(chan bool) @@ -1544,6 +1651,9 @@ func TestServer_StateTransitions(t *testing.T) { // test HEADERS w/o EndHeaders + another HEADERS (should get rejected) func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_Headers) +} +func testServer_Rejects_HeadersNoEnd_Then_Headers(t testing.TB) { st := newServerTesterForError(t) st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1562,6 +1672,9 @@ func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) { // test HEADERS w/o EndHeaders + PING (should get rejected) func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_Ping) +} +func testServer_Rejects_HeadersNoEnd_Then_Ping(t testing.TB) { st := newServerTesterForError(t) st.writeHeaders(HeadersFrameParam{ StreamID: 1, 
@@ -1577,6 +1690,9 @@ func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) { // test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected) func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersEnd_Then_Continuation) +} +func testServer_Rejects_HeadersEnd_Then_Continuation(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optQuiet) st.greet() st.writeHeaders(HeadersFrameParam{ @@ -1597,6 +1713,9 @@ func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) { // test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream) +} +func testServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t testing.TB) { st := newServerTesterForError(t) st.writeHeaders(HeadersFrameParam{ StreamID: 1, @@ -1611,7 +1730,8 @@ func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) } // No HEADERS on stream 0. -func TestServer_Rejects_Headers0(t *testing.T) { +func TestServer_Rejects_Headers0(t *testing.T) { synctestTest(t, testServer_Rejects_Headers0) } +func testServer_Rejects_Headers0(t testing.TB) { st := newServerTesterForError(t) st.fr.AllowIllegalWrites = true st.writeHeaders(HeadersFrameParam{ @@ -1625,6 +1745,9 @@ func TestServer_Rejects_Headers0(t *testing.T) { // No CONTINUATION on stream 0. func TestServer_Rejects_Continuation0(t *testing.T) { + synctestTest(t, testServer_Rejects_Continuation0) +} +func testServer_Rejects_Continuation0(t testing.TB) { st := newServerTesterForError(t) st.fr.AllowIllegalWrites = true if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil { @@ -1634,7 +1757,8 @@ func TestServer_Rejects_Continuation0(t *testing.T) { } // No PRIORITY on stream 0. 
-func TestServer_Rejects_Priority0(t *testing.T) { +func TestServer_Rejects_Priority0(t *testing.T) { synctestTest(t, testServer_Rejects_Priority0) } +func testServer_Rejects_Priority0(t testing.TB) { st := newServerTesterForError(t) st.fr.AllowIllegalWrites = true st.writePriority(0, PriorityParam{StreamDep: 1}) @@ -1643,6 +1767,9 @@ func TestServer_Rejects_Priority0(t *testing.T) { // No HEADERS frame with a self-dependence. func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { + synctestTest(t, testServer_Rejects_HeadersSelfDependence) +} +func testServer_Rejects_HeadersSelfDependence(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writeHeaders(HeadersFrameParam{ @@ -1657,13 +1784,17 @@ func TestServer_Rejects_HeadersSelfDependence(t *testing.T) { // No PRIORITY frame with a self-dependence. func TestServer_Rejects_PrioritySelfDependence(t *testing.T) { + synctestTest(t, testServer_Rejects_PrioritySelfDependence) +} +func testServer_Rejects_PrioritySelfDependence(t testing.TB) { testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) { st.fr.AllowIllegalWrites = true st.writePriority(1, PriorityParam{StreamDep: 1}) }) } -func TestServer_Rejects_PushPromise(t *testing.T) { +func TestServer_Rejects_PushPromise(t *testing.T) { synctestTest(t, testServer_Rejects_PushPromise) } +func testServer_Rejects_PushPromise(t testing.TB) { st := newServerTesterForError(t) pp := PushPromiseParam{ StreamID: 1, @@ -1677,7 +1808,7 @@ func TestServer_Rejects_PushPromise(t *testing.T) { // testServerRejectsStream tests that the server sends a RST_STREAM with the provided // error code after a client sends a bogus request. 
-func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) { +func testServerRejectsStream(t testing.TB, code ErrCode, writeReq func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() st.greet() @@ -1688,7 +1819,7 @@ func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTe // testServerRequest sets up an idle HTTP/2 connection and lets you // write a single request with writeReq, and then verify that the // *http.Request is built correctly in checkReq. -func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) { +func testServerRequest(t testing.TB, writeReq func(*serverTester), checkReq func(*http.Request)) { gotReq := make(chan bool, 1) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { if r.Body == nil { @@ -1706,7 +1837,8 @@ func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func func getSlash(st *serverTester) { st.bodylessReq1() } -func TestServer_Response_NoData(t *testing.T) { +func TestServer_Response_NoData(t *testing.T) { synctestTest(t, testServer_Response_NoData) } +func testServer_Response_NoData(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { // Nothing. return nil @@ -1720,6 +1852,9 @@ func TestServer_Response_NoData(t *testing.T) { } func TestServer_Response_NoData_Header_FooBar(t *testing.T) { + synctestTest(t, testServer_Response_NoData_Header_FooBar) +} +func testServer_Response_NoData_Header_FooBar(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Foo-Bar", "some-value") return nil @@ -1740,6 +1875,9 @@ func TestServer_Response_NoData_Header_FooBar(t *testing.T) { // Reject content-length headers containing a sign. 
// See https://golang.org/issue/39017 func TestServerIgnoresContentLengthSignWhenWritingChunks(t *testing.T) { + synctestTest(t, testServerIgnoresContentLengthSignWhenWritingChunks) +} +func testServerIgnoresContentLengthSignWhenWritingChunks(t testing.TB) { tests := []struct { name string cl string @@ -1827,7 +1965,7 @@ func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) { for _, tt := range tests { tt := tt - t.Run(tt.name, func(t *testing.T) { + synctestSubtest(t, tt.name, func(t testing.TB) { writeReq := func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ StreamID: 1, // clients send odd numbers @@ -1848,6 +1986,9 @@ func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) { } func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) { + synctestTest(t, testServer_Response_Data_Sniff_DoesntOverride) +} +func testServer_Response_Data_Sniff_DoesntOverride(t testing.TB) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Content-Type", "foo/bar") @@ -1873,6 +2014,9 @@ func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) { } func TestServer_Response_TransferEncoding_chunked(t *testing.T) { + synctestTest(t, testServer_Response_TransferEncoding_chunked) +} +func testServer_Response_TransferEncoding_chunked(t testing.TB) { const msg = "hi" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Transfer-Encoding", "chunked") // should be stripped @@ -1894,6 +2038,9 @@ func TestServer_Response_TransferEncoding_chunked(t *testing.T) { // Header accessed only after the initial write. func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) { + synctestTest(t, testServer_Response_Data_IgnoreHeaderAfterWrite_After) +} +func testServer_Response_Data_IgnoreHeaderAfterWrite_After(t testing.TB) { const msg = "this is HTML." 
testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) @@ -1915,6 +2062,9 @@ func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) { // Header accessed before the initial write and later mutated. func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) { + synctestTest(t, testServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite) +} +func testServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t testing.TB) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("foo", "proper value") @@ -1937,6 +2087,9 @@ func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) { } func TestServer_Response_Data_SniffLenType(t *testing.T) { + synctestTest(t, testServer_Response_Data_SniffLenType) +} +func testServer_Response_Data_SniffLenType(t testing.TB) { const msg = "this is HTML." testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.WriteString(w, msg) @@ -1961,6 +2114,9 @@ func TestServer_Response_Data_SniffLenType(t *testing.T) { } func TestServer_Response_Header_Flush_MidWrite(t *testing.T) { + synctestTest(t, testServer_Response_Header_Flush_MidWrite) +} +func testServer_Response_Header_Flush_MidWrite(t testing.TB) { const msg = "this is HTML" const msg2 = ", and this is the next chunk" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -1992,7 +2148,8 @@ func TestServer_Response_Header_Flush_MidWrite(t *testing.T) { }) } -func TestServer_Response_LargeWrite(t *testing.T) { +func TestServer_Response_LargeWrite(t *testing.T) { synctestTest(t, testServer_Response_LargeWrite) } +func testServer_Response_LargeWrite(t testing.TB) { const size = 1 << 20 const maxFrameSize = 16 << 10 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -2058,6 +2215,9 @@ func TestServer_Response_LargeWrite(t *testing.T) { // Test that the handler can't 
write more than the client allows func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) { + synctestTest(t, testServer_Response_LargeWrite_FlowControlled) +} +func testServer_Response_LargeWrite_FlowControlled(t testing.TB) { // Make these reads. Before each read, the client adds exactly enough // flow-control to satisfy the read. Numbers chosen arbitrarily. reads := []int{123, 1, 13, 127} @@ -2112,6 +2272,9 @@ func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) { // Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM. func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) { + synctestTest(t, testServer_Response_RST_Unblocks_LargeWrite) +} +func testServer_Response_RST_Unblocks_LargeWrite(t testing.TB) { const size = 1 << 20 const maxFrameSize = 16 << 10 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -2144,6 +2307,9 @@ func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) { } func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) { + synctestTest(t, testServer_Response_Empty_Data_Not_FlowControlled) +} +func testServer_Response_Empty_Data_Not_FlowControlled(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.(http.Flusher).Flush() // Nothing; send empty DATA @@ -2171,6 +2337,9 @@ func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) { } func TestServer_Response_Automatic100Continue(t *testing.T) { + synctestTest(t, testServer_Response_Automatic100Continue) +} +func testServer_Response_Automatic100Continue(t testing.TB) { const msg = "foo" const reply = "bar" testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { @@ -2222,6 +2391,9 @@ func TestServer_Response_Automatic100Continue(t *testing.T) { } func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { + synctestTest(t, testServer_HandlerWriteErrorOnDisconnect) +} +func testServer_HandlerWriteErrorOnDisconnect(t 
testing.TB) { errc := make(chan error, 1) testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { p := []byte("some data.\n") @@ -2250,16 +2422,17 @@ func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { } func TestServer_Rejects_Too_Many_Streams(t *testing.T) { - const testPath = "/some/path" - + synctestTest(t, testServer_Rejects_Too_Many_Streams) +} +func testServer_Rejects_Too_Many_Streams(t testing.TB) { inHandler := make(chan uint32) leaveHandler := make(chan bool) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - id := w.(*responseWriter).rws.stream.id - inHandler <- id - if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath { - t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath) + var streamID uint32 + if _, err := fmt.Sscanf(r.URL.Path, "/%d", &streamID); err != nil { + t.Errorf("parsing %q: %v", r.URL.Path, err) } + inHandler <- streamID <-leaveHandler }) defer st.Close() @@ -2274,12 +2447,14 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { defer func() { nextStreamID += 2 }() return nextStreamID } - sendReq := func(id uint32, headers ...string) { + sendReq := func(id uint32) { st.writeHeaders(HeadersFrameParam{ - StreamID: id, - BlockFragment: st.encodeHeader(headers...), - EndStream: true, - EndHeaders: true, + StreamID: id, + BlockFragment: st.encodeHeader( + ":path", fmt.Sprintf("/%v", id), + ), + EndStream: true, + EndHeaders: true, }) } for i := 0; i < defaultMaxStreams; i++ { @@ -2296,7 +2471,7 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { // (It's also sent as a CONTINUATION, to verify we still track the decoder context, // even if we're rejecting it) rejectID := streamID() - headerBlock := st.encodeHeader(":path", testPath) + headerBlock := st.encodeHeader(":path", fmt.Sprintf("/%v", rejectID)) frag1, frag2 := headerBlock[:3], headerBlock[3:] st.writeHeaders(HeadersFrameParam{ StreamID: rejectID, @@ -2320,7 +2495,7 @@ func 
TestServer_Rejects_Too_Many_Streams(t *testing.T) { // And now another stream should be able to start: goodID := streamID() - sendReq(goodID, ":path", testPath) + sendReq(goodID) if got := <-inHandler; got != goodID { t.Errorf("Got stream %d; want %d", got, goodID) } @@ -2328,6 +2503,9 @@ func TestServer_Rejects_Too_Many_Streams(t *testing.T) { // So many response headers that the server needs to use CONTINUATION frames: func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) { + synctestTest(t, testServer_Response_ManyHeaders_With_Continuation) +} +func testServer_Response_ManyHeaders_With_Continuation(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { h := w.Header() for i := 0; i < 5000; i++ { @@ -2362,6 +2540,9 @@ func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) { // defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop // ended. func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { + synctestTest(t, testServer_NoCrash_HandlerClose_Then_ClientClose) +} +func testServer_NoCrash_HandlerClose_Then_ClientClose(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { // nothing return nil @@ -2396,7 +2577,7 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { // We should have our flow control bytes back, // since the handler didn't get them. - st.wantFlowControlConsumed(0, 0) + st.wantConnFlowControlConsumed(0) // Set up a bunch of machinery to record the panic we saw // previously. @@ -2416,7 +2597,7 @@ func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { // Now force the serve loop to end, via closing the connection. 
st.cc.Close() - <-st.sc.doneServing + synctest.Wait() panMu.Lock() got := panicVal @@ -2431,17 +2612,20 @@ func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) } func testRejectTLS(t *testing.T, version uint16) { - st := newServerTester(t, nil, func(state *tls.ConnectionState) { - // As of 1.18 the default minimum Go TLS version is - // 1.2. In order to test rejection of lower versions, - // manually set the version to 1.0 - state.Version = version + synctestTest(t, func(t testing.TB) { + st := newServerTester(t, nil, func(state *tls.ConnectionState) { + // As of 1.18 the default minimum Go TLS version is + // 1.2. In order to test rejection of lower versions, + // manually set the version to 1.0 + state.Version = version + }) + defer st.Close() + st.wantGoAway(0, ErrCodeInadequateSecurity) }) - defer st.Close() - st.wantGoAway(0, ErrCodeInadequateSecurity) } -func TestServer_Rejects_TLSBadCipher(t *testing.T) { +func TestServer_Rejects_TLSBadCipher(t *testing.T) { synctestTest(t, testServer_Rejects_TLSBadCipher) } +func testServer_Rejects_TLSBadCipher(t testing.TB) { st := newServerTester(t, nil, func(state *tls.ConnectionState) { state.Version = tls.VersionTLS12 state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA @@ -2451,6 +2635,9 @@ func TestServer_Rejects_TLSBadCipher(t *testing.T) { } func TestServer_Advertises_Common_Cipher(t *testing.T) { + synctestTest(t, testServer_Advertises_Common_Cipher) +} +func testServer_Advertises_Common_Cipher(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { }, func(srv *http.Server) { // Have the server configured with no specific cipher suites. @@ -2508,7 +2695,7 @@ func testServerResponse(t testing.TB, // readBodyHandler returns an http Handler func that reads len(want) // bytes from r.Body and fails t if the contents read were not // the value of want. 
-func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) { +func readBodyHandler(t testing.TB, want string) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { buf := make([]byte, len(want)) _, err := io.ReadFull(r.Body, buf) @@ -2523,6 +2710,9 @@ func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *h } func TestServer_MaxDecoderHeaderTableSize(t *testing.T) { + synctestTest(t, testServer_MaxDecoderHeaderTableSize) +} +func testServer_MaxDecoderHeaderTableSize(t testing.TB) { wantHeaderTableSize := uint32(initialHeaderTableSize * 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) { s.MaxDecoderHeaderTableSize = wantHeaderTableSize @@ -2546,6 +2736,9 @@ func TestServer_MaxDecoderHeaderTableSize(t *testing.T) { } func TestServer_MaxEncoderHeaderTableSize(t *testing.T) { + synctestTest(t, testServer_MaxEncoderHeaderTableSize) +} +func testServer_MaxEncoderHeaderTableSize(t testing.TB) { wantHeaderTableSize := uint32(initialHeaderTableSize / 2) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) { s.MaxEncoderHeaderTableSize = wantHeaderTableSize @@ -2560,7 +2753,8 @@ func TestServer_MaxEncoderHeaderTableSize(t *testing.T) { } // Issue 12843 -func TestServerDoS_MaxHeaderListSize(t *testing.T) { +func TestServerDoS_MaxHeaderListSize(t *testing.T) { synctestTest(t, testServerDoS_MaxHeaderListSize) } +func testServerDoS_MaxHeaderListSize(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() @@ -2630,6 +2824,9 @@ func TestServerDoS_MaxHeaderListSize(t *testing.T) { } func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) { + synctestTest(t, testServer_Response_Stream_With_Missing_Trailer) +} +func testServer_Response_Stream_With_Missing_Trailer(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r 
*http.Request) error { w.Header().Set("Trailer", "test-trailer") return nil @@ -2647,7 +2844,8 @@ func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) { }) } -func TestCompressionErrorOnWrite(t *testing.T) { +func TestCompressionErrorOnWrite(t *testing.T) { synctestTest(t, testCompressionErrorOnWrite) } +func testCompressionErrorOnWrite(t testing.TB) { const maxStrLen = 8 << 10 var serverConfig *http.Server st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -2709,7 +2907,8 @@ func TestCompressionErrorOnWrite(t *testing.T) { st.wantGoAway(3, ErrCodeCompression) } -func TestCompressionErrorOnClose(t *testing.T) { +func TestCompressionErrorOnClose(t *testing.T) { synctestTest(t, testCompressionErrorOnClose) } +func testCompressionErrorOnClose(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. }) @@ -2729,7 +2928,8 @@ func TestCompressionErrorOnClose(t *testing.T) { } // test that a server handler can read trailers from a client -func TestServerReadsTrailers(t *testing.T) { +func TestServerReadsTrailers(t *testing.T) { synctestTest(t, testServerReadsTrailers) } +func testServerReadsTrailers(t testing.TB) { const testBody = "some test body" writeReq := func(st *serverTester) { st.writeHeaders(HeadersFrameParam{ @@ -2780,10 +2980,18 @@ func TestServerReadsTrailers(t *testing.T) { } // test that a server handler can send trailers -func TestServerWritesTrailers_WithFlush(t *testing.T) { testServerWritesTrailers(t, true) } -func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) } +func TestServerWritesTrailers_WithFlush(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testServerWritesTrailers(t, true) + }) +} +func TestServerWritesTrailers_WithoutFlush(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testServerWritesTrailers(t, false) + }) +} -func testServerWritesTrailers(t *testing.T, withFlush bool) { +func 
testServerWritesTrailers(t testing.TB, withFlush bool) { // See https://httpwg.github.io/specs/rfc7540.html#rfc.section.8.1.3 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") @@ -2851,6 +3059,9 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) { } func TestServerWritesUndeclaredTrailers(t *testing.T) { + synctestTest(t, testServerWritesUndeclaredTrailers) +} +func testServerWritesUndeclaredTrailers(t testing.TB) { const trailer = "Trailer-Header" const value = "hi1" ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -2876,6 +3087,9 @@ func TestServerWritesUndeclaredTrailers(t *testing.T) { // validate transmitted header field names & values // golang.org/issue/14048 func TestServerDoesntWriteInvalidHeaders(t *testing.T) { + synctestTest(t, testServerDoesntWriteInvalidHeaders) +} +func testServerDoesntWriteInvalidHeaders(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Add("OK1", "x") w.Header().Add("Bad:Colon", "x") // colon (non-token byte) in key @@ -3054,7 +3268,8 @@ func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) { // go-fuzz bug, originally reported at https://github.com/bradfitz/http2/issues/53 // Verify we don't hang. -func TestIssue53(t *testing.T) { +func TestIssue53(t *testing.T) { synctestTest(t, testIssue53) } +func testIssue53(t testing.TB) { const data = "PRI * HTTP/2.0\r\n\r\nSM" + "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad" s := &http.Server{ @@ -3109,28 +3324,53 @@ func (c *issue53Conn) SetDeadline(t time.Time) error { return nil } func (c *issue53Conn) SetReadDeadline(t time.Time) error { return nil } func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil } +// TestServeConnNilOpts ensures that Server.ServeConn(conn, nil) works. 
+// // golang.org/issue/33839 -func TestServeConnOptsNilReceiverBehavior(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("got a panic that should not happen: %v", r) - } - }() +func TestServeConnNilOpts(t *testing.T) { synctestTest(t, testServeConnNilOpts) } +func testServeConnNilOpts(t testing.TB) { + // A nil ServeConnOpts uses http.DefaultServeMux as the handler. + var gotRequest string + var mux http.ServeMux + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + gotRequest = r.URL.Path + }) + setForTest(t, &http.DefaultServeMux, &mux) + + srvConn, cliConn := net.Pipe() + defer srvConn.Close() + defer cliConn.Close() + + s2 := &Server{} + go s2.ServeConn(srvConn, nil) + + fr := NewFramer(cliConn, cliConn) + io.WriteString(cliConn, ClientPreface) + fr.WriteSettings() + fr.WriteSettingsAck() + var henc hpackEncoder + const reqPath = "/request" + fr.WriteHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: henc.encodeHeaderRaw(t, + ":method", "GET", + ":path", reqPath, + ":scheme", "https", + ":authority", "foo.com", + ), + EndStream: true, + EndHeaders: true, + }) - var o *ServeConnOpts - if o.context() == nil { - t.Error("o.context should not return nil") - } - if o.baseConfig() == nil { - t.Error("o.baseConfig should not return nil") - } - if o.handler() == nil { - t.Error("o.handler should not return nil") + synctest.Wait() + if got, want := gotRequest, reqPath; got != want { + t.Errorf("got request: %q, want %q", got, want) } } // golang.org/issue/12895 -func TestConfigureServer(t *testing.T) { +func TestConfigureServer(t *testing.T) { synctestTest(t, testConfigureServer) } +func testConfigureServer(t testing.TB) { tests := []struct { name string tlsConfig *tls.Config @@ -3202,6 +3442,9 @@ func TestConfigureServer(t *testing.T) { } func TestServerNoAutoContentLengthOnHead(t *testing.T) { + synctestTest(t, testServerNoAutoContentLengthOnHead) +} +func testServerNoAutoContentLengthOnHead(t testing.TB) { st := 
newServerTester(t, func(w http.ResponseWriter, r *http.Request) { // No response body. (or smaller than one frame) }) @@ -3224,6 +3467,9 @@ func TestServerNoAutoContentLengthOnHead(t *testing.T) { // golang.org/issue/13495 func TestServerNoDuplicateContentType(t *testing.T) { + synctestTest(t, testServerNoDuplicateContentType) +} +func testServerNoDuplicateContentType(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "hi") @@ -3248,6 +3494,9 @@ func TestServerNoDuplicateContentType(t *testing.T) { } func TestServerContentLengthCanBeDisabled(t *testing.T) { + synctestTest(t, testServerContentLengthCanBeDisabled) +} +func testServerContentLengthCanBeDisabled(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Header()["Content-Length"] = nil fmt.Fprintf(w, "OK") @@ -3271,9 +3520,10 @@ func TestServerContentLengthCanBeDisabled(t *testing.T) { } func disableGoroutineTracking(t testing.TB) { - old := DebugGoroutines - DebugGoroutines = false - t.Cleanup(func() { DebugGoroutines = old }) + disableDebugGoroutines.Store(true) + t.Cleanup(func() { + disableDebugGoroutines.Store(false) + }) } func BenchmarkServer_GetRequest(b *testing.B) { @@ -3349,7 +3599,8 @@ func (c connStateConn) ConnectionState() tls.ConnectionState { return c.cs } // golang.org/issue/12737 -- handle any net.Conn, not just // *tls.Conn. 
-func TestServerHandleCustomConn(t *testing.T) { +func TestServerHandleCustomConn(t *testing.T) { synctestTest(t, testServerHandleCustomConn) } +func testServerHandleCustomConn(t testing.TB) { var s Server c1, c2 := net.Pipe() clientDone := make(chan struct{}) @@ -3414,7 +3665,8 @@ func TestServerHandleCustomConn(t *testing.T) { } // golang.org/issue/14214 -func TestServer_Rejects_ConnHeaders(t *testing.T) { +func TestServer_Rejects_ConnHeaders(t *testing.T) { synctestTest(t, testServer_Rejects_ConnHeaders) } +func testServer_Rejects_ConnHeaders(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("should not get to Handler") }) @@ -3438,7 +3690,7 @@ type hpackEncoder struct { buf bytes.Buffer } -func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte { +func (he *hpackEncoder) encodeHeaderRaw(t testing.TB, headers ...string) []byte { if len(headers)%2 == 1 { panic("odd number of kv args") } @@ -3457,7 +3709,8 @@ func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte return he.buf.Bytes() } -func TestCheckValidHTTP2Request(t *testing.T) { +func TestCheckValidHTTP2Request(t *testing.T) { synctestTest(t, testCheckValidHTTP2Request) } +func testCheckValidHTTP2Request(t testing.TB) { tests := []struct { h http.Header want error @@ -3501,6 +3754,9 @@ func TestCheckValidHTTP2Request(t *testing.T) { // golang.org/issue/14030 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) { + synctestTest(t, testExpect100ContinueAfterHandlerWrites) +} +func testExpect100ContinueAfterHandlerWrites(t testing.TB) { const msg = "Hello" const msg2 = "World" @@ -3578,7 +3834,7 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) { }, }, } { - t.Run(tt.name, func(t *testing.T) { + synctestSubtest(t, tt.name, func(t testing.TB) { unblock := make(chan bool, 1) defer close(unblock) @@ -3618,6 +3874,9 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) { } func 
TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) { + synctestTest(t, testServerReturnsStreamAndConnFlowControlOnBodyClose) +} +func testServerReturnsStreamAndConnFlowControlOnBodyClose(t testing.TB) { unblockHandler := make(chan struct{}) defer close(unblockHandler) @@ -3649,7 +3908,8 @@ func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) { }) } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { synctestTest(t, testServerIdleTimeout) } +func testServerIdleTimeout(t testing.TB) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3666,6 +3926,9 @@ func TestServerIdleTimeout(t *testing.T) { } func TestServerIdleTimeout_AfterRequest(t *testing.T) { + synctestTest(t, testServerIdleTimeout_AfterRequest) +} +func testServerIdleTimeout_AfterRequest(t testing.TB) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3676,7 +3939,7 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) { var st *serverTester st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - st.group.Sleep(requestTimeout) + time.Sleep(requestTimeout) }, func(h2s *Server) { h2s.IdleTimeout = idleTimeout }) @@ -3702,28 +3965,46 @@ func TestServerIdleTimeout_AfterRequest(t *testing.T) { // grpc-go closes the Request.Body currently with a Read. // Verify that it doesn't race. 
// See https://github.com/grpc/grpc-go/pull/938 -func TestRequestBodyReadCloseRace(t *testing.T) { - for i := 0; i < 100; i++ { - body := &requestBody{ - pipe: &pipe{ - b: new(bytes.Buffer), - }, - } - body.pipe.CloseWithError(io.EOF) +func TestRequestBodyReadCloseRace(t *testing.T) { synctestTest(t, testRequestBodyReadCloseRace) } +func testRequestBodyReadCloseRace(t testing.TB) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + go r.Body.Close() + io.Copy(io.Discard, r.Body) + }) + st.greet() + + data := make([]byte, 1024) + for i := range 100 { + streamID := uint32(1 + (i * 2)) // clients send odd numbers + st.writeHeaders(HeadersFrameParam{ + StreamID: streamID, + BlockFragment: st.encodeHeader(), + EndHeaders: true, + }) + st.writeData(1, false, data) - done := make(chan bool, 1) - buf := make([]byte, 10) - go func() { - time.Sleep(1 * time.Millisecond) - body.Close() - done <- true - }() - body.Read(buf) - <-done + for { + // Look for a RST_STREAM frame. + // Skip over anything else (HEADERS and WINDOW_UPDATE). + fr := st.readFrame() + if fr == nil { + t.Fatalf("got no RSTStreamFrame, want one") + } + rst, ok := fr.(*RSTStreamFrame) + if !ok { + continue + } + // We can get NO or STREAM_CLOSED depending on scheduling. 
+ if rst.ErrCode != ErrCodeNo && rst.ErrCode != ErrCodeStreamClosed { + t.Fatalf("got RSTStreamFrame with error code %v, want ErrCodeNo or ErrCodeStreamClosed", rst.ErrCode) + } + break + } } } -func TestIssue20704Race(t *testing.T) { +func TestIssue20704Race(t *testing.T) { synctestTest(t, testIssue20704Race) } +func testIssue20704Race(t testing.TB) { if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" { t.Skip("skipping in short mode") } @@ -3756,7 +4037,8 @@ func TestIssue20704Race(t *testing.T) { } } -func TestServer_Rejects_TooSmall(t *testing.T) { +func TestServer_Rejects_TooSmall(t *testing.T) { synctestTest(t, testServer_Rejects_TooSmall) } +func testServer_Rejects_TooSmall(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { io.ReadAll(r.Body) return nil @@ -3772,13 +4054,16 @@ func TestServer_Rejects_TooSmall(t *testing.T) { }) st.writeData(1, true, []byte("12345")) st.wantRSTStream(1, ErrCodeProtocol) - st.wantFlowControlConsumed(0, 0) + st.wantConnFlowControlConsumed(0) }) } // Tests that a handler setting "Connection: close" results in a GOAWAY being sent, // and the connection still completing. 
func TestServerHandlerConnectionClose(t *testing.T) { + synctestTest(t, testServerHandlerConnectionClose) +} +func testServerHandlerConnectionClose(t testing.TB) { unblockHandler := make(chan bool, 1) testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Connection", "close") @@ -3872,6 +4157,9 @@ func TestServerHandlerConnectionClose(t *testing.T) { } func TestServer_Headers_HalfCloseRemote(t *testing.T) { + synctestTest(t, testServer_Headers_HalfCloseRemote) +} +func testServer_Headers_HalfCloseRemote(t testing.TB) { var st *serverTester writeData := make(chan bool) writeHeaders := make(chan bool) @@ -3919,7 +4207,8 @@ func TestServer_Headers_HalfCloseRemote(t *testing.T) { st.wantRSTStream(1, ErrCodeStreamClosed) } -func TestServerGracefulShutdown(t *testing.T) { +func TestServerGracefulShutdown(t *testing.T) { synctestTest(t, testServerGracefulShutdown) } +func testServerGracefulShutdown(t testing.TB) { handlerDone := make(chan struct{}) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { <-handlerDone @@ -4016,7 +4305,7 @@ func TestContentEncodingNoSniffing(t *testing.T) { } for _, tt := range resps { - t.Run(tt.name, func(t *testing.T) { + synctestSubtest(t, tt.name, func(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { if tt.contentEncoding != nil { w.Header().Set("Content-Encoding", tt.contentEncoding.(string)) @@ -4055,10 +4344,12 @@ func TestContentEncodingNoSniffing(t *testing.T) { } func TestServerWindowUpdateOnBodyClose(t *testing.T) { + synctestTest(t, testServerWindowUpdateOnBodyClose) +} +func testServerWindowUpdateOnBodyClose(t testing.TB) { const windowSize = 65535 * 2 content := make([]byte, windowSize) - blockCh := make(chan bool) - errc := make(chan error, 1) + errc := make(chan error) st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { buf := make([]byte, 4) n, err := io.ReadFull(r.Body, buf) @@ -4070,8 +4361,7 @@ func 
TestServerWindowUpdateOnBodyClose(t *testing.T) { errc <- fmt.Errorf("too few bytes read: %d", n) return } - blockCh <- true - <-blockCh + r.Body.Close() errc <- nil }, func(s *Server) { s.MaxUploadBufferPerConnection = windowSize @@ -4090,9 +4380,9 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { EndHeaders: true, }) st.writeData(1, false, content[:windowSize/2]) - <-blockCh - st.stream(1).body.CloseWithError(io.EOF) - blockCh <- true + if err := <-errc; err != nil { + t.Fatal(err) + } // Wait for flow control credit for the portion of the request written so far. increments := windowSize / 2 @@ -4112,13 +4402,12 @@ func TestServerWindowUpdateOnBodyClose(t *testing.T) { // Writing data after the stream is reset immediately returns flow control credit. st.writeData(1, false, content[windowSize/2:]) st.wantWindowUpdate(0, windowSize/2) - - if err := <-errc; err != nil { - t.Error(err) - } } func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) { + synctestTest(t, testNoErrorLoggedOnPostAfterGOAWAY) +} +func testNoErrorLoggedOnPostAfterGOAWAY(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) defer st.Close() @@ -4151,7 +4440,8 @@ func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) { } } -func TestServerSendsProcessing(t *testing.T) { +func TestServerSendsProcessing(t *testing.T) { synctestTest(t, testServerSendsProcessing) } +func testServerSendsProcessing(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { w.WriteHeader(http.StatusProcessing) w.Write([]byte("stuff")) @@ -4178,7 +4468,8 @@ func TestServerSendsProcessing(t *testing.T) { }) } -func TestServerSendsEarlyHints(t *testing.T) { +func TestServerSendsEarlyHints(t *testing.T) { synctestTest(t, testServerSendsEarlyHints) } +func testServerSendsEarlyHints(t testing.TB) { testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { h := w.Header() h.Add("Content-Length", "123") @@ -4234,7 +4525,8 @@ func 
TestServerSendsEarlyHints(t *testing.T) { }) } -func TestProtocolErrorAfterGoAway(t *testing.T) { +func TestProtocolErrorAfterGoAway(t *testing.T) { synctestTest(t, testProtocolErrorAfterGoAway) } +func testProtocolErrorAfterGoAway(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.Copy(io.Discard, r.Body) }) @@ -4279,7 +4571,7 @@ func TestServerInitialFlowControlWindow(t *testing.T) { // test this case, but we currently do not. 65535 * 2, } { - t.Run(fmt.Sprint(want), func(t *testing.T) { + synctestSubtest(t, fmt.Sprint(want), func(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { @@ -4322,7 +4614,8 @@ func TestServerInitialFlowControlWindow(t *testing.T) { // TestCanonicalHeaderCacheGrowth verifies that the canonical header cache // size is capped to a reasonable level. -func TestCanonicalHeaderCacheGrowth(t *testing.T) { +func TestCanonicalHeaderCacheGrowth(t *testing.T) { synctestTest(t, testCanonicalHeaderCacheGrowth) } +func testCanonicalHeaderCacheGrowth(t testing.TB) { for _, size := range []int{1, (1 << 20) - 10} { base := strings.Repeat("X", size) sc := &serverConn{ @@ -4355,6 +4648,9 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { // Terminating the request stream on the client causes Write to return. // We should not access the slice after this point. func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) { + synctestTest(t, testServerWriteDoesNotRetainBufferAfterReturn) +} +func testServerWriteDoesNotRetainBufferAfterReturn(t testing.TB) { donec := make(chan struct{}) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { defer close(donec) @@ -4390,6 +4686,9 @@ func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) { // Shutting down the Server causes Write to return. // We should not access the slice after this point. 
func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { + synctestTest(t, testServerWriteDoesNotRetainBufferAfterServerClose) +} +func testServerWriteDoesNotRetainBufferAfterServerClose(t testing.TB) { donec := make(chan struct{}, 1) ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { donec <- struct{}{} @@ -4422,7 +4721,8 @@ func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { <-donec } -func TestServerMaxHandlerGoroutines(t *testing.T) { +func TestServerMaxHandlerGoroutines(t *testing.T) { synctestTest(t, testServerMaxHandlerGoroutines) } +func testServerMaxHandlerGoroutines(t testing.TB) { const maxHandlers = 10 handlerc := make(chan chan bool) donec := make(chan struct{}) @@ -4522,7 +4822,8 @@ func TestServerMaxHandlerGoroutines(t *testing.T) { } } -func TestServerContinuationFlood(t *testing.T) { +func TestServerContinuationFlood(t *testing.T) { synctestTest(t, testServerContinuationFlood) } +func testServerContinuationFlood(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { fmt.Println(r.Header) }, func(s *http.Server) { @@ -4575,6 +4876,9 @@ func TestServerContinuationFlood(t *testing.T) { } func TestServerContinuationAfterInvalidHeader(t *testing.T) { + synctestTest(t, testServerContinuationAfterInvalidHeader) +} +func testServerContinuationAfterInvalidHeader(t testing.TB) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { fmt.Println(r.Header) }) @@ -4613,6 +4917,9 @@ func TestServerContinuationAfterInvalidHeader(t *testing.T) { } func TestServerUpgradeRequestPrefaceFailure(t *testing.T) { + synctestTest(t, testServerUpgradeRequestPrefaceFailure) +} +func testServerUpgradeRequestPrefaceFailure(t testing.TB) { // An h2c upgrade request fails when the client preface is not as expected. s2 := &Server{ // Setting IdleTimeout triggers #67168. 
@@ -4633,7 +4940,8 @@ func TestServerUpgradeRequestPrefaceFailure(t *testing.T) { } // Issue 67036: A stream error should result in the handler's request context being canceled. -func TestServerRequestCancelOnError(t *testing.T) { +func TestServerRequestCancelOnError(t *testing.T) { synctestTest(t, testServerRequestCancelOnError) } +func testServerRequestCancelOnError(t testing.TB) { recvc := make(chan struct{}) // handler has started donec := make(chan struct{}) // handler has finished st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { @@ -4667,6 +4975,9 @@ func TestServerRequestCancelOnError(t *testing.T) { } func TestServerSetReadWriteDeadlineRace(t *testing.T) { + synctestTest(t, testServerSetReadWriteDeadlineRace) +} +func testServerSetReadWriteDeadlineRace(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { ctl := http.NewResponseController(w) ctl.SetReadDeadline(time.Now().Add(3600 * time.Second)) @@ -4679,7 +4990,8 @@ func TestServerSetReadWriteDeadlineRace(t *testing.T) { resp.Body.Close() } -func TestServerWriteByteTimeout(t *testing.T) { +func TestServerWriteByteTimeout(t *testing.T) { synctestTest(t, testServerWriteByteTimeout) } +func testServerWriteByteTimeout(t testing.TB) { const timeout = 1 * time.Second st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { w.Write(make([]byte, 100)) @@ -4711,7 +5023,8 @@ func TestServerWriteByteTimeout(t *testing.T) { st.wantClosed() } -func TestServerPingSent(t *testing.T) { +func TestServerPingSent(t *testing.T) { synctestTest(t, testServerPingSent) } +func testServerPingSent(t testing.TB) { const readIdleTimeout = 15 * time.Second st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { @@ -4731,7 +5044,8 @@ func TestServerPingSent(t *testing.T) { st.wantClosed() } -func TestServerPingResponded(t *testing.T) { +func TestServerPingResponded(t *testing.T) { synctestTest(t, testServerPingResponded) } +func 
testServerPingResponded(t testing.TB) { const readIdleTimeout = 15 * time.Second st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, func(s *Server) { @@ -4753,3 +5067,60 @@ func TestServerPingResponded(t *testing.T) { st.advance(2 * time.Second) st.wantIdle() } + +// golang.org/issue/15425: test that a handler closing the request +// body doesn't terminate the stream to the peer. (It just stops +// readability from the handler's side, and eventually the client +// runs out of flow control tokens) +func TestServerSendDataAfterRequestBodyClose(t *testing.T) { + synctestTest(t, testServerSendDataAfterRequestBodyClose) +} +func testServerSendDataAfterRequestBodyClose(t testing.TB) { + st := newServerTester(t, nil) + st.greet() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: false, + EndHeaders: true, + }) + + // Handler starts writing the response body. + call := st.nextHandlerCall() + call.do(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("one")) + http.NewResponseController(w).Flush() + }) + st.wantFrameType(FrameHeaders) + st.wantData(wantData{ + streamID: 1, + endStream: false, + data: []byte("one"), + }) + st.wantIdle() + + // Handler closes the request body. + // This is not observable by the client. + call.do(func(w http.ResponseWriter, req *http.Request) { + req.Body.Close() + }) + st.wantIdle() + + // The client can still send request data, which is discarded. + st.writeData(1, false, []byte("client-sent data")) + st.wantIdle() + + // Handler can still write more response body, + // which is sent to the client. 
+ call.do(func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("two")) + http.NewResponseController(w).Flush() + }) + st.wantData(wantData{ + streamID: 1, + endStream: false, + data: []byte("two"), + }) + st.wantIdle() +} diff --git a/http2/sync_test.go b/http2/sync_test.go deleted file mode 100644 index 6687202d2c..0000000000 --- a/http2/sync_test.go +++ /dev/null @@ -1,329 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http2 - -import ( - "context" - "fmt" - "runtime" - "strconv" - "strings" - "sync" - "testing" - "time" -) - -// A synctestGroup synchronizes between a set of cooperating goroutines. -type synctestGroup struct { - mu sync.Mutex - gids map[int]bool - now time.Time - timers map[*fakeTimer]struct{} -} - -type goroutine struct { - id int - parent int - state string - syscall bool -} - -// newSynctest creates a new group with the synthetic clock set the provided time. -func newSynctest(now time.Time) *synctestGroup { - return &synctestGroup{ - gids: map[int]bool{ - currentGoroutine(): true, - }, - now: now, - } -} - -// Join adds the current goroutine to the group. -func (g *synctestGroup) Join() { - g.mu.Lock() - defer g.mu.Unlock() - g.gids[currentGoroutine()] = true -} - -// Count returns the number of goroutines in the group. -func (g *synctestGroup) Count() int { - gs := stacks(true) - count := 0 - for _, gr := range gs { - if !g.gids[gr.id] && !g.gids[gr.parent] { - continue - } - count++ - } - return count -} - -// Close calls t.Fatal if the group contains any running goroutines. 
-func (g *synctestGroup) Close(t testing.TB) { - if count := g.Count(); count != 1 { - buf := make([]byte, 16*1024) - n := runtime.Stack(buf, true) - t.Logf("stacks:\n%s", buf[:n]) - t.Fatalf("%v goroutines still running after test completed, expect 1", count) - } -} - -// Wait blocks until every goroutine in the group and their direct children are idle. -func (g *synctestGroup) Wait() { - for i := 0; ; i++ { - if g.idle() { - return - } - runtime.Gosched() - if runtime.GOOS == "js" { - // When GOOS=js, we appear to need to time.Sleep to make progress - // on some syscalls. In particular, without this sleep - // writing to stdout (including via t.Log) can block forever. - for range 10 { - time.Sleep(1) - } - } - } -} - -func (g *synctestGroup) idle() bool { - gs := stacks(true) - g.mu.Lock() - defer g.mu.Unlock() - for _, gr := range gs[1:] { - if !g.gids[gr.id] && !g.gids[gr.parent] { - continue - } - if gr.syscall { - return false - } - // From runtime/runtime2.go. - switch gr.state { - case "IO wait": - case "chan receive (nil chan)": - case "chan send (nil chan)": - case "select": - case "select (no cases)": - case "chan receive": - case "chan send": - case "sync.Cond.Wait": - default: - return false - } - } - return true -} - -func currentGoroutine() int { - s := stacks(false) - return s[0].id -} - -func stacks(all bool) []goroutine { - buf := make([]byte, 16*1024) - for { - n := runtime.Stack(buf, all) - if n < len(buf) { - buf = buf[:n] - break - } - buf = make([]byte, len(buf)*2) - } - - var goroutines []goroutine - for _, gs := range strings.Split(string(buf), "\n\n") { - skip, rest, ok := strings.Cut(gs, "goroutine ") - if skip != "" || !ok { - panic(fmt.Errorf("1 unparsable goroutine stack:\n%s", gs)) - } - ids, rest, ok := strings.Cut(rest, " [") - if !ok { - panic(fmt.Errorf("2 unparsable goroutine stack:\n%s", gs)) - } - id, err := strconv.Atoi(ids) - if err != nil { - panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs)) - } - state, rest, ok := 
strings.Cut(rest, "]") - isSyscall := false - if strings.Contains(rest, "\nsyscall.") { - isSyscall = true - } - var parent int - _, rest, ok = strings.Cut(rest, "\ncreated by ") - if ok && strings.Contains(rest, " in goroutine ") { - _, rest, ok := strings.Cut(rest, " in goroutine ") - if !ok { - panic(fmt.Errorf("4 unparsable goroutine stack:\n%s", gs)) - } - parents, rest, ok := strings.Cut(rest, "\n") - if !ok { - panic(fmt.Errorf("5 unparsable goroutine stack:\n%s", gs)) - } - parent, err = strconv.Atoi(parents) - if err != nil { - panic(fmt.Errorf("6 unparsable goroutine stack:\n%s", gs)) - } - } - goroutines = append(goroutines, goroutine{ - id: id, - parent: parent, - state: state, - syscall: isSyscall, - }) - } - return goroutines -} - -// AdvanceTime advances the synthetic clock by d. -func (g *synctestGroup) AdvanceTime(d time.Duration) { - defer g.Wait() - g.mu.Lock() - defer g.mu.Unlock() - g.now = g.now.Add(d) - for tm := range g.timers { - if tm.when.After(g.now) { - continue - } - tm.run() - delete(g.timers, tm) - } -} - -// Now returns the current synthetic time. -func (g *synctestGroup) Now() time.Time { - g.mu.Lock() - defer g.mu.Unlock() - return g.now -} - -// TimeUntilEvent returns the amount of time until the next scheduled timer. -func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) { - g.mu.Lock() - defer g.mu.Unlock() - for tm := range g.timers { - if dd := tm.when.Sub(g.now); !scheduled || dd < d { - d = dd - scheduled = true - } - } - return d, scheduled -} - -// Sleep is time.Sleep, but using synthetic time. -func (g *synctestGroup) Sleep(d time.Duration) { - tm := g.NewTimer(d) - <-tm.C() -} - -// NewTimer is time.NewTimer, but using synthetic time. -func (g *synctestGroup) NewTimer(d time.Duration) Timer { - return g.addTimer(d, &fakeTimer{ - ch: make(chan time.Time), - }) -} - -// AfterFunc is time.AfterFunc, but using synthetic time. 
-func (g *synctestGroup) AfterFunc(d time.Duration, f func()) Timer { - return g.addTimer(d, &fakeTimer{ - f: f, - }) -} - -// ContextWithTimeout is context.WithTimeout, but using synthetic time. -func (g *synctestGroup) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(ctx) - tm := g.AfterFunc(d, cancel) - return ctx, func() { - tm.Stop() - cancel() - } -} - -func (g *synctestGroup) addTimer(d time.Duration, tm *fakeTimer) *fakeTimer { - g.mu.Lock() - defer g.mu.Unlock() - tm.g = g - tm.when = g.now.Add(d) - if g.timers == nil { - g.timers = make(map[*fakeTimer]struct{}) - } - if tm.when.After(g.now) { - g.timers[tm] = struct{}{} - } else { - tm.run() - } - return tm -} - -type Timer = interface { - C() <-chan time.Time - Reset(d time.Duration) bool - Stop() bool -} - -type fakeTimer struct { - g *synctestGroup - when time.Time - ch chan time.Time - f func() -} - -func (tm *fakeTimer) run() { - if tm.ch != nil { - tm.ch <- tm.g.now - } else { - go func() { - tm.g.Join() - tm.f() - }() - } -} - -func (tm *fakeTimer) C() <-chan time.Time { return tm.ch } - -func (tm *fakeTimer) Reset(d time.Duration) bool { - tm.g.mu.Lock() - defer tm.g.mu.Unlock() - _, stopped := tm.g.timers[tm] - if d <= 0 { - delete(tm.g.timers, tm) - tm.run() - } else { - tm.when = tm.g.now.Add(d) - tm.g.timers[tm] = struct{}{} - } - return stopped -} - -func (tm *fakeTimer) Stop() bool { - tm.g.mu.Lock() - defer tm.g.mu.Unlock() - _, stopped := tm.g.timers[tm] - delete(tm.g.timers, tm) - return stopped -} - -// TestSynctestLogs verifies that t.Log works, -// in particular that the GOOS=js workaround in synctestGroup.Wait is working. -// (When GOOS=js, writing to stdout can hang indefinitely if some goroutine loops -// calling runtime.Gosched; see Wait for the workaround.) 
-func TestSynctestLogs(t *testing.T) { - g := newSynctest(time.Now()) - donec := make(chan struct{}) - go func() { - g.Join() - for range 100 { - t.Logf("logging a long line") - } - close(donec) - }() - g.Wait() - select { - case <-donec: - default: - panic("done") - } -} diff --git a/http2/synctest_go124_test.go b/http2/synctest_go124_test.go new file mode 100644 index 0000000000..59f66ac2da --- /dev/null +++ b/http2/synctest_go124_test.go @@ -0,0 +1,42 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.25 && goexperiment.synctest + +package http2 + +import ( + "slices" + "testing" + "testing/synctest" +) + +// synctestTest emulates the Go 1.25 synctest.Test function on Go 1.24. +func synctestTest(t *testing.T, f func(t testing.TB)) { + t.Helper() + synctest.Run(func() { + t.Helper() + ct := &cleanupT{T: t} + defer ct.done() + f(ct) + }) +} + +// cleanupT wraps a testing.T and adds its own Cleanup method. +// Used to execute cleanup functions within a synctest bubble. +type cleanupT struct { + *testing.T + cleanups []func() +} + +// Cleanup replaces T.Cleanup. +func (t *cleanupT) Cleanup(f func()) { + t.cleanups = append(t.cleanups, f) +} + +func (t *cleanupT) done() { + for _, f := range slices.Backward(t.cleanups) { + f() + } +} diff --git a/http2/synctest_go125_test.go b/http2/synctest_go125_test.go new file mode 100644 index 0000000000..a0c5696160 --- /dev/null +++ b/http2/synctest_go125_test.go @@ -0,0 +1,20 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +//go:build go1.25 + +package http2 + +import ( + "testing" + "testing/synctest" +) + +func synctestTest(t *testing.T, f func(t testing.TB)) { + t.Helper() + synctest.Test(t, func(t *testing.T) { + t.Helper() + f(t) + }) +} diff --git a/http2/timer.go b/http2/timer.go deleted file mode 100644 index 0b1c17b812..0000000000 --- a/http2/timer.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2024 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -package http2 - -import "time" - -// A timer is a time.Timer, as an interface which can be replaced in tests. -type timer = interface { - C() <-chan time.Time - Reset(d time.Duration) bool - Stop() bool -} - -// timeTimer adapts a time.Timer to the timer interface. -type timeTimer struct { - *time.Timer -} - -func (t timeTimer) C() <-chan time.Time { return t.Timer.C } diff --git a/http2/transport.go b/http2/transport.go index f26356b9cd..35e3902519 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -193,50 +193,6 @@ type Transport struct { type transportTestHooks struct { newclientconn func(*ClientConn) - group synctestGroupInterface -} - -func (t *Transport) markNewGoroutine() { - if t != nil && t.transportTestHooks != nil { - t.transportTestHooks.group.Join() - } -} - -func (t *Transport) now() time.Time { - if t != nil && t.transportTestHooks != nil { - return t.transportTestHooks.group.Now() - } - return time.Now() -} - -func (t *Transport) timeSince(when time.Time) time.Duration { - if t != nil && t.transportTestHooks != nil { - return t.now().Sub(when) - } - return time.Since(when) -} - -// newTimer creates a new time.Timer, or a synthetic timer in tests. -func (t *Transport) newTimer(d time.Duration) timer { - if t.transportTestHooks != nil { - return t.transportTestHooks.group.NewTimer(d) - } - return timeTimer{time.NewTimer(d)} -} - -// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. 
-func (t *Transport) afterFunc(d time.Duration, f func()) timer { - if t.transportTestHooks != nil { - return t.transportTestHooks.group.AfterFunc(d, f) - } - return timeTimer{time.AfterFunc(d, f)} -} - -func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { - if t.transportTestHooks != nil { - return t.transportTestHooks.group.ContextWithTimeout(ctx, d) - } - return context.WithTimeout(ctx, d) } func (t *Transport) maxHeaderListSize() uint32 { @@ -366,7 +322,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer timer + idleTimer *time.Timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -534,14 +490,12 @@ func (cs *clientStream) closeReqBodyLocked() { cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed go func() { - cs.cc.t.markNewGoroutine() cs.reqBody.Close() close(reqBodyClosed) }() } type stickyErrWriter struct { - group synctestGroupInterface conn net.Conn timeout time.Duration err *error @@ -551,7 +505,7 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) { if *sew.err != nil { return 0, *sew.err } - n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p) + n, err = writeWithByteTimeout(sew.conn, sew.timeout, p) *sew.err = err return n, err } @@ -650,9 +604,9 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - tm := t.newTimer(d) + tm := time.NewTimer(d) select { - case <-tm.C(): + case <-tm.C: t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): @@ -699,6 +653,7 @@ var ( errClientConnUnusable = errors.New("http2: client conn not usable") errClientConnNotEstablished = errors.New("http2: client conn could 
not be established") errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + errClientConnForceClosed = errors.New("http2: client connection force closed via ClientConn.Close") ) // shouldRetryRequest is called by RoundTrip when a request fails to get @@ -838,14 +793,11 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro pingTimeout: conf.PingTimeout, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), - lastActive: t.now(), + lastActive: time.Now(), } - var group synctestGroupInterface if t.transportTestHooks != nil { - t.markNewGoroutine() t.transportTestHooks.newclientconn(cc) c = cc.tconn - group = t.group } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -857,7 +809,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. cc.bw = bufio.NewWriter(stickyErrWriter{ - group: group, conn: c, timeout: conf.WriteByteTimeout, err: &cc.werr, @@ -906,7 +857,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro // Start the idle timer after the connection is fully initialized. if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = t.afterFunc(d, cc.onIdleTimeout) + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } go cc.readLoop() @@ -917,7 +868,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.pingTimeout // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. 
- ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout) + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1120,7 +1071,7 @@ func (cc *ClientConn) tooIdleLocked() bool { // times are compared based on their wall time. We don't want // to reuse a connection that's been sitting idle during // VM/laptop suspend if monotonic time was also frozen. - return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && cc.t.timeSince(cc.lastIdle.Round(0)) > cc.idleTimeout + return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout } // onIdleTimeout is called from a time.AfterFunc goroutine. It will @@ -1186,7 +1137,6 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { done := make(chan struct{}) cancelled := false // guarded by cc.mu go func() { - cc.t.markNewGoroutine() cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1257,8 +1207,7 @@ func (cc *ClientConn) closeForError(err error) { // // In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. func (cc *ClientConn) Close() error { - err := errors.New("http2: client connection force closed via ClientConn.Close") - cc.closeForError(err) + cc.closeForError(errClientConnForceClosed) return nil } @@ -1427,7 +1376,6 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) // // It sends the request and performs post-request cleanup (closing Request.Body, etc.). 
func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)) { - cs.cc.t.markNewGoroutine() err := cs.writeRequest(req, streamf) cs.cleanupWriteRequest(err) } @@ -1558,9 +1506,9 @@ func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStre var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := cc.t.newTimer(d) + timer := time.NewTimer(d) defer timer.Stop() - respHeaderTimer = timer.C() + respHeaderTimer = timer.C respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, @@ -1753,7 +1701,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { // Return a fatal error which aborts the retry loop. return errClientConnNotEstablished } - cc.lastActive = cc.t.now() + cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { return errClientConnUnusable } @@ -2092,10 +2040,10 @@ func (cc *ClientConn) forgetStreamID(id uint32) { if len(cc.streams) != slen-1 { panic("forgetting unknown stream id") } - cc.lastActive = cc.t.now() + cc.lastActive = time.Now() if len(cc.streams) == 0 && cc.idleTimer != nil { cc.idleTimer.Reset(cc.idleTimeout) - cc.lastIdle = cc.t.now() + cc.lastIdle = time.Now() } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. @@ -2121,7 +2069,6 @@ type clientConnReadLoop struct { // readLoop runs in its own goroutine and reads and dispatches frames. 
func (cc *ClientConn) readLoop() { - cc.t.markNewGoroutine() rl := &clientConnReadLoop{cc: cc} defer rl.cleanup() cc.readerErr = rl.run() @@ -2188,9 +2135,9 @@ func (rl *clientConnReadLoop) cleanup() { if cc.idleTimeout > 0 && unusedWaitTime > cc.idleTimeout { unusedWaitTime = cc.idleTimeout } - idleTime := cc.t.now().Sub(cc.lastActive) + idleTime := time.Now().Sub(cc.lastActive) if atomic.LoadUint32(&cc.atomicReused) == 0 && idleTime < unusedWaitTime && !cc.closedOnIdle { - cc.idleTimer = cc.t.afterFunc(unusedWaitTime-idleTime, func() { + cc.idleTimer = time.AfterFunc(unusedWaitTime-idleTime, func() { cc.t.connPool().MarkDead(cc) }) } else { @@ -2250,9 +2197,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.readIdleTimeout - var t timer + var t *time.Timer if readIdleTimeout != 0 { - t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck) + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -2998,7 +2945,6 @@ func (cc *ClientConn) Ping(ctx context.Context) error { var pingError error errc := make(chan struct{}) go func() { - cc.t.markNewGoroutine() cc.wmu.Lock() defer cc.wmu.Unlock() if pingError = cc.fr.WritePing(false, p); pingError != nil { @@ -3228,7 +3174,7 @@ func traceGotConn(req *http.Request, cc *ClientConn, reused bool) { cc.mu.Lock() ci.WasIdle = len(cc.streams) == 0 && reused if ci.WasIdle && !cc.lastActive.IsZero() { - ci.IdleTime = cc.t.timeSince(cc.lastActive) + ci.IdleTime = time.Since(cc.lastActive) } cc.mu.Unlock() diff --git a/http2/transport_test.go b/http2/transport_test.go index f94d9e400b..49aaf8c0d4 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
+//go:build go1.25 || goexperiment.synctest + package http2 import ( @@ -26,13 +28,13 @@ import ( "net/url" "os" "reflect" - "runtime" "sort" "strconv" "strings" "sync" "sync/atomic" "testing" + "testing/synctest" "time" "golang.org/x/net/http2/hpack" @@ -121,7 +123,7 @@ func TestIdleConnTimeout(t *testing.T) { }, wantNewConn: false, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t testing.TB) { tt := newTestTransport(t, func(tr *Transport) { tr.IdleConnTimeout = test.idleConnTimeout }) @@ -166,7 +168,7 @@ func TestIdleConnTimeout(t *testing.T) { tc.wantFrameType(FrameSettings) // ACK to our settings } - tt.advance(test.wait) + time.Sleep(test.wait) if got, want := tc.isClosed(), test.wantNewConn; got != want { t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want) } @@ -849,10 +851,18 @@ func newLocalListener(t *testing.T) net.Listener { return ln } -func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) } -func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) } +func TestTransportReqBodyAfterResponse_200(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testTransportReqBodyAfterResponse(t, 200) + }) +} +func TestTransportReqBodyAfterResponse_403(t *testing.T) { + synctestTest(t, func(t testing.TB) { + testTransportReqBodyAfterResponse(t, 403) + }) +} -func testTransportReqBodyAfterResponse(t *testing.T, status int) { +func testTransportReqBodyAfterResponse(t testing.TB, status int) { const bodySize = 1 << 10 tc := newTestClientConn(t) @@ -1083,6 +1093,11 @@ func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) } func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) { + synctestTest(t, func(t testing.TB) { + 
testTransportResPatternBubble(t, expect100Continue, resHeader, withData, trailers) + }) +} +func testTransportResPatternBubble(t testing.TB, expect100Continue, resHeader headerType, withData bool, trailers headerType) { const reqBody = "some request body" const resBody = "some response body" @@ -1163,7 +1178,8 @@ func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerTy } // Issue 26189, Issue 17739: ignore unknown 1xx responses -func TestTransportUnknown1xx(t *testing.T) { +func TestTransportUnknown1xx(t *testing.T) { synctestTest(t, testTransportUnknown1xx) } +func testTransportUnknown1xx(t testing.TB) { var buf bytes.Buffer defer func() { got1xxFuncForTests = nil }() got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error { @@ -1213,6 +1229,9 @@ code=114 header=map[Foo-Bar:[114]] } func TestTransportReceiveUndeclaredTrailer(t *testing.T) { + synctestTest(t, testTransportReceiveUndeclaredTrailer) +} +func testTransportReceiveUndeclaredTrailer(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -1280,6 +1299,11 @@ func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) { } func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) { + synctestTest(t, func(t testing.TB) { + testInvalidTrailerBubble(t, mode, wantErr, trailers...) + }) +} +func testInvalidTrailerBubble(t testing.TB, mode headerType, wantErr error, trailers ...string) { tc := newTestClientConn(t) tc.greet() @@ -1334,7 +1358,7 @@ func headerListSize(h http.Header) (size uint32) { // space for an empty "Pad-Headers" key, then adds as many copies of // filler as possible. Any remaining bytes necessary to push the // header list size up to limit are added to h["Pad-Headers"]. -func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) { +func padHeaders(t testing.TB, h http.Header, limit uint64, filler string) { if limit > 0xffffffff { t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. 
limit = %v", limit) } @@ -1427,61 +1451,35 @@ func TestPadHeaders(t *testing.T) { } func TestTransportChecksRequestHeaderListSize(t *testing.T) { - ts := newTestServer(t, - func(w http.ResponseWriter, r *http.Request) { - // Consume body & force client to send - // trailers before writing response. - // io.ReadAll returns non-nil err for - // requests that attempt to send greater than - // maxHeaderListSize bytes of trailers, since - // those requests generate a stream reset. - io.ReadAll(r.Body) - r.Body.Close() - }, - func(ts *httptest.Server) { - ts.Config.MaxHeaderBytes = 16 << 10 - }, - optQuiet, - ) + synctestTest(t, testTransportChecksRequestHeaderListSize) +} +func testTransportChecksRequestHeaderListSize(t testing.TB) { + const peerSize = 16 << 10 - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() + tc := newTestClientConn(t) + tc.greet(Setting{SettingMaxHeaderListSize, peerSize}) checkRoundTrip := func(req *http.Request, wantErr error, desc string) { - // Make an arbitrary request to ensure we get the server's - // settings frame and initialize peerMaxHeaderListSize. 
- req0, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatalf("newRequest: NewRequest: %v", err) - } - res0, err := tr.RoundTrip(req0) - if err != nil { - t.Errorf("%v: Initial RoundTrip err = %v", desc, err) - } - res0.Body.Close() - - res, err := tr.RoundTrip(req) - if !errors.Is(err, wantErr) { - if res != nil { - res.Body.Close() - } - t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) - return - } - if err == nil { - if res == nil { - t.Errorf("%v: response nil; want non-nil.", desc) - return - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK) + t.Helper() + rt := tc.roundTrip(req) + if wantErr != nil { + if err := rt.err(); !errors.Is(err, wantErr) { + t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr) } return } - if res != nil { - t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err) - } + + tc.wantFrameType(FrameHeaders) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + + rt.wantStatus(http.StatusOK) } headerListSizeForRequest := func(req *http.Request) (size uint64) { const addGzipHeader = true @@ -1501,56 +1499,15 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { newRequest := func() *http.Request { // Body must be non-nil to enable writing trailers. 
body := strings.NewReader("hello") - req, err := http.NewRequest("POST", ts.URL, body) + req, err := http.NewRequest("POST", "https://example.tld/", body) if err != nil { t.Fatalf("newRequest: NewRequest: %v", err) } return req } - var ( - scMu sync.Mutex - sc *serverConn - ) - testHookGetServerConn = func(v *serverConn) { - scMu.Lock() - defer scMu.Unlock() - if sc != nil { - panic("testHookGetServerConn called multiple times") - } - sc = v - } - defer func() { - testHookGetServerConn = nil - }() - - // Validate peerMaxHeaderListSize. - req := newRequest() - checkRoundTrip(req, nil, "Initial request") - addr := authorityAddr(req.URL.Scheme, req.URL.Host) - cc, err := tr.connPool().GetClientConn(req, addr) - if err != nil { - t.Fatalf("GetClientConn: %v", err) - } - cc.mu.Lock() - peerSize := cc.peerMaxHeaderListSize - cc.mu.Unlock() - scMu.Lock() - wantSize := uint64(sc.maxHeaderListSize()) - scMu.Unlock() - if peerSize != wantSize { - t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize) - } - - // Sanity check peerSize. (*serverConn) maxHeaderListSize adds - // 320 bytes of padding. - wantHeaderBytes := uint64(ts.Config.MaxHeaderBytes) + 320 - if peerSize != wantHeaderBytes { - t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes) - } - // Pad headers & trailers, but stay under peerSize. 
- req = newRequest() + req := newRequest() req.Header = make(http.Header) req.Trailer = make(http.Header) filler := strings.Repeat("*", 1024) @@ -1588,6 +1545,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } func TestTransportChecksResponseHeaderListSize(t *testing.T) { + synctestTest(t, testTransportChecksResponseHeaderListSize) +} +func testTransportChecksResponseHeaderListSize(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -1633,7 +1593,8 @@ func TestTransportChecksResponseHeaderListSize(t *testing.T) { } } -func TestTransportCookieHeaderSplit(t *testing.T) { +func TestTransportCookieHeaderSplit(t *testing.T) { synctestTest(t, testTransportCookieHeaderSplit) } +func testTransportCookieHeaderSplit(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -1862,13 +1823,17 @@ func isTimeout(err error) bool { // Test that the http1 Transport.ResponseHeaderTimeout option and cancel is sent. func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) { - testTransportResponseHeaderTimeout(t, false) + synctestTest(t, func(t testing.TB) { + testTransportResponseHeaderTimeout(t, false) + }) } func TestTransportResponseHeaderTimeout_Body(t *testing.T) { - testTransportResponseHeaderTimeout(t, true) + synctestTest(t, func(t testing.TB) { + testTransportResponseHeaderTimeout(t, true) + }) } -func testTransportResponseHeaderTimeout(t *testing.T, body bool) { +func testTransportResponseHeaderTimeout(t testing.TB, body bool) { const bodySize = 4 << 20 tc := newTestClientConn(t, func(tr *Transport) { tr.t1 = &http.Transport{ @@ -1904,11 +1869,11 @@ func testTransportResponseHeaderTimeout(t *testing.T, body bool) { }) } - tc.advance(4 * time.Millisecond) + time.Sleep(4 * time.Millisecond) if rt.done() { t.Fatalf("RoundTrip is done after 4ms; want still waiting") } - tc.advance(1 * time.Millisecond) + time.Sleep(1 * time.Millisecond) if err := rt.err(); !isTimeout(err) { t.Fatalf("RoundTrip error: %v; want timeout error", err) @@ -2304,7 +2269,8 @@ 
func TestTransportNewTLSConfig(t *testing.T) { // The Google GFE responds to HEAD requests with a HEADERS frame // without END_STREAM, followed by a 0-length DATA frame with // END_STREAM. Make sure we don't get confused by that. (We did.) -func TestTransportReadHeadResponse(t *testing.T) { +func TestTransportReadHeadResponse(t *testing.T) { synctestTest(t, testTransportReadHeadResponse) } +func testTransportReadHeadResponse(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -2331,6 +2297,9 @@ func TestTransportReadHeadResponse(t *testing.T) { } func TestTransportReadHeadResponseWithBody(t *testing.T) { + synctestTest(t, testTransportReadHeadResponseWithBody) +} +func testTransportReadHeadResponseWithBody(t testing.TB) { // This test uses an invalid response format. // Discard logger output to not spam tests output. log.SetOutput(io.Discard) @@ -2371,101 +2340,102 @@ func (b neverEnding) Read(p []byte) (int, error) { return len(p), nil } -// golang.org/issue/15425: test that a handler closing the request -// body doesn't terminate the stream to the peer. (It just stops -// readability from the handler's side, and eventually the client -// runs out of flow control tokens) -func TestTransportHandlerBodyClose(t *testing.T) { - const bodySize = 10 << 20 - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - r.Body.Close() - io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) - }) - - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() +// #15425: Transport goroutine leak while the transport is still trying to +// write its body after the stream has completed. 
+func TestTransportStreamEndsWhileBodyIsBeingWritten(t *testing.T) { + synctestTest(t, testTransportStreamEndsWhileBodyIsBeingWritten) +} +func testTransportStreamEndsWhileBodyIsBeingWritten(t testing.TB) { + body := "this is the client request body" + const windowSize = 10 // less than len(body) - g0 := runtime.NumGoroutine() + tc := newTestClientConn(t) + tc.greet(Setting{SettingInitialWindowSize, windowSize}) - const numReq = 10 - for i := 0; i < numReq; i++ { - req, err := http.NewRequest("POST", ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) - if err != nil { - t.Fatal(err) - } - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatal(err) - } - n, err := io.Copy(io.Discard, res.Body) - res.Body.Close() - if n != bodySize || err != nil { - t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) - } - } - tr.CloseIdleConnections() + // Client sends a request, and as much body as fits into the stream window. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", strings.NewReader(body)) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + tc.wantData(wantData{ + streamID: rt.streamID(), + endStream: false, + size: windowSize, + }) - if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool { - gd := runtime.NumGoroutine() - g0 - return gd < numReq/2 - }) { - t.Errorf("appeared to leak goroutines") - } + // Server responds without permitting the rest of the body to be sent. 
+ tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: true, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "413", + ), + }) + rt.wantStatus(413) } -// https://golang.org/issue/15930 -func TestTransportFlowControl(t *testing.T) { - const bufLen = 64 << 10 - var total int64 = 100 << 20 // 100MB - if testing.Short() { - total = 10 << 20 - } - - var wrote int64 // updated atomically - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - b := make([]byte, bufLen) - for wrote < total { - n, err := w.Write(b) - atomic.AddInt64(&wrote, int64(n)) - if err != nil { - t.Errorf("ResponseWriter.Write error: %v", err) - break - } - w.(http.Flusher).Flush() +func TestTransportFlowControl(t *testing.T) { synctestTest(t, testTransportFlowControl) } +func testTransportFlowControl(t testing.TB) { + const maxBuffer = 64 << 10 // 64KiB + tc := newTestClientConn(t, func(tr *http.Transport) { + tr.HTTP2 = &http.HTTP2Config{ + MaxReceiveBufferPerConnection: maxBuffer, + MaxReceiveBufferPerStream: maxBuffer, + MaxReadFrameSize: 16 << 20, // 16MiB } }) + tc.greet() - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal("NewRequest error:", err) - } - resp, err := tr.RoundTrip(req) - if err != nil { - t.Fatal("RoundTrip error:", err) - } - defer resp.Body.Close() + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) - var read int64 - b := make([]byte, bufLen) + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + rt.wantStatus(200) + + // Server fills up its transmit buffer. + // The client does not provide more flow control tokens, + // since the data hasn't been consumed by the user. 
+	tc.writeData(rt.streamID(), false, make([]byte, maxBuffer))
+	tc.wantIdle()
+
+	// User reads data from the response body.
+	// The client sends more flow control tokens.
+	resp := rt.response()
+	if _, err := io.ReadFull(resp.Body, make([]byte, maxBuffer)); err != nil {
+		t.Fatalf("io.Body.Read: %v", err)
+	}
+	var connTokens, streamTokens uint32
 	for {
-		n, err := resp.Body.Read(b)
-		if err == io.EOF {
+		f := tc.readFrame()
+		if f == nil {
 			break
 		}
-		if err != nil {
-			t.Fatal("Read error:", err)
+		wu, ok := f.(*WindowUpdateFrame)
+		if !ok {
+			t.Fatalf("received unexpected frame %T (want WINDOW_UPDATE)", f)
 		}
-		read += int64(n)
-
-		const max = transportDefaultStreamFlow
-		if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
-			t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
+		switch wu.StreamID {
+		case 0:
+			connTokens += wu.Increment
+		case rt.streamID():
+			streamTokens += wu.Increment
+		default:
+			t.Fatalf("received unexpected WINDOW_UPDATE for stream %v", wu.StreamID)
 		}
-
-		// Let the server get ahead of the client.
- time.Sleep(1 * time.Millisecond) + } + if got, want := connTokens, uint32(maxBuffer); got != want { + t.Errorf("transport provided %v bytes of connection WINDOW_UPDATE, want %v", got, want) + } + if got, want := streamTokens, uint32(maxBuffer); got != want { + t.Errorf("transport provided %v bytes of stream WINDOW_UPDATE, want %v", got, want) } } @@ -2475,14 +2445,18 @@ func TestTransportFlowControl(t *testing.T) { // proceeds to close the TCP connection before the client gets its // response) func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) { - testTransportUsesGoAwayDebugError(t, false) + synctestTest(t, func(t testing.TB) { + testTransportUsesGoAwayDebugError(t, false) + }) } func TestTransportUsesGoAwayDebugError_Body(t *testing.T) { - testTransportUsesGoAwayDebugError(t, true) + synctestTest(t, func(t testing.TB) { + testTransportUsesGoAwayDebugError(t, true) + }) } -func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { +func testTransportUsesGoAwayDebugError(t testing.TB, failMidBody bool) { tc := newTestClientConn(t) tc.greet() @@ -2532,7 +2506,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { } } -func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { +func testTransportReturnsUnusedFlowControl(t testing.TB, oneDataFrame bool) { tc := newTestClientConn(t) tc.greet() @@ -2573,7 +2547,7 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { t.Fatalf("body read = %v, %v; want 1, nil", n, err) } res.Body.Close() // leaving 4999 bytes unread - tc.sync() + synctest.Wait() sentAdditionalData := false tc.wantUnorderedFrames( @@ -2609,17 +2583,22 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) { // See golang.org/issue/16481 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) { - testTransportReturnsUnusedFlowControl(t, true) + synctestTest(t, func(t testing.TB) { + testTransportReturnsUnusedFlowControl(t, 
true) + }) } // See golang.org/issue/20469 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) { - testTransportReturnsUnusedFlowControl(t, false) + synctestTest(t, func(t testing.TB) { + testTransportReturnsUnusedFlowControl(t, false) + }) } // Issue 16612: adjust flow control on open streams when transport // receives SETTINGS with INITIAL_WINDOW_SIZE from server. -func TestTransportAdjustsFlowControl(t *testing.T) { +func TestTransportAdjustsFlowControl(t *testing.T) { synctestTest(t, testTransportAdjustsFlowControl) } +func testTransportAdjustsFlowControl(t testing.TB) { const bodySize = 1 << 20 tc := newTestClientConn(t) @@ -2676,6 +2655,9 @@ func TestTransportAdjustsFlowControl(t *testing.T) { // See golang.org/issue/16556 func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { + synctestTest(t, testTransportReturnsDataPaddingFlowControl) +} +func testTransportReturnsDataPaddingFlowControl(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -2711,6 +2693,9 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { // golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a // StreamError as a result of the response HEADERS func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { + synctestTest(t, testTransportReturnsErrorOnBadResponseHeaders) +} +func testTransportReturnsErrorOnBadResponseHeaders(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -2762,6 +2747,9 @@ func (b byteAndEOFReader) Read(p []byte) (n int, err error) { // which returns (non-0, io.EOF) and also needs to set the ContentLength // explicitly. func TestTransportBodyDoubleEndStream(t *testing.T) { + synctestTest(t, testTransportBodyDoubleEndStream) +} +func testTransportBodyDoubleEndStream(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { // Nothing. 
}) @@ -2916,17 +2904,20 @@ func TestTransportRequestPathPseudo(t *testing.T) { // golang.org/issue/17071 -- don't sniff the first byte of the request body // before we've determined that the ClientConn is usable. func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) { + synctestTest(t, testRoundTripDoesntConsumeRequestBodyEarly) +} +func testRoundTripDoesntConsumeRequestBodyEarly(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + tc.closeWrite() + const body = "foo" req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body))) - cc := &ClientConn{ - closed: true, - reqHeaderMu: make(chan struct{}, 1), - t: &Transport{}, - } - _, err := cc.RoundTrip(req) - if err != errClientConnUnusable { - t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err) + rt := tc.roundTrip(req) + if err := rt.err(); err != errClientConnNotEstablished { + t.Fatalf("RoundTrip = %v; want errClientConnNotEstablished", err) } + slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("ReadAll = %v", err) @@ -3031,7 +3022,8 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) { req.Header = http.Header{} } -func TestTransportCloseAfterLostPing(t *testing.T) { +func TestTransportCloseAfterLostPing(t *testing.T) { synctestTest(t, testTransportCloseAfterLostPing) } +func testTransportCloseAfterLostPing(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.PingTimeout = 1 * time.Second tr.ReadIdleTimeout = 1 * time.Second @@ -3042,10 +3034,10 @@ func TestTransportCloseAfterLostPing(t *testing.T) { rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) tc.wantFrameType(FramePing) - tc.advance(1 * time.Second) + time.Sleep(1 * time.Second) err := rt.err() if err == nil || !strings.Contains(err.Error(), "client connection lost") { t.Fatalf("expected to get error about \"connection lost\", got %v", err) @@ -3081,6 +3073,9 @@ func 
TestTransportPingWriteBlocks(t *testing.T) { } func TestTransportPingWhenReadingMultiplePings(t *testing.T) { + synctestTest(t, testTransportPingWhenReadingMultiplePings) +} +func testTransportPingWhenReadingMultiplePings(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.ReadIdleTimeout = 1000 * time.Millisecond }) @@ -3102,20 +3097,20 @@ func TestTransportPingWhenReadingMultiplePings(t *testing.T) { for i := 0; i < 5; i++ { // No ping yet... - tc.advance(999 * time.Millisecond) + time.Sleep(999 * time.Millisecond) if f := tc.readFrame(); f != nil { t.Fatalf("unexpected frame: %v", f) } // ...ping now. - tc.advance(1 * time.Millisecond) + time.Sleep(1 * time.Millisecond) f := readFrame[*PingFrame](t, tc) tc.writePing(true, f.Data) } // Cancel the request, Transport resets it and returns an error from body reads. cancel() - tc.sync() + synctest.Wait() tc.wantFrameType(FrameRSTStream) _, err := rt.readBody() @@ -3125,6 +3120,9 @@ func TestTransportPingWhenReadingMultiplePings(t *testing.T) { } func TestTransportPingWhenReadingPingDisabled(t *testing.T) { + synctestTest(t, testTransportPingWhenReadingPingDisabled) +} +func testTransportPingWhenReadingPingDisabled(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.ReadIdleTimeout = 0 // PINGs disabled }) @@ -3144,13 +3142,16 @@ func TestTransportPingWhenReadingPingDisabled(t *testing.T) { }) // No PING is sent, even after a long delay. 
- tc.advance(1 * time.Minute) + time.Sleep(1 * time.Minute) if f := tc.readFrame(); f != nil { t.Fatalf("unexpected frame: %v", f) } } func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) { + synctestTest(t, testTransportRetryAfterGOAWAYNoRetry) +} +func testTransportRetryAfterGOAWAYNoRetry(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) @@ -3175,6 +3176,9 @@ func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) { } func TestTransportRetryAfterGOAWAYRetry(t *testing.T) { + synctestTest(t, testTransportRetryAfterGOAWAYRetry) +} +func testTransportRetryAfterGOAWAYRetry(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) @@ -3219,6 +3223,9 @@ func TestTransportRetryAfterGOAWAYRetry(t *testing.T) { } func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) { + synctestTest(t, testTransportRetryAfterGOAWAYSecondRequest) +} +func testTransportRetryAfterGOAWAYSecondRequest(t testing.TB) { tt := newTestTransport(t) // First request succeeds. @@ -3282,6 +3289,9 @@ func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) { } func TestTransportRetryAfterRefusedStream(t *testing.T) { + synctestTest(t, testTransportRetryAfterRefusedStream) +} +func testTransportRetryAfterRefusedStream(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) @@ -3320,20 +3330,21 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { rt.wantStatus(204) } -func TestTransportRetryHasLimit(t *testing.T) { +func TestTransportRetryHasLimit(t *testing.T) { synctestTest(t, testTransportRetryHasLimit) } +func testTransportRetryHasLimit(t testing.TB) { tt := newTestTransport(t) req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) rt := tt.roundTrip(req) - // First attempt: Server sends a GOAWAY. 
 	tc := tt.getConn()
+	tc.netconn.SetReadDeadline(time.Time{})
 	tc.wantFrameType(FrameSettings)
 	tc.wantFrameType(FrameWindowUpdate)
 
-	var totalDelay time.Duration
 	count := 0
-	for streamID := uint32(1); ; streamID += 2 {
+	start := time.Now()
+	for streamID := uint32(1); !rt.done(); streamID += 2 {
 		count++
 		tc.wantHeaders(wantHeader{
 			streamID: streamID,
@@ -3345,18 +3356,9 @@ func TestTransportRetryHasLimit(t *testing.T) {
 		}
 		tc.writeRSTStream(streamID, ErrCodeRefusedStream)
-		d, scheduled := tt.group.TimeUntilEvent()
-		if !scheduled {
-			if streamID == 1 {
-				continue
-			}
-			break
-		}
-		totalDelay += d
-		if totalDelay > 5*time.Minute {
+		if totalDelay := time.Since(start); totalDelay > 5*time.Minute {
 			t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay)
 		}
-		tt.advance(d)
 	}
 	if got, want := count, 5; got < want {
 		t.Errorf("RoundTrip made %v attempts, want at least %v", got, want)
 	}
@@ -3367,6 +3369,9 @@ func TestTransportRetryHasLimit(t *testing.T) {
 }
 
 func TestTransportResponseDataBeforeHeaders(t *testing.T) {
+	synctestTest(t, testTransportResponseDataBeforeHeaders)
+}
+func testTransportResponseDataBeforeHeaders(t testing.TB) {
 	// Discard log output complaining about protocol error.
log.SetOutput(io.Discard) t.Cleanup(func() { log.SetOutput(os.Stderr) }) // after other cleanup is done @@ -3408,7 +3413,7 @@ func TestTransportMaxFrameReadSize(t *testing.T) { maxReadFrameSize: 1024, want: minMaxFrameSize, }} { - t.Run(fmt.Sprint(test.maxReadFrameSize), func(t *testing.T) { + synctestSubtest(t, fmt.Sprint(test.maxReadFrameSize), func(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.MaxReadFrameSize = test.maxReadFrameSize }) @@ -3470,6 +3475,9 @@ func TestTransportRequestsLowServerLimit(t *testing.T) { // tests Transport.StrictMaxConcurrentStreams func TestTransportRequestsStallAtServerLimit(t *testing.T) { + synctestTest(t, testTransportRequestsStallAtServerLimit) +} +func testTransportRequestsStallAtServerLimit(t testing.TB) { const maxConcurrent = 2 tc := newTestClientConn(t, func(tr *Transport) { @@ -3517,7 +3525,7 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { // Cancel the maxConcurrent'th request. // The request should fail. close(cancelClientRequest) - tc.sync() + synctest.Wait() if err := rts[maxConcurrent].err(); err == nil { t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent) } @@ -3551,6 +3559,9 @@ func TestTransportRequestsStallAtServerLimit(t *testing.T) { } func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { + synctestTest(t, testTransportMaxDecoderHeaderTableSize) +} +func testTransportMaxDecoderHeaderTableSize(t testing.TB) { var reqSize, resSize uint32 = 8192, 16384 tc := newTestClientConn(t, func(tr *Transport) { tr.MaxDecoderHeaderTableSize = reqSize @@ -3572,6 +3583,9 @@ func TestTransportMaxDecoderHeaderTableSize(t *testing.T) { } func TestTransportMaxEncoderHeaderTableSize(t *testing.T) { + synctestTest(t, testTransportMaxEncoderHeaderTableSize) +} +func testTransportMaxEncoderHeaderTableSize(t testing.TB) { var peerAdvertisedMaxHeaderTableSize uint32 = 16384 tc := newTestClientConn(t, func(tr *Transport) { tr.MaxEncoderHeaderTableSize = 8192 @@ 
-3610,59 +3624,52 @@ func TestAuthorityAddr(t *testing.T) { // Issue 20448: stop allocating for DATA frames' payload after // Response.Body.Close is called. func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) { - megabyteZero := make([]byte, 1<<20) + synctestTest(t, testTransportAllocationsAfterResponseBodyClose) +} +func testTransportAllocationsAfterResponseBodyClose(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() - writeErr := make(chan error, 1) + // Send request. + req, _ := http.NewRequest("PUT", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - w.(http.Flusher).Flush() - var sum int64 - for i := 0; i < 100; i++ { - n, err := w.Write(megabyteZero) - sum += int64(n) - if err != nil { - writeErr <- err - return - } - } - t.Logf("wrote all %d bytes", sum) - writeErr <- nil + // Receive response with some body. + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), }) + tc.writeData(rt.streamID(), false, make([]byte, 64)) + tc.wantIdle() - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - c := &http.Client{Transport: tr} - res, err := c.Get(ts.URL) - if err != nil { - t.Fatal(err) - } + // Client reads a byte of the body, and then closes it. 
+ respBody := rt.response().Body var buf [1]byte - if _, err := res.Body.Read(buf[:]); err != nil { + if _, err := respBody.Read(buf[:]); err != nil { t.Error(err) } - if err := res.Body.Close(); err != nil { + if err := respBody.Close(); err != nil { t.Error(err) } + tc.wantFrameType(FrameRSTStream) - trb, ok := res.Body.(transportResponseBody) - if !ok { - t.Fatalf("res.Body = %T; want transportResponseBody", res.Body) - } - if trb.cs.bufPipe.b != nil { - t.Errorf("response body pipe is still open") - } + // Server sends more of the body, which is ignored. + tc.writeData(rt.streamID(), false, make([]byte, 64)) - gotErr := <-writeErr - if gotErr == nil { - t.Errorf("Handler unexpectedly managed to write its entire response without getting an error") - } else if gotErr != errStreamClosed { - t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr) + if _, err := respBody.Read(buf[:]); err == nil { + t.Error("read from closed body unexpectedly succeeded") } } // Issue 18891: make sure Request.Body == NoBody means no DATA frame // is ever sent, even if empty. -func TestTransportNoBodyMeansNoDATA(t *testing.T) { +func TestTransportNoBodyMeansNoDATA(t *testing.T) { synctestTest(t, testTransportNoBodyMeansNoDATA) } +func testTransportNoBodyMeansNoDATA(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -3756,6 +3763,9 @@ func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) { // Verify transport doesn't crash when receiving bogus response lacking a :status header. // Issue 22880. 
func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) { + synctestTest(t, testTransportHandlesInvalidStatuslessResponse) +} +func testTransportHandlesInvalidStatuslessResponse(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -3842,200 +3852,128 @@ func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { } } -func activeStreams(cc *ClientConn) int { - count := 0 - cc.mu.Lock() - defer cc.mu.Unlock() - for _, cs := range cc.streams { - select { - case <-cs.abort: - default: - count++ - } +// The client closes the connection just after the server got the client's HEADERS +// frame, but before the server sends its HEADERS response back. The expected +// result is an error on RoundTrip explaining the client closed the connection. +func TestClientConnCloseAtHeaders(t *testing.T) { synctestTest(t, testClientConnCloseAtHeaders) } +func testClientConnCloseAtHeaders(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + tc.cc.Close() + synctest.Wait() + if err := rt.err(); err != errClientConnForceClosed { + t.Fatalf("RoundTrip error = %v, want errClientConnForceClosed", err) } - return count } -type closeMode int +// The client closes the connection while reading the response. +// The expected behavior is a response body io read error on the client. 
+func TestClientConnCloseAtBody(t *testing.T) { synctestTest(t, testClientConnCloseAtBody) } +func testClientConnCloseAtBody(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() -const ( - closeAtHeaders closeMode = iota - closeAtBody - shutdown - shutdownCancel -) + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) -// See golang.org/issue/17292 -func testClientConnClose(t *testing.T, closeMode closeMode) { - clientDone := make(chan struct{}) - defer close(clientDone) - handlerDone := make(chan struct{}) - closeDone := make(chan struct{}) - beforeHeader := func() {} - bodyWrite := func(w http.ResponseWriter) {} - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - defer close(handlerDone) - beforeHeader() - w.WriteHeader(http.StatusOK) - w.(http.Flusher).Flush() - bodyWrite(w) - select { - case <-w.(http.CloseNotifier).CloseNotify(): - // client closed connection before completion - if closeMode == shutdown || closeMode == shutdownCancel { - t.Error("expected request to complete") - } - case <-clientDone: - if closeMode == closeAtHeaders || closeMode == closeAtBody { - t.Error("expected connection closed by client") - } - } + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), }) - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false) - req, err := http.NewRequest("GET", ts.URL, nil) - if err != nil { - t.Fatal(err) - } - if closeMode == closeAtHeaders { - beforeHeader = func() { - if err := cc.Close(); err != nil { - t.Error(err) - } - close(closeDone) - } - } - var sendBody chan struct{} - if closeMode == closeAtBody { - sendBody = make(chan struct{}) - bodyWrite = func(w http.ResponseWriter) { - <-sendBody - b := 
make([]byte, 32) - w.Write(b) - w.(http.Flusher).Flush() - if err := cc.Close(); err != nil { - t.Errorf("unexpected ClientConn close error: %v", err) - } - close(closeDone) - w.Write(b) - w.(http.Flusher).Flush() - } - } - res, err := cc.RoundTrip(req) - if res != nil { - defer res.Body.Close() - } - if closeMode == closeAtHeaders { - got := fmt.Sprint(err) - want := "http2: client connection force closed via ClientConn.Close" - if got != want { - t.Fatalf("RoundTrip error = %v, want %v", got, want) - } - } else { - if err != nil { - t.Fatalf("RoundTrip: %v", err) - } - if got, want := activeStreams(cc), 1; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - } - switch closeMode { - case shutdownCancel: - if err = cc.Shutdown(canceledCtx); err != context.Canceled { - t.Errorf("got %v, want %v", err, context.Canceled) - } - if cc.closing == false { - t.Error("expected closing to be true") - } - if cc.CanTakeNewRequest() == true { - t.Error("CanTakeNewRequest to return false") - } - if v, want := len(cc.streams), 1; v != want { - t.Errorf("expected %d active streams, got %d", want, v) - } - clientDone <- struct{}{} - <-handlerDone - case shutdown: - wait := make(chan struct{}) - shutdownEnterWaitStateHook = func() { - close(wait) - shutdownEnterWaitStateHook = func() {} - } - defer func() { shutdownEnterWaitStateHook = func() {} }() - shutdown := make(chan struct{}, 1) - go func() { - if err = cc.Shutdown(context.Background()); err != nil { - t.Error(err) - } - close(shutdown) - }() - // Let the shutdown to enter wait state - <-wait - cc.mu.Lock() - if cc.closing == false { - t.Error("expected closing to be true") - } - cc.mu.Unlock() - if cc.CanTakeNewRequest() == true { - t.Error("CanTakeNewRequest to return false") - } - if got, want := activeStreams(cc), 1; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - // Let the active request finish - clientDone <- struct{}{} - // Wait for the shutdown to end - select { - 
case <-shutdown: - case <-time.After(2 * time.Second): - t.Fatal("expected server connection to close") - } - case closeAtHeaders, closeAtBody: - if closeMode == closeAtBody { - go close(sendBody) - if _, err := io.Copy(io.Discard, res.Body); err == nil { - t.Error("expected a Copy error, got nil") - } - } - <-closeDone - if got, want := activeStreams(cc), 0; got != want { - t.Errorf("got %d active streams, want %d", got, want) - } - // wait for server to get the connection close notice - select { - case <-handlerDone: - case <-time.After(2 * time.Second): - t.Fatal("expected server connection to close") - } - } -} + tc.writeData(rt.streamID(), false, make([]byte, 64)) + tc.cc.Close() + synctest.Wait() -// The client closes the connection just after the server got the client's HEADERS -// frame, but before the server sends its HEADERS response back. The expected -// result is an error on RoundTrip explaining the client closed the connection. -func TestClientConnCloseAtHeaders(t *testing.T) { - testClientConnClose(t, closeAtHeaders) -} - -// The client closes the connection between two server's response DATA frames. -// The expected behavior is a response body io read error on the client. -func TestClientConnCloseAtBody(t *testing.T) { - testClientConnClose(t, closeAtBody) + if _, err := io.Copy(io.Discard, rt.response().Body); err == nil { + t.Error("expected a Copy error, got nil") + } } // The client sends a GOAWAY frame before the server finished processing a request. // We expect the connection not to close until the request is completed. 
-func TestClientConnShutdown(t *testing.T) { - testClientConnClose(t, shutdown) +func TestClientConnShutdown(t *testing.T) { synctestTest(t, testClientConnShutdown) } +func testClientConnShutdown(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + go tc.cc.Shutdown(context.Background()) + synctest.Wait() + + tc.wantFrameType(FrameGoAway) + tc.wantIdle() // connection is not closed + body := []byte("body") + tc.writeHeaders(HeadersFrameParam{ + StreamID: rt.streamID(), + EndHeaders: true, + EndStream: false, + BlockFragment: tc.makeHeaderBlockFragment( + ":status", "200", + ), + }) + tc.writeData(rt.streamID(), true, body) + + rt.wantStatus(200) + rt.wantBody(body) + + // Now that the client has received the response, it closes the connection. + tc.wantClosed() } // The client sends a GOAWAY frame before the server finishes processing a request, // but cancels the passed context before the request is completed. The expected // behavior is the client closing the connection after the context is canceled. 
-func TestClientConnShutdownCancel(t *testing.T) { - testClientConnClose(t, shutdownCancel) +func TestClientConnShutdownCancel(t *testing.T) { synctestTest(t, testClientConnShutdownCancel) } +func testClientConnShutdownCancel(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + rt := tc.roundTrip(req) + tc.wantFrameType(FrameHeaders) + + ctx, cancel := context.WithCancel(t.Context()) + var shutdownErr error + go func() { + shutdownErr = tc.cc.Shutdown(ctx) + }() + synctest.Wait() + + tc.wantFrameType(FrameGoAway) + tc.wantIdle() // connection is not closed + + cancel() + synctest.Wait() + + if shutdownErr != context.Canceled { + t.Fatalf("ClientConn.Shutdown(ctx) did not return context.Canceled after cancelling context") + } + + // The documentation for this test states: + // The expected behavior is the client closing the connection + // after the context is canceled. + // + // This seems reasonable, but it isn't what we do. + // When ClientConn.Shutdown's context is canceled, Shutdown returns but + // the connection is not closed. + // + // TODO: Figure out the correct behavior. 
+ if rt.done() { + t.Fatal("RoundTrip unexpectedly returned during shutdown") + } } // Issue 25009: use Request.GetBody if present, even if it seems like @@ -4117,6 +4055,11 @@ func (r *errReader) Read(p []byte) (int, error) { } func testTransportBodyReadError(t *testing.T, body []byte) { + synctestTest(t, func(t testing.TB) { + testTransportBodyReadErrorBubble(t, body) + }) +} +func testTransportBodyReadErrorBubble(t testing.TB, body []byte) { tc := newTestClientConn(t) tc.greet() @@ -4149,10 +4092,6 @@ readFrames: if err := rt.err(); err != bodyReadError { t.Fatalf("err = %v; want %v", err, bodyReadError) } - - if got := activeStreams(tc.cc); got != 0 { - t.Fatalf("active streams count: %v; want 0", got) - } } func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } @@ -4161,7 +4100,8 @@ func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyRea // Issue 32254: verify that the client sends END_STREAM flag eagerly with the last // (or in this test-case the only one) request body data frame, and does not send // extra zero-len data frames. 
-func TestTransportBodyEagerEndStream(t *testing.T) { +func TestTransportBodyEagerEndStream(t *testing.T) { synctestTest(t, testTransportBodyEagerEndStream) } +func testTransportBodyEagerEndStream(t testing.TB) { const reqBody = "some request body" const resBody = "some response body" @@ -4205,17 +4145,21 @@ func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) { []byte("123"), []byte("456"), }} - testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) + synctestTest(t, func(t testing.TB) { + testTransportBodyLargerThanSpecifiedContentLength(t, body, 3) + }) } func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) { body := &chunkReader{[][]byte{ []byte("123"), }} - testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) + synctestTest(t, func(t testing.TB) { + testTransportBodyLargerThanSpecifiedContentLength(t, body, 2) + }) } -func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) { +func testTransportBodyLargerThanSpecifiedContentLength(t testing.TB, body *chunkReader, contentLen int64) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { r.Body.Read(make([]byte, 6)) }) @@ -4299,35 +4243,28 @@ func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { } func TestTransportRoundtripCloseOnWriteError(t *testing.T) { - req, err := http.NewRequest("GET", "https://dummy.tld/", nil) - if err != nil { - t.Fatal(err) - } - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}) + synctestTest(t, testTransportRoundtripCloseOnWriteError) +} +func testTransportRoundtripCloseOnWriteError(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - ctx := context.Background() - cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false) - if err != nil { - t.Fatal(err) - } + body := tc.newRequestBody() + body.writeBytes(1) + req, _ := 
http.NewRequest("GET", "https://dummy.tld/", body) + rt := tc.roundTrip(req) writeErr := errors.New("write error") - cc.wmu.Lock() - cc.werr = writeErr - cc.wmu.Unlock() + tc.closeWriteWithError(writeErr) - _, err = cc.RoundTrip(req) - if err != writeErr { - t.Fatalf("expected %v, got %v", writeErr, err) + body.writeBytes(1) + if err := rt.err(); err != writeErr { + t.Fatalf("RoundTrip error %v, want %v", err, writeErr) } - cc.mu.Lock() - closed := cc.closed - cc.mu.Unlock() - if !closed { - t.Fatal("expected closed") + rt2 := tc.roundTrip(req) + if err := rt2.err(); err != errClientConnUnusable { + t.Fatalf("RoundTrip error %v, want errClientConnUnusable", err) } } @@ -4818,6 +4755,9 @@ func TestTransportCloseRequestBody(t *testing.T) { } func TestTransportRetriesOnStreamProtocolError(t *testing.T) { + synctestTest(t, testTransportRetriesOnStreamProtocolError) +} +func testTransportRetriesOnStreamProtocolError(t testing.TB) { // This test verifies that // - receiving a protocol error on a connection does not interfere with // other requests in flight on that connection; @@ -4893,7 +4833,8 @@ func TestTransportRetriesOnStreamProtocolError(t *testing.T) { rt1.wantStatus(200) } -func TestClientConnReservations(t *testing.T) { +func TestClientConnReservations(t *testing.T) { synctestTest(t, testClientConnReservations) } +func testClientConnReservations(t testing.TB) { tc := newTestClientConn(t) tc.greet( Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams}, @@ -4944,7 +4885,8 @@ func TestClientConnReservations(t *testing.T) { } } -func TestTransportTimeoutServerHangs(t *testing.T) { +func TestTransportTimeoutServerHangs(t *testing.T) { synctestTest(t, testTransportTimeoutServerHangs) } +func testTransportTimeoutServerHangs(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -4953,7 +4895,7 @@ func TestTransportTimeoutServerHangs(t *testing.T) { rt := tc.roundTrip(req) tc.wantFrameType(FrameHeaders) - tc.advance(5 * time.Second) + 
time.Sleep(5 * time.Second) if f := tc.readFrame(); f != nil { t.Fatalf("unexpected frame: %v", f) } @@ -4962,20 +4904,13 @@ func TestTransportTimeoutServerHangs(t *testing.T) { } cancel() - tc.sync() + synctest.Wait() if rt.err() != context.Canceled { t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err()) } } func TestTransportContentLengthWithoutBody(t *testing.T) { - contentLength := "" - ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Length", contentLength) - }) - tr := &Transport{TLSClientConfig: tlsConfigInsecure} - defer tr.CloseIdleConnections() - for _, test := range []struct { name string contentLength string @@ -4996,7 +4931,14 @@ func TestTransportContentLengthWithoutBody(t *testing.T) { wantContentLength: 0, }, } { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t testing.TB) { + contentLength := "" + ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", contentLength) + }) + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + contentLength = test.contentLength req, _ := http.NewRequest("GET", ts.URL, nil) @@ -5021,6 +4963,9 @@ func TestTransportContentLengthWithoutBody(t *testing.T) { } func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { + synctestTest(t, testTransportCloseResponseBodyWhileRequestBodyHangs) +} +func testTransportCloseResponseBodyWhileRequestBodyHangs(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.(http.Flusher).Flush() @@ -5044,7 +4989,8 @@ func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) { pw.Close() } -func TestTransport300ResponseBody(t *testing.T) { +func TestTransport300ResponseBody(t *testing.T) { synctestTest(t, testTransport300ResponseBody) } +func testTransport300ResponseBody(t testing.TB) { reqc := make(chan struct{}) body := []byte("response body") 
ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -5120,7 +5066,8 @@ func (c *slowWriteConn) Write(b []byte) (n int, err error) { return c.Conn.Write(b) } -func TestTransportSlowWrites(t *testing.T) { +func TestTransportSlowWrites(t *testing.T) { synctestTest(t, testTransportSlowWrites) } +func testTransportSlowWrites(t testing.TB) { ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {}, ) @@ -5145,10 +5092,14 @@ func TestTransportSlowWrites(t *testing.T) { } func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) { - testTransportClosesConnAfterGoAway(t, 0) + synctestTest(t, func(t testing.TB) { + testTransportClosesConnAfterGoAway(t, 0) + }) } func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { - testTransportClosesConnAfterGoAway(t, 1) + synctestTest(t, func(t testing.TB) { + testTransportClosesConnAfterGoAway(t, 1) + }) } // testTransportClosesConnAfterGoAway verifies that the transport @@ -5157,7 +5108,7 @@ func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) { // lastStream is the last stream ID in the GOAWAY frame. // When 0, the transport (unsuccessfully) retries the request (stream 1); // when 1, the transport reads the response after receiving the GOAWAY. -func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) { +func testTransportClosesConnAfterGoAway(t testing.TB, lastStream uint32) { tc := newTestClientConn(t) tc.greet() @@ -5384,7 +5335,8 @@ func TestDialRaceResumesDial(t *testing.T) { } } -func TestTransportDataAfter1xxHeader(t *testing.T) { +func TestTransportDataAfter1xxHeader(t *testing.T) { synctestTest(t, testTransportDataAfter1xxHeader) } +func testTransportDataAfter1xxHeader(t testing.TB) { // Discard logger output to avoid spamming stderr. 
log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) @@ -5514,7 +5466,7 @@ func TestTransport1xxLimits(t *testing.T) { hcount: 20, limited: false, }} { - t.Run(test.name, func(t *testing.T) { + synctestSubtest(t, test.name, func(t testing.TB) { tc := newTestClientConn(t, test.opt) tc.greet() @@ -5549,7 +5501,8 @@ func TestTransport1xxLimits(t *testing.T) { } } -func TestTransportSendPingWithReset(t *testing.T) { +func TestTransportSendPingWithReset(t *testing.T) { synctestTest(t, testTransportSendPingWithReset) } +func testTransportSendPingWithReset(t testing.TB) { tc := newTestClientConn(t, func(tr *Transport) { tr.StrictMaxConcurrentStreams = true }) @@ -5609,6 +5562,9 @@ func TestTransportSendPingWithReset(t *testing.T) { // Issue #70505: gRPC gets upset if we send more than 2 pings per HEADERS/DATA frame // sent by the server. func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) { + synctestTest(t, testTransportSendNoMoreThanOnePingWithReset) +} +func testTransportSendNoMoreThanOnePingWithReset(t testing.TB) { tc := newTestClientConn(t) tc.greet() @@ -5674,6 +5630,9 @@ func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) { } func TestTransportConnBecomesUnresponsive(t *testing.T) { + synctestTest(t, testTransportConnBecomesUnresponsive) +} +func testTransportConnBecomesUnresponsive(t testing.TB) { // We send a number of requests in series to an unresponsive connection. // Each request is canceled or times out without a response. // Eventually, we open a new connection rather than trying to use the old one. @@ -5744,19 +5703,19 @@ func TestTransportConnBecomesUnresponsive(t *testing.T) { } // Test that the Transport can use a conn provided to it by a TLSNextProto hook. 
-func TestTransportTLSNextProtoConnOK(t *testing.T) { +func TestTransportTLSNextProtoConnOK(t *testing.T) { synctestTest(t, testTransportTLSNextProtoConnOK) } +func testTransportTLSNextProtoConnOK(t testing.TB) { t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() + synctest.Wait() tc := tt.getConn() tc.greet() @@ -5787,18 +5746,20 @@ func TestTransportTLSNextProtoConnOK(t *testing.T) { // Test the case where a conn provided via a TLSNextProto hook immediately encounters an error. func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) { + synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUsed) +} +func testTransportTLSNextProtoConnImmediateFailureUsed(t testing.TB) { t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() + synctest.Wait() tc := tt.getConn() // The connection encounters an error before we send a request that uses it. @@ -5825,6 +5786,9 @@ func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) { // Test the case where a conn provided via a TLSNextProto hook is closed for idleness // before we use it. 
func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { + synctestTest(t, testTransportTLSNextProtoConnIdleTimoutBeforeUse) +} +func testTransportTLSNextProtoConnIdleTimoutBeforeUse(t testing.TB) { t1 := &http.Transport{ IdleConnTimeout: 1 * time.Second, } @@ -5832,17 +5796,17 @@ func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() - tc := tt.getConn() + synctest.Wait() + _ = tt.getConn() // The connection encounters an error before we send a request that uses it. - tc.advance(2 * time.Second) + time.Sleep(2 * time.Second) + synctest.Wait() // Send a request on the Transport. // @@ -5857,18 +5821,20 @@ func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) { // Test the case where a conn provided via a TLSNextProto hook immediately encounters an error, // but no requests are sent which would use the bad connection. func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) { + synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUnused) +} +func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) { t1 := &http.Transport{} t2, _ := ConfigureTransports(t1) tt := newTestTransport(t, t2) // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook. - cli, _ := synctestNetPipe(tt.group) + cli, _ := synctestNetPipe() cliTLS := tls.Client(cli, tlsConfigInsecure) go func() { - tt.group.Join() t1.TLSNextProto["h2"]("dummy.tld", cliTLS) }() - tt.sync() + synctest.Wait() tc := tt.getConn() // The connection encounters an error before we send a request that uses it. 
@@ -5876,7 +5842,7 @@ func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) { // Some time passes. // The dead connection is removed from the pool. - tc.advance(10 * time.Second) + time.Sleep(10 * time.Second) // Send a request on the Transport. // @@ -5959,6 +5925,9 @@ func TestExtendedConnectClientWithoutServerSupport(t *testing.T) { // Issue #70658: Make sure extended CONNECT requests don't get stuck if a // connection fails early in its lifetime. func TestExtendedConnectReadFrameError(t *testing.T) { + synctestTest(t, testExtendedConnectReadFrameError) +} +func testExtendedConnectReadFrameError(t testing.TB) { tc := newTestClientConn(t) tc.wantFrameType(FrameSettings) tc.wantFrameType(FrameWindowUpdate)