diff --git a/authhandler/authhandler.go b/authhandler/authhandler.go index 9bc6cd7bc..46d1396f1 100644 --- a/authhandler/authhandler.go +++ b/authhandler/authhandler.go @@ -34,7 +34,7 @@ type PKCEParams struct { // and returns an auth code and state upon approval. type AuthorizationHandler func(authCodeURL string) (code string, state string, err error) -// TokenSourceWithPKCE is an enhanced version of TokenSource with PKCE support. +// TokenSourceWithPKCE is an enhanced version of [oauth2.TokenSource] with PKCE support. // // The pkce parameter supports PKCE flow, which uses code challenge and code verifier // to prevent CSRF attacks. A unique code challenge and code verifier should be generated @@ -43,12 +43,12 @@ func TokenSourceWithPKCE(ctx context.Context, config *oauth2.Config, state strin return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, pkce: pkce}) } -// TokenSource returns an oauth2.TokenSource that fetches access tokens +// TokenSource returns an [oauth2.TokenSource] that fetches access tokens // using 3-legged-OAuth flow. // -// The provided context.Context is used for oauth2 Exchange operation. +// The provided [context.Context] is used for oauth2 Exchange operation. // -// The provided oauth2.Config should be a full configuration containing AuthURL, +// The provided [oauth2.Config] should be a full configuration containing AuthURL, // TokenURL, and Scope. // // An environment-specific AuthorizationHandler is used to obtain user consent. diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 51121a3d5..e86346e8b 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -55,7 +55,7 @@ type Config struct { // Token uses client credentials to retrieve a token. // -// The provided context optionally controls which HTTP client is used. See the oauth2.HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [oauth2.HTTPClient] variable. func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { return c.TokenSource(ctx).Token() } @@ -64,18 +64,18 @@ func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { // The token will auto-refresh as necessary. // // The provided context optionally controls which HTTP client -// is returned. See the oauth2.HTTPClient variable. +// is returned. See the [oauth2.HTTPClient] variable. // -// The returned Client and its Transport should not be modified. +// The returned [http.Client] and its Transport should not be modified. func (c *Config) Client(ctx context.Context) *http.Client { return oauth2.NewClient(ctx, c.TokenSource(ctx)) } -// TokenSource returns a TokenSource that returns t until t expires, +// TokenSource returns a [oauth2.TokenSource] that returns t until t expires, // automatically refreshing it as necessary using the provided context and the // client ID and client secret. // -// Most users will use Config.Client instead. +// Most users will use [Config.Client] instead. func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { source := &tokenSource{ ctx: ctx, diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 078e75ec7..e2a99eb14 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -7,7 +7,6 @@ package clientcredentials import ( "context" "io" - "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -36,9 +35,9 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) { wantGrantType := "password" var gotGrantType string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { - t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, ", body, err) + t.Errorf("io.ReadAll(r.Body) == %v, %v, want _, ", body, err) } if err := r.Body.Close(); err != nil { t.Errorf("r.Body.Close() == %v, want ", err) @@ -81,7 +80,7 @@ func TestTokenRequest(t *testing.T) { if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { t.Errorf("Content-Type header = %q; want %q", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { r.Body.Close() } @@ -123,7 +122,7 @@ func TestTokenRefreshRequest(t *testing.T) { if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { t.Errorf("Content-Type = %q; want %q", got, want) } - body, _ := ioutil.ReadAll(r.Body) + body, _ := io.ReadAll(r.Body) const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" if string(body) != want { t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want) diff --git a/deviceauth_test.go b/deviceauth_test.go index 3b9962005..0e61a2559 100644 --- a/deviceauth_test.go +++ b/deviceauth_test.go @@ -7,9 +7,6 @@ import ( "strings" "testing" "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" ) func TestDeviceAuthResponseMarshalJson(t *testing.T) { @@ -74,7 +71,16 @@ func TestDeviceAuthResponseUnmarshalJson(t *testing.T) { if err != nil { t.Fatal(err) } - if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) { + margin := time.Second + time.Since(begin) + timeDiff := got.Expiry.Sub(tc.want.Expiry) + if timeDiff < 0 { + timeDiff *= -1 + } + if timeDiff > margin { + t.Errorf("expiry time difference too large, got=%v, want=%v margin=%v", got.Expiry, tc.want.Expiry, margin) + } + got.Expiry, tc.want.Expiry = time.Time{}, time.Time{} + if got != tc.want { t.Errorf("want=%#v, got=%#v", tc.want, got) } }) diff --git a/endpoints/endpoints.go b/endpoints/endpoints.go index d6e575e1f..e862a3238 100644 --- a/endpoints/endpoints.go +++ b/endpoints/endpoints.go @@ -6,7 +6,7 @@ package endpoints import ( - "strings" + "net/url" "golang.org/x/oauth2" ) @@ -17,6 +17,30 @@ var Amazon = oauth2.Endpoint{ TokenURL: "https://api.amazon.com/auth/o2/token", } +// Apple is the endpoint for "Sign in with Apple". +// +// Documentation: https://developer.apple.com/documentation/signinwithapplerestapi +var Apple = oauth2.Endpoint{ + AuthURL: "https://appleid.apple.com/auth/authorize", + TokenURL: "https://appleid.apple.com/auth/token", +} + +// Asana is the endpoint for Asana. +// +// Documentation: https://developers.asana.com/docs/oauth +var Asana = oauth2.Endpoint{ + AuthURL: "https://app.asana.com/-/oauth_authorize", + TokenURL: "https://app.asana.com/-/oauth_token", +} + +// Badgr is the endpoint for Canvas Badges. +// +// Documentation: https://community.canvaslms.com/t5/Canvas-Badges-Credentials/Developers-Build-an-app-that-integrates-with-the-Canvas-Badges/ta-p/528727 +var Badgr = oauth2.Endpoint{ + AuthURL: "https://badgr.com/auth/oauth2/authorize", + TokenURL: "https://api.badgr.io/o/token", +} + // Battlenet is the endpoint for Battlenet. var Battlenet = oauth2.Endpoint{ AuthURL: "https://battle.net/oauth/authorize", @@ -35,16 +59,44 @@ var Cern = oauth2.Endpoint{ TokenURL: "https://oauth.web.cern.ch/OAuth/Token", } +// Coinbase is the endpoint for Coinbase. +// +// Documentation: https://docs.cdp.coinbase.com/coinbase-app/docs/coinbase-app-reference +var Coinbase = oauth2.Endpoint{ + AuthURL: "https://login.coinbase.com/oauth2/auth", + TokenURL: "https://login.coinbase.com/oauth2/token", +} + // Discord is the endpoint for Discord. +// +// Documentation: https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-urls var Discord = oauth2.Endpoint{ AuthURL: "https://discord.com/oauth2/authorize", TokenURL: "https://discord.com/api/oauth2/token", } +// Dropbox is the endpoint for Dropbox. +// +// Documentation: https://developers.dropbox.com/oauth-guide +var Dropbox = oauth2.Endpoint{ + AuthURL: "https://www.dropbox.com/oauth2/authorize", + TokenURL: "https://api.dropboxapi.com/oauth2/token", +} + +// Endpoint is Ebay's OAuth 2.0 endpoint. +// +// Documentation: https://developer.ebay.com/api-docs/static/authorization_guide_landing.html +var Endpoint = oauth2.Endpoint{ + AuthURL: "https://auth.ebay.com/oauth2/authorize", + TokenURL: "https://api.ebay.com/identity/v1/oauth2/token", +} + // Facebook is the endpoint for Facebook. +// +// Documentation: https://developers.facebook.com/docs/facebook-login/guides/advanced/manual-flow var Facebook = oauth2.Endpoint{ - AuthURL: "https://www.facebook.com/v3.2/dialog/oauth", - TokenURL: "https://graph.facebook.com/v3.2/oauth/access_token", + AuthURL: "https://www.facebook.com/v22.0/dialog/oauth", + TokenURL: "https://graph.facebook.com/v22.0/oauth/access_token", } // Foursquare is the endpoint for Foursquare. @@ -104,6 +156,14 @@ var KaKao = oauth2.Endpoint{ TokenURL: "https://kauth.kakao.com/oauth/token", } +// Line is the endpoint for Line. +// +// Documentation: https://developers.line.biz/en/docs/line-login/integrate-line-login/ +var Line = oauth2.Endpoint{ + AuthURL: "https://access.line.me/oauth2/v2.1/authorize", + TokenURL: "https://api.line.me/oauth2/v2.1/token", +} + // LinkedIn is the endpoint for LinkedIn. var LinkedIn = oauth2.Endpoint{ AuthURL: "https://www.linkedin.com/oauth/v2/authorization", @@ -140,7 +200,17 @@ var Microsoft = oauth2.Endpoint{ TokenURL: "https://login.live.com/oauth20_token.srf", } +// Naver is the endpoint for Naver. +// +// Documentation: https://developers.naver.com/docs/login/devguide/devguide.md +var Naver = oauth2.Endpoint{ + AuthURL: "https://nid.naver.com/oauth2/authorize", + TokenURL: "https://nid.naver.com/oauth2/token", +} + // NokiaHealth is the endpoint for Nokia Health. +// +// Deprecated: Nokia Health is now Withings. var NokiaHealth = oauth2.Endpoint{ AuthURL: "https://account.health.nokia.com/oauth2_user/authorize2", TokenURL: "https://account.health.nokia.com/oauth2/token", @@ -152,6 +222,14 @@ var Odnoklassniki = oauth2.Endpoint{ TokenURL: "https://api.odnoklassniki.ru/oauth/token.do", } +// OpenStreetMap is the endpoint for OpenStreetMap.org. +// +// Documentation: https://wiki.openstreetmap.org/wiki/OAuth +var OpenStreetMap = oauth2.Endpoint{ + AuthURL: "https://www.openstreetmap.org/oauth2/authorize", + TokenURL: "https://www.openstreetmap.org/oauth2/token", +} + // Patreon is the endpoint for Patreon. var Patreon = oauth2.Endpoint{ AuthURL: "https://www.patreon.com/oauth2/authorize", @@ -170,10 +248,52 @@ var PayPalSandbox = oauth2.Endpoint{ TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice", } +// Pinterest is the endpoint for Pinterest. +// +// Documentation: https://developers.pinterest.com/docs/getting-started/set-up-authentication-and-authorization/ +var Pinterest = oauth2.Endpoint{ + AuthURL: "https://www.pinterest.com/oauth", + TokenURL: "https://api.pinterest.com/v5/oauth/token", +} + +// Pipedrive is the endpoint for Pipedrive. +// +// Documentation: https://developers.pipedrive.com/docs/api/v1/Oauth +var Pipedrive = oauth2.Endpoint{ + AuthURL: "https://oauth.pipedrive.com/oauth/authorize", + TokenURL: "https://oauth.pipedrive.com/oauth/token", +} + +// QQ is the endpoint for QQ. +// +// Documentation: https://wiki.connect.qq.com/%e5%bc%80%e5%8f%91%e6%94%bb%e7%95%a5_server-side +var QQ = oauth2.Endpoint{ + AuthURL: "https://graph.qq.com/oauth2.0/authorize", + TokenURL: "https://graph.qq.com/oauth2.0/token", +} + +// Rakuten is the endpoint for Rakuten. +// +// Documentation: https://webservice.rakuten.co.jp/documentation +var Rakuten = oauth2.Endpoint{ + AuthURL: "https://app.rakuten.co.jp/services/authorize", + TokenURL: "https://app.rakuten.co.jp/services/token", +} + // Slack is the endpoint for Slack. +// +// Documentation: https://api.slack.com/authentication/oauth-v2 var Slack = oauth2.Endpoint{ - AuthURL: "https://slack.com/oauth/authorize", - TokenURL: "https://slack.com/api/oauth.access", + AuthURL: "https://slack.com/oauth/v2/authorize", + TokenURL: "https://slack.com/api/oauth.v2.access", +} + +// Splitwise is the endpoint for Splitwise. +// +// Documentation: https://dev.splitwise.com/ +var Splitwise = oauth2.Endpoint{ + AuthURL: "https://www.splitwise.com/oauth/authorize", + TokenURL: "https://www.splitwise.com/oauth/token", } // Spotify is the endpoint for Spotify. @@ -212,6 +332,22 @@ var Vk = oauth2.Endpoint{ TokenURL: "https://oauth.vk.com/access_token", } +// Withings is the endpoint for Withings. +// +// Documentation: https://account.withings.com/oauth2_user/authorize2 +var Withings = oauth2.Endpoint{ + AuthURL: "https://account.withings.com/oauth2_user/authorize2", + TokenURL: "https://account.withings.com/oauth2/token", +} + +// X is the endpoint for X (Twitter). +// +// Documentation: https://docs.x.com/resources/fundamentals/authentication/oauth-2-0/user-access-token +var X = oauth2.Endpoint{ + AuthURL: "https://x.com/i/oauth2/authorize", + TokenURL: "https://api.x.com/2/oauth2/token", +} + // Yahoo is the endpoint for Yahoo. var Yahoo = oauth2.Endpoint{ AuthURL: "https://api.login.yahoo.com/oauth2/request_auth", @@ -230,6 +366,20 @@ var Zoom = oauth2.Endpoint{ TokenURL: "https://zoom.us/oauth/token", } +// Asgardeo returns a new oauth2.Endpoint for the given tenant. +// +// Documentation: https://wso2.com/asgardeo/docs/guides/authentication/oidc/discover-oidc-configs/ +func AsgardeoEndpoint(tenant string) oauth2.Endpoint { + u := url.URL{ + Scheme: "https", + Host: "api.asgardeo.io", + } + return oauth2.Endpoint{ + AuthURL: u.JoinPath("t", tenant, "/oauth2/authorize").String(), + TokenURL: u.JoinPath("t", tenant, "/oauth2/token").String(), + } +} + // AzureAD returns a new oauth2.Endpoint for the given tenant at Azure Active Directory. // If tenant is empty, it uses the tenant called `common`. // @@ -239,19 +389,29 @@ func AzureAD(tenant string) oauth2.Endpoint { if tenant == "" { tenant = "common" } + u := url.URL{ + Scheme: "https", + Host: "login.microsoftonline.com", + } return oauth2.Endpoint{ - AuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/authorize", - TokenURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token", - DeviceAuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/devicecode", + AuthURL: u.JoinPath(tenant, "/oauth2/v2.0/authorize").String(), + TokenURL: u.JoinPath(tenant, "/oauth2/v2.0/token").String(), + DeviceAuthURL: u.JoinPath(tenant, "/oauth2/v2.0/devicecode").String(), } } -// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance -// running on the given domain or host. -func HipChatServer(host string) oauth2.Endpoint { +// AzureADB2CEndpoint returns a new oauth2.Endpoint for the given tenant and policy at Azure Active Directory B2C. +// policy is the Azure B2C User flow name Example: `B2C_1_SignUpSignIn`. +// +// Documentation: https://docs.microsoft.com/en-us/azure/active-directory-b2c/tokens-overview#endpoints +func AzureADB2CEndpoint(tenant string, policy string) oauth2.Endpoint { + u := url.URL{ + Scheme: "https", + Host: tenant + ".b2clogin.com", + } return oauth2.Endpoint{ - AuthURL: "https://" + host + "/users/authorize", - TokenURL: "https://" + host + "/v2/oauth/token", + AuthURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/authorize").String(), + TokenURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/token").String(), } } @@ -264,9 +424,42 @@ func HipChatServer(host string) oauth2.Endpoint { // https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-assign-domain.html // https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-userpools-server-contract-reference.html func AWSCognito(domain string) oauth2.Endpoint { - domain = strings.TrimRight(domain, "/") + u, err := url.Parse(domain) + if err != nil || u.Scheme == "" || u.Host == "" { + panic("endpoints: invalid domain" + domain) + } + return oauth2.Endpoint{ + AuthURL: u.JoinPath("/oauth2/authorize").String(), + TokenURL: u.JoinPath("/oauth2/token").String(), + } +} + +// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance. +// host should be a hostname, without any scheme prefix. +// +// Documentation: https://developer.atlassian.com/server/hipchat/hipchat-rest-api-access-tokens/ +func HipChatServer(host string) oauth2.Endpoint { + u := url.URL{ + Scheme: "https", + Host: host, + } + return oauth2.Endpoint{ + AuthURL: u.JoinPath("/users/authorize").String(), + TokenURL: u.JoinPath("/v2/oauth/token").String(), + } +} + +// Shopify returns a new oauth2.Endpoint for the supplied shop domain name. +// host should be a hostname, without any scheme prefix. +// +// Documentation: https://shopify.dev/docs/apps/auth/oauth +func Shopify(host string) oauth2.Endpoint { + u := url.URL{ + Scheme: "https", + Host: host, + } return oauth2.Endpoint{ - AuthURL: domain + "/oauth2/authorize", - TokenURL: domain + "/oauth2/token", + AuthURL: u.JoinPath("/admin/oauth/authorize").String(), + TokenURL: u.JoinPath("/admin/oauth/access_token").String(), } } diff --git a/go.mod b/go.mod index da302fb45..48950e730 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,4 @@ module golang.org/x/oauth2 go 1.23.0 -require ( - cloud.google.com/go/compute/metadata v0.3.0 - github.com/google/go-cmp v0.5.9 -) +require cloud.google.com/go/compute/metadata v0.3.0 diff --git a/go.sum b/go.sum index 0c9052866..3eecfcce7 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,2 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/google/downscope/downscoping.go b/google/downscope/downscoping.go index ebe8b0509..f704f7eef 100644 --- a/google/downscope/downscoping.go +++ b/google/downscope/downscoping.go @@ -39,7 +39,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/url" "strings" @@ -198,7 +198,7 @@ func (dts downscopingTokenSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("unable to generate POST Request %v", err) } defer resp.Body.Close() - respBody, err := ioutil.ReadAll(resp.Body) + respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("downscope: unable to read response body: %v", err) } diff --git a/google/downscope/downscoping_test.go b/google/downscope/downscoping_test.go index ecdd98691..bd1684de7 100644 --- a/google/downscope/downscoping_test.go +++ b/google/downscope/downscoping_test.go @@ -6,7 +6,7 @@ package downscope import ( "context" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -27,7 +27,7 @@ func Test_DownscopedTokenSource(t *testing.T) { if r.URL.String() != "/" { t.Errorf("Unexpected request URL, %v is found", r.URL) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Failed to read request body: %v", err) } diff --git a/google/example_test.go b/google/example_test.go index 3fc9cad3f..e19d7e6be 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -7,9 +7,9 @@ package google_test import ( "context" "fmt" - "io/ioutil" "log" "net/http" + "os" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -60,7 +60,7 @@ func ExampleJWTConfigFromJSON() { // To create a service account client, click "Create new Client ID", // select "Service Account", and click "Create Client ID". A JSON // key file will then be downloaded to your computer. - data, err := ioutil.ReadFile("/path/to/your-project-key.json") + data, err := os.ReadFile("/path/to/your-project-key.json") if err != nil { log.Fatal(err) } @@ -136,7 +136,7 @@ func ExampleComputeTokenSource() { func ExampleCredentialsFromJSON() { ctx := context.Background() - data, err := ioutil.ReadFile("/path/to/key-file.json") + data, err := os.ReadFile("/path/to/key-file.json") if err != nil { log.Fatal(err) } diff --git a/google/externalaccount/aws.go b/google/externalaccount/aws.go index 55d59999e..e1a735e01 100644 --- a/google/externalaccount/aws.go +++ b/google/externalaccount/aws.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/url" "os" @@ -170,7 +169,7 @@ func requestDataHash(req *http.Request) (string, error) { } defer requestBody.Close() - requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20)) + requestData, err = io.ReadAll(io.LimitReader(requestBody, 1<<20)) if err != nil { return "", err } @@ -419,7 +418,7 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } defer resp.Body.Close() - respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return "", err } @@ -462,7 +461,7 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err } defer resp.Body.Close() - respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return "", err } @@ -531,7 +530,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h } defer resp.Body.Close() - respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return result, err } @@ -564,7 +563,7 @@ func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (s } defer resp.Body.Close() - respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return "", err } diff --git a/google/externalaccount/basecredentials.go b/google/externalaccount/basecredentials.go index aa0bba2eb..6f7662170 100644 --- a/google/externalaccount/basecredentials.go +++ b/google/externalaccount/basecredentials.go @@ -486,11 +486,11 @@ func (ts tokenSource) Token() (*oauth2.Token, error) { ClientID: conf.ClientID, ClientSecret: conf.ClientSecret, } - var options map[string]interface{} + var options map[string]any // Do not pass workforce_pool_user_project when client authentication is used. // The client ID is sufficient for determining the user project. if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" { - options = map[string]interface{}{ + options = map[string]any{ "userProject": conf.WorkforcePoolUserProject, } } diff --git a/google/externalaccount/basecredentials_test.go b/google/externalaccount/basecredentials_test.go index d52f6a789..31a79fc84 100644 --- a/google/externalaccount/basecredentials_test.go +++ b/google/externalaccount/basecredentials_test.go @@ -8,7 +8,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -77,7 +77,7 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T if got, want := headerMetrics, tets.metricsHeader; got != want { t.Errorf("got %v but want %v", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %s.", err) } @@ -131,7 +131,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin if got, want := headerContentType, "application/json"; got != want { t.Errorf("got %v but want %v", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %v.", err) } @@ -160,7 +160,7 @@ func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Serv if got, want := headerMetrics, metricsHeaderWanted; got != want { t.Errorf("got %v but want %v", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %v.", err) } diff --git a/google/externalaccount/executablecredsource.go b/google/externalaccount/executablecredsource.go index dca5681a4..b173c61f0 100644 --- a/google/externalaccount/executablecredsource.go +++ b/google/externalaccount/executablecredsource.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "os" "os/exec" "regexp" @@ -258,7 +257,7 @@ func (cs executableCredentialSource) getTokenFromOutputFile() (token string, err } defer file.Close() - data, err := ioutil.ReadAll(io.LimitReader(file, 1<<20)) + data, err := io.ReadAll(io.LimitReader(file, 1<<20)) if err != nil || len(data) == 0 { // Cachefile exists, but no data found. Get new credential. return "", nil diff --git a/google/externalaccount/executablecredsource_test.go b/google/externalaccount/executablecredsource_test.go index 3ecc05f92..fd0c79fc5 100644 --- a/google/externalaccount/executablecredsource_test.go +++ b/google/externalaccount/executablecredsource_test.go @@ -8,13 +8,10 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" "os" - "sort" + "slices" "testing" "time" - - "github.com/google/go-cmp/cmp" ) type testEnvironment struct { @@ -254,14 +251,12 @@ func TestExecutableCredentialGetEnvironment(t *testing.T) { ecs.env = &tt.environment - // This Transformer sorts a []string. - sorter := cmp.Transformer("Sort", func(in []string) []string { - out := append([]string(nil), in...) // Copy input to avoid mutating it - sort.Strings(out) - return out - }) + got := ecs.executableEnvironment() + slices.Sort(got) + want := tt.expectedEnvironment + slices.Sort(want) - if got, want := ecs.executableEnvironment(), tt.expectedEnvironment; !cmp.Equal(got, want, sorter) { + if !slices.Equal(got, want) { t.Errorf("Incorrect environment received.\nReceived: %s\nExpected: %s", got, want) } }) @@ -614,7 +609,7 @@ func TestRetrieveExecutableSubjectTokenSuccesses(t *testing.T) { } func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) { - outputFile, err := ioutil.TempFile("testdata", "result.*.json") + outputFile, err := os.CreateTemp("testdata", "result.*.json") if err != nil { t.Fatalf("Tempfile failed: %v", err) } @@ -763,7 +758,7 @@ var cacheFailureTests = []struct { func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) { for _, tt := range cacheFailureTests { t.Run(tt.name, func(t *testing.T) { - outputFile, err := ioutil.TempFile("testdata", "result.*.json") + outputFile, err := os.CreateTemp("testdata", "result.*.json") if err != nil { t.Fatalf("Tempfile failed: %v", err) } @@ -866,7 +861,7 @@ var invalidCacheTests = []struct { func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) { for _, tt := range invalidCacheTests { t.Run(tt.name, func(t *testing.T) { - outputFile, err := ioutil.TempFile("testdata", "result.*.json") + outputFile, err := os.CreateTemp("testdata", "result.*.json") if err != nil { t.Fatalf("Tempfile failed: %v", err) } @@ -970,8 +965,7 @@ var cacheSuccessTests = []struct { func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) { for _, tt := range cacheSuccessTests { t.Run(tt.name, func(t *testing.T) { - - outputFile, err := ioutil.TempFile("testdata", "result.*.json") + outputFile, err := os.CreateTemp("testdata", "result.*.json") if err != nil { t.Fatalf("Tempfile failed: %v", err) } diff --git a/google/externalaccount/filecredsource.go b/google/externalaccount/filecredsource.go index 33766b972..46ebc1836 100644 --- a/google/externalaccount/filecredsource.go +++ b/google/externalaccount/filecredsource.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "os" ) @@ -29,14 +28,14 @@ func (cs fileCredentialSource) subjectToken() (string, error) { return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File) } defer tokenFile.Close() - tokenBytes, err := ioutil.ReadAll(io.LimitReader(tokenFile, 1<<20)) + tokenBytes, err := io.ReadAll(io.LimitReader(tokenFile, 1<<20)) if err != nil { return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err) } tokenBytes = bytes.TrimSpace(tokenBytes) switch cs.Format.Type { case "json": - jsonData := make(map[string]interface{}) + jsonData := make(map[string]any) err = json.Unmarshal(tokenBytes, &jsonData) if err != nil { return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err) diff --git a/google/externalaccount/header_test.go b/google/externalaccount/header_test.go index 39f279deb..bd59a0987 100644 --- a/google/externalaccount/header_test.go +++ b/google/externalaccount/header_test.go @@ -7,8 +7,6 @@ package externalaccount import ( "runtime" "testing" - - "github.com/google/go-cmp/cmp" ) func TestGoVersion(t *testing.T) { @@ -40,8 +38,8 @@ func TestGoVersion(t *testing.T) { } { version = tst.v got := goVersion() - if diff := cmp.Diff(got, tst.want); diff != "" { - t.Errorf("got(-),want(+):\n%s", diff) + if got != tst.want { + t.Errorf("go version = %q, want = %q", got, tst.want) } } version = runtime.Version diff --git a/google/externalaccount/urlcredsource.go b/google/externalaccount/urlcredsource.go index 71a7184e0..65bfd2046 100644 --- a/google/externalaccount/urlcredsource.go +++ b/google/externalaccount/urlcredsource.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "golang.org/x/oauth2" @@ -44,7 +43,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) { } defer resp.Body.Close() - respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err) } @@ -54,7 +53,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) { switch cs.Format.Type { case "json": - jsonData := make(map[string]interface{}) + jsonData := make(map[string]any) err = json.Unmarshal(respBody, &jsonData) if err != nil { return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err) diff --git a/google/google.go b/google/google.go index 7b82e7a08..e2eb9c927 100644 --- a/google/google.go +++ b/google/google.go @@ -285,27 +285,23 @@ func (cs computeSource) Token() (*oauth2.Token, error) { if err != nil { return nil, err } - var res struct { - AccessToken string `json:"access_token"` - ExpiresInSec int `json:"expires_in"` - TokenType string `json:"token_type"` - } + var res oauth2.Token err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res) if err != nil { return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err) } - if res.ExpiresInSec == 0 || res.AccessToken == "" { + if res.ExpiresIn == 0 || res.AccessToken == "" { return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata") } tok := &oauth2.Token{ AccessToken: res.AccessToken, TokenType: res.TokenType, - Expiry: time.Now().Add(time.Duration(res.ExpiresInSec) * time.Second), + Expiry: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), } // NOTE(cbro): add hidden metadata about where the token is from. // This is needed for detection by client libraries to know that credentials come from the metadata server. // This may be removed in a future version of this library. - return tok.WithExtra(map[string]interface{}{ + return tok.WithExtra(map[string]any{ "oauth2.google.tokenSource": "compute-metadata", "oauth2.google.serviceAccount": acct, }), nil diff --git a/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go b/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go index 1bbbbac19..bcb9f5cd4 100644 --- a/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go +++ b/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go @@ -8,7 +8,7 @@ import ( "context" "encoding/json" "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" @@ -227,7 +227,7 @@ func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) { if got, want := headerContentType, trts.ContentType; got != want { t.Errorf("got %v but want %v", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %s.", err) } diff --git a/google/internal/impersonate/impersonate.go b/google/internal/impersonate/impersonate.go index 6bc3af110..eaa8b5c71 100644 --- a/google/internal/impersonate/impersonate.go +++ b/google/internal/impersonate/impersonate.go @@ -10,7 +10,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "time" @@ -81,7 +80,7 @@ func (its ImpersonateTokenSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err) } diff --git a/google/internal/stsexchange/sts_exchange.go b/google/internal/stsexchange/sts_exchange.go index 1a0bebd15..edf700e21 100644 --- a/google/internal/stsexchange/sts_exchange.go +++ b/google/internal/stsexchange/sts_exchange.go @@ -9,7 +9,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strconv" @@ -28,7 +27,7 @@ func defaultHeader() http.Header { // The first 4 fields are all mandatory. headers can be used to pass additional // headers beyond the bare minimum required by the token exchange. options can // be used to pass additional JSON-structured options to the remote server. -func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) { +func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]any) (*Response, error) { data := url.Values{} data.Set("audience", request.Audience) data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") @@ -82,7 +81,7 @@ func makeRequest(ctx context.Context, endpoint string, data url.Values, authenti } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, err } diff --git a/google/internal/stsexchange/sts_exchange_test.go b/google/internal/stsexchange/sts_exchange_test.go index ff9a9ad08..1cac58269 100644 --- a/google/internal/stsexchange/sts_exchange_test.go +++ b/google/internal/stsexchange/sts_exchange_test.go @@ -7,7 +7,7 @@ package stsexchange import ( "context" "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -73,7 +73,7 @@ func TestExchangeToken(t *testing.T) { if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %v.", err) } @@ -132,7 +132,7 @@ var optsValues = [][]string{{"foo", "bar"}, {"cat", "pan"}} func TestExchangeToken_Opts(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %v.", err) } @@ -146,7 +146,7 @@ func TestExchangeToken_Opts(t *testing.T) { } else if len(strOpts) < 1 { t.Errorf("\"options\" field has length 0.") } - var opts map[string]interface{} + var opts map[string]any err = json.Unmarshal([]byte(strOpts[0]), &opts) if err != nil { t.Fatalf("Couldn't parse received \"options\" field.") @@ -159,7 +159,7 @@ func TestExchangeToken_Opts(t *testing.T) { if !ok { t.Errorf("Couldn't find first option parameter.") } else { - tOpts1, ok := val.(map[string]interface{}) + tOpts1, ok := val.(map[string]any) if !ok { t.Errorf("Failed to assert the first option parameter as type testOpts.") } else { @@ -176,7 +176,7 @@ func TestExchangeToken_Opts(t *testing.T) { if !ok { t.Errorf("Couldn't find second option parameter.") } else { - tOpts2, ok := val2.(map[string]interface{}) + tOpts2, ok := val2.(map[string]any) if !ok { t.Errorf("Failed to assert the second option parameter as type testOpts.") } else { @@ -200,7 +200,7 @@ func TestExchangeToken_Opts(t *testing.T) { firstOption := testOpts{optsValues[0][0], optsValues[0][1]} secondOption := testOpts{optsValues[1][0], optsValues[1][1]} - inputOpts := make(map[string]interface{}) + inputOpts := make(map[string]any) inputOpts["one"] = firstOption inputOpts["two"] = secondOption ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts) @@ -220,7 +220,7 @@ func TestRefreshToken(t *testing.T) { if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %v.", err) } diff --git a/internal/doc.go b/internal/doc.go index 03265e888..8c7c475f2 100644 --- a/internal/doc.go +++ b/internal/doc.go @@ -2,5 +2,5 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package internal contains support packages for oauth2 package. +// Package internal contains support packages for [golang.org/x/oauth2]. package internal diff --git a/internal/oauth2.go b/internal/oauth2.go index 14989beaf..71ea6ad1f 100644 --- a/internal/oauth2.go +++ b/internal/oauth2.go @@ -13,7 +13,7 @@ import ( ) // ParseKey converts the binary contents of a private key file -// to an *rsa.PrivateKey. It detects whether the private key is in a +// to an [*rsa.PrivateKey]. It detects whether the private key is in a // PEM container or not. If so, it extracts the private key // from PEM container before conversion. It only supports PEM // containers with no passphrase. diff --git a/internal/token.go b/internal/token.go index e83ddeef0..8389f2462 100644 --- a/internal/token.go +++ b/internal/token.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "math" "mime" "net/http" @@ -26,9 +25,9 @@ import ( // the requests to access protected resources on the OAuth 2.0 // provider's backend. // -// This type is a mirror of oauth2.Token and exists to break +// This type is a mirror of [golang.org/x/oauth2.Token] and exists to break // an otherwise-circular dependency. Other internal packages -// should convert this Token into an oauth2.Token before use. +// should convert this Token into an [golang.org/x/oauth2.Token] before use. type Token struct { // AccessToken is the token that authorizes and authenticates // the requests. @@ -50,9 +49,16 @@ type Token struct { // mechanisms for that TokenSource will not be used. Expiry time.Time + // ExpiresIn is the OAuth2 wire format "expires_in" field, + // which specifies how many seconds later the token expires, + // relative to an unknown time base approximately around "now". + // It is the application's responsibility to populate + // `Expiry` from `ExpiresIn` when required. + ExpiresIn int64 `json:"expires_in,omitempty"` + // Raw optionally contains extra metadata from the server // when updating a token. - Raw interface{} + Raw any } // tokenJSON is the struct representing the HTTP response from OAuth2 @@ -99,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { return nil } -// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. -// -// Deprecated: this function no longer does anything. Caller code that -// wants to avoid potential extra HTTP requests made during -// auto-probing of the provider's auth style should set -// Endpoint.AuthStyle. -func RegisterBrokenAuthHeaderProvider(tokenURL string) {} - // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. type AuthStyle int @@ -143,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { return c } +type authStyleCacheKey struct { + url string + clientID string +} + // AuthStyleCache is the set of tokenURLs we've successfully used via // RetrieveToken and which style auth we ended up using. // It's called a cache, but it doesn't (yet?) shrink. It's expected that @@ -150,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { // small. type AuthStyleCache struct { mu sync.Mutex - m map[string]AuthStyle // keyed by tokenURL + m map[authStyleCacheKey]AuthStyle } // lookupAuthStyle reports which auth style we last used with tokenURL // when calling RetrieveToken and whether we have ever done so. -func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { +func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) { c.mu.Lock() defer c.mu.Unlock() - style, ok = c.m[tokenURL] + style, ok = c.m[authStyleCacheKey{tokenURL, clientID}] return } // setAuthStyle adds an entry to authStyleCache, documented above. -func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { +func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) { c.mu.Lock() defer c.mu.Unlock() if c.m == nil { - c.m = make(map[string]AuthStyle) + c.m = make(map[authStyleCacheKey]AuthStyle) } - c.m[tokenURL] = v + c.m[authStyleCacheKey{tokenURL, clientID}] = v } // newTokenRequest returns a new *http.Request to retrieve a new token @@ -210,9 +213,9 @@ func cloneURLValues(v url.Values) url.Values { } func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { - needsAuthStyleProbe := authStyle == 0 + needsAuthStyleProbe := authStyle == AuthStyleUnknown if needsAuthStyleProbe { - if style, ok := styleCache.lookupAuthStyle(tokenURL); ok { + if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok { authStyle = style needsAuthStyleProbe = false } else { @@ -242,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, token, err = doTokenRoundTrip(ctx, req) } if needsAuthStyleProbe && err == nil { - styleCache.setAuthStyle(tokenURL, authStyle) + styleCache.setAuthStyle(tokenURL, clientID, authStyle) } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. @@ -257,7 +260,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { if err != nil { return nil, err } - body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) r.Body.Close() if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) @@ -312,7 +315,8 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) { TokenType: tj.TokenType, RefreshToken: tj.RefreshToken, Expiry: tj.expiry(), - Raw: make(map[string]interface{}), + ExpiresIn: int64(tj.ExpiresIn), + Raw: make(map[string]any), } json.Unmarshal(body, &token.Raw) // no error checks for optional fields } diff --git a/internal/token_test.go b/internal/token_test.go index c08862ae6..ef28c1162 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) { t.Errorf("expiration time = %v; want %v", e, want) } } + +func TestAuthStyleCache(t *testing.T) { + var c LazyAuthStyleCache + + cases := []struct { + url string + clientID string + style AuthStyle + }{ + { + "https://host1.example.com/token", + "client_1", + AuthStyleInHeader, + }, { + "https://host2.example.com/token", + "client_2", + AuthStyleInParams, + }, { + "https://host1.example.com/token", + "client_3", + AuthStyleInParams, + }, + } + + for _, tt := range cases { + t.Run(tt.clientID, func(t *testing.T) { + cc := c.Get() + got, ok := cc.lookupAuthStyle(tt.url, tt.clientID) + if ok { + t.Fatalf("unexpected auth style found on first request: %v", got) + } + + cc.setAuthStyle(tt.url, tt.clientID, tt.style) + + got, ok = cc.lookupAuthStyle(tt.url, tt.clientID) + if !ok { + t.Fatalf("auth style not found in cache") + } + + if got != tt.style { + t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style) + } + }) + } +} diff --git a/internal/transport.go b/internal/transport.go index b9db01ddf..afc0aeb27 100644 --- a/internal/transport.go +++ b/internal/transport.go @@ -9,8 +9,8 @@ import ( "net/http" ) -// HTTPClient is the context key to use with golang.org/x/net/context's -// WithValue function to associate an *http.Client value with a context. +// HTTPClient is the context key to use with [context.WithValue] +// to associate an [*http.Client] value with a context. var HTTPClient ContextKey // ContextKey is just an empty struct. It exists so HTTPClient can be diff --git a/jira/jira.go b/jira/jira.go index 814656e9e..0a28d1ef0 100644 --- a/jira/jira.go +++ b/jira/jira.go @@ -13,7 +13,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -114,7 +113,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -123,11 +122,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { } // tokenRes is the JSON response body. - var tokenRes struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int64 `json:"expires_in"` // relative seconds from now - } + var tokenRes oauth2.Token if err := json.Unmarshal(body, &tokenRes); err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } diff --git a/jws/jws.go b/jws/jws.go index 27ab06139..9bc484406 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -4,7 +4,7 @@ // Package jws provides a partial implementation // of JSON Web Signature encoding and decoding. -// It exists to support the golang.org/x/oauth2 package. +// It exists to support the [golang.org/x/oauth2] package. // // See RFC 7515. // @@ -48,7 +48,7 @@ type ClaimSet struct { // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // This array is marshalled using custom code (see (c *ClaimSet) encode()). - PrivateClaims map[string]interface{} `json:"-"` + PrivateClaims map[string]any `json:"-"` } func (c *ClaimSet) encode() (string, error) { @@ -152,7 +152,7 @@ func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) { } // Encode encodes a signed JWS with provided header and claim set. -// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key. +// This invokes [EncodeWithSigner] using [crypto/rsa.SignPKCS1v15] with the given RSA private key. func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { sg := func(data []byte) (sig []byte, err error) { h := sha256.New() diff --git a/jwt/jwt.go b/jwt/jwt.go index b2bf18298..38a92daca 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -13,7 +13,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -69,7 +68,7 @@ type Config struct { // PrivateClaims optionally specifies custom private claims in the JWT. // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 - PrivateClaims map[string]interface{} + PrivateClaims map[string]any // UseIDToken optionally specifies whether ID token should be used instead // of access token when the server returns both. @@ -136,7 +135,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20)) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) } @@ -148,10 +147,8 @@ func (js jwtSource) Token() (*oauth2.Token, error) { } // tokenRes is the JSON response body. var tokenRes struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - IDToken string `json:"id_token"` - ExpiresIn int64 `json:"expires_in"` // relative seconds from now + oauth2.Token + IDToken string `json:"id_token"` } if err := json.Unmarshal(body, &tokenRes); err != nil { return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) @@ -160,7 +157,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) { AccessToken: tokenRes.AccessToken, TokenType: tokenRes.TokenType, } - raw := make(map[string]interface{}) + raw := make(map[string]any) json.Unmarshal(body, &raw) // no error checks for optional fields token = token.WithExtra(raw) diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 9772dc520..c7619a10a 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -227,7 +227,7 @@ func TestJWTFetch_AssertionPayload(t *testing.T) { PrivateKey: dummyPrivateKey, PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", TokenURL: ts.URL, - PrivateClaims: map[string]interface{}{ + PrivateClaims: map[string]any{ "private0": "claim0", "private1": "claim1", }, @@ -273,11 +273,11 @@ func TestJWTFetch_AssertionPayload(t *testing.T) { t.Errorf("payload prn = %q; want %q", got, want) } if len(conf.PrivateClaims) > 0 { - var got interface{} + var got any if err := json.Unmarshal(gotjson, &got); err != nil { t.Errorf("failed to parse payload; err = %q", err) } - m := got.(map[string]interface{}) + m := got.(map[string]any) for v, k := range conf.PrivateClaims { if !reflect.DeepEqual(m[v], k) { t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k) diff --git a/oauth2.go b/oauth2.go index eacdd7fd9..de34feb84 100644 --- a/oauth2.go +++ b/oauth2.go @@ -22,9 +22,9 @@ import ( ) // NoContext is the default context you should supply if not using -// your own context.Context (see https://golang.org/x/net/context). +// your own [context.Context]. // -// Deprecated: Use context.Background() or context.TODO() instead. +// Deprecated: Use [context.Background] or [context.TODO] instead. var NoContext = context.TODO() // RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. @@ -37,8 +37,8 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {} // Config describes a typical 3-legged OAuth2 flow, with both the // client application information and the server's endpoint URLs. -// For the client credentials 2-legged OAuth2 flow, see the clientcredentials -// package (https://golang.org/x/oauth2/clientcredentials). +// For the client credentials 2-legged OAuth2 flow, see the +// [golang.org/x/oauth2/clientcredentials] package. type Config struct { // ClientID is the application's ID. ClientID string @@ -46,7 +46,7 @@ type Config struct { // ClientSecret is the application's secret. ClientSecret string - // Endpoint contains the resource server's token endpoint + // Endpoint contains the authorization server's token endpoint // URLs. These are constants specific to each server and are // often available via site-specific packages, such as // google.Endpoint or github.Endpoint. @@ -135,7 +135,7 @@ type setParam struct{ k, v string } func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } -// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters +// SetAuthURLParam builds an [AuthCodeOption] which passes key/value parameters // to a provider's authorization endpoint. func SetAuthURLParam(key, value string) AuthCodeOption { return setParam{key, value} @@ -148,8 +148,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption { // request and callback. The authorization server includes this value when // redirecting the user agent back to the client. // -// Opts may include AccessTypeOnline or AccessTypeOffline, as well -// as ApprovalForce. +// Opts may include [AccessTypeOnline] or [AccessTypeOffline], as well +// as [ApprovalForce]. // // To protect against CSRF attacks, opts should include a PKCE challenge // (S256ChallengeOption). Not all servers support PKCE. An alternative is to @@ -194,7 +194,7 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { // and when other authorization grant types are not available." // See https://tools.ietf.org/html/rfc6749#section-4.3 for more info. // -// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { v := url.Values{ "grant_type": {"password"}, @@ -212,10 +212,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor // It is used after a resource provider redirects the user back // to the Redirect URI (the URL obtained from AuthCodeURL). // -// The provided context optionally controls which HTTP client is used. See the HTTPClient variable. +// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. // -// The code will be in the *http.Request.FormValue("code"). Before -// calling Exchange, be sure to validate FormValue("state") if you are +// The code will be in the [http.Request.FormValue]("code"). Before +// calling Exchange, be sure to validate [http.Request.FormValue]("state") if you are // using it to protect against CSRF attacks. // // If using PKCE to protect against CSRF attacks, opts should include a @@ -242,10 +242,10 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client { return NewClient(ctx, c.TokenSource(ctx, t)) } -// TokenSource returns a TokenSource that returns t until t expires, +// TokenSource returns a [TokenSource] that returns t until t expires, // automatically refreshing it as necessary using the provided context. // -// Most users will use Config.Client instead. +// Most users will use [Config.Client] instead. func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { tkr := &tokenRefresher{ ctx: ctx, @@ -260,7 +260,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { } } -// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" +// tokenRefresher is a TokenSource that makes "grant_type=refresh_token" // HTTP requests to renew a token using a RefreshToken. type tokenRefresher struct { ctx context.Context // used to get HTTP requests @@ -305,8 +305,7 @@ type reuseTokenSource struct { } // Token returns the current token if it's still valid, else will -// refresh the current token (using r.Context for HTTP client -// information) and return the new one. +// refresh the current token and return the new one. func (s *reuseTokenSource) Token() (*Token, error) { s.mu.Lock() defer s.mu.Unlock() @@ -322,7 +321,7 @@ func (s *reuseTokenSource) Token() (*Token, error) { return t, nil } -// StaticTokenSource returns a TokenSource that always returns the same token. +// StaticTokenSource returns a [TokenSource] that always returns the same token. // Because the provided token t is never refreshed, StaticTokenSource is only // useful for tokens that never expire. func StaticTokenSource(t *Token) TokenSource { @@ -338,16 +337,16 @@ func (s staticTokenSource) Token() (*Token, error) { return s.t, nil } -// HTTPClient is the context key to use with golang.org/x/net/context's -// WithValue function to associate an *http.Client value with a context. +// HTTPClient is the context key to use with [context.WithValue] +// to associate a [*http.Client] value with a context. var HTTPClient internal.ContextKey -// NewClient creates an *http.Client from a Context and TokenSource. +// NewClient creates an [*http.Client] from a [context.Context] and [TokenSource]. // The returned client is not valid beyond the lifetime of the context. // -// Note that if a custom *http.Client is provided via the Context it +// Note that if a custom [*http.Client] is provided via the [context.Context] it // is used only for token acquisition and is not used to configure the -// *http.Client returned from NewClient. +// [*http.Client] returned from NewClient. // // As a special case, if src is nil, a non-OAuth2 client is returned // using the provided context. This exists to support related OAuth2 @@ -368,7 +367,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { } } -// ReuseTokenSource returns a TokenSource which repeatedly returns the +// ReuseTokenSource returns a [TokenSource] which repeatedly returns the // same token as long as it's valid, starting with t. // When its cached token is invalid, a new token is obtained from src. // @@ -376,10 +375,10 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { // (such as a file on disk) between runs of a program, rather than // obtaining new tokens unnecessarily. // -// The initial token t may be nil, in which case the TokenSource is +// The initial token t may be nil, in which case the [TokenSource] is // wrapped in a caching version if it isn't one already. This also // means it's always safe to wrap ReuseTokenSource around any other -// TokenSource without adverse effects. +// [TokenSource] without adverse effects. func ReuseTokenSource(t *Token, src TokenSource) TokenSource { // Don't wrap a reuseTokenSource in itself. That would work, // but cause an unnecessary number of mutex operations. @@ -397,8 +396,8 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource { } } -// ReuseTokenSourceWithExpiry returns a TokenSource that acts in the same manner as the -// TokenSource returned by ReuseTokenSource, except the expiry buffer is +// ReuseTokenSourceWithExpiry returns a [TokenSource] that acts in the same manner as the +// [TokenSource] returned by [ReuseTokenSource], except the expiry buffer is // configurable. The expiration time of a token is calculated as // t.Expiry.Add(-earlyExpiry). func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource { diff --git a/oauth2_test.go b/oauth2_test.go index 37f0580d7..5db78f21e 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -104,7 +103,7 @@ func TestExchangeRequest(t *testing.T) { if headerContentType != "application/x-www-form-urlencoded" { t.Errorf("Unexpected Content-Type header %q", headerContentType) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %s.", err) } @@ -148,7 +147,7 @@ func TestExchangeRequest_CustomParam(t *testing.T) { if headerContentType != "application/x-www-form-urlencoded" { t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %s.", err) } @@ -194,7 +193,7 @@ func TestExchangeRequest_JSONResponse(t *testing.T) { if headerContentType != "application/x-www-form-urlencoded" { t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %s.", err) } @@ -301,7 +300,7 @@ func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, want, nul conf := newConf(ts.URL) t1 := time.Now().Add(day) tok, err := conf.Exchange(context.Background(), "exchange-code") - t2 := t1.Add(day) + t2 := time.Now().Add(day) if got := (err == nil); got != want { if want { @@ -393,7 +392,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { if headerContentType != expected { t.Errorf("Content-Type header = %q; want %q", headerContentType, expected) } - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed reading request body: %s.", err) } @@ -435,7 +434,7 @@ func TestTokenRefreshRequest(t *testing.T) { if headerContentType != "application/x-www-form-urlencoded" { t.Errorf("Unexpected Content-Type header %q", headerContentType) } - body, _ := ioutil.ReadAll(r.Body) + body, _ := io.ReadAll(r.Body) if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { t.Errorf("Unexpected refresh token payload %q", body) } @@ -460,7 +459,7 @@ func TestFetchWithNoRefreshToken(t *testing.T) { if headerContentType != "application/x-www-form-urlencoded" { t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) } - body, _ := ioutil.ReadAll(r.Body) + body, _ := io.ReadAll(r.Body) if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) } diff --git a/pkce.go b/pkce.go index 6a95da975..cea8374d5 100644 --- a/pkce.go +++ b/pkce.go @@ -1,6 +1,7 @@ // Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. + package oauth2 import ( @@ -20,9 +21,9 @@ const ( // This follows recommendations in RFC 7636. // // A fresh verifier should be generated for each authorization. -// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL -// (or Config.DeviceAuth) and VerifierOption(verifier) to Config.Exchange -// (or Config.DeviceAccessToken). +// The resulting verifier should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] +// with [S256ChallengeOption], and to [Config.Exchange] or [Config.DeviceAccessToken] +// with [VerifierOption]. func GenerateVerifier() string { // "RECOMMENDED that the output of a suitable random number generator be // used to create a 32-octet sequence. The octet sequence is then @@ -36,22 +37,22 @@ func GenerateVerifier() string { return base64.RawURLEncoding.EncodeToString(data) } -// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be -// passed to Config.Exchange or Config.DeviceAccessToken only. +// VerifierOption returns a PKCE code verifier [AuthCodeOption]. It should only be +// passed to [Config.Exchange] or [Config.DeviceAccessToken]. func VerifierOption(verifier string) AuthCodeOption { return setParam{k: codeVerifierKey, v: verifier} } // S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256. // -// Prefer to use S256ChallengeOption where possible. +// Prefer to use [S256ChallengeOption] where possible. func S256ChallengeFromVerifier(verifier string) string { sha := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(sha[:]) } // S256ChallengeOption derives a PKCE code challenge derived from verifier with -// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAuth +// method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] // only. func S256ChallengeOption(verifier string) AuthCodeOption { return challengeOption{ diff --git a/token.go b/token.go index 8c31136c4..239ec3296 100644 --- a/token.go +++ b/token.go @@ -44,7 +44,7 @@ type Token struct { // Expiry is the optional expiration time of the access token. // - // If zero, TokenSource implementations will reuse the same + // If zero, [TokenSource] implementations will reuse the same // token forever and RefreshToken or equivalent // mechanisms for that TokenSource will not be used. Expiry time.Time `json:"expiry,omitempty"` @@ -58,7 +58,7 @@ type Token struct { // raw optionally contains extra metadata from the server // when updating a token. - raw interface{} + raw any // expiryDelta is used to calculate when a token is considered // expired, by subtracting from Expiry. If zero, defaultExpiryDelta @@ -86,16 +86,16 @@ func (t *Token) Type() string { // SetAuthHeader sets the Authorization header to r using the access // token in t. // -// This method is unnecessary when using Transport or an HTTP Client +// This method is unnecessary when using [Transport] or an HTTP Client // returned by this package. func (t *Token) SetAuthHeader(r *http.Request) { r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) } -// WithExtra returns a new Token that's a clone of t, but using the +// WithExtra returns a new [Token] that's a clone of t, but using the // provided raw extra map. This is only intended for use by packages // implementing derivative OAuth2 flows. -func (t *Token) WithExtra(extra interface{}) *Token { +func (t *Token) WithExtra(extra any) *Token { t2 := new(Token) *t2 = *t t2.raw = extra @@ -105,8 +105,8 @@ func (t *Token) WithExtra(extra interface{}) *Token { // Extra returns an extra field. // Extra fields are key-value pairs returned by the server as a // part of the token retrieval response. -func (t *Token) Extra(key string) interface{} { - if raw, ok := t.raw.(map[string]interface{}); ok { +func (t *Token) Extra(key string) any { + if raw, ok := t.raw.(map[string]any); ok { return raw[key] } @@ -163,6 +163,7 @@ func tokenFromInternal(t *internal.Token) *Token { TokenType: t.TokenType, RefreshToken: t.RefreshToken, Expiry: t.Expiry, + ExpiresIn: t.ExpiresIn, raw: t.Raw, } } diff --git a/token_test.go b/token_test.go index 0d8c7df86..5fa14fa8b 100644 --- a/token_test.go +++ b/token_test.go @@ -12,8 +12,8 @@ import ( func TestTokenExtra(t *testing.T) { type testCase struct { key string - val interface{} - want interface{} + val any + want any } const key = "extra-key" cases := []testCase{ @@ -23,7 +23,7 @@ func TestTokenExtra(t *testing.T) { {key: "other-key", val: "def", want: nil}, } for _, tc := range cases { - extra := make(map[string]interface{}) + extra := make(map[string]any) extra[tc.key] = tc.val tok := &Token{raw: extra} if got, want := tok.Extra(key), tc.want; got != want { diff --git a/transport.go b/transport.go index 90657915f..8bbebbac9 100644 --- a/transport.go +++ b/transport.go @@ -11,12 +11,12 @@ import ( "sync" ) -// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests, -// wrapping a base RoundTripper and adding an Authorization header -// with a token from the supplied Sources. +// Transport is an [http.RoundTripper] that makes OAuth 2.0 HTTP requests, +// wrapping a base [http.RoundTripper] and adding an Authorization header +// with a token from the supplied [TokenSource]. // // Transport is a low-level mechanism. Most code will use the -// higher-level Config.Client method instead. +// higher-level [Config.Client] method instead. type Transport struct { // Source supplies the token to add to outgoing requests' // Authorization headers. @@ -47,7 +47,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } - req2 := cloneRequest(req) // per RoundTripper contract + req2 := req.Clone(req.Context()) token.SetAuthHeader(req2) // req.Body is assumed to be closed by the base RoundTripper. @@ -73,17 +73,3 @@ func (t *Transport) base() http.RoundTripper { } return http.DefaultTransport } - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) - } - return r2 -}