diff --git a/README.md b/README.md index e967cd8a..1f1ca46d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # websocket -[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) +[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) websocket is a minimal and idiomatic WebSocket library for Go. @@ -15,16 +15,16 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- Thorough unit tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) -- [Minimal dependencies](https://godoc.org/nhooyr.io/websocket?imports) -- JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- Thorough tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) +- [Minimal dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes -- [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) -- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper -- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API +- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper +- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) +- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) ## Roadmap @@ -32,9 +32,10 @@ go get nhooyr.io/websocket ## Examples -For a production quality example that demonstrates the complete API, see the [echo example](https://godoc.org/nhooyr.io/websocket#example-package--Echo). +For a production quality example that demonstrates the complete API, see the +[echo example](./examples/echo). -For a full stack example, see [./chat-example](./chat-example). +For a full stack example, see the [chat example](./examples/chat). ### Server @@ -88,14 +89,14 @@ c.Close(websocket.StatusNormalClosure, "") Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used -- [Prepared writes](https://godoc.org/github.com/gorilla/websocket#PreparedMessage) -- Configurable [buffer sizes](https://godoc.org/github.com/gorilla/websocket#hdr-Buffers) +- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) +- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) Advantages of nhooyr.io/websocket: - Minimal and idiomatic API - - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. -- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper + - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) @@ -103,24 +104,24 @@ Advantages of nhooyr.io/websocket: - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) -- Idiomatic [ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API +- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) -- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) -- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) + - We use a vendored [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket -[golang.org/x/net/websocket](https://godoc.org/golang.org/x/net/websocket) is deprecated. +[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). -The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning +The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning to nhooyr.io/websocket. #### gobwas/ws diff --git a/accept.go b/accept.go index 479138fc..47e20b52 100644 --- a/accept.go +++ b/accept.go @@ -9,10 +9,11 @@ import ( "errors" "fmt" "io" + "log" "net/http" "net/textproto" "net/url" - "strconv" + "path/filepath" "strings" "nhooyr.io/websocket/internal/errd" @@ -25,18 +26,27 @@ type AcceptOptions struct { // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string - // InsecureSkipVerify disables Accept's origin verification behaviour. By default, - // the connection will only be accepted if the request origin is equal to the request - // host. + // InsecureSkipVerify is used to disable Accept's origin verification behaviour. // - // This is only required if you want javascript served from a different domain - // to access your WebSocket server. + // Deprecated: Use OriginPatterns with a match all pattern of * instead to control + // origin authorization yourself. + InsecureSkipVerify bool + + // OriginPatterns lists the host patterns for authorized origins. + // The request host is always authorized. + // Use this to enable cross origin WebSockets. + // + // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. + // In such a case, example.com is the origin and chat.example.com is the request host. + // One would set this field to []string{"example.com"} to authorize example.com to connect. // - // See https://stackoverflow.com/a/37837709/4283659 + // Each pattern is matched case insensitively against the request origin host + // with filepath.Match. + // See https://golang.org/pkg/path/filepath/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. - InsecureSkipVerify bool + OriginPatterns []string // CompressionMode controls the compression mode. // Defaults to CompressionNoContextTakeover. @@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) + err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { + if errors.Is(err, filepath.ErrBadPattern) { + log.Printf("websocket: %v", err) + err = errors.New(http.StatusText(http.StatusForbidden)) + } http.Error(w, err.Error(), http.StatusForbidden) return nil, err } @@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ return 0, nil } -func authenticateOrigin(r *http.Request) error { +func authenticateOrigin(r *http.Request, originHosts []string) error { origin := r.Header.Get("Origin") - if origin != "" { - u, err := url.Parse(origin) + if origin == "" { + return nil + } + + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + + if strings.EqualFold(r.Host, u.Host) { + return nil + } + + for _, hostPattern := range originHosts { + matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if matched { + return nil } } - return nil + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) +} + +func match(pattern, s string) (bool, error) { + return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { @@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi return copts, nil } -// parseExtensionParameter parses the value in the extension parameter p. -func parseExtensionParameter(p string) (int, bool) { - ps := strings.Split(p, "=") - if len(ps) == 1 { - return 0, false - } - i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) - return i, e == nil -} - func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. diff --git a/accept_js.go b/accept_js.go index 724b35b5..daad4b79 100644 --- a/accept_js.go +++ b/accept_js.go @@ -9,6 +9,7 @@ import ( type AcceptOptions struct { Subprotocols []string InsecureSkipVerify bool + OriginPatterns []string CompressionMode CompressionMode CompressionThreshold int } diff --git a/accept_test.go b/accept_test.go index 49667799..40a7b40c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -244,10 +244,11 @@ func Test_authenticateOrigin(t *testing.T) { t.Parallel() testCases := []struct { - name string - origin string - host string - success bool + name string + origin string + host string + originPatterns []string + success bool }{ { name: "none", @@ -278,6 +279,26 @@ func Test_authenticateOrigin(t *testing.T) { host: "example.com", success: true, }, + { + name: "originPatterns", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "*.example.com", + "bar.com", + }, + success: true, + }, + { + name: "originPatternsUnauthorized", + origin: "https://two.examplE.com", + host: "example.com", + originPatterns: []string{ + "exam3.com", + "bar.com", + }, + success: false, + }, } for _, tc := range testCases { @@ -288,7 +309,7 @@ func Test_authenticateOrigin(t *testing.T) { r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r.Header.Set("Origin", tc.origin) - err := authenticateOrigin(r) + err := authenticateOrigin(r, tc.originPatterns) if tc.success { assert.Success(t, err) } else { diff --git a/chat-example/chat.go b/chat-example/chat.go deleted file mode 100644 index e6e355d0..00000000 --- a/chat-example/chat.go +++ /dev/null @@ -1,128 +0,0 @@ -package main - -import ( - "context" - "errors" - "io" - "io/ioutil" - "log" - "net/http" - "sync" - "time" - - "nhooyr.io/websocket" -) - -// chatServer enables broadcasting to a set of subscribers. -type chatServer struct { - subscribersMu sync.RWMutex - subscribers map[chan<- []byte]struct{} -} - -// subscribeHandler accepts the WebSocket connection and then subscribes -// it to all future messages. -func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, nil) - if err != nil { - log.Print(err) - return - } - defer c.Close(websocket.StatusInternalError, "") - - err = cs.subscribe(r.Context(), c) - if errors.Is(err, context.Canceled) { - return - } - if websocket.CloseStatus(err) == websocket.StatusNormalClosure || - websocket.CloseStatus(err) == websocket.StatusGoingAway { - return - } - if err != nil { - log.Print(err) - } -} - -// publishHandler reads the request body with a limit of 8192 bytes and then publishes -// the received message. -func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - body := io.LimitReader(r.Body, 8192) - msg, err := ioutil.ReadAll(body) - if err != nil { - http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) - return - } - - cs.publish(msg) -} - -// subscribe subscribes the given WebSocket to all broadcast messages. -// It creates a msgs chan with a buffer of 16 to give some room to slower -// connections and then registers it. It then listens for all messages -// and writes them to the WebSocket. If the context is cancelled or -// an error occurs, it returns and deletes the subscription. -// -// It uses CloseRead to keep reading from the connection to process control -// messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - msgs := make(chan []byte, 16) - cs.addSubscriber(msgs) - defer cs.deleteSubscriber(msgs) - - for { - select { - case msg := <-msgs: - err := writeTimeout(ctx, time.Second*5, c, msg) - if err != nil { - return err - } - case <-ctx.Done(): - return ctx.Err() - } - } -} - -// publish publishes the msg to all subscribers. -// It never blocks and so messages to slow subscribers -// are dropped. -func (cs *chatServer) publish(msg []byte) { - cs.subscribersMu.RLock() - defer cs.subscribersMu.RUnlock() - - for c := range cs.subscribers { - select { - case c <- msg: - default: - } - } -} - -// addSubscriber registers a subscriber with a channel -// on which to send messages. -func (cs *chatServer) addSubscriber(msgs chan<- []byte) { - cs.subscribersMu.Lock() - if cs.subscribers == nil { - cs.subscribers = make(map[chan<- []byte]struct{}) - } - cs.subscribers[msgs] = struct{}{} - cs.subscribersMu.Unlock() -} - -// deleteSubscriber deletes the subscriber with the given msgs channel. -func (cs *chatServer) deleteSubscriber(msgs chan []byte) { - cs.subscribersMu.Lock() - delete(cs.subscribers, msgs) - cs.subscribersMu.Unlock() -} - -func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error { - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - return c.Write(ctx, websocket.MessageText, msg) -} diff --git a/chat-example/go.mod b/chat-example/go.mod deleted file mode 100644 index 34fa5a69..00000000 --- a/chat-example/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module nhooyr.io/websocket/example-chat - -go 1.13 - -require nhooyr.io/websocket v1.8.2 diff --git a/ci/fmt.mk b/ci/fmt.mk index 3512d02f..1ed2920f 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -13,7 +13,7 @@ goimports: gen goimports -w "-local=$$(go list -m)" . prettier: - prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") + prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn --arrow-parens=avoid $$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") gen: stringer -type=opcode,MessageType,StatusCode -output=stringer.go diff --git a/ci/test.mk b/ci/test.mk index c62a25b6..b2f92b7c 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -11,6 +11,7 @@ coveralls: gotest goveralls -coverprofile=ci/out/coverage.prof gotest: - go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... + go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof + sed -i '/example/d' ci/out/coverage.prof diff --git a/close_notjs.go b/close_notjs.go index 25372995..4251311d 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -34,15 +34,17 @@ func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - err = c.writeClose(code, reason) - if err != nil && CloseStatus(err) == -1 && err != errAlreadyWroteClose { - return err + writeErr := c.writeClose(code, reason) + closeHandshakeErr := c.waitCloseHandshake() + + if writeErr != nil { + return writeErr } - err = c.waitCloseHandshake() - if CloseStatus(err) == -1 { - return err + if CloseStatus(closeHandshakeErr) == -1 { + return closeHandshakeErr } + return nil } @@ -50,10 +52,10 @@ var errAlreadyWroteClose = errors.New("already wrote close") func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() - closing := c.wroteClose + wroteClose := c.wroteClose c.wroteClose = true c.closeMu.Unlock() - if closing { + if wroteClose { return errAlreadyWroteClose } @@ -62,28 +64,34 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { Reason: reason, } - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) - var p []byte - var err error + var marshalErr error if ce.Code != StatusNoStatusRcvd { - p, err = ce.bytes() - if err != nil { - log.Printf("websocket: %v", err) + p, marshalErr = ce.bytes() + if marshalErr != nil { + log.Printf("websocket: %v", marshalErr) } } - werr := c.writeControl(context.Background(), opClose, p) - if err != nil { - return err + writeErr := c.writeControl(context.Background(), opClose, p) + if CloseStatus(writeErr) != -1 { + // Not a real error if it's due to a close frame being received. + writeErr = nil + } + + // We do this after in case there was an error writing the close frame. + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + if marshalErr != nil { + return marshalErr } - return werr + return writeErr } func (c *Conn) waitCloseHandshake() error { defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := c.readMu.lock(ctx) diff --git a/conn_test.go b/conn_test.go index 64e6736f..451d093a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,7 +13,6 @@ import ( "os" "os/exec" "strings" - "sync" "testing" "time" @@ -268,14 +267,14 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() - var wg sync.WaitGroup - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wg.Add(1) - defer wg.Done() + if os.Getenv("CI") != "" { + t.Skip("skipping on CI") + } + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, + Subprotocols: []string{"echo"}, + OriginPatterns: []string{"*"}, }) if err != nil { t.Errorf("echo server failed: %v", err) @@ -291,14 +290,13 @@ func TestWasm(t *testing.T) { return } })) - defer wg.Wait() defer s.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wstest.URL(s))) + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { @@ -333,8 +331,10 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs tt = &connTest{t: t, ctx: ctx} tt.appendDone(cancel) - c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) - assert.Success(tt.t, err) + c1, c2 = wstest.Pipe(dialOpts, acceptOpts) + if xrand.Bool() { + c1, c2 = c2, c1 + } tt.appendDone(func() { c2.Close(websocket.StatusInternalError, "") c1.Close(websocket.StatusInternalError, "") diff --git a/dial.go b/dial.go index 50a0ecce..2b25e351 100644 --- a/dial.go +++ b/dial.go @@ -58,6 +58,8 @@ type DialOptions struct { // This function requires at least Go 1.12 as it uses a new feature // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 +// +// URLs with http/https schemes will work and are interpreted as ws/wss. func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { return dial(ctx, u, opts, nil) } @@ -145,6 +147,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts u.Scheme = "http" case "wss": u.Scheme = "https" + case "http", "https": default: return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } diff --git a/example_test.go b/example_test.go index 075107b0..632c4d6e 100644 --- a/example_test.go +++ b/example_test.go @@ -1,5 +1,3 @@ -// +build !js - package websocket_test import ( @@ -12,9 +10,10 @@ import ( "nhooyr.io/websocket/wsjson" ) -// This example accepts a WebSocket connection, reads a single JSON -// message from the client and then closes the connection. func ExampleAccept() { + // This handler accepts a WebSocket connection, reads a single JSON + // message from the client and then closes the connection. + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { @@ -40,9 +39,10 @@ func ExampleAccept() { log.Fatal(err) } -// This example dials a server, writes a single JSON message and then -// closes the connection. func ExampleDial() { + // Dials a server, writes a single JSON message and then + // closes the connection. + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -60,9 +60,10 @@ func ExampleDial() { c.Close(websocket.StatusNormalClosure, "") } -// This example dials a server and then expects to be disconnected with status code -// websocket.StatusNormalClosure. func ExampleCloseStatus() { + // Dials a server and then expects to be disconnected with status code + // websocket.StatusNormalClosure. + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -78,9 +79,9 @@ func ExampleCloseStatus() { } } -// This example shows how to correctly handle a WebSocket connection -// on which you will only write and do not expect to read data messages. func Example_writeOnly() { + // This handler demonstrates how to correctly handle a write only WebSocket connection. + // i.e you only expect to write messages and do not expect to read any messages. fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, nil) if err != nil { @@ -115,3 +116,83 @@ func Example_writeOnly() { err := http.ListenAndServe("localhost:8080", fn) log.Fatal(err) } + +func Example_crossOrigin() { + // This handler demonstrates how to safely accept cross origin WebSockets + // from the origin example.com. + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"example.com"}, + }) + if err != nil { + log.Println(err) + return + } + c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted") + }) + + err := http.ListenAndServe("localhost:8080", fn) + log.Fatal(err) +} + +// This example demonstrates how to create a WebSocket server +// that gracefully exits when sent a signal. +// +// It starts a WebSocket server that keeps every connection open +// for 10 seconds. +// If you CTRL+C while a connection is open, it will wait at most 30s +// for all connections to terminate before shutting down. +// func ExampleGrace() { +// fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// c, err := websocket.Accept(w, r, nil) +// if err != nil { +// log.Println(err) +// return +// } +// defer c.Close(websocket.StatusInternalError, "the sky is falling") +// +// ctx := c.CloseRead(r.Context()) +// select { +// case <-ctx.Done(): +// case <-time.After(time.Second * 10): +// } +// +// c.Close(websocket.StatusNormalClosure, "") +// }) +// +// var g websocket.Grace +// s := &http.Server{ +// Handler: g.Handler(fn), +// ReadTimeout: time.Second * 15, +// WriteTimeout: time.Second * 15, +// } +// +// errc := make(chan error, 1) +// go func() { +// errc <- s.ListenAndServe() +// }() +// +// sigs := make(chan os.Signal, 1) +// signal.Notify(sigs, os.Interrupt) +// select { +// case err := <-errc: +// log.Printf("failed to listen and serve: %v", err) +// case sig := <-sigs: +// log.Printf("terminating: %v", sig) +// } +// +// ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) +// defer cancel() +// s.Shutdown(ctx) +// g.Shutdown(ctx) +// } + +// This example demonstrates full stack chat with an automated test. +func Example_fullStackChat() { + // https://github.com/nhooyr/websocket/tree/master/examples/chat +} + +// This example demonstrates a echo server. +func Example_echo() { + // https://github.com/nhooyr/websocket/tree/master/examples/echo +} diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..3cb437ae --- /dev/null +++ b/examples/README.md @@ -0,0 +1,4 @@ +# Examples + +This directory contains more involved examples unsuitable +for display with godoc. diff --git a/chat-example/README.md b/examples/chat/README.md similarity index 74% rename from chat-example/README.md rename to examples/chat/README.md index ef06275d..a4c99a93 100644 --- a/chat-example/README.md +++ b/examples/chat/README.md @@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by `index.html` and then `index.js`. + +There are two automated tests for the server included in `chat_test.go`. The first is a simple one +client echo test. It publishes a single message and ensures it's received. + +The second is a complex concurrency test where 10 clients send 128 unique messages +of max 128 bytes concurrently. The test ensures all messages are seen by every client. diff --git a/examples/chat/chat.go b/examples/chat/chat.go new file mode 100644 index 00000000..532e50f5 --- /dev/null +++ b/examples/chat/chat.go @@ -0,0 +1,182 @@ +package main + +import ( + "context" + "errors" + "io/ioutil" + "log" + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" + + "nhooyr.io/websocket" +) + +// chatServer enables broadcasting to a set of subscribers. +type chatServer struct { + // subscriberMessageBuffer controls the max number + // of messages that can be queued for a subscriber + // before it is kicked. + // + // Defaults to 16. + subscriberMessageBuffer int + + // publishLimiter controls the rate limit applied to the publish endpoint. + // + // Defaults to one publish every 100ms with a burst of 8. + publishLimiter *rate.Limiter + + // logf controls where logs are sent. + // Defaults to log.Printf. + logf func(f string, v ...interface{}) + + // serveMux routes the various endpoints to the appropriate handler. + serveMux http.ServeMux + + subscribersMu sync.Mutex + subscribers map[*subscriber]struct{} +} + +// newChatServer constructs a chatServer with the defaults. +func newChatServer() *chatServer { + cs := &chatServer{ + subscriberMessageBuffer: 16, + logf: log.Printf, + subscribers: make(map[*subscriber]struct{}), + publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8), + } + cs.serveMux.Handle("/", http.FileServer(http.Dir("."))) + cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler) + cs.serveMux.HandleFunc("/publish", cs.publishHandler) + + return cs +} + +// subscriber represents a subscriber. +// Messages are sent on the msgs channel and if the client +// cannot keep up with the messages, closeSlow is called. +type subscriber struct { + msgs chan []byte + closeSlow func() +} + +func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cs.serveMux.ServeHTTP(w, r) +} + +// subscribeHandler accepts the WebSocket connection and then subscribes +// it to all future messages. +func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + cs.logf("%v", err) + return + } + defer c.Close(websocket.StatusInternalError, "") + + err = cs.subscribe(r.Context(), c) + if errors.Is(err, context.Canceled) { + return + } + if websocket.CloseStatus(err) == websocket.StatusNormalClosure || + websocket.CloseStatus(err) == websocket.StatusGoingAway { + return + } + if err != nil { + cs.logf("%v", err) + return + } +} + +// publishHandler reads the request body with a limit of 8192 bytes and then publishes +// the received message. +func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + body := http.MaxBytesReader(w, r.Body, 8192) + msg, err := ioutil.ReadAll(body) + if err != nil { + http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) + return + } + + cs.publish(msg) + + w.WriteHeader(http.StatusAccepted) +} + +// subscribe subscribes the given WebSocket to all broadcast messages. +// It creates a subscriber with a buffered msgs chan to give some room to slower +// connections and then registers the subscriber. It then listens for all messages +// and writes them to the WebSocket. If the context is cancelled or +// an error occurs, it returns and deletes the subscription. +// +// It uses CloseRead to keep reading from the connection to process control +// messages and cancel the context if the connection drops. +func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error { + ctx = c.CloseRead(ctx) + + s := &subscriber{ + msgs: make(chan []byte, cs.subscriberMessageBuffer), + closeSlow: func() { + c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages") + }, + } + cs.addSubscriber(s) + defer cs.deleteSubscriber(s) + + for { + select { + case msg := <-s.msgs: + err := writeTimeout(ctx, time.Second*5, c, msg) + if err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// publish publishes the msg to all subscribers. +// It never blocks and so messages to slow subscribers +// are dropped. +func (cs *chatServer) publish(msg []byte) { + cs.subscribersMu.Lock() + defer cs.subscribersMu.Unlock() + + cs.publishLimiter.Wait(context.Background()) + + for s := range cs.subscribers { + select { + case s.msgs <- msg: + default: + go s.closeSlow() + } + } +} + +// addSubscriber registers a subscriber. +func (cs *chatServer) addSubscriber(s *subscriber) { + cs.subscribersMu.Lock() + cs.subscribers[s] = struct{}{} + cs.subscribersMu.Unlock() +} + +// deleteSubscriber deletes the given subscriber. +func (cs *chatServer) deleteSubscriber(s *subscriber) { + cs.subscribersMu.Lock() + delete(cs.subscribers, s) + cs.subscribersMu.Unlock() +} + +func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + return c.Write(ctx, websocket.MessageText, msg) +} diff --git a/examples/chat/chat_test.go b/examples/chat/chat_test.go new file mode 100644 index 00000000..eae18580 --- /dev/null +++ b/examples/chat/chat_test.go @@ -0,0 +1,279 @@ +// +build !js + +package main + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/time/rate" + + "nhooyr.io/websocket" +) + +func Test_chatServer(t *testing.T) { + t.Parallel() + + // This is a simple echo test with a single client. + // The client sends a message and ensures it receives + // it on its WebSocket. + t.Run("simple", func(t *testing.T) { + t.Parallel() + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + expMsg := randString(512) + err = cl.publish(ctx, expMsg) + assertSuccess(t, err) + + msg, err := cl.nextMessage() + assertSuccess(t, err) + + if expMsg != msg { + t.Fatalf("expected %v but got %v", expMsg, msg) + } + }) + + // This test is a complex concurrency test. + // 10 clients are started that send 128 different + // messages of max 128 bytes concurrently. + // + // The test verifies that every message is seen by ever client + // and no errors occur anywhere. + t.Run("concurrency", func(t *testing.T) { + t.Parallel() + + const nmessages = 128 + const maxMessageSize = 128 + const nclients = 16 + + url, closeFn := setupTest(t) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var clients []*client + var clientMsgs []map[string]struct{} + for i := 0; i < nclients; i++ { + cl, err := newClient(ctx, url) + assertSuccess(t, err) + defer cl.Close() + + clients = append(clients, cl) + clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize)) + } + + allMessages := make(map[string]struct{}) + for _, msgs := range clientMsgs { + for m := range msgs { + allMessages[m] = struct{}{} + } + } + + var wg sync.WaitGroup + for i, cl := range clients { + i := i + cl := cl + + wg.Add(1) + go func() { + defer wg.Done() + err := cl.publishMsgs(ctx, clientMsgs[i]) + if err != nil { + t.Errorf("client %d failed to publish all messages: %v", i, err) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := testAllMessagesReceived(cl, nclients*nmessages, allMessages) + if err != nil { + t.Errorf("client %d failed to receive all messages: %v", i, err) + } + }() + } + + wg.Wait() + }) +} + +// setupTest sets up chatServer that can be used +// via the returned url. +// +// Defer closeFn to ensure everything is cleaned up at +// the end of the test. +// +// chatServer logs will be logged via t.Logf. +func setupTest(t *testing.T) (url string, closeFn func()) { + cs := newChatServer() + cs.logf = t.Logf + + // To ensure tests run quickly under even -race. + cs.subscriberMessageBuffer = 4096 + cs.publishLimiter.SetLimit(rate.Inf) + + s := httptest.NewServer(cs) + return s.URL, func() { + s.Close() + } +} + +// testAllMessagesReceived ensures that after n reads, all msgs in msgs +// have been read. +func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error { + msgs = cloneMessages(msgs) + + for i := 0; i < n; i++ { + msg, err := cl.nextMessage() + if err != nil { + return err + } + delete(msgs, msg) + } + + if len(msgs) != 0 { + return fmt.Errorf("did not receive all expected messages: %q", msgs) + } + return nil +} + +func cloneMessages(msgs map[string]struct{}) map[string]struct{} { + msgs2 := make(map[string]struct{}, len(msgs)) + for m := range msgs { + msgs2[m] = struct{}{} + } + return msgs2 +} + +func randMessages(n, maxMessageLength int) map[string]struct{} { + msgs := make(map[string]struct{}) + for i := 0; i < n; i++ { + m := randString(randInt(maxMessageLength)) + if _, ok := msgs[m]; ok { + i-- + continue + } + msgs[m] = struct{}{} + } + return msgs +} + +func assertSuccess(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +type client struct { + url string + c *websocket.Conn +} + +func newClient(ctx context.Context, url string) (*client, error) { + c, _, err := websocket.Dial(ctx, url+"/subscribe", nil) + if err != nil { + return nil, err + } + + cl := &client{ + url: url, + c: c, + } + + return cl, nil +} + +func (cl *client) publish(ctx context.Context, msg string) (err error) { + defer func() { + if err != nil { + cl.c.Close(websocket.StatusInternalError, "publish failed") + } + }() + + req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusAccepted { + return fmt.Errorf("publish request failed: %v", resp.StatusCode) + } + return nil +} + +func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error { + for m := range msgs { + err := cl.publish(ctx, m) + if err != nil { + return err + } + } + return nil +} + +func (cl *client) nextMessage() (string, error) { + typ, b, err := cl.c.Read(context.Background()) + if err != nil { + return "", err + } + + if typ != websocket.MessageText { + cl.c.Close(websocket.StatusUnsupportedData, "expected text message") + return "", fmt.Errorf("expected text message but got %v", typ) + } + return string(b), nil +} + +func (cl *client) Close() error { + return cl.c.Close(websocket.StatusNormalClosure, "") +} + +// randString generates a random string with length n. +func randString(n int) string { + b := make([]byte, n) + _, err := rand.Reader.Read(b) + if err != nil { + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + + s := strings.ToValidUTF8(string(b), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] + } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s +} + +// randInt returns a randomly generated integer between [0, max). +func randInt(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + panic(fmt.Sprintf("failed to get random int: %v", err)) + } + return int(x.Int64()) +} diff --git a/chat-example/go.sum b/examples/chat/go.sum similarity index 57% rename from chat-example/go.sum rename to examples/chat/go.sum index 0755fca5..e4bbd62d 100644 --- a/chat-example/go.sum +++ b/examples/chat/go.sum @@ -1,12 +1,18 @@ +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA= -nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ= diff --git a/chat-example/index.css b/examples/chat/index.css similarity index 98% rename from chat-example/index.css rename to examples/chat/index.css index 29804662..73a8e0f3 100644 --- a/chat-example/index.css +++ b/examples/chat/index.css @@ -5,7 +5,7 @@ body { #root { padding: 40px 20px; - max-width: 480px; + max-width: 600px; margin: auto; height: 100vh; diff --git a/chat-example/index.html b/examples/chat/index.html similarity index 100% rename from chat-example/index.html rename to examples/chat/index.html diff --git a/chat-example/index.js b/examples/chat/index.js similarity index 69% rename from chat-example/index.js rename to examples/chat/index.js index 8fb3dfb8..5868e7ca 100644 --- a/chat-example/index.js +++ b/examples/chat/index.js @@ -7,8 +7,11 @@ const conn = new WebSocket(`ws://${location.host}/subscribe`) conn.addEventListener("close", ev => { - console.info("websocket disconnected, reconnecting in 1000ms", ev) - setTimeout(dial, 1000) + appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true) + if (ev.code !== 1001) { + appendLog("Reconnecting in 1s", true) + setTimeout(dial, 1000) + } }) conn.addEventListener("open", ev => { console.info("websocket connected") @@ -34,17 +37,21 @@ const messageInput = document.getElementById("message-input") // appendLog appends the passed text to messageLog. - function appendLog(text) { + function appendLog(text, error) { const p = document.createElement("p") // Adding a timestamp to each message makes the log easier to read. p.innerText = `${new Date().toLocaleTimeString()}: ${text}` + if (error) { + p.style.color = "red" + p.style.fontStyle = "bold" + } messageLog.append(p) return p } appendLog("Submit a message to get started!") // onsubmit publishes the message from the user when the form is submitted. - publishForm.onsubmit = ev => { + publishForm.onsubmit = async ev => { ev.preventDefault() const msg = messageInput.value @@ -54,9 +61,16 @@ messageInput.value = "" expectingMessage = true - fetch("/publish", { - method: "POST", - body: msg, - }) + try { + const resp = await fetch("/publish", { + method: "POST", + body: msg, + }) + if (resp.status !== 202) { + throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`) + } + } catch (err) { + appendLog(`Publish failed: ${err.message}`, true) + } } })() diff --git a/chat-example/main.go b/examples/chat/main.go similarity index 55% rename from chat-example/main.go rename to examples/chat/main.go index 2a520924..7f3cf6f3 100644 --- a/chat-example/main.go +++ b/examples/chat/main.go @@ -1,11 +1,13 @@ package main import ( + "context" "errors" "log" "net" "net/http" "os" + "os/signal" "time" ) @@ -31,17 +33,28 @@ func run() error { } log.Printf("listening on http://%v", l.Addr()) - var ws chatServer - - m := http.NewServeMux() - m.Handle("/", http.FileServer(http.Dir("."))) - m.HandleFunc("/subscribe", ws.subscribeHandler) - m.HandleFunc("/publish", ws.publishHandler) - - s := http.Server{ - Handler: m, + cs := newChatServer() + s := &http.Server{ + Handler: cs, ReadTimeout: time.Second * 10, WriteTimeout: time.Second * 10, } - return s.Serve(l) + errc := make(chan error, 1) + go func() { + errc <- s.Serve(l) + }() + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, os.Interrupt) + select { + case err := <-errc: + log.Printf("failed to serve: %v", err) + case sig := <-sigs: + log.Printf("terminating: %v", sig) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + return s.Shutdown(ctx) } diff --git a/example_echo_test.go b/examples/echo/main.go similarity index 98% rename from example_echo_test.go rename to examples/echo/main.go index cd195d2e..f1771752 100644 --- a/example_echo_test.go +++ b/examples/echo/main.go @@ -1,6 +1,4 @@ -// +build !js - -package websocket_test +package main import ( "context" @@ -21,7 +19,7 @@ import ( // This example starts a WebSocket echo server, // dials the server and then sends 5 different messages // and prints out the server's responses. -func Example_echo() { +func main() { // First we listen on port 0 which means the OS will // assign us a random free port. This is the listener // the server will serve on and the client will connect to. @@ -56,6 +54,7 @@ func Example_echo() { if err != nil { log.Fatalf("client failed: %v", err) } + // Output: // received: map[i:0] // received: map[i:1] diff --git a/go.sum b/go.sum index dac1ed3a..736df430 100644 --- a/go.sum +++ b/go.sum @@ -4,16 +4,12 @@ github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= -github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 0a2899ee..1534f316 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -5,26 +5,19 @@ package wstest import ( "bufio" "context" - "fmt" "net" "net/http" "net/http/httptest" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/test/xrand" ) // Pipe is used to create an in memory connection // between two websockets analogous to net.Pipe. -func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (_ *websocket.Conn, _ *websocket.Conn, err error) { - defer errd.Wrap(&err, "failed to create ws pipe") - - var serverConn *websocket.Conn - var acceptErr error +func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (clientConn, serverConn *websocket.Conn) { tt := fakeTransport{ h: func(w http.ResponseWriter, r *http.Request) { - serverConn, acceptErr = websocket.Accept(w, r, acceptOpts) + serverConn, _ = websocket.Accept(w, r, acceptOpts) }, } @@ -36,19 +29,8 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) Transport: tt, } - clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) - if err != nil { - return nil, nil, fmt.Errorf("failed to dial with fake transport: %w", err) - } - - if serverConn == nil { - return nil, nil, fmt.Errorf("failed to get server conn from fake transport: %w", acceptErr) - } - - if xrand.Bool() { - return serverConn, clientConn, nil - } - return clientConn, serverConn, nil + clientConn, _, _ = websocket.Dial(context.Background(), "ws://example.com", dialOpts) + return clientConn, serverConn } type fakeTransport struct { diff --git a/internal/test/wstest/url.go b/internal/test/wstest/url.go deleted file mode 100644 index a11c61b4..00000000 --- a/internal/test/wstest/url.go +++ /dev/null @@ -1,11 +0,0 @@ -package wstest - -import ( - "net/http/httptest" - "strings" -) - -// URL returns the ws url for s. -func URL(s *httptest.Server) string { - return strings.Replace(s.URL, "http", "ws", 1) -} diff --git a/read.go b/read.go index 381cea3d..afd08cc7 100644 --- a/read.go +++ b/read.go @@ -52,7 +52,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close -// frames are responded to. +// frames are responded to. This means c.Ping and c.Close will still work as expected. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) go func() { diff --git a/write.go b/write.go index baa5e6e2..60a4fba0 100644 --- a/write.go +++ b/write.go @@ -246,7 +246,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if err != nil { return 0, err } - defer c.writeFrameMu.unlock() + defer func() { + // We leave it locked when writing the close frame to avoid + // any other goroutine writing any other frame. + if opcode != opClose { + c.writeFrameMu.unlock() + } + }() select { case <-c.closed: @@ -256,8 +262,14 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco defer func() { if err != nil { - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + err = c.closeErr + case <-ctx.Done(): + err = ctx.Err() + } c.close(err) + err = fmt.Errorf("failed to write frame: %w", err) } }() diff --git a/ws_js.go b/ws_js.go index 2b560ce8..b87e32cd 100644 --- a/ws_js.go +++ b/ws_js.go @@ -9,6 +9,7 @@ import ( "net/http" "reflect" "runtime" + "strings" "sync" "syscall/js" @@ -255,6 +256,9 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp opts = &DialOptions{} } + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.Replace(url, "https://", "wss://", 1) + ws, err := wsjs.New(url, opts.Subprotocols) if err != nil { return nil, nil, err