diff --git a/accept.go b/accept.go
index 47e20b52..dd96c9bd 100644
--- a/accept.go
+++ b/accept.go
@@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")
+ g := graceFromRequest(r)
+ if g != nil && g.isShuttingdown() {
+ err := errors.New("server shutting down")
+ http.Error(w, err.Error(), http.StatusServiceUnavailable)
+ return nil, err
+ }
+
if opts == nil {
opts = &AcceptOptions{}
}
@@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
- return newConn(connConfig{
+ c := newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
@@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
br: brw.Reader,
bw: brw.Writer,
- }), nil
+ })
+
+ if g != nil {
+ err = g.addConn(c)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return c, nil
}
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
diff --git a/chat-example/README.md b/chat-example/README.md
index ef06275d..a4c99a93 100644
--- a/chat-example/README.md
+++ b/chat-example/README.md
@@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin
The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
`index.html` and then `index.js`.
+
+There are two automated tests for the server included in `chat_test.go`. The first is a simple one
+client echo test. It publishes a single message and ensures it's received.
+
+The second is a complex concurrency test where 10 clients send 128 unique messages
+of max 128 bytes concurrently. The test ensures all messages are seen by every client.
diff --git a/chat-example/chat.go b/chat-example/chat.go
index e6e355d0..532e50f5 100644
--- a/chat-example/chat.go
+++ b/chat-example/chat.go
@@ -3,20 +3,67 @@ package main
import (
"context"
"errors"
- "io"
"io/ioutil"
"log"
"net/http"
"sync"
"time"
+ "golang.org/x/time/rate"
+
"nhooyr.io/websocket"
)
// chatServer enables broadcasting to a set of subscribers.
type chatServer struct {
- subscribersMu sync.RWMutex
- subscribers map[chan<- []byte]struct{}
+ // subscriberMessageBuffer controls the max number
+ // of messages that can be queued for a subscriber
+ // before it is kicked.
+ //
+ // Defaults to 16.
+ subscriberMessageBuffer int
+
+ // publishLimiter controls the rate limit applied to the publish endpoint.
+ //
+ // Defaults to one publish every 100ms with a burst of 8.
+ publishLimiter *rate.Limiter
+
+ // logf controls where logs are sent.
+ // Defaults to log.Printf.
+ logf func(f string, v ...interface{})
+
+ // serveMux routes the various endpoints to the appropriate handler.
+ serveMux http.ServeMux
+
+ subscribersMu sync.Mutex
+ subscribers map[*subscriber]struct{}
+}
+
+// newChatServer constructs a chatServer with the defaults.
+func newChatServer() *chatServer {
+ cs := &chatServer{
+ subscriberMessageBuffer: 16,
+ logf: log.Printf,
+ subscribers: make(map[*subscriber]struct{}),
+ publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
+ }
+ cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
+ cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
+ cs.serveMux.HandleFunc("/publish", cs.publishHandler)
+
+ return cs
+}
+
+// subscriber represents a subscriber.
+// Messages are sent on the msgs channel and if the client
+// cannot keep up with the messages, closeSlow is called.
+type subscriber struct {
+ msgs chan []byte
+ closeSlow func()
+}
+
+func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ cs.serveMux.ServeHTTP(w, r)
}
// subscribeHandler accepts the WebSocket connection and then subscribes
@@ -24,7 +71,7 @@ type chatServer struct {
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
- log.Print(err)
+ cs.logf("%v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
@@ -38,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
return
}
if err != nil {
- log.Print(err)
+ cs.logf("%v", err)
+ return
}
}
@@ -49,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
- body := io.LimitReader(r.Body, 8192)
+ body := http.MaxBytesReader(w, r.Body, 8192)
msg, err := ioutil.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
@@ -57,11 +105,13 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
}
cs.publish(msg)
+
+ w.WriteHeader(http.StatusAccepted)
}
// subscribe subscribes the given WebSocket to all broadcast messages.
-// It creates a msgs chan with a buffer of 16 to give some room to slower
-// connections and then registers it. It then listens for all messages
+// It creates a subscriber with a buffered msgs chan to give some room to slower
+// connections and then registers the subscriber. It then listens for all messages
// and writes them to the WebSocket. If the context is cancelled or
// an error occurs, it returns and deletes the subscription.
//
@@ -70,13 +120,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
ctx = c.CloseRead(ctx)
- msgs := make(chan []byte, 16)
- cs.addSubscriber(msgs)
- defer cs.deleteSubscriber(msgs)
+ s := &subscriber{
+ msgs: make(chan []byte, cs.subscriberMessageBuffer),
+ closeSlow: func() {
+ c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
+ },
+ }
+ cs.addSubscriber(s)
+ defer cs.deleteSubscriber(s)
for {
select {
- case msg := <-msgs:
+ case msg := <-s.msgs:
err := writeTimeout(ctx, time.Second*5, c, msg)
if err != nil {
return err
@@ -91,32 +146,31 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
// It never blocks and so messages to slow subscribers
// are dropped.
func (cs *chatServer) publish(msg []byte) {
- cs.subscribersMu.RLock()
- defer cs.subscribersMu.RUnlock()
+ cs.subscribersMu.Lock()
+ defer cs.subscribersMu.Unlock()
+
+ cs.publishLimiter.Wait(context.Background())
- for c := range cs.subscribers {
+ for s := range cs.subscribers {
select {
- case c <- msg:
+ case s.msgs <- msg:
default:
+ go s.closeSlow()
}
}
}
-// addSubscriber registers a subscriber with a channel
-// on which to send messages.
-func (cs *chatServer) addSubscriber(msgs chan<- []byte) {
+// addSubscriber registers a subscriber.
+func (cs *chatServer) addSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
- if cs.subscribers == nil {
- cs.subscribers = make(map[chan<- []byte]struct{})
- }
- cs.subscribers[msgs] = struct{}{}
+ cs.subscribers[s] = struct{}{}
cs.subscribersMu.Unlock()
}
-// deleteSubscriber deletes the subscriber with the given msgs channel.
-func (cs *chatServer) deleteSubscriber(msgs chan []byte) {
+// deleteSubscriber deletes the given subscriber.
+func (cs *chatServer) deleteSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
- delete(cs.subscribers, msgs)
+ delete(cs.subscribers, s)
cs.subscribersMu.Unlock()
}
diff --git a/chat-example/chat_test.go b/chat-example/chat_test.go
new file mode 100644
index 00000000..491499cc
--- /dev/null
+++ b/chat-example/chat_test.go
@@ -0,0 +1,282 @@
+// +build !js
+
+package main
+
+import (
+ "context"
+ "crypto/rand"
+ "fmt"
+ "math/big"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "golang.org/x/time/rate"
+
+ "nhooyr.io/websocket"
+)
+
+func Test_chatServer(t *testing.T) {
+ t.Parallel()
+
+ // This is a simple echo test with a single client.
+ // The client sends a message and ensures it receives
+ // it on its WebSocket.
+ t.Run("simple", func(t *testing.T) {
+ t.Parallel()
+
+ url, closeFn := setupTest(t)
+ defer closeFn()
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+ defer cancel()
+
+ cl, err := newClient(ctx, url)
+ assertSuccess(t, err)
+ defer cl.Close()
+
+ expMsg := randString(512)
+ err = cl.publish(ctx, expMsg)
+ assertSuccess(t, err)
+
+ msg, err := cl.nextMessage()
+ assertSuccess(t, err)
+
+ if expMsg != msg {
+ t.Fatalf("expected %v but got %v", expMsg, msg)
+ }
+ })
+
+ // This test is a complex concurrency test.
+ // 10 clients are started that send 128 different
+ // messages of max 128 bytes concurrently.
+ //
+ // The test verifies that every message is seen by ever client
+ // and no errors occur anywhere.
+ t.Run("concurrency", func(t *testing.T) {
+ t.Parallel()
+
+ const nmessages = 128
+ const maxMessageSize = 128
+ const nclients = 10
+
+ url, closeFn := setupTest(t)
+ defer closeFn()
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ defer cancel()
+
+ var clients []*client
+ var clientMsgs []map[string]struct{}
+ for i := 0; i < nclients; i++ {
+ cl, err := newClient(ctx, url)
+ assertSuccess(t, err)
+ defer cl.Close()
+
+ clients = append(clients, cl)
+ clientMsgs = append(clientMsgs, randMessages(nmessages, maxMessageSize))
+ }
+
+ allMessages := make(map[string]struct{})
+ for _, msgs := range clientMsgs {
+ for m := range msgs {
+ allMessages[m] = struct{}{}
+ }
+ }
+
+ var wg sync.WaitGroup
+ for i, cl := range clients {
+ i := i
+ cl := cl
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := cl.publishMsgs(ctx, clientMsgs[i])
+ if err != nil {
+ t.Errorf("client %d failed to publish all messages: %v", i, err)
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ err := testAllMessagesReceived(cl, nclients*nmessages, allMessages)
+ if err != nil {
+ t.Errorf("client %d failed to receive all messages: %v", i, err)
+ }
+ }()
+ }
+
+ wg.Wait()
+ })
+}
+
+// setupTest sets up chatServer that can be used
+// via the returned url.
+//
+// Defer closeFn to ensure everything is cleaned up at
+// the end of the test.
+//
+// chatServer logs will be logged via t.Logf.
+func setupTest(t *testing.T) (url string, closeFn func()) {
+ cs := newChatServer()
+ cs.logf = t.Logf
+
+ // To ensure tests run quickly under even -race.
+ cs.subscriberMessageBuffer = 4096
+ cs.publishLimiter.SetLimit(rate.Inf)
+
+ var g websocket.Grace
+ s := httptest.NewServer(g.Handler(cs))
+ return s.URL, func() {
+ s.Close()
+ g.Close()
+ }
+}
+
+// testAllMessagesReceived ensures that after n reads, all msgs in msgs
+// have been read.
+func testAllMessagesReceived(cl *client, n int, msgs map[string]struct{}) error {
+ msgs = cloneMessages(msgs)
+
+ for i := 0; i < n; i++ {
+ msg, err := cl.nextMessage()
+ if err != nil {
+ return err
+ }
+ delete(msgs, msg)
+ }
+
+ if len(msgs) != 0 {
+ return fmt.Errorf("did not receive all expected messages: %q", msgs)
+ }
+ return nil
+}
+
+func cloneMessages(msgs map[string]struct{}) map[string]struct{} {
+ msgs2 := make(map[string]struct{}, len(msgs))
+ for m := range msgs {
+ msgs2[m] = struct{}{}
+ }
+ return msgs2
+}
+
+func randMessages(n, maxMessageLength int) map[string]struct{} {
+ msgs := make(map[string]struct{})
+ for i := 0; i < n; i++ {
+ m := randString(randInt(maxMessageLength))
+ if _, ok := msgs[m]; ok {
+ i--
+ continue
+ }
+ msgs[m] = struct{}{}
+ }
+ return msgs
+}
+
+func assertSuccess(t *testing.T, err error) {
+ t.Helper()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+type client struct {
+ url string
+ c *websocket.Conn
+}
+
+func newClient(ctx context.Context, url string) (*client, error) {
+ wsURL := strings.Replace(url, "http://", "ws://", 1)
+ c, _, err := websocket.Dial(ctx, wsURL+"/subscribe", nil)
+ if err != nil {
+ return nil, err
+ }
+
+ cl := &client{
+ url: url,
+ c: c,
+ }
+
+ return cl, nil
+}
+
+func (cl *client) publish(ctx context.Context, msg string) (err error) {
+ defer func() {
+ if err != nil {
+ cl.c.Close(websocket.StatusInternalError, "publish failed")
+ }
+ }()
+
+ req, _ := http.NewRequestWithContext(ctx, http.MethodPost, cl.url+"/publish", strings.NewReader(msg))
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusAccepted {
+ return fmt.Errorf("publish request failed: %v", resp.StatusCode)
+ }
+ return nil
+}
+
+func (cl *client) publishMsgs(ctx context.Context, msgs map[string]struct{}) error {
+ for m := range msgs {
+ err := cl.publish(ctx, m)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (cl *client) nextMessage() (string, error) {
+ typ, b, err := cl.c.Read(context.Background())
+ if err != nil {
+ return "", err
+ }
+
+ if typ != websocket.MessageText {
+ cl.c.Close(websocket.StatusUnsupportedData, "expected text message")
+ return "", fmt.Errorf("expected text message but got %v", typ)
+ }
+ return string(b), nil
+}
+
+func (cl *client) Close() error {
+ return cl.c.Close(websocket.StatusNormalClosure, "")
+}
+
+// randString generates a random string with length n.
+func randString(n int) string {
+ b := make([]byte, n)
+ _, err := rand.Reader.Read(b)
+ if err != nil {
+ panic(fmt.Sprintf("failed to generate rand bytes: %v", err))
+ }
+
+ s := strings.ToValidUTF8(string(b), "_")
+ s = strings.ReplaceAll(s, "\x00", "_")
+ if len(s) > n {
+ return s[:n]
+ }
+ if len(s) < n {
+ // Pad with =
+ extra := n - len(s)
+ return s + strings.Repeat("=", extra)
+ }
+ return s
+}
+
+// randInt returns a randomly generated integer between [0, max).
+func randInt(max int) int {
+ x, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
+ if err != nil {
+ panic(fmt.Sprintf("failed to get random int: %v", err))
+ }
+ return int(x.Int64())
+}
diff --git a/chat-example/go.mod b/chat-example/go.mod
deleted file mode 100644
index 34fa5a69..00000000
--- a/chat-example/go.mod
+++ /dev/null
@@ -1,5 +0,0 @@
-module nhooyr.io/websocket/example-chat
-
-go 1.13
-
-require nhooyr.io/websocket v1.8.2
diff --git a/chat-example/go.sum b/chat-example/go.sum
index 0755fca5..e4bbd62d 100644
--- a/chat-example/go.sum
+++ b/chat-example/go.sum
@@ -1,12 +1,18 @@
+github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
+github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
+github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
+github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y=
github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-nhooyr.io/websocket v1.8.2 h1:LwdzfyyOZKtVFoXay6A39Acu03KmidSZ3YUUvPa13PA=
-nhooyr.io/websocket v1.8.2/go.mod h1:LiqdCg1Cu7TPWxEvPjPa0TGYxCsy4pHNTN9gGluwBpQ=
diff --git a/chat-example/index.css b/chat-example/index.css
index 29804662..73a8e0f3 100644
--- a/chat-example/index.css
+++ b/chat-example/index.css
@@ -5,7 +5,7 @@ body {
#root {
padding: 40px 20px;
- max-width: 480px;
+ max-width: 600px;
margin: auto;
height: 100vh;
diff --git a/chat-example/index.js b/chat-example/index.js
index 8fb3dfb8..5868e7ca 100644
--- a/chat-example/index.js
+++ b/chat-example/index.js
@@ -7,8 +7,11 @@
const conn = new WebSocket(`ws://${location.host}/subscribe`)
conn.addEventListener("close", ev => {
- console.info("websocket disconnected, reconnecting in 1000ms", ev)
- setTimeout(dial, 1000)
+ appendLog(`WebSocket Disconnected code: ${ev.code}, reason: ${ev.reason}`, true)
+ if (ev.code !== 1001) {
+ appendLog("Reconnecting in 1s", true)
+ setTimeout(dial, 1000)
+ }
})
conn.addEventListener("open", ev => {
console.info("websocket connected")
@@ -34,17 +37,21 @@
const messageInput = document.getElementById("message-input")
// appendLog appends the passed text to messageLog.
- function appendLog(text) {
+ function appendLog(text, error) {
const p = document.createElement("p")
// Adding a timestamp to each message makes the log easier to read.
p.innerText = `${new Date().toLocaleTimeString()}: ${text}`
+ if (error) {
+ p.style.color = "red"
+ p.style.fontStyle = "bold"
+ }
messageLog.append(p)
return p
}
appendLog("Submit a message to get started!")
// onsubmit publishes the message from the user when the form is submitted.
- publishForm.onsubmit = ev => {
+ publishForm.onsubmit = async ev => {
ev.preventDefault()
const msg = messageInput.value
@@ -54,9 +61,16 @@
messageInput.value = ""
expectingMessage = true
- fetch("/publish", {
- method: "POST",
- body: msg,
- })
+ try {
+ const resp = await fetch("/publish", {
+ method: "POST",
+ body: msg,
+ })
+ if (resp.status !== 202) {
+ throw new Error(`Unexpected HTTP Status ${resp.status} ${resp.statusText}`)
+ }
+ } catch (err) {
+ appendLog(`Publish failed: ${err.message}`, true)
+ }
}
})()
diff --git a/chat-example/main.go b/chat-example/main.go
index 2a520924..1b6f3266 100644
--- a/chat-example/main.go
+++ b/chat-example/main.go
@@ -1,12 +1,16 @@
package main
import (
+ "context"
"errors"
"log"
"net"
"net/http"
"os"
+ "os/signal"
"time"
+
+ "nhooyr.io/websocket"
)
func main() {
@@ -31,17 +35,32 @@ func run() error {
}
log.Printf("listening on http://%v", l.Addr())
- var ws chatServer
-
- m := http.NewServeMux()
- m.Handle("/", http.FileServer(http.Dir(".")))
- m.HandleFunc("/subscribe", ws.subscribeHandler)
- m.HandleFunc("/publish", ws.publishHandler)
-
+ cs := newChatServer()
+ var g websocket.Grace
s := http.Server{
- Handler: m,
+ Handler: g.Handler(cs),
ReadTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
}
- return s.Serve(l)
+ errc := make(chan error, 1)
+ go func() {
+ errc <- s.Serve(l)
+ }()
+
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, os.Interrupt)
+ select {
+ case err := <-errc:
+ log.Printf("failed to serve: %v", err)
+ case sig := <-sigs:
+ log.Printf("terminating: %v", sig)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
+ defer cancel()
+
+ s.Shutdown(ctx)
+ g.Shutdown(ctx)
+
+ return nil
}
diff --git a/ci/test.mk b/ci/test.mk
index c62a25b6..291d6beb 100644
--- a/ci/test.mk
+++ b/ci/test.mk
@@ -11,6 +11,7 @@ coveralls: gotest
goveralls -coverprofile=ci/out/coverage.prof
gotest:
- go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./...
+ go test -timeout=30m -covermode=atomic -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./...
sed -i '/stringer\.go/d' ci/out/coverage.prof
sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
+ sed -i '/chat-example/d' ci/out/coverage.prof
diff --git a/conn_notjs.go b/conn_notjs.go
index bb2eb22f..f604898e 100644
--- a/conn_notjs.go
+++ b/conn_notjs.go
@@ -33,6 +33,7 @@ type Conn struct {
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
+ g *Grace
readTimeout chan context.Context
writeTimeout chan context.Context
@@ -138,6 +139,10 @@ func (c *Conn) close(err error) {
// closeErr.
c.rwc.Close()
+ if c.g != nil {
+ c.g.delConn(c)
+ }
+
go func() {
c.msgWriterState.close()
diff --git a/conn_test.go b/conn_test.go
index 28da3c07..af4fa4c0 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -13,7 +13,6 @@ import (
"os"
"os/exec"
"strings"
- "sync"
"testing"
"time"
@@ -272,11 +271,9 @@ func TestWasm(t *testing.T) {
t.Skip("skipping on CI")
}
- var wg sync.WaitGroup
- s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- wg.Add(1)
- defer wg.Done()
-
+ var g websocket.Grace
+ defer g.Close()
+ s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"echo"},
InsecureSkipVerify: true,
@@ -294,8 +291,7 @@ func TestWasm(t *testing.T) {
t.Errorf("echo server failed: %v", err)
return
}
- }))
- defer wg.Wait()
+ })))
defer s.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
diff --git a/example_echo_test.go b/example_echo_test.go
index cd195d2e..0c0b84ea 100644
--- a/example_echo_test.go
+++ b/example_echo_test.go
@@ -31,13 +31,15 @@ func Example_echo() {
}
defer l.Close()
+ var g websocket.Grace
+ defer g.Close()
s := &http.Server{
- Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r)
if err != nil {
log.Printf("echo server: %v", err)
}
- }),
+ })),
ReadTimeout: time.Second * 15,
WriteTimeout: time.Second * 15,
}
diff --git a/example_test.go b/example_test.go
index c56e53f3..462de376 100644
--- a/example_test.go
+++ b/example_test.go
@@ -6,6 +6,8 @@ import (
"context"
"log"
"net/http"
+ "os"
+ "os/signal"
"time"
"nhooyr.io/websocket"
@@ -133,3 +135,55 @@ func Example_crossOrigin() {
err := http.ListenAndServe("localhost:8080", fn)
log.Fatal(err)
}
+
+// This example demonstrates how to create a WebSocket server
+// that gracefully exits when sent a signal.
+//
+// It starts a WebSocket server that keeps every connection open
+// for 10 seconds.
+// If you CTRL+C while a connection is open, it will wait at most 30s
+// for all connections to terminate before shutting down.
+func ExampleGrace() {
+ fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ c, err := websocket.Accept(w, r, nil)
+ if err != nil {
+ log.Println(err)
+ return
+ }
+ defer c.Close(websocket.StatusInternalError, "the sky is falling")
+
+ ctx := c.CloseRead(r.Context())
+ select {
+ case <-ctx.Done():
+ case <-time.After(time.Second * 10):
+ }
+
+ c.Close(websocket.StatusNormalClosure, "")
+ })
+
+ var g websocket.Grace
+ s := &http.Server{
+ Handler: g.Handler(fn),
+ ReadTimeout: time.Second * 15,
+ WriteTimeout: time.Second * 15,
+ }
+
+ errc := make(chan error, 1)
+ go func() {
+ errc <- s.ListenAndServe()
+ }()
+
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, os.Interrupt)
+ select {
+ case err := <-errc:
+ log.Printf("failed to listen and serve: %v", err)
+ case sig := <-sigs:
+ log.Printf("terminating: %v", sig)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+ defer cancel()
+ s.Shutdown(ctx)
+ g.Shutdown(ctx)
+}
diff --git a/grace.go b/grace.go
new file mode 100644
index 00000000..c53cd40b
--- /dev/null
+++ b/grace.go
@@ -0,0 +1,127 @@
+package websocket
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "sync"
+ "time"
+)
+
+// Grace enables graceful shutdown of accepted WebSocket connections.
+//
+// Use Handler to wrap WebSocket handlers to record accepted connections
+// and then use Close or Shutdown to gracefully close these connections.
+//
+// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
+// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket
+// connections.
+type Grace struct {
+ mu sync.Mutex
+ closed bool
+ shuttingDown bool
+ conns map[*Conn]struct{}
+}
+
+// Handler returns a handler that wraps around h to record
+// all WebSocket connections accepted.
+//
+// Use Close or Shutdown to gracefully close recorded connections.
+func (g *Grace) Handler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ctx := context.WithValue(r.Context(), gracefulContextKey{}, g)
+ r = r.WithContext(ctx)
+ h.ServeHTTP(w, r)
+ })
+}
+
+func (g *Grace) isShuttingdown() bool {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ return g.shuttingDown
+}
+
+func graceFromRequest(r *http.Request) *Grace {
+ g, _ := r.Context().Value(gracefulContextKey{}).(*Grace)
+ return g
+}
+
+func (g *Grace) addConn(c *Conn) error {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ if g.closed {
+ c.Close(StatusGoingAway, "server shutting down")
+ return errors.New("server shutting down")
+ }
+ if g.conns == nil {
+ g.conns = make(map[*Conn]struct{})
+ }
+ g.conns[c] = struct{}{}
+ c.g = g
+ return nil
+}
+
+func (g *Grace) delConn(c *Conn) {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ delete(g.conns, c)
+}
+
+type gracefulContextKey struct{}
+
+// Close prevents the acceptance of new connections with
+// http.StatusServiceUnavailable and closes all accepted
+// connections with StatusGoingAway.
+func (g *Grace) Close() error {
+ g.mu.Lock()
+ g.shuttingDown = true
+ g.closed = true
+ var wg sync.WaitGroup
+ for c := range g.conns {
+ wg.Add(1)
+ go func(c *Conn) {
+ defer wg.Done()
+ c.Close(StatusGoingAway, "server shutting down")
+ }(c)
+
+ delete(g.conns, c)
+ }
+ g.mu.Unlock()
+
+ wg.Wait()
+
+ return nil
+}
+
+// Shutdown prevents the acceptance of new connections and waits until
+// all connections close. If the context is cancelled before that, it
+// calls Close to close all connections immediately.
+func (g *Grace) Shutdown(ctx context.Context) error {
+ defer g.Close()
+
+ g.mu.Lock()
+ g.shuttingDown = true
+ g.mu.Unlock()
+
+ // Same poll period used by net/http.
+ t := time.NewTicker(500 * time.Millisecond)
+ defer t.Stop()
+ for {
+ if g.zeroConns() {
+ return nil
+ }
+
+ select {
+ case <-t.C:
+ case <-ctx.Done():
+ return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err())
+ }
+ }
+}
+
+func (g *Grace) zeroConns() bool {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ return len(g.conns) == 0
+}
diff --git a/ws_js.go b/ws_js.go
index 2b560ce8..a8c8b771 100644
--- a/ws_js.go
+++ b/ws_js.go
@@ -38,6 +38,8 @@ type Conn struct {
readSignal chan struct{}
readBufMu sync.Mutex
readBuf []wsjs.MessageEvent
+
+ g *Grace
}
func (c *Conn) close(err error, wasClean bool) {