@@ -50,15 +50,15 @@ func TestOAuth2(t *testing.T) {
5050 t .Parallel ()
5151 req := httptest .NewRequest ("GET" , "/" , nil )
5252 res := httptest .NewRecorder ()
53- httpmw .ExtractOAuth2 (nil , nil , nil )(nil ).ServeHTTP (res , req )
53+ httpmw .ExtractOAuth2 (nil , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
5454 require .Equal (t , http .StatusBadRequest , res .Result ().StatusCode )
5555 })
5656 t .Run ("RedirectWithoutCode" , func (t * testing.T ) {
5757 t .Parallel ()
5858 req := httptest .NewRequest ("GET" , "/?redirect=" + url .QueryEscape ("/dashboard" ), nil )
5959 res := httptest .NewRecorder ()
6060 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
61- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
61+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
6262 location := res .Header ().Get ("Location" )
6363 if ! assert .NotEmpty (t , location ) {
6464 return
@@ -82,7 +82,7 @@ func TestOAuth2(t *testing.T) {
8282 req := httptest .NewRequest ("GET" , "/?redirect=" + url .QueryEscape (uri .String ()), nil )
8383 res := httptest .NewRecorder ()
8484 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
85- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
85+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
8686 location := res .Header ().Get ("Location" )
8787 if ! assert .NotEmpty (t , location ) {
8888 return
@@ -97,15 +97,15 @@ func TestOAuth2(t *testing.T) {
9797 req := httptest .NewRequest ("GET" , "/?code=something" , nil )
9898 res := httptest .NewRecorder ()
9999 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
100- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
100+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
101101 require .Equal (t , http .StatusBadRequest , res .Result ().StatusCode )
102102 })
103103 t .Run ("NoStateCookie" , func (t * testing.T ) {
104104 t .Parallel ()
105105 req := httptest .NewRequest ("GET" , "/?code=something&state=test" , nil )
106106 res := httptest .NewRecorder ()
107107 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
108- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
108+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
109109 require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
110110 })
111111 t .Run ("MismatchedState" , func (t * testing.T ) {
@@ -117,7 +117,7 @@ func TestOAuth2(t *testing.T) {
117117 })
118118 res := httptest .NewRecorder ()
119119 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
120- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
120+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(nil ).ServeHTTP (res , req )
121121 require .Equal (t , http .StatusUnauthorized , res .Result ().StatusCode )
122122 })
123123 t .Run ("ExchangeCodeAndState" , func (t * testing.T ) {
@@ -133,7 +133,7 @@ func TestOAuth2(t *testing.T) {
133133 })
134134 res := httptest .NewRecorder ()
135135 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
136- httpmw .ExtractOAuth2 (tp , nil , nil )(http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
136+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, nil )(http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
137137 state := httpmw .OAuth2 (r )
138138 require .Equal (t , "/dashboard" , state .Redirect )
139139 })).ServeHTTP (res , req )
@@ -144,7 +144,7 @@ func TestOAuth2(t *testing.T) {
144144 res := httptest .NewRecorder ()
145145 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline , oauth2 .SetAuthURLParam ("foo" , "bar" ))
146146 authOpts := map [string ]string {"foo" : "bar" }
147- httpmw .ExtractOAuth2 (tp , nil , authOpts )(nil ).ServeHTTP (res , req )
147+ httpmw .ExtractOAuth2 (tp , nil , codersdk. HTTPCookieConfig {}, authOpts )(nil ).ServeHTTP (res , req )
148148 location := res .Header ().Get ("Location" )
149149 // Ideally we would also assert that the location contains the query params
150150 // we set in the auth URL but this would essentially be testing the oauth2 package.
@@ -157,12 +157,17 @@ func TestOAuth2(t *testing.T) {
157157 req := httptest .NewRequest ("GET" , "/?oidc_merge_state=" + customState + "&redirect=" + url .QueryEscape ("/dashboard" ), nil )
158158 res := httptest .NewRecorder ()
159159 tp := newTestOAuth2Provider (t , oauth2 .AccessTypeOffline )
160- httpmw .ExtractOAuth2 (tp , nil , nil )(nil ).ServeHTTP (res , req )
160+ httpmw .ExtractOAuth2 (tp , nil , codersdk.HTTPCookieConfig {
161+ Secure : true ,
162+ SameSite : "none" ,
163+ }, nil )(nil ).ServeHTTP (res , req )
161164
162165 found := false
163166 for _ , cookie := range res .Result ().Cookies () {
164167 if cookie .Name == codersdk .OAuth2StateCookie {
165168 require .Equal (t , cookie .Value , customState , "expected state" )
169+ require .Equal (t , true , cookie .Secure , "cookie set to secure" )
170+ require .Equal (t , http .SameSiteNoneMode , cookie .SameSite , "same-site = none" )
166171 found = true
167172 }
168173 }
0 commit comments