88 "context"
99 "crypto/rand"
1010 "encoding/base64"
11- "errors"
1211 "fmt"
1312 "io"
1413 "io/ioutil"
@@ -47,18 +46,27 @@ type DialOptions struct {
4746 CompressionThreshold int
4847}
4948
50- func (opts * DialOptions ) cloneWithDefaults () * DialOptions {
49+ func (opts * DialOptions ) cloneWithDefaults (ctx context.Context ) (context.Context , context.CancelFunc , * DialOptions ) {
50+ var cancel context.CancelFunc
51+
5152 var o DialOptions
5253 if opts != nil {
5354 o = * opts
5455 }
5556 if o .HTTPClient == nil {
5657 o .HTTPClient = http .DefaultClient
58+ } else if opts .HTTPClient .Timeout > 0 {
59+ ctx , cancel = context .WithTimeout (ctx , opts .HTTPClient .Timeout )
60+
61+ newClient := * opts .HTTPClient
62+ newClient .Timeout = 0
63+ opts .HTTPClient = & newClient
5764 }
5865 if o .HTTPHeader == nil {
5966 o .HTTPHeader = http.Header {}
6067 }
61- return & o
68+
69+ return ctx , cancel , & o
6270}
6371
6472// Dial performs a WebSocket handshake on url.
@@ -81,7 +89,11 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
8189func dial (ctx context.Context , urls string , opts * DialOptions , rand io.Reader ) (_ * Conn , _ * http.Response , err error ) {
8290 defer errd .Wrap (& err , "failed to WebSocket dial" )
8391
84- opts = opts .cloneWithDefaults ()
92+ var cancel context.CancelFunc
93+ ctx , cancel , opts = opts .cloneWithDefaults (ctx )
94+ if cancel != nil {
95+ defer cancel ()
96+ }
8597
8698 secWebSocketKey , err := secWebSocketKey (rand )
8799 if err != nil {
@@ -137,10 +149,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
137149}
138150
139151func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ) {
140- if opts .HTTPClient .Timeout > 0 {
141- return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
142- }
143-
144152 u , err := url .Parse (urls )
145153 if err != nil {
146154 return nil , fmt .Errorf ("failed to parse url: %w" , err )
@@ -193,11 +201,11 @@ func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSo
193201 return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
194202 }
195203
196- if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ) {
204+ if ! headerContainsTokenIgnoreCase (resp .Header , "Connection" , "Upgrade" ) {
197205 return nil , fmt .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
198206 }
199207
200- if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
208+ if ! headerContainsTokenIgnoreCase (resp .Header , "Upgrade" , "WebSocket" ) {
201209 return nil , fmt .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
202210 }
203211
0 commit comments