@@ -62,20 +62,50 @@ func TestAccept(t *testing.T) {
6262 t .Run ("badCompression" , func (t * testing.T ) {
6363 t .Parallel ()
6464
65- w := mockHijacker {
66- ResponseWriter : httptest .NewRecorder (),
65+ newRequest := func (extensions string ) * http.Request {
66+ r := httptest .NewRequest ("GET" , "/" , nil )
67+ r .Header .Set ("Connection" , "Upgrade" )
68+ r .Header .Set ("Upgrade" , "websocket" )
69+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
70+ r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
71+ r .Header .Set ("Sec-WebSocket-Extensions" , extensions )
72+ return r
73+ }
74+ errHijack := errors .New ("hijack error" )
75+ newResponseWriter := func () http.ResponseWriter {
76+ return mockHijacker {
77+ ResponseWriter : httptest .NewRecorder (),
78+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
79+ return nil , nil , errHijack
80+ },
81+ }
6782 }
68- r := httptest .NewRequest ("GET" , "/" , nil )
69- r .Header .Set ("Connection" , "Upgrade" )
70- r .Header .Set ("Upgrade" , "websocket" )
71- r .Header .Set ("Sec-WebSocket-Version" , "13" )
72- r .Header .Set ("Sec-WebSocket-Key" , "meow123" )
73- r .Header .Set ("Sec-WebSocket-Extensions" , "permessage-deflate; harharhar" )
7483
75- _ , err := Accept (w , r , & AcceptOptions {
76- CompressionMode : CompressionContextTakeover ,
84+ t .Run ("withoutFallback" , func (t * testing.T ) {
85+ t .Parallel ()
86+
87+ w := newResponseWriter ()
88+ r := newRequest ("permessage-deflate; harharhar" )
89+ _ , err := Accept (w , r , & AcceptOptions {
90+ CompressionMode : CompressionNoContextTakeover ,
91+ })
92+ assert .ErrorIs (t , errHijack , err )
93+ assert .Equal (t , "extension header" , w .Header ().Get ("Sec-WebSocket-Extensions" ), "" )
94+ })
95+ t .Run ("withFallback" , func (t * testing.T ) {
96+ t .Parallel ()
97+
98+ w := newResponseWriter ()
99+ r := newRequest ("permessage-deflate; harharhar, permessage-deflate" )
100+ _ , err := Accept (w , r , & AcceptOptions {
101+ CompressionMode : CompressionNoContextTakeover ,
102+ })
103+ assert .ErrorIs (t , errHijack , err )
104+ assert .Equal (t , "extension header" ,
105+ w .Header ().Get ("Sec-WebSocket-Extensions" ),
106+ CompressionNoContextTakeover .opts ().String (),
107+ )
77108 })
78- assert .Contains (t , err , `unsupported permessage-deflate parameter` )
79109 })
80110
81111 t .Run ("requireHttpHijacker" , func (t * testing.T ) {
@@ -344,79 +374,66 @@ func Test_authenticateOrigin(t *testing.T) {
344374 }
345375}
346376
347- func Test_acceptCompression (t * testing.T ) {
377+ func Test_selectDeflate (t * testing.T ) {
348378 t .Parallel ()
349379
350380 testCases := []struct {
351- name string
352- mode CompressionMode
353- reqSecWebSocketExtensions string
354- respSecWebSocketExtensions string
355- expCopts * compressionOptions
356- error bool
381+ name string
382+ mode CompressionMode
383+ header string
384+ expCopts * compressionOptions
385+ expOK bool
357386 }{
358387 {
359388 name : "disabled" ,
360389 mode : CompressionDisabled ,
361390 expCopts : nil ,
391+ expOK : false ,
362392 },
363393 {
364394 name : "noClientSupport" ,
365395 mode : CompressionNoContextTakeover ,
366396 expCopts : nil ,
397+ expOK : false ,
367398 },
368399 {
369- name : "permessage-deflate" ,
370- mode : CompressionNoContextTakeover ,
371- reqSecWebSocketExtensions : "permessage-deflate; client_max_window_bits" ,
372- respSecWebSocketExtensions : "permessage-deflate; client_no_context_takeover; server_no_context_takeover" ,
400+ name : "permessage-deflate" ,
401+ mode : CompressionNoContextTakeover ,
402+ header : "permessage-deflate; client_max_window_bits" ,
373403 expCopts : & compressionOptions {
374404 clientNoContextTakeover : true ,
375405 serverNoContextTakeover : true ,
376406 },
407+ expOK : true ,
408+ },
409+ {
410+ name : "permessage-deflate/unknown-parameter" ,
411+ mode : CompressionNoContextTakeover ,
412+ header : "permessage-deflate; meow" ,
413+ expOK : false ,
377414 },
378415 {
379- name : "permessage-deflate/error" ,
380- mode : CompressionNoContextTakeover ,
381- reqSecWebSocketExtensions : "permessage-deflate; meow" ,
382- error : true ,
416+ name : "permessage-deflate/unknown-parameter" ,
417+ mode : CompressionNoContextTakeover ,
418+ header : "permessage-deflate; meow, permessage-deflate; client_max_window_bits" ,
419+ expCopts : & compressionOptions {
420+ clientNoContextTakeover : true ,
421+ serverNoContextTakeover : true ,
422+ },
423+ expOK : true ,
383424 },
384- // {
385- // name: "x-webkit-deflate-frame",
386- // mode: CompressionNoContextTakeover,
387- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
388- // respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
389- // expCopts: &compressionOptions{
390- // clientNoContextTakeover: true,
391- // serverNoContextTakeover: true,
392- // },
393- // },
394- // {
395- // name: "x-webkit-deflate/error",
396- // mode: CompressionNoContextTakeover,
397- // reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
398- // error: true,
399- // },
400425 }
401426
402427 for _ , tc := range testCases {
403428 tc := tc
404429 t .Run (tc .name , func (t * testing.T ) {
405430 t .Parallel ()
406431
407- r := httptest .NewRequest (http .MethodGet , "/" , nil )
408- r .Header .Set ("Sec-WebSocket-Extensions" , tc .reqSecWebSocketExtensions )
409-
410- w := httptest .NewRecorder ()
411- copts , err := acceptCompression (r , w , tc .mode )
412- if tc .error {
413- assert .Error (t , err )
414- return
415- }
416-
417- assert .Success (t , err )
432+ h := http.Header {}
433+ h .Set ("Sec-WebSocket-Extensions" , tc .header )
434+ copts , ok := selectDeflate (websocketExtensions (h ), tc .mode )
435+ assert .Equal (t , "selected options" , tc .expOK , ok )
418436 assert .Equal (t , "compression options" , tc .expCopts , copts )
419- assert .Equal (t , "Sec-WebSocket-Extensions" , tc .respSecWebSocketExtensions , w .Header ().Get ("Sec-WebSocket-Extensions" ))
420437 })
421438 }
422439}
0 commit comments