Skip to content

Commit 541dbe5

Browse files
committed
http2: add Server.WriteByteTimeout
Transports support a WriteByteTimeout option which sets the maximum amount of time we can go without being able to write any bytes to a connection. Add an equivalent option to Server for consistency. Fixes golang/go#61777 Change-Id: Iaa8a69dfc403906eb224829320f901e5a6a5c429 Reviewed-on: https://go-review.googlesource.com/c/net/+/601496 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Carlos Amedee <carlos@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
1 parent 3c333c0 commit 541dbe5

File tree

5 files changed

+98
-28
lines changed

5 files changed

+98
-28
lines changed

http2/connframes_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package http2
66

77
import (
88
"bytes"
9-
"context"
109
"io"
1110
"net/http"
1211
"os"
@@ -295,7 +294,7 @@ func (tf *testConnFramer) wantClosed() {
295294
if err == nil {
296295
tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
297296
}
298-
if err == context.DeadlineExceeded {
297+
if err == os.ErrDeadlineExceeded {
299298
tf.t.Fatalf("connection is not closed; want it to be")
300299
}
301300
}
@@ -306,7 +305,7 @@ func (tf *testConnFramer) wantIdle() {
306305
if err == nil {
307306
tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
308307
}
309-
if err != context.DeadlineExceeded {
308+
if err != os.ErrDeadlineExceeded {
310309
tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
311310
}
312311
}

http2/http2.go

+46-7
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ import (
1919
"bufio"
2020
"context"
2121
"crypto/tls"
22+
"errors"
2223
"fmt"
23-
"io"
24+
"net"
2425
"net/http"
2526
"os"
2627
"sort"
@@ -237,13 +238,19 @@ func (cw closeWaiter) Wait() {
237238
// Its buffered writer is lazily allocated as needed, to minimize
238239
// idle memory usage with many connections.
239240
type bufferedWriter struct {
240-
_ incomparable
241-
w io.Writer // immutable
242-
bw *bufio.Writer // non-nil when data is buffered
241+
_ incomparable
242+
group synctestGroupInterface // immutable
243+
conn net.Conn // immutable
244+
bw *bufio.Writer // non-nil when data is buffered
245+
byteTimeout time.Duration // immutable, WriteByteTimeout
243246
}
244247

245-
func newBufferedWriter(w io.Writer) *bufferedWriter {
246-
return &bufferedWriter{w: w}
248+
func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
249+
return &bufferedWriter{
250+
group: group,
251+
conn: conn,
252+
byteTimeout: timeout,
253+
}
247254
}
248255

249256
// bufWriterPoolBufferSize is the size of bufio.Writer's
@@ -270,7 +277,7 @@ func (w *bufferedWriter) Available() int {
270277
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
271278
if w.bw == nil {
272279
bw := bufWriterPool.Get().(*bufio.Writer)
273-
bw.Reset(w.w)
280+
bw.Reset((*bufferedWriterTimeoutWriter)(w))
274281
w.bw = bw
275282
}
276283
return w.bw.Write(p)
@@ -288,6 +295,38 @@ func (w *bufferedWriter) Flush() error {
288295
return err
289296
}
290297

298+
type bufferedWriterTimeoutWriter bufferedWriter
299+
300+
func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
301+
return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
302+
}
303+
304+
// writeWithByteTimeout writes to conn.
305+
// If more than timeout passes without any bytes being written to the connection,
306+
// the write fails.
307+
func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
308+
if timeout <= 0 {
309+
return conn.Write(p)
310+
}
311+
for {
312+
var now time.Time
313+
if group == nil {
314+
now = time.Now()
315+
} else {
316+
now = group.Now()
317+
}
318+
conn.SetWriteDeadline(now.Add(timeout))
319+
nn, err := conn.Write(p[n:])
320+
n += nn
321+
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
322+
// Either we finished the write, made no progress, or hit the deadline.
323+
// Whichever it is, we're done now.
324+
conn.SetWriteDeadline(time.Time{})
325+
return n, err
326+
}
327+
}
328+
}
329+
291330
func mustUint31(v int32) uint32 {
292331
if v < 0 || v > 2147483647 {
293332
panic("out of range")

http2/server.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,12 @@ type Server struct {
127127
// If zero or negative, there is no timeout.
128128
IdleTimeout time.Duration
129129

130+
// WriteByteTimeout is the timeout after which a connection will be
131+
// closed if no data can be written to it. The timeout begins when data is
132+
// available to write, and is extended whenever any bytes are written.
133+
// If zero or negative, there is no timeout.
134+
WriteByteTimeout time.Duration
135+
130136
// MaxUploadBufferPerConnection is the size of the initial flow
131137
// control window for each connections. The HTTP/2 spec does not
132138
// allow this to be smaller than 65535 or larger than 2^32-1.
@@ -446,7 +452,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
446452
conn: c,
447453
baseCtx: baseCtx,
448454
remoteAddrStr: c.RemoteAddr().String(),
449-
bw: newBufferedWriter(c),
455+
bw: newBufferedWriter(s.group, c, s.WriteByteTimeout),
450456
handler: opts.handler(),
451457
streams: make(map[uint32]*stream),
452458
readFrameCh: make(chan readFrameResult),
@@ -1320,6 +1326,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
13201326
sc.writingFrame = false
13211327
sc.writingFrameAsync = false
13221328

1329+
if res.err != nil {
1330+
sc.conn.Close()
1331+
}
1332+
13231333
wr := res.wr
13241334

13251335
if writeEndsStream(wr.write) {

http2/server_test.go

+32
Original file line numberDiff line numberDiff line change
@@ -4674,3 +4674,35 @@ func TestServerSetReadWriteDeadlineRace(t *testing.T) {
46744674
}
46754675
resp.Body.Close()
46764676
}
4677+
4678+
func TestServerWriteByteTimeout(t *testing.T) {
4679+
const timeout = 1 * time.Second
4680+
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4681+
w.Write(make([]byte, 100))
4682+
}, func(s *Server) {
4683+
s.WriteByteTimeout = timeout
4684+
})
4685+
st.greet()
4686+
4687+
st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time
4688+
st.writeHeaders(HeadersFrameParam{
4689+
StreamID: 1,
4690+
BlockFragment: st.encodeHeader(),
4691+
EndStream: true,
4692+
EndHeaders: true,
4693+
})
4694+
4695+
// Read a few bytes, staying just under WriteByteTimeout.
4696+
for i := 0; i < 10; i++ {
4697+
st.advance(timeout - 1)
4698+
if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
4699+
t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
4700+
}
4701+
}
4702+
4703+
// Wait for WriteByteTimeout.
4704+
// The connection should close.
4705+
st.advance(1 * time.Second) // timeout after writing one byte
4706+
st.advance(1 * time.Second) // timeout after failing to write any more bytes
4707+
st.wantClosed()
4708+
}

http2/transport.go

+7-17
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import (
2525
"net/http"
2626
"net/http/httptrace"
2727
"net/textproto"
28-
"os"
2928
"sort"
3029
"strconv"
3130
"strings"
@@ -499,6 +498,7 @@ func (cs *clientStream) closeReqBodyLocked() {
499498
}
500499

501500
type stickyErrWriter struct {
501+
group synctestGroupInterface
502502
conn net.Conn
503503
timeout time.Duration
504504
err *error
@@ -508,22 +508,9 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
508508
if *sew.err != nil {
509509
return 0, *sew.err
510510
}
511-
for {
512-
if sew.timeout != 0 {
513-
sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
514-
}
515-
nn, err := sew.conn.Write(p[n:])
516-
n += nn
517-
if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
518-
// Keep extending the deadline so long as we're making progress.
519-
continue
520-
}
521-
if sew.timeout != 0 {
522-
sew.conn.SetWriteDeadline(time.Time{})
523-
}
524-
*sew.err = err
525-
return n, err
526-
}
511+
n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
512+
*sew.err = err
513+
return n, err
527514
}
528515

529516
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -792,10 +779,12 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
792779
pings: make(map[[8]byte]chan struct{}),
793780
reqHeaderMu: make(chan struct{}, 1),
794781
}
782+
var group synctestGroupInterface
795783
if t.transportTestHooks != nil {
796784
t.markNewGoroutine()
797785
t.transportTestHooks.newclientconn(cc)
798786
c = cc.tconn
787+
group = t.group
799788
}
800789
if VerboseLogs {
801790
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -807,6 +796,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
807796
// TODO: adjust this writer size to account for frame size +
808797
// MTU + crypto/tls record padding.
809798
cc.bw = bufio.NewWriter(stickyErrWriter{
799+
group: group,
810800
conn: c,
811801
timeout: t.WriteByteTimeout,
812802
err: &cc.werr,

0 commit comments

Comments
 (0)