From e8f5bc8900a1c929f71a1c892451e19f82f9f6ee Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 21 Feb 2020 01:41:40 -0500 Subject: [PATCH 01/20] Add Example_crossOrigin Closes #194 --- example_test.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/example_test.go b/example_test.go index 075107b0..666914d2 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,7 @@ import ( "context" "log" "net/http" + "net/url" "time" "nhooyr.io/websocket" @@ -115,3 +116,30 @@ func Example_writeOnly() { err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } + +// This example demonstrates how to safely accept cross origin WebSockets +// from the origin example.com. +func Example_crossOrigin() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin != "" { + u, err := url.Parse(origin) + if err != nil || u.Host != "example.com" { + http.Error(w, "bad origin header", http.StatusForbidden) + return + } + } + + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + log.Println(err) + return + } + c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted") + }) + + err := http.ListenAndServe("localhost:8080", fn) + log.Fatal(err) +} From 500b9d734508a2c8ab09457ac9895b895bc86470 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 25 Feb 2020 22:20:19 -0500 Subject: [PATCH 02/20] Add OriginPatterns to AcceptOptions Closes #194 --- accept.go | 73 +++++++++++++++++++++++++++++++------------------ accept_test.go | 31 +++++++++++++++++---- example_test.go | 12 +------- 3 files changed, 74 insertions(+), 42 deletions(-) diff --git a/accept.go b/accept.go index 479138fc..47e20b52 100644 --- a/accept.go +++ b/accept.go @@ -9,10 +9,11 @@ import ( "errors" "fmt" "io" + "log" "net/http" "net/textproto" "net/url" - "strconv" + "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" @@ -25,18 +26,27 @@ type AcceptOptions struct { // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string - // InsecureSkipVerify disables Accept's origin verification behaviour. By default, - // the connection will only be accepted if the request origin is equal to the request - // host. + // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // - // This is only required if you want javascript served from a different domain - // to access your WebSocket server. + // Deprecated: Use OriginPatterns with a match all pattern of * instead to control + // origin authorization yourself. + InsecureSkipVerify bool + + // OriginPatterns lists the host patterns for authorized origins. + // The request host is always authorized. + // Use this to enable cross origin WebSockets. + // + // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. + // In such a case, example.com is the origin and chat.example.com is the request host. + // One would set this field to []string{"example.com"} to authorize example.com to connect. // - // See https://stackoverflow.com/a/37837709/4283659 + // Each pattern is matched case insensitively against the request origin host + // with filepath.Match. + // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. - InsecureSkipVerify bool + OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. @@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) + err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { + if errors.Is(err, filepath.ErrBadPattern) { + log.Printf("websocket: %v", err) + err = errors.New(http.StatusText(http.StatusForbidden)) + } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } @@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return 0, nil } -func authenticateOrigin(r *http.Request) error { +func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) + if origin == "" { + return nil + } + + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + + if strings.EqualFold(r.Host, u.Host) { + return nil + } + + for _, hostPattern := range originHosts { + matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if matched { + return nil } } - return nil + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) +} + +func match(pattern, s string) (bool, error) { + return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { @@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi return copts, nil } -// parseExtensionParameter parses the value in the extension parameter p. -func parseExtensionParameter(p string) (int, bool) { - ps := strings.Split(p, "=") - if len(ps) == 1 { - return 0, false - } - i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) - return i, e == nil -} - func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. diff --git a/accept_test.go b/accept_test.go index 49667799..40a7b40c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -244,10 +244,11 @@ func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { - name string - origin string - host string - success bool + name string + origin string + host string + originPatterns []string + success bool }{ { name: "none", @@ -278,6 +279,26 @@ func Test_authenticateOrigin(t *testing.T) { host: "example.com", success: true, }, + { + name: "originPatterns", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "*.example.com", + "bar.com", + }, + success: true, + }, + { + name: "originPatternsUnauthorized", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "exam3.com", + "bar.com", + }, + success: false, + }, } for _, tc := range testCases { @@ -288,7 +309,7 @@ func Test_authenticateOrigin(t *testing.T) { r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) - err := authenticateOrigin(r) + err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { diff --git a/example_test.go b/example_test.go index 666914d2..c56e53f3 100644 --- a/example_test.go +++ b/example_test.go @@ -6,7 +6,6 @@ import ( "context" "log" "net/http" - "net/url" "time" "nhooyr.io/websocket" @@ -121,17 +120,8 @@ func Example_writeOnly() { // from the origin example.com. func Example_crossOrigin() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) - if err != nil || u.Host != "example.com" { - http.Error(w, "bad origin header", http.StatusForbidden) - return - } - } - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, + OriginPatterns: []string{"example.com"}, }) if err != nil { log.Println(err) From 97345d8042c26a1e56f8e3cc5b494230677e4ab0 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 00:29:13 -0500 Subject: [PATCH 03/20] Simplify wstest.Pipe --- conn_test.go | 6 ++++-- internal/test/wstest/pipe.go | 26 ++++---------------------- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/conn_test.go b/conn_test.go index 535afe24..28da3c07 100644 --- a/conn_test.go +++ b/conn_test.go @@ -337,8 +337,10 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs tt = &connTest{t: t, ctx: ctx} tt.appendDone(cancel) - c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) - assert.Success(tt.t, err) + c1, c2 = wstest.Pipe(dialOpts, acceptOpts) + if xrand.Bool() { + c1, c2 = c2, c1 + } tt.appendDone(func() { c2.Close(websocket.StatusInternalError, "") c1.Close(websocket.StatusInternalError, "") diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 0a2899ee..1534f316 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -5,26 +5,19 @@ package wstest import ( "bufio" "context" - "fmt" "net" "net/http" "net/http/httptest" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/test/xrand" ) // Pipe is used to create an in memory connection // between two websockets analogous to net.Pipe. -func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (_ *websocket.Conn, _ *websocket.Conn, err error) { - defer errd.Wrap(&err, "failed to create ws pipe") - - var serverConn *websocket.Conn - var acceptErr error +func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) { tt := fakeTransport{ h: func(w http.ResponseWriter, r *http.Request) { - serverConn, acceptErr = websocket.Accept(w, r, acceptOpts) + serverConn, _ = websocket.Accept(w, r, acceptOpts) }, } @@ -36,19 +29,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) Transport: tt, } - clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) - if err != nil { - return nil, nil, fmt.Errorf("failed to dial with fake transport: %w", err) - } - - if serverConn == nil { - return nil, nil, fmt.Errorf("failed to get server conn from fake transport: %w", acceptErr) - } - - if xrand.Bool() { - return serverConn, clientConn, nil - } - return clientConn, serverConn, nil + clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts) + return clientConn, serverConn } type fakeTransport struct { From deb14cfd901a252aa2ac0395db8d5c553d271248 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 27 Feb 2020 16:09:06 -0500 Subject: [PATCH 04/20] Make sure to release lock when acquiring and connection is closed. Closes #205 --- conn_notjs.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/conn_notjs.go b/conn_notjs.go index 2ec5f5bf..bb2eb22f 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -245,9 +245,11 @@ func (m *mu) lock(ctx context.Context) error { case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected - // the receive on closed. + // over the receive on closed. select { case <-m.c.closed: + // Make sure to release. + m.unlock() return m.c.closeErr default: } From 97172f3339a9bef16fa82fde84b4b0c7a1357e56 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 25 Feb 2020 23:59:57 -0500 Subject: [PATCH 05/20] Add Grace to gracefully close WebSocket connections Closes #199 --- accept.go | 20 ++++++- conn_notjs.go | 5 ++ conn_test.go | 12 ++--- example_echo_test.go | 6 ++- example_test.go | 46 ++++++++++++++++ grace.go | 123 +++++++++++++++++++++++++++++++++++++++++++ ws_js.go | 2 + 7 files changed, 202 insertions(+), 12 deletions(-) create mode 100644 grace.go diff --git a/accept.go b/accept.go index 47e20b52..52a93459 100644 --- a/accept.go +++ b/accept.go @@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") + g := graceFromRequest(r) + if g != nil && g.isClosing() { + err := errors.New("server closing") + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return nil, err + } + if opts == nil { opts = &AcceptOptions{} } @@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - return newConn(connConfig{ + c := newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, @@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con br: brw.Reader, bw: brw.Writer, - }), nil + }) + + if g != nil { + err = g.addConn(c) + if err != nil { + return nil, err + } + } + + return c, nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { diff --git a/conn_notjs.go b/conn_notjs.go index bb2eb22f..f604898e 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -33,6 +33,7 @@ type Conn struct { flateThreshold int br *bufio.Reader bw *bufio.Writer + g *Grace readTimeout chan context.Context writeTimeout chan context.Context @@ -138,6 +139,10 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() + if c.g != nil { + c.g.delConn(c) + } + go func() { c.msgWriterState.close() diff --git a/conn_test.go b/conn_test.go index 28da3c07..af4fa4c0 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,7 +13,6 @@ import ( "os" "os/exec" "strings" - "sync" "testing" "time" @@ -272,11 +271,9 @@ func TestWasm(t *testing.T) { t.Skip("skipping on CI") } - var wg sync.WaitGroup - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wg.Add(1) - defer wg.Done() - + var g websocket.Grace + defer g.Close() + s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, @@ -294,8 +291,7 @@ func TestWasm(t *testing.T) { t.Errorf("echo server failed: %v", err) return } - })) - defer wg.Wait() + }))) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) diff --git a/example_echo_test.go b/example_echo_test.go index cd195d2e..0c0b84ea 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -31,13 +31,15 @@ func Example_echo() { } defer l.Close() + var g websocket.Grace + defer g.Close() s := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r) if err != nil { log.Printf("echo server: %v", err) } - }), + })), ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } diff --git a/example_test.go b/example_test.go index c56e53f3..ce049bc3 100644 --- a/example_test.go +++ b/example_test.go @@ -6,6 +6,8 @@ import ( "context" "log" "net/http" + "os" + "os/signal" "time" "nhooyr.io/websocket" @@ -133,3 +135,47 @@ func Example_crossOrigin() { err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } + +// This example demonstrates how to create a WebSocket server +// that gracefully exits when sent a signal. +// +// It starts a WebSocket server that keeps every connection open +// for 10 seconds. +// If you CTRL+C while a connection is open, it will wait at most 30s +// for all connections to terminate before shutting down. +func ExampleGrace() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + log.Println(err) + return + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx := c.CloseRead(r.Context()) + select { + case <-ctx.Done(): + case <-time.After(time.Second * 10): + } + + c.Close(websocket.StatusNormalClosure, "") + }) + + var g websocket.Grace + s := &http.Server{ + Handler: g.Handler(fn), + ReadTimeout: time.Second * 15, + WriteTimeout: time.Second * 15, + } + go s.ListenAndServe() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + sig := <-sigs + log.Printf("recieved %v, shutting down", sig) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + s.Shutdown(ctx) + g.Shutdown(ctx) +} diff --git a/grace.go b/grace.go new file mode 100644 index 00000000..8dadc43d --- /dev/null +++ b/grace.go @@ -0,0 +1,123 @@ +package websocket + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "time" +) + +// Grace enables graceful shutdown of accepted WebSocket connections. +// +// Use Handler to wrap WebSocket handlers to record accepted connections +// and then use Close or Shutdown to gracefully close these connections. +// +// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. +type Grace struct { + mu sync.Mutex + closing bool + conns map[*Conn]struct{} +} + +// Handler returns a handler that wraps around h to record +// all WebSocket connections accepted. +// +// Use Close or Shutdown to gracefully close recorded connections. +func (g *Grace) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), gracefulContextKey{}, g) + r = r.WithContext(ctx) + h.ServeHTTP(w, r) + }) +} + +func (g *Grace) isClosing() bool { + g.mu.Lock() + defer g.mu.Unlock() + return g.closing +} + +func graceFromRequest(r *http.Request) *Grace { + g, _ := r.Context().Value(gracefulContextKey{}).(*Grace) + return g +} + +func (g *Grace) addConn(c *Conn) error { + g.mu.Lock() + defer g.mu.Unlock() + if g.closing { + c.Close(StatusGoingAway, "server shutting down") + return errors.New("server shutting down") + } + if g.conns == nil { + g.conns = make(map[*Conn]struct{}) + } + g.conns[c] = struct{}{} + c.g = g + return nil +} + +func (g *Grace) delConn(c *Conn) { + g.mu.Lock() + defer g.mu.Unlock() + delete(g.conns, c) +} + +type gracefulContextKey struct{} + +// Close prevents the acceptance of new connections with +// http.StatusServiceUnavailable and closes all accepted +// connections with StatusGoingAway. +func (g *Grace) Close() error { + g.mu.Lock() + g.closing = true + var wg sync.WaitGroup + for c := range g.conns { + wg.Add(1) + go func(c *Conn) { + defer wg.Done() + c.Close(StatusGoingAway, "server shutting down") + }(c) + + delete(g.conns, c) + } + g.mu.Unlock() + + wg.Wait() + + return nil +} + +// Shutdown prevents the acceptance of new connections and waits until +// all connections close. If the context is cancelled before that, it +// calls Close to close all connections immediately. +func (g *Grace) Shutdown(ctx context.Context) error { + defer g.Close() + + g.mu.Lock() + g.closing = true + g.mu.Unlock() + + // Same poll period used by net/http. + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + if g.zeroConns() { + return nil + } + + select { + case <-t.C: + case <-ctx.Done(): + return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err()) + } + } +} + +func (g *Grace) zeroConns() bool { + g.mu.Lock() + defer g.mu.Unlock() + return len(g.conns) == 0 +} diff --git a/ws_js.go b/ws_js.go index 2b560ce8..a8c8b771 100644 --- a/ws_js.go +++ b/ws_js.go @@ -38,6 +38,8 @@ type Conn struct { readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent + + g *Grace } func (c *Conn) close(err error, wasClean bool) { From e335b09210e47739545fe30c69f3a0f56ede98a0 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 14:47:40 -0500 Subject: [PATCH 06/20] Use grace in chat example --- accept.go | 4 ++-- chat-example/go.mod | 4 +++- chat-example/go.sum | 10 ++++++++-- chat-example/index.css | 2 +- chat-example/index.js | 13 ++++++++++--- chat-example/main.go | 29 +++++++++++++++++++++++++++-- example_test.go | 14 +++++++++++--- grace.go | 20 ++++++++++++-------- 8 files changed, 74 insertions(+), 22 deletions(-) diff --git a/accept.go b/accept.go index 52a93459..dd96c9bd 100644 --- a/accept.go +++ b/accept.go @@ -76,8 +76,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con defer errd.Wrap(&err, "failed to accept WebSocket connection") g := graceFromRequest(r) - if g != nil && g.isClosing() { - err := errors.New("server closing") + if g != nil && g.isShuttingdown() { + err := errors.New("server shutting down") http.Error(w, err.Error(), http.StatusServiceUnavailable) return nil, err } diff --git a/chat-example/go.mod b/chat-example/go.mod index 34fa5a69..c47a5a2f 100644 --- a/chat-example/go.mod +++ b/chat-example/go.mod @@ -2,4 +2,6 @@ module nhooyr.io/websocket/example-chat go 1.13 -require nhooyr.io/websocket v1.8.2 +require nhooyr.io/websocket v0.0.0 + +replace nhooyr.io/websocket => ../ diff --git a/chat-example/go.sum b/chat-example/go.sum index 0755fca5..e4bbd62d 100644 --- a/chat-example/go.sum +++ b/chat-example/go.sum @@ -1,12 +1,18 @@ +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA= -nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ= diff --git a/chat-example/index.css b/chat-example/index.css index 29804662..73a8e0f3 100644 --- a/chat-example/index.css +++ b/chat-example/index.css @@ -5,7 +5,7 @@ body { #root { padding: 40px 20px; - max-width: 480px; + max-width: 600px; margin: auto; height: 100vh; diff --git a/chat-example/index.js b/chat-example/index.js index 8fb3dfb8..a42c2d30 100644 --- a/chat-example/index.js +++ b/chat-example/index.js @@ -7,8 +7,11 @@ const conn = new WebSocket(`ws://${location.host}/subscribe`) conn.addEventListener("close", ev => { - console.info("websocket disconnected, reconnecting in 1000ms", ev) - setTimeout(dial, 1000) + appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) + if (ev.code !== 1001) { + appendLog("Reconnecting in 1s", true) + setTimeout(dial, 1000) + } }) conn.addEventListener("open", ev => { console.info("websocket connected") @@ -34,10 +37,14 @@ const messageInput = document.getElementById("message-input") // appendLog appends the passed text to messageLog. - function appendLog(text) { + function appendLog(text, error) { const p = document.createElement("p") // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` + if (error) { + p.style.color = "red" + p.style.fontStyle = "bold" + } messageLog.append(p) return p } diff --git a/chat-example/main.go b/chat-example/main.go index 2a520924..f985d382 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -1,12 +1,16 @@ package main import ( + "context" "errors" "log" "net" "net/http" "os" + "os/signal" "time" + + "nhooyr.io/websocket" ) func main() { @@ -38,10 +42,31 @@ func run() error { m.HandleFunc("/subscribe", ws.subscribeHandler) m.HandleFunc("/publish", ws.publishHandler) + var g websocket.Grace s := http.Server{ - Handler: m, + Handler: g.Handler(m), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } - return s.Serve(l) + errc := make(chan error, 1) + go func() { + errc <- s.Serve(l) + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case err := <-errc: + log.Printf("failed to serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + s.Shutdown(ctx) + g.Shutdown(ctx) + + return nil } diff --git a/example_test.go b/example_test.go index ce049bc3..462de376 100644 --- a/example_test.go +++ b/example_test.go @@ -167,12 +167,20 @@ func ExampleGrace() { ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } - go s.ListenAndServe() + + errc := make(chan error, 1) + go func() { + errc <- s.ListenAndServe() + }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) - sig := <-sigs - log.Printf("recieved %v, shutting down", sig) + select { + case err := <-errc: + log.Printf("failed to listen and serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() diff --git a/grace.go b/grace.go index 8dadc43d..c53cd40b 100644 --- a/grace.go +++ b/grace.go @@ -15,10 +15,13 @@ import ( // and then use Close or Shutdown to gracefully close these connections. // // Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. +// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket +// connections. type Grace struct { - mu sync.Mutex - closing bool - conns map[*Conn]struct{} + mu sync.Mutex + closed bool + shuttingDown bool + conns map[*Conn]struct{} } // Handler returns a handler that wraps around h to record @@ -33,10 +36,10 @@ func (g *Grace) Handler(h http.Handler) http.Handler { }) } -func (g *Grace) isClosing() bool { +func (g *Grace) isShuttingdown() bool { g.mu.Lock() defer g.mu.Unlock() - return g.closing + return g.shuttingDown } func graceFromRequest(r *http.Request) *Grace { @@ -47,7 +50,7 @@ func graceFromRequest(r *http.Request) *Grace { func (g *Grace) addConn(c *Conn) error { g.mu.Lock() defer g.mu.Unlock() - if g.closing { + if g.closed { c.Close(StatusGoingAway, "server shutting down") return errors.New("server shutting down") } @@ -72,7 +75,8 @@ type gracefulContextKey struct{} // connections with StatusGoingAway. func (g *Grace) Close() error { g.mu.Lock() - g.closing = true + g.shuttingDown = true + g.closed = true var wg sync.WaitGroup for c := range g.conns { wg.Add(1) @@ -97,7 +101,7 @@ func (g *Grace) Shutdown(ctx context.Context) error { defer g.Close() g.mu.Lock() - g.closing = true + g.shuttingDown = true g.mu.Unlock() // Same poll period used by net/http. From 190981dcf7f6af74049e8c6eab9dd500b0a9a47f Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 15:31:29 -0500 Subject: [PATCH 07/20] Add automated test to chat example --- chat-example/chat.go | 61 ++++++++++++----- chat-example/chat_test.go | 137 ++++++++++++++++++++++++++++++++++++++ chat-example/go.mod | 7 -- chat-example/main.go | 10 +-- 4 files changed, 183 insertions(+), 32 deletions(-) create mode 100644 chat-example/chat_test.go delete mode 100644 chat-example/go.mod diff --git a/chat-example/chat.go b/chat-example/chat.go index e6e355d0..9b264195 100644 --- a/chat-example/chat.go +++ b/chat-example/chat.go @@ -15,8 +15,28 @@ import ( // chatServer enables broadcasting to a set of subscribers. type chatServer struct { + registerOnce sync.Once + m http.ServeMux + subscribersMu sync.RWMutex - subscribers map[chan<- []byte]struct{} + subscribers map[*subscriber]struct{} +} + +// subscriber represents a subscriber. +// Messages are sent on the msgs channel and if the client +// cannot keep up with the messages, closeSlow is called. +type subscriber struct { + msgs chan []byte + closeSlow func() +} + +func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cs.registerOnce.Do(func() { + cs.m.Handle("/", http.FileServer(http.Dir("."))) + cs.m.HandleFunc("/subscribe", cs.subscribeHandler) + cs.m.HandleFunc("/publish", cs.publishHandler) + }) + cs.m.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes @@ -57,11 +77,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { } cs.publish(msg) + + w.WriteHeader(http.StatusAccepted) } // subscribe subscribes the given WebSocket to all broadcast messages. -// It creates a msgs chan with a buffer of 16 to give some room to slower -// connections and then registers it. It then listens for all messages +// It creates a subscriber with a buffered msgs chan to give some room to slower +// connections and then registers the subscriber. It then listens for all messages // and writes them to the WebSocket. If the context is cancelled or // an error occurs, it returns and deletes the subscription. // @@ -70,13 +92,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) - msgs := make(chan []byte, 16) - cs.addSubscriber(msgs) - defer cs.deleteSubscriber(msgs) + s := &subscriber{ + msgs: make(chan []byte, 16), + closeSlow: func() { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + }, + } + cs.addSubscriber(s) + defer cs.deleteSubscriber(s) for { select { - case msg := <-msgs: + case msg := <-s.msgs: err := writeTimeout(ctx, time.Second*5, c, msg) if err != nil { return err @@ -94,29 +121,29 @@ func (cs *chatServer) publish(msg []byte) { cs.subscribersMu.RLock() defer cs.subscribersMu.RUnlock() - for c := range cs.subscribers { + for s := range cs.subscribers { select { - case c <- msg: + case s.msgs <- msg: default: + go s.closeSlow() } } } -// addSubscriber registers a subscriber with a channel -// on which to send messages. -func (cs *chatServer) addSubscriber(msgs chan<- []byte) { +// addSubscriber registers a subscriber. +func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() if cs.subscribers == nil { - cs.subscribers = make(map[chan<- []byte]struct{}) + cs.subscribers = make(map[*subscriber]struct{}) } - cs.subscribers[msgs] = struct{}{} + cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } -// deleteSubscriber deletes the subscriber with the given msgs channel. -func (cs *chatServer) deleteSubscriber(msgs chan []byte) { +// deleteSubscriber deletes the given subscriber. +func (cs *chatServer) deleteSubscriber(s *subscriber) { cs.subscribersMu.Lock() - delete(cs.subscribers, msgs) + delete(cs.subscribers, s) cs.subscribersMu.Unlock() } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go new file mode 100644 index 00000000..d1772381 --- /dev/null +++ b/chat-example/chat_test.go @@ -0,0 +1,137 @@ +// +build !js + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestGrace(t *testing.T) { + t.Parallel() + + var cs chatServer + var g websocket.Grace + s := httptest.NewServer(g.Handler(&cs)) + defer s.Close() + defer g.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + cl1, err := newClient(ctx, s.URL) + assertSuccess(t, err) + defer cl1.Close() + + cl2, err := newClient(ctx, s.URL) + assertSuccess(t, err) + defer cl2.Close() + + err = cl1.publish(ctx, "hello") + assertSuccess(t, err) + + assertReceivedMessage(ctx, cl1, "hello") + assertReceivedMessage(ctx, cl2, "hello") +} + +type client struct { + msgs chan string + url string + c *websocket.Conn +} + +func newClient(ctx context.Context, url string) (*client, error) { + wsURL := strings.ReplaceAll(url, "http://", "ws://") + c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) + if err != nil { + return nil, err + } + + cl := &client{ + msgs: make(chan string, 16), + url: url, + c: c, + } + go cl.readLoop() + + return cl, nil +} + +func (cl *client) readLoop() { + defer cl.c.Close(websocket.StatusInternalError, "") + defer close(cl.msgs) + + for { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return + } + + select { + case cl.msgs <- string(b): + default: + cl.c.Close(websocket.StatusInternalError, "messages coming in too fast to handle") + return + } + } +} + +func (cl *client) receive(ctx context.Context) (string, error) { + select { + case msg, ok := <-cl.msgs: + if !ok { + return "", errors.New("client closed") + } + return msg, nil + case <-ctx.Done(): + return "", ctx.Err() + } +} + +func (cl *client) publish(ctx context.Context, msg string) error { + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("publish request failed: %v", resp.StatusCode) + } + return nil +} + +func (cl *client) Close() error { + return cl.c.Close(websocket.StatusNormalClosure, "") +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func assertReceivedMessage(ctx context.Context, cl *client, msg string) error { + msg, err := cl.receive(ctx) + if err != nil { + return err + } + if msg != "hello" { + return fmt.Errorf("expected hello but got %q", msg) + } + return nil +} diff --git a/chat-example/go.mod b/chat-example/go.mod deleted file mode 100644 index c47a5a2f..00000000 --- a/chat-example/go.mod +++ /dev/null @@ -1,7 +0,0 @@ -module nhooyr.io/websocket/example-chat - -go 1.13 - -require nhooyr.io/websocket v0.0.0 - -replace nhooyr.io/websocket => ../ diff --git a/chat-example/main.go b/chat-example/main.go index f985d382..a265f60c 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -35,16 +35,10 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var ws chatServer - - m := http.NewServeMux() - m.Handle("/", http.FileServer(http.Dir("."))) - m.HandleFunc("/subscribe", ws.subscribeHandler) - m.HandleFunc("/publish", ws.publishHandler) - + var cs chatServer var g websocket.Grace s := http.Server{ - Handler: g.Handler(m), + Handler: g.Handler(&cs), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } From da3aa8cfcc08909ea3cd41153637e5c1697bac59 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 26 Feb 2020 20:39:59 -0500 Subject: [PATCH 08/20] Improve chat example test --- chat-example/README.md | 6 + chat-example/chat.go | 67 ++++++--- chat-example/chat_test.go | 285 ++++++++++++++++++++++++++++---------- chat-example/index.js | 17 ++- chat-example/main.go | 4 +- ci/test.mk | 3 +- 6 files changed, 284 insertions(+), 98 deletions(-) diff --git a/chat-example/README.md b/chat-example/README.md index ef06275d..a4c99a93 100644 --- a/chat-example/README.md +++ b/chat-example/README.md @@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`. + +There are two automated tests for the server included in `chat_test.go`. The first is a simple one +client echo test. It publishes a single message and ensures it's received. + +The second is a complex concurrency test where 10 clients send 128 unique messages +of max 128 bytes concurrently. The test ensures all messages are seen by every client. diff --git a/chat-example/chat.go b/chat-example/chat.go index 9b264195..532e50f5 100644 --- a/chat-example/chat.go +++ b/chat-example/chat.go @@ -3,25 +3,57 @@ package main import ( "context" "errors" - "io" "io/ioutil" "log" "net/http" "sync" "time" + "golang.org/x/time/rate" + "nhooyr.io/websocket" ) // chatServer enables broadcasting to a set of subscribers. type chatServer struct { - registerOnce sync.Once - m http.ServeMux - - subscribersMu sync.RWMutex + // subscriberMessageBuffer controls the max number + // of messages that can be queued for a subscriber + // before it is kicked. + // + // Defaults to 16. + subscriberMessageBuffer int + + // publishLimiter controls the rate limit applied to the publish endpoint. + // + // Defaults to one publish every 100ms with a burst of 8. + publishLimiter *rate.Limiter + + // logf controls where logs are sent. + // Defaults to log.Printf. + logf func(f string, v ...interface{}) + + // serveMux routes the various endpoints to the appropriate handler. + serveMux http.ServeMux + + subscribersMu sync.Mutex subscribers map[*subscriber]struct{} } +// newChatServer constructs a chatServer with the defaults. +func newChatServer() *chatServer { + cs := &chatServer{ + subscriberMessageBuffer: 16, + logf: log.Printf, + subscribers: make(map[*subscriber]struct{}), + publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), + } + cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) + cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) + cs.serveMux.HandleFunc("/publish", cs.publishHandler) + + return cs +} + // subscriber represents a subscriber. // Messages are sent on the msgs channel and if the client // cannot keep up with the messages, closeSlow is called. @@ -31,12 +63,7 @@ type subscriber struct { } func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - cs.registerOnce.Do(func() { - cs.m.Handle("/", http.FileServer(http.Dir("."))) - cs.m.HandleFunc("/subscribe", cs.subscribeHandler) - cs.m.HandleFunc("/publish", cs.publishHandler) - }) - cs.m.ServeHTTP(w, r) + cs.serveMux.ServeHTTP(w, r) } // subscribeHandler accepts the WebSocket connection and then subscribes @@ -44,7 +71,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { - log.Print(err) + cs.logf("%v", err) return } defer c.Close(websocket.StatusInternalError, "") @@ -58,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { return } if err != nil { - log.Print(err) + cs.logf("%v", err) + return } } @@ -69,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } - body := io.LimitReader(r.Body, 8192) + body := http.MaxBytesReader(w, r.Body, 8192) msg, err := ioutil.ReadAll(body) if err != nil { http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) @@ -93,7 +121,7 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { ctx = c.CloseRead(ctx) s := &subscriber{ - msgs: make(chan []byte, 16), + msgs: make(chan []byte, cs.subscriberMessageBuffer), closeSlow: func() { c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") }, @@ -118,8 +146,10 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { // It never blocks and so messages to slow subscribers // are dropped. func (cs *chatServer) publish(msg []byte) { - cs.subscribersMu.RLock() - defer cs.subscribersMu.RUnlock() + cs.subscribersMu.Lock() + defer cs.subscribersMu.Unlock() + + cs.publishLimiter.Wait(context.Background()) for s := range cs.subscribers { select { @@ -133,9 +163,6 @@ func (cs *chatServer) publish(msg []byte) { // addSubscriber registers a subscriber. func (cs *chatServer) addSubscriber(s *subscriber) { cs.subscribersMu.Lock() - if cs.subscribers == nil { - cs.subscribers = make(map[*subscriber]struct{}) - } cs.subscribers[s] = struct{}{} cs.subscribersMu.Unlock() } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go index d1772381..491499cc 100644 --- a/chat-example/chat_test.go +++ b/chat-example/chat_test.go @@ -4,104 +4,214 @@ package main import ( "context" - "errors" + "crypto/rand" "fmt" + "math/big" "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" + "golang.org/x/time/rate" + "nhooyr.io/websocket" ) -func TestGrace(t *testing.T) { +func Test_chatServer(t *testing.T) { t.Parallel() - var cs chatServer + // This is a simple echo test with a single client. + // The client sends a message and ensures it receives + // it on its WebSocket. + t.Run("simple", func(t *testing.T) { + t.Parallel() + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + expMsg := randString(512) + err = cl.publish(ctx, expMsg) + assertSuccess(t, err) + + msg, err := cl.nextMessage() + assertSuccess(t, err) + + if expMsg != msg { + t.Fatalf("expected %v but got %v", expMsg, msg) + } + }) + + // This test is a complex concurrency test. + // 10 clients are started that send 128 different + // messages of max 128 bytes concurrently. + // + // The test verifies that every message is seen by ever client + // and no errors occur anywhere. + t.Run("concurrency", func(t *testing.T) { + t.Parallel() + + const nmessages = 128 + const maxMessageSize = 128 + const nclients = 10 + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var clients []*client + var clientMsgs []map[string]struct{} + for i := 0; i < nclients; i++ { + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + clients = append(clients, cl) + clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) + } + + allMessages := make(map[string]struct{}) + for _, msgs := range clientMsgs { + for m := range msgs { + allMessages[m] = struct{}{} + } + } + + var wg sync.WaitGroup + for i, cl := range clients { + i := i + cl := cl + + wg.Add(1) + go func() { + defer wg.Done() + err := cl.publishMsgs(ctx, clientMsgs[i]) + if err != nil { + t.Errorf("client %d failed to publish all messages: %v", i, err) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) + if err != nil { + t.Errorf("client %d failed to receive all messages: %v", i, err) + } + }() + } + + wg.Wait() + }) +} + +// setupTest sets up chatServer that can be used +// via the returned url. +// +// Defer closeFn to ensure everything is cleaned up at +// the end of the test. +// +// chatServer logs will be logged via t.Logf. +func setupTest(t *testing.T) (url string, closeFn func()) { + cs := newChatServer() + cs.logf = t.Logf + + // To ensure tests run quickly under even -race. + cs.subscriberMessageBuffer = 4096 + cs.publishLimiter.SetLimit(rate.Inf) + var g websocket.Grace - s := httptest.NewServer(g.Handler(&cs)) - defer s.Close() - defer g.Close() + s := httptest.NewServer(g.Handler(cs)) + return s.URL, func() { + s.Close() + g.Close() + } +} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() +// testAllMessagesReceived ensures that after n reads, all msgs in msgs +// have been read. +func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { + msgs = cloneMessages(msgs) - cl1, err := newClient(ctx, s.URL) - assertSuccess(t, err) - defer cl1.Close() + for i := 0; i < n; i++ { + msg, err := cl.nextMessage() + if err != nil { + return err + } + delete(msgs, msg) + } - cl2, err := newClient(ctx, s.URL) - assertSuccess(t, err) - defer cl2.Close() + if len(msgs) != 0 { + return fmt.Errorf("did not receive all expected messages: %q", msgs) + } + return nil +} - err = cl1.publish(ctx, "hello") - assertSuccess(t, err) +func cloneMessages(msgs map[string]struct{}) map[string]struct{} { + msgs2 := make(map[string]struct{}, len(msgs)) + for m := range msgs { + msgs2[m] = struct{}{} + } + return msgs2 +} - assertReceivedMessage(ctx, cl1, "hello") - assertReceivedMessage(ctx, cl2, "hello") +func randMessages(n, maxMessageLength int) map[string]struct{} { + msgs := make(map[string]struct{}) + for i := 0; i < n; i++ { + m := randString(randInt(maxMessageLength)) + if _, ok := msgs[m]; ok { + i-- + continue + } + msgs[m] = struct{}{} + } + return msgs +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } } type client struct { - msgs chan string - url string - c *websocket.Conn + url string + c *websocket.Conn } func newClient(ctx context.Context, url string) (*client, error) { - wsURL := strings.ReplaceAll(url, "http://", "ws://") + wsURL := strings.Replace(url, "http://", "ws://", 1) c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) if err != nil { return nil, err } cl := &client{ - msgs: make(chan string, 16), - url: url, - c: c, + url: url, + c: c, } - go cl.readLoop() return cl, nil } -func (cl *client) readLoop() { - defer cl.c.Close(websocket.StatusInternalError, "") - defer close(cl.msgs) - - for { - typ, b, err := cl.c.Read(context.Background()) +func (cl *client) publish(ctx context.Context, msg string) (err error) { + defer func() { if err != nil { - return + cl.c.Close(websocket.StatusInternalError, "publish failed") } + }() - if typ != websocket.MessageText { - cl.c.Close(websocket.StatusUnsupportedData, "expected text message") - return - } - - select { - case cl.msgs <- string(b): - default: - cl.c.Close(websocket.StatusInternalError, "messages coming in too fast to handle") - return - } - } -} - -func (cl *client) receive(ctx context.Context) (string, error) { - select { - case msg, ok := <-cl.msgs: - if !ok { - return "", errors.New("client closed") - } - return msg, nil - case <-ctx.Done(): - return "", ctx.Err() - } -} - -func (cl *client) publish(ctx context.Context, msg string) error { req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -114,24 +224,59 @@ func (cl *client) publish(ctx context.Context, msg string) error { return nil } +func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { + for m := range msgs { + err := cl.publish(ctx, m) + if err != nil { + return err + } + } + return nil +} + +func (cl *client) nextMessage() (string, error) { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return "", err + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return "", fmt.Errorf("expected text message but got %v", typ) + } + return string(b), nil +} + func (cl *client) Close() error { return cl.c.Close(websocket.StatusNormalClosure, "") } -func assertSuccess(t *testing.T, err error) { - t.Helper() +// randString generates a random string with length n. +func randString(n int) string { + b := make([]byte, n) + _, err := rand.Reader.Read(b) if err != nil { - t.Fatal(err) + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + + s := strings.ToValidUTF8(string(b), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s } -func assertReceivedMessage(ctx context.Context, cl *client, msg string) error { - msg, err := cl.receive(ctx) +// randInt returns a randomly generated integer between [0, max). +func randInt(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) if err != nil { - return err + panic(fmt.Sprintf("failed to get random int: %v", err)) } - if msg != "hello" { - return fmt.Errorf("expected hello but got %q", msg) - } - return nil + return int(x.Int64()) } diff --git a/chat-example/index.js b/chat-example/index.js index a42c2d30..5868e7ca 100644 --- a/chat-example/index.js +++ b/chat-example/index.js @@ -51,7 +51,7 @@ appendLog("Submit a message to get started!") // onsubmit publishes the message from the user when the form is submitted. - publishForm.onsubmit = ev => { + publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value @@ -61,9 +61,16 @@ messageInput.value = "" expectingMessage = true - fetch("/publish", { - method: "POST", - body: msg, - }) + try { + const resp = await fetch("/publish", { + method: "POST", + body: msg, + }) + if (resp.status !== 202) { + throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) + } + } catch (err) { + appendLog(`Publish failed: ${err.message}`, true) + } } })() diff --git a/chat-example/main.go b/chat-example/main.go index a265f60c..1b6f3266 100644 --- a/chat-example/main.go +++ b/chat-example/main.go @@ -35,10 +35,10 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var cs chatServer + cs := newChatServer() var g websocket.Grace s := http.Server{ - Handler: g.Handler(&cs), + Handler: g.Handler(cs), ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } diff --git a/ci/test.mk b/ci/test.mk index c62a25b6..291d6beb 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -11,6 +11,7 @@ coveralls: gotest goveralls -coverprofile=ci/out/coverage.prof gotest: - go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... + go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof + sed -i '/chat-example/d' ci/out/coverage.prof From 07343c2a717904b31466177f5dfeefb5ef5ab687 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 27 Feb 2020 21:03:35 -0500 Subject: [PATCH 09/20] Allow passing http:// and https:// URLS to Dial --- README.md | 8 ++++---- accept_js.go | 1 + chat-example/chat_test.go | 5 ++--- close_notjs.go | 17 +++++++++-------- conn_test.go | 2 +- dial.go | 3 +++ example_test.go | 7 +++++-- internal/test/wstest/url.go | 11 ----------- ws_js.go | 4 ++++ 9 files changed, 29 insertions(+), 29 deletions(-) delete mode 100644 internal/test/wstest/url.go diff --git a/README.md b/README.md index e967cd8a..2d71ce0b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # websocket -[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) +[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) websocket is a minimal and idiomatic WebSocket library for Go. @@ -16,8 +16,8 @@ go get nhooyr.io/websocket - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Thorough unit tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) -- [Minimal dependencies](https://godoc.org/nhooyr.io/websocket?imports) -- JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- [Minimal dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson?tab=doc) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb?tab=doc) subpackages - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) @@ -98,7 +98,7 @@ Advantages of nhooyr.io/websocket: - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support -- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) +- Dials use [net/http.Client](https://golang.org/pkg/net/http/#Client) - Will enable easy HTTP/2 support in the future - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes diff --git a/accept_js.go b/accept_js.go index 724b35b5..daad4b79 100644 --- a/accept_js.go +++ b/accept_js.go @@ -9,6 +9,7 @@ import ( type AcceptOptions struct { Subprotocols []string InsecureSkipVerify bool + OriginPatterns []string CompressionMode CompressionMode CompressionThreshold int } diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go index 491499cc..2cbc995e 100644 --- a/chat-example/chat_test.go +++ b/chat-example/chat_test.go @@ -61,7 +61,7 @@ func Test_chatServer(t *testing.T) { const nmessages = 128 const maxMessageSize = 128 - const nclients = 10 + const nclients = 16 url, closeFn := setupTest(t) defer closeFn() @@ -191,8 +191,7 @@ type client struct { } func newClient(ctx context.Context, url string) (*client, error) { - wsURL := strings.Replace(url, "http://", "ws://", 1) - c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil) + c, _, err := websocket.Dial(ctx, url+"/subscribe", nil) if err != nil { return nil, err } diff --git a/close_notjs.go b/close_notjs.go index 25372995..4f1cebcb 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -34,14 +34,15 @@ func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - err = c.writeClose(code, reason) - if err != nil && CloseStatus(err) == -1 && err != errAlreadyWroteClose { - return err + writeErr := c.writeClose(code, reason) + closeHandshakeErr := c.waitCloseHandshake() + + if writeErr != nil { + return writeErr } - err = c.waitCloseHandshake() - if CloseStatus(err) == -1 { - return err + if CloseStatus(closeHandshakeErr) == -1 { + return closeHandshakeErr } return nil } @@ -50,10 +51,10 @@ var errAlreadyWroteClose = errors.New("already wrote close") func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() - closing := c.wroteClose + wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() - if closing { + if wroteClose { return errAlreadyWroteClose } diff --git a/conn_test.go b/conn_test.go index af4fa4c0..7514540d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -298,7 +298,7 @@ func TestWasm(t *testing.T) { defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wstest.URL(s))) + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { diff --git a/dial.go b/dial.go index 50a0ecce..9ab680eb 100644 --- a/dial.go +++ b/dial.go @@ -58,6 +58,8 @@ type DialOptions struct { // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 +// +// URLs with http/https schemes will work and translated into ws/wss. func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } @@ -145,6 +147,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts u.Scheme = "http" case "wss": u.Scheme = "https" + case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } diff --git a/example_test.go b/example_test.go index 462de376..39de0b80 100644 --- a/example_test.go +++ b/example_test.go @@ -1,5 +1,3 @@ -// +build !js - package websocket_test import ( @@ -187,3 +185,8 @@ func ExampleGrace() { s.Shutdown(ctx) g.Shutdown(ctx) } + +// This example demonstrates full stack chat with an automated test. +func Example_fullStackChat() { + // https://github.com/nhooyr/websocket/tree/master/chat-example +} diff --git a/internal/test/wstest/url.go b/internal/test/wstest/url.go deleted file mode 100644 index a11c61b4..00000000 --- a/internal/test/wstest/url.go +++ /dev/null @@ -1,11 +0,0 @@ -package wstest - -import ( - "net/http/httptest" - "strings" -) - -// URL returns the ws url for s. -func URL(s *httptest.Server) string { - return strings.Replace(s.URL, "http", "ws", 1) -} diff --git a/ws_js.go b/ws_js.go index a8c8b771..69019e61 100644 --- a/ws_js.go +++ b/ws_js.go @@ -9,6 +9,7 @@ import ( "net/http" "reflect" "runtime" + "strings" "sync" "syscall/js" @@ -257,6 +258,9 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp opts = &DialOptions{} } + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.Replace(url, "https://", "wss://", 1) + ws, err := wsjs.New(url, opts.Subprotocols) if err != nil { return nil, nil, err From 008b61622b1cb1774a35191d35b102139b02f321 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 22 Mar 2020 12:33:24 -0400 Subject: [PATCH 10/20] Remove Grace partially --- accept.go | 14 ------ conn_notjs.go | 5 --- example_echo_test.go | 5 +++ grace.go | 105 ++++++++++++++++++++----------------------- 4 files changed, 54 insertions(+), 75 deletions(-) diff --git a/accept.go b/accept.go index dd96c9bd..a583f232 100644 --- a/accept.go +++ b/accept.go @@ -75,13 +75,6 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { defer errd.Wrap(&err, "failed to accept WebSocket connection") - g := graceFromRequest(r) - if g != nil && g.isShuttingdown() { - err := errors.New("server shutting down") - http.Error(w, err.Error(), http.StatusServiceUnavailable) - return nil, err - } - if opts == nil { opts = &AcceptOptions{} } @@ -152,13 +145,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con bw: brw.Writer, }) - if g != nil { - err = g.addConn(c) - if err != nil { - return nil, err - } - } - return c, nil } diff --git a/conn_notjs.go b/conn_notjs.go index f604898e..bb2eb22f 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -33,7 +33,6 @@ type Conn struct { flateThreshold int br *bufio.Reader bw *bufio.Writer - g *Grace readTimeout chan context.Context writeTimeout chan context.Context @@ -139,10 +138,6 @@ func (c *Conn) close(err error) { // closeErr. c.rwc.Close() - if c.g != nil { - c.g.delConn(c) - } - go func() { c.msgWriterState.close() diff --git a/example_echo_test.go b/example_echo_test.go index 0c0b84ea..fb212c1f 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -18,6 +18,11 @@ import ( "nhooyr.io/websocket/wsjson" ) +// TODO IMPROVE CANCELLATION AND SHUTDOWN +// TODO on context cancel send websocket going away and fix the read timeout error to be dependant on context deadline reached. +// TODO this way you cancel your context and the right message automatically gets sent. Furthrmore, then u can just use a simple waitgroup to wait for connections. +// TODO grace is wrong as it doesn't wait for the individual goroutines. + // This example starts a WebSocket echo server, // dials the server and then sends 5 different messages // and prints out the server's responses. diff --git a/grace.go b/grace.go index c53cd40b..a0ec8969 100644 --- a/grace.go +++ b/grace.go @@ -2,7 +2,6 @@ package websocket import ( "context" - "errors" "fmt" "net/http" "sync" @@ -17,79 +16,75 @@ import ( // Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. // It's required as net/http's Shutdown and Close methods do not keep track of WebSocket // connections. +// +// Make sure to Close or Shutdown the *http.Server first as you don't want to accept +// any new connections while the existing websockets are being shut down. type Grace struct { - mu sync.Mutex - closed bool - shuttingDown bool - conns map[*Conn]struct{} + handlersMu sync.Mutex + closing bool + handlers map[context.Context]context.CancelFunc } // Handler returns a handler that wraps around h to record // all WebSocket connections accepted. // // Use Close or Shutdown to gracefully close recorded connections. +// Make sure to Close or Shutdown the *http.Server first. func (g *Grace) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), gracefulContextKey{}, g) + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + r = r.WithContext(ctx) + + ok := g.add(w, ctx, cancel) + if !ok { + return + } + defer g.del(ctx) + h.ServeHTTP(w, r) }) } -func (g *Grace) isShuttingdown() bool { - g.mu.Lock() - defer g.mu.Unlock() - return g.shuttingDown -} - -func graceFromRequest(r *http.Request) *Grace { - g, _ := r.Context().Value(gracefulContextKey{}).(*Grace) - return g -} +func (g *Grace) add(w http.ResponseWriter, ctx context.Context, cancel context.CancelFunc) bool { + g.handlersMu.Lock() + defer g.handlersMu.Unlock() -func (g *Grace) addConn(c *Conn) error { - g.mu.Lock() - defer g.mu.Unlock() - if g.closed { - c.Close(StatusGoingAway, "server shutting down") - return errors.New("server shutting down") + if g.closing { + http.Error(w, "shutting down", http.StatusServiceUnavailable) + return false } - if g.conns == nil { - g.conns = make(map[*Conn]struct{}) + + if g.handlers == nil { + g.handlers = make(map[context.Context]context.CancelFunc) } - g.conns[c] = struct{}{} - c.g = g - return nil -} + g.handlers[ctx] = cancel -func (g *Grace) delConn(c *Conn) { - g.mu.Lock() - defer g.mu.Unlock() - delete(g.conns, c) + return true } -type gracefulContextKey struct{} +func (g *Grace) del(ctx context.Context) { + g.handlersMu.Lock() + defer g.handlersMu.Unlock() + + delete(g.handlers, ctx) +} // Close prevents the acceptance of new connections with // http.StatusServiceUnavailable and closes all accepted // connections with StatusGoingAway. +// +// Make sure to Close or Shutdown the *http.Server first. func (g *Grace) Close() error { - g.mu.Lock() - g.shuttingDown = true - g.closed = true - var wg sync.WaitGroup - for c := range g.conns { - wg.Add(1) - go func(c *Conn) { - defer wg.Done() - c.Close(StatusGoingAway, "server shutting down") - }(c) - - delete(g.conns, c) + g.handlersMu.Lock() + for _, cancel := range g.handlers { + cancel() } - g.mu.Unlock() + g.handlersMu.Unlock() - wg.Wait() + // Wait for all goroutines to exit. + g.Shutdown(context.Background()) return nil } @@ -97,18 +92,16 @@ func (g *Grace) Close() error { // Shutdown prevents the acceptance of new connections and waits until // all connections close. If the context is cancelled before that, it // calls Close to close all connections immediately. +// +// Make sure to Close or Shutdown the *http.Server first. func (g *Grace) Shutdown(ctx context.Context) error { defer g.Close() - g.mu.Lock() - g.shuttingDown = true - g.mu.Unlock() - // Same poll period used by net/http. t := time.NewTicker(500 * time.Millisecond) defer t.Stop() for { - if g.zeroConns() { + if g.zeroHandlers() { return nil } @@ -120,8 +113,8 @@ func (g *Grace) Shutdown(ctx context.Context) error { } } -func (g *Grace) zeroConns() bool { - g.mu.Lock() - defer g.mu.Unlock() - return len(g.conns) == 0 +func (g *Grace) zeroHandlers() bool { + g.handlersMu.Lock() + defer g.handlersMu.Unlock() + return len(g.handlers) == 0 } From b307b475131604cad9f43f1f30cc757725eca80e Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 21:22:14 -0400 Subject: [PATCH 11/20] Improve docs and fix examples Closes #207 --- README.md | 17 +-- ci/test.mk | 2 +- example_test.go | 120 +++++++++--------- examples/README.md | 4 + {chat-example => examples/chat}/README.md | 0 {chat-example => examples/chat}/chat.go | 0 {chat-example => examples/chat}/chat_test.go | 0 {chat-example => examples/chat}/go.sum | 0 {chat-example => examples/chat}/index.css | 0 {chat-example => examples/chat}/index.html | 0 {chat-example => examples/chat}/index.js | 0 {chat-example => examples/chat}/main.go | 0 example_echo_test.go => examples/echo/echo.go | 4 +- 13 files changed, 78 insertions(+), 69 deletions(-) create mode 100644 examples/README.md rename {chat-example => examples/chat}/README.md (100%) rename {chat-example => examples/chat}/chat.go (100%) rename {chat-example => examples/chat}/chat_test.go (100%) rename {chat-example => examples/chat}/go.sum (100%) rename {chat-example => examples/chat}/index.css (100%) rename {chat-example => examples/chat}/index.html (100%) rename {chat-example => examples/chat}/index.js (100%) rename {chat-example => examples/chat}/main.go (100%) rename example_echo_test.go => examples/echo/echo.go (99%) diff --git a/README.md b/README.md index 2d71ce0b..3debf2f8 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- Thorough unit tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) -- [Minimal dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) -- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson?tab=doc) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb?tab=doc) subpackages +- Thorough tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) +- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) @@ -32,9 +32,10 @@ go get nhooyr.io/websocket ## Examples -For a production quality example that demonstrates the complete API, see the [echo example](https://godoc.org/nhooyr.io/websocket#example-package--Echo). +For a production quality example that demonstrates the complete API, see the +[echo example](./examples/echo). -For a full stack example, see [./chat-example](./chat-example). +For a full stack example, see the [chat example](./examples/chat). ### Server @@ -98,7 +99,7 @@ Advantages of nhooyr.io/websocket: - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support -- Dials use [net/http.Client](https://golang.org/pkg/net/http/#Client) +- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Will enable easy HTTP/2 support in the future - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes @@ -111,7 +112,7 @@ Advantages of nhooyr.io/websocket: - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) + - We use a vendored [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) @@ -120,7 +121,7 @@ Advantages of nhooyr.io/websocket: [golang.org/x/net/websocket](https://godoc.org/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). -The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning +The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) can help in transitioning to nhooyr.io/websocket. #### gobwas/ws diff --git a/ci/test.mk b/ci/test.mk index 291d6beb..b2f92b7c 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -14,4 +14,4 @@ gotest: go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof - sed -i '/chat-example/d' ci/out/coverage.prof + sed -i '/example/d' ci/out/coverage.prof diff --git a/example_test.go b/example_test.go index 39de0b80..632c4d6e 100644 --- a/example_test.go +++ b/example_test.go @@ -4,17 +4,16 @@ import ( "context" "log" "net/http" - "os" - "os/signal" "time" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" ) -// This example accepts a WebSocket connection, reads a single JSON -// message from the client and then closes the connection. func ExampleAccept() { + // This handler accepts a WebSocket connection, reads a single JSON + // message from the client and then closes the connection. + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { @@ -40,9 +39,10 @@ func ExampleAccept() { log.Fatal(err) } -// This example dials a server, writes a single JSON message and then -// closes the connection. func ExampleDial() { + // Dials a server, writes a single JSON message and then + // closes the connection. + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -60,9 +60,10 @@ func ExampleDial() { c.Close(websocket.StatusNormalClosure, "") } -// This example dials a server and then expects to be disconnected with status code -// websocket.StatusNormalClosure. func ExampleCloseStatus() { + // Dials a server and then expects to be disconnected with status code + // websocket.StatusNormalClosure. + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -78,9 +79,9 @@ func ExampleCloseStatus() { } } -// This example shows how to correctly handle a WebSocket connection -// on which you will only write and do not expect to read data messages. func Example_writeOnly() { + // This handler demonstrates how to correctly handle a write only WebSocket connection. + // i.e you only expect to write messages and do not expect to read any messages. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { @@ -116,9 +117,9 @@ func Example_writeOnly() { log.Fatal(err) } -// This example demonstrates how to safely accept cross origin WebSockets -// from the origin example.com. func Example_crossOrigin() { + // This handler demonstrates how to safely accept cross origin WebSockets + // from the origin example.com. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ OriginPatterns: []string{"example.com"}, @@ -141,52 +142,57 @@ func Example_crossOrigin() { // for 10 seconds. // If you CTRL+C while a connection is open, it will wait at most 30s // for all connections to terminate before shutting down. -func ExampleGrace() { - fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, nil) - if err != nil { - log.Println(err) - return - } - defer c.Close(websocket.StatusInternalError, "the sky is falling") - - ctx := c.CloseRead(r.Context()) - select { - case <-ctx.Done(): - case <-time.After(time.Second * 10): - } - - c.Close(websocket.StatusNormalClosure, "") - }) - - var g websocket.Grace - s := &http.Server{ - Handler: g.Handler(fn), - ReadTimeout: time.Second * 15, - WriteTimeout: time.Second * 15, - } - - errc := make(chan error, 1) - go func() { - errc <- s.ListenAndServe() - }() - - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, os.Interrupt) - select { - case err := <-errc: - log.Printf("failed to listen and serve: %v", err) - case sig := <-sigs: - log.Printf("terminating: %v", sig) - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - s.Shutdown(ctx) - g.Shutdown(ctx) -} +// func ExampleGrace() { +// fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// c, err := websocket.Accept(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// defer c.Close(websocket.StatusInternalError, "the sky is falling") +// +// ctx := c.CloseRead(r.Context()) +// select { +// case <-ctx.Done(): +// case <-time.After(time.Second * 10): +// } +// +// c.Close(websocket.StatusNormalClosure, "") +// }) +// +// var g websocket.Grace +// s := &http.Server{ +// Handler: g.Handler(fn), +// ReadTimeout: time.Second * 15, +// WriteTimeout: time.Second * 15, +// } +// +// errc := make(chan error, 1) +// go func() { +// errc <- s.ListenAndServe() +// }() +// +// sigs := make(chan os.Signal, 1) +// signal.Notify(sigs, os.Interrupt) +// select { +// case err := <-errc: +// log.Printf("failed to listen and serve: %v", err) +// case sig := <-sigs: +// log.Printf("terminating: %v", sig) +// } +// +// ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) +// defer cancel() +// s.Shutdown(ctx) +// g.Shutdown(ctx) +// } // This example demonstrates full stack chat with an automated test. func Example_fullStackChat() { - // https://github.com/nhooyr/websocket/tree/master/chat-example + // https://github.com/nhooyr/websocket/tree/master/examples/chat +} + +// This example demonstrates a echo server. +func Example_echo() { + // https://github.com/nhooyr/websocket/tree/master/examples/echo } diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..3cb437ae --- /dev/null +++ b/examples/README.md @@ -0,0 +1,4 @@ +# Examples + +This directory contains more involved examples unsuitable +for display with godoc. diff --git a/chat-example/README.md b/examples/chat/README.md similarity index 100% rename from chat-example/README.md rename to examples/chat/README.md diff --git a/chat-example/chat.go b/examples/chat/chat.go similarity index 100% rename from chat-example/chat.go rename to examples/chat/chat.go diff --git a/chat-example/chat_test.go b/examples/chat/chat_test.go similarity index 100% rename from chat-example/chat_test.go rename to examples/chat/chat_test.go diff --git a/chat-example/go.sum b/examples/chat/go.sum similarity index 100% rename from chat-example/go.sum rename to examples/chat/go.sum diff --git a/chat-example/index.css b/examples/chat/index.css similarity index 100% rename from chat-example/index.css rename to examples/chat/index.css diff --git a/chat-example/index.html b/examples/chat/index.html similarity index 100% rename from chat-example/index.html rename to examples/chat/index.html diff --git a/chat-example/index.js b/examples/chat/index.js similarity index 100% rename from chat-example/index.js rename to examples/chat/index.js diff --git a/chat-example/main.go b/examples/chat/main.go similarity index 100% rename from chat-example/main.go rename to examples/chat/main.go diff --git a/example_echo_test.go b/examples/echo/echo.go similarity index 99% rename from example_echo_test.go rename to examples/echo/echo.go index fb212c1f..0f31235d 100644 --- a/example_echo_test.go +++ b/examples/echo/echo.go @@ -1,6 +1,4 @@ -// +build !js - -package websocket_test +package main import ( "context" From f7ef6b827324774cae310d49aa5f2482cd953a4c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 21:22:18 -0400 Subject: [PATCH 12/20] Remove grace for now Updates #209 --- conn_test.go | 11 ++- examples/chat/chat_test.go | 5 +- examples/chat/main.go | 11 +-- examples/echo/{echo.go => main.go} | 10 +-- grace.go | 120 ----------------------------- ws_js.go | 2 - 6 files changed, 15 insertions(+), 144 deletions(-) rename examples/echo/{echo.go => main.go} (96%) delete mode 100644 grace.go diff --git a/conn_test.go b/conn_test.go index 7514540d..68dc837d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -271,12 +271,11 @@ func TestWasm(t *testing.T) { t.Skip("skipping on CI") } - var g websocket.Grace - defer g.Close() - s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TODO grace + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, + Subprotocols: []string{"echo"}, + OriginPatterns: []string{"*"}, }) if err != nil { t.Errorf("echo server failed: %v", err) @@ -291,7 +290,7 @@ func TestWasm(t *testing.T) { t.Errorf("echo server failed: %v", err) return } - }))) + })) defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) diff --git a/examples/chat/chat_test.go b/examples/chat/chat_test.go index 2cbc995e..79523d2a 100644 --- a/examples/chat/chat_test.go +++ b/examples/chat/chat_test.go @@ -130,11 +130,10 @@ func setupTest(t *testing.T) (url string, closeFn func()) { cs.subscriberMessageBuffer = 4096 cs.publishLimiter.SetLimit(rate.Inf) - var g websocket.Grace - s := httptest.NewServer(g.Handler(cs)) + // TODO grace + s := httptest.NewServer(cs) return s.URL, func() { s.Close() - g.Close() } } diff --git a/examples/chat/main.go b/examples/chat/main.go index 1b6f3266..cc2d01e8 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -9,8 +9,6 @@ import ( "os" "os/signal" "time" - - "nhooyr.io/websocket" ) func main() { @@ -36,9 +34,9 @@ func run() error { log.Printf("listening on http://%v", l.Addr()) cs := newChatServer() - var g websocket.Grace + // TODO grace s := http.Server{ - Handler: g.Handler(cs), + Handler: cs, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } @@ -59,8 +57,5 @@ func run() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - s.Shutdown(ctx) - g.Shutdown(ctx) - - return nil + return s.Shutdown(ctx) } diff --git a/examples/echo/echo.go b/examples/echo/main.go similarity index 96% rename from examples/echo/echo.go rename to examples/echo/main.go index 0f31235d..db2d06c9 100644 --- a/examples/echo/echo.go +++ b/examples/echo/main.go @@ -24,7 +24,7 @@ import ( // This example starts a WebSocket echo server, // dials the server and then sends 5 different messages // and prints out the server's responses. -func Example_echo() { +func main() { // First we listen on port 0 which means the OS will // assign us a random free port. This is the listener // the server will serve on and the client will connect to. @@ -34,15 +34,14 @@ func Example_echo() { } defer l.Close() - var g websocket.Grace - defer g.Close() + // TODO grace s := &http.Server{ - Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r) if err != nil { log.Printf("echo server: %v", err) } - })), + }), ReadTimeout: time.Second * 15, WriteTimeout: time.Second * 15, } @@ -61,6 +60,7 @@ func Example_echo() { if err != nil { log.Fatalf("client failed: %v", err) } + // Output: // received: map[i:0] // received: map[i:1] diff --git a/grace.go b/grace.go deleted file mode 100644 index a0ec8969..00000000 --- a/grace.go +++ /dev/null @@ -1,120 +0,0 @@ -package websocket - -import ( - "context" - "fmt" - "net/http" - "sync" - "time" -) - -// Grace enables graceful shutdown of accepted WebSocket connections. -// -// Use Handler to wrap WebSocket handlers to record accepted connections -// and then use Close or Shutdown to gracefully close these connections. -// -// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods. -// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket -// connections. -// -// Make sure to Close or Shutdown the *http.Server first as you don't want to accept -// any new connections while the existing websockets are being shut down. -type Grace struct { - handlersMu sync.Mutex - closing bool - handlers map[context.Context]context.CancelFunc -} - -// Handler returns a handler that wraps around h to record -// all WebSocket connections accepted. -// -// Use Close or Shutdown to gracefully close recorded connections. -// Make sure to Close or Shutdown the *http.Server first. -func (g *Grace) Handler(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, cancel := context.WithCancel(r.Context()) - defer cancel() - - r = r.WithContext(ctx) - - ok := g.add(w, ctx, cancel) - if !ok { - return - } - defer g.del(ctx) - - h.ServeHTTP(w, r) - }) -} - -func (g *Grace) add(w http.ResponseWriter, ctx context.Context, cancel context.CancelFunc) bool { - g.handlersMu.Lock() - defer g.handlersMu.Unlock() - - if g.closing { - http.Error(w, "shutting down", http.StatusServiceUnavailable) - return false - } - - if g.handlers == nil { - g.handlers = make(map[context.Context]context.CancelFunc) - } - g.handlers[ctx] = cancel - - return true -} - -func (g *Grace) del(ctx context.Context) { - g.handlersMu.Lock() - defer g.handlersMu.Unlock() - - delete(g.handlers, ctx) -} - -// Close prevents the acceptance of new connections with -// http.StatusServiceUnavailable and closes all accepted -// connections with StatusGoingAway. -// -// Make sure to Close or Shutdown the *http.Server first. -func (g *Grace) Close() error { - g.handlersMu.Lock() - for _, cancel := range g.handlers { - cancel() - } - g.handlersMu.Unlock() - - // Wait for all goroutines to exit. - g.Shutdown(context.Background()) - - return nil -} - -// Shutdown prevents the acceptance of new connections and waits until -// all connections close. If the context is cancelled before that, it -// calls Close to close all connections immediately. -// -// Make sure to Close or Shutdown the *http.Server first. -func (g *Grace) Shutdown(ctx context.Context) error { - defer g.Close() - - // Same poll period used by net/http. - t := time.NewTicker(500 * time.Millisecond) - defer t.Stop() - for { - if g.zeroHandlers() { - return nil - } - - select { - case <-t.C: - case <-ctx.Done(): - return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err()) - } - } -} - -func (g *Grace) zeroHandlers() bool { - g.handlersMu.Lock() - defer g.handlersMu.Unlock() - return len(g.handlers) == 0 -} diff --git a/ws_js.go b/ws_js.go index 69019e61..b87e32cd 100644 --- a/ws_js.go +++ b/ws_js.go @@ -39,8 +39,6 @@ type Conn struct { readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent - - g *Grace } func (c *Conn) close(err error, wasClean bool) { From 98779ee0af50df3c774208c733ae18572ef409b6 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 21:22:21 -0400 Subject: [PATCH 13/20] Fix outdated close handshake docs Closes #212 --- close_notjs.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/close_notjs.go b/close_notjs.go index 4f1cebcb..fd9262b5 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -84,7 +84,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { func (c *Conn) waitCloseHandshake() error { defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) From 5db7b716a02b179de735256d5502a4ac360a7942 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 21:22:23 -0400 Subject: [PATCH 14/20] Clarify CloseRead docs Closes #208 --- read.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/read.go b/read.go index 381cea3d..afd08cc7 100644 --- a/read.go +++ b/read.go @@ -52,7 +52,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close -// frames are responded to. +// frames are responded to. This means c.Ping and c.Close will still work as expected. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) go func() { From d0fa6bf84a14583ed29ded6f11d1e8665363a8b3 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 21:36:00 -0400 Subject: [PATCH 15/20] Update prettier invocation for v2.0.0 --- ci/fmt.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/fmt.mk b/ci/fmt.mk index 3512d02f..1ed2920f 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -13,7 +13,7 @@ goimports: gen goimports -w "-local=$$(go list -m)" . prettier: - prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") + prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn --arrow-parens=avoid $$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") gen: stringer -type=opcode,MessageType,StatusCode -output=stringer.go From ba35516b80c8f8474024174e0899b557310b8f14 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 21:45:23 -0400 Subject: [PATCH 16/20] Doc fixes --- README.md | 26 +++++++++++++------------- close_notjs.go | 1 + dial.go | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 3debf2f8..60967789 100644 --- a/README.md +++ b/README.md @@ -20,11 +20,11 @@ go get nhooyr.io/websocket - JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes -- [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) -- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper -- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API +- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper +- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) +- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) ## Roadmap @@ -89,14 +89,14 @@ c.Close(websocket.StatusNormalClosure, "") Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used -- [Prepared writes](https://godoc.org/github.com/gorilla/websocket#PreparedMessage) -- Configurable [buffer sizes](https://godoc.org/github.com/gorilla/websocket#hdr-Buffers) +- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) +- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) Advantages of nhooyr.io/websocket: - Minimal and idiomatic API - - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. -- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper + - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) @@ -104,24 +104,24 @@ Advantages of nhooyr.io/websocket: - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) -- Idiomatic [ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API +- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) -- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - We use a vendored [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) -- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket -[golang.org/x/net/websocket](https://godoc.org/golang.org/x/net/websocket) is deprecated. +[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). -The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) can help in transitioning +The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning to nhooyr.io/websocket. #### gobwas/ws diff --git a/close_notjs.go b/close_notjs.go index fd9262b5..c30ac87a 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -63,6 +63,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { Reason: reason, } + // TODO one problem with this is that if the connection is actually closed in the meantime, the error returned below will be this one lol. c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) var p []byte diff --git a/dial.go b/dial.go index 9ab680eb..2b25e351 100644 --- a/dial.go +++ b/dial.go @@ -59,7 +59,7 @@ type DialOptions struct { // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 // -// URLs with http/https schemes will work and translated into ws/wss. +// URLs with http/https schemes will work and are interpreted as ws/wss. func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } From c4d4650128fa3c5a2e2da7c41b7749d62e53fae7 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 22:27:02 -0400 Subject: [PATCH 17/20] Fix bad close handshake logic This doesn't affect real world applications due to buffering but the testss would occasionally fail on CI due to the code not handling an immediate close after writing the close frame while resetting the write timeout. --- close_notjs.go | 28 +++++++++++++++++----------- go.sum | 4 ---- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/close_notjs.go b/close_notjs.go index c30ac87a..4251311d 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -44,6 +44,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { if CloseStatus(closeHandshakeErr) == -1 { return closeHandshakeErr } + return nil } @@ -63,23 +64,28 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { Reason: reason, } - // TODO one problem with this is that if the connection is actually closed in the meantime, the error returned below will be this one lol. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) - var p []byte - var err error + var marshalErr error if ce.Code != StatusNoStatusRcvd { - p, err = ce.bytes() - if err != nil { - log.Printf("websocket: %v", err) + p, marshalErr = ce.bytes() + if marshalErr != nil { + log.Printf("websocket: %v", marshalErr) } } - werr := c.writeControl(context.Background(), opClose, p) - if err != nil { - return err + writeErr := c.writeControl(context.Background(), opClose, p) + if CloseStatus(writeErr) != -1 { + // Not a real error if it's due to a close frame being received. + writeErr = nil + } + + // We do this after in case there was an error writing the close frame. + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + if marshalErr != nil { + return marshalErr } - return werr + return writeErr } func (c *Conn) waitCloseHandshake() error { diff --git a/go.sum b/go.sum index dac1ed3a..736df430 100644 --- a/go.sum +++ b/go.sum @@ -4,16 +4,12 @@ github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= -github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= From d34c89a325afbef0a8c4d23cece6fb4e4cc2b1a5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 22:31:35 -0400 Subject: [PATCH 18/20] Prevent all writes after close Closes #213 --- write.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/write.go b/write.go index baa5e6e2..2d439de7 100644 --- a/write.go +++ b/write.go @@ -246,7 +246,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { return 0, err } - defer c.writeFrameMu.unlock() + defer func() { + // We leave it locked when writing the close frame to avoid + // any other goroutine writing any other frame. + if opcode != opClose { + c.writeFrameMu.unlock() + } + }() select { case <-c.closed: From 2dc66c3f143f34f669248262380fd7c38eba107e Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 22:34:47 -0400 Subject: [PATCH 19/20] Check whether the connection is closed before returning a write IO error Closes #215 --- write.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/write.go b/write.go index 2d439de7..60a4fba0 100644 --- a/write.go +++ b/write.go @@ -262,8 +262,14 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco defer func() { if err != nil { - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + err = c.closeErr + case <-ctx.Done(): + err = ctx.Err() + } c.close(err) + err = fmt.Errorf("failed to write frame: %w", err) } }() From 1d80cf339293725be66c457a5caa0a136ec743c5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 13 Apr 2020 22:45:59 -0400 Subject: [PATCH 20/20] Final doc fixes --- README.md | 2 +- accept.go | 6 ++---- conn_test.go | 1 - examples/chat/chat_test.go | 1 - examples/chat/main.go | 3 +-- examples/echo/main.go | 6 ------ 6 files changed, 4 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 60967789..1f1ca46d 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ go get nhooyr.io/websocket - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Thorough tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) -- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- [Minimal dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) - JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes diff --git a/accept.go b/accept.go index a583f232..47e20b52 100644 --- a/accept.go +++ b/accept.go @@ -134,7 +134,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - c := newConn(connConfig{ + return newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), rwc: netConn, client: false, @@ -143,9 +143,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con br: brw.Reader, bw: brw.Writer, - }) - - return c, nil + }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { diff --git a/conn_test.go b/conn_test.go index 68dc837d..451d093a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -271,7 +271,6 @@ func TestWasm(t *testing.T) { t.Skip("skipping on CI") } - // TODO grace s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, diff --git a/examples/chat/chat_test.go b/examples/chat/chat_test.go index 79523d2a..eae18580 100644 --- a/examples/chat/chat_test.go +++ b/examples/chat/chat_test.go @@ -130,7 +130,6 @@ func setupTest(t *testing.T) (url string, closeFn func()) { cs.subscriberMessageBuffer = 4096 cs.publishLimiter.SetLimit(rate.Inf) - // TODO grace s := httptest.NewServer(cs) return s.URL, func() { s.Close() diff --git a/examples/chat/main.go b/examples/chat/main.go index cc2d01e8..7f3cf6f3 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -34,8 +34,7 @@ func run() error { log.Printf("listening on http://%v", l.Addr()) cs := newChatServer() - // TODO grace - s := http.Server{ + s := &http.Server{ Handler: cs, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, diff --git a/examples/echo/main.go b/examples/echo/main.go index db2d06c9..f1771752 100644 --- a/examples/echo/main.go +++ b/examples/echo/main.go @@ -16,11 +16,6 @@ import ( "nhooyr.io/websocket/wsjson" ) -// TODO IMPROVE CANCELLATION AND SHUTDOWN -// TODO on context cancel send websocket going away and fix the read timeout error to be dependant on context deadline reached. -// TODO this way you cancel your context and the right message automatically gets sent. Furthrmore, then u can just use a simple waitgroup to wait for connections. -// TODO grace is wrong as it doesn't wait for the individual goroutines. - // This example starts a WebSocket echo server, // dials the server and then sends 5 different messages // and prints out the server's responses. @@ -34,7 +29,6 @@ func main() { } defer l.Close() - // TODO grace s := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r)