Skip to content

Commit 92e95e4

Browse files
authoredAug 13, 2024
Merge pull request smallstep#1940 from smallstep/mariano/self-trust
Allow to use private IdPs with the OIDC provisioner
2 parents ffae6a5 + a01a2fb commit 92e95e4

14 files changed

+276
-44
lines changed
 

‎authority/authority.go

+10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ type Authority struct {
4949
templates *templates.Templates
5050
linkedCAToken string
5151
webhookClient *http.Client
52+
httpClient *http.Client
5253

5354
// X509 CA
5455
password []byte
@@ -491,6 +492,15 @@ func (a *Authority) init() error {
491492
a.certificates.Store(hex.EncodeToString(sum[:]), crt)
492493
}
493494

495+
// Initialize HTTPClient with all root certs
496+
clientRoots := make([]*x509.Certificate, 0, len(a.rootX509Certs)+len(a.federatedX509Certs))
497+
clientRoots = append(clientRoots, a.rootX509Certs...)
498+
clientRoots = append(clientRoots, a.federatedX509Certs...)
499+
a.httpClient, err = newHTTPClient(clientRoots...)
500+
if err != nil {
501+
return err
502+
}
503+
494504
// Decrypt and load SSH keys
495505
var tmplVars templates.Step
496506
if a.config.SSH != nil {

‎authority/http_client.go

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package authority
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"fmt"
7+
"net/http"
8+
)
9+
10+
// newHTTPClient returns an HTTP client that trusts the system cert pool and the
11+
// given roots.
12+
func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) {
13+
pool, err := x509.SystemCertPool()
14+
if err != nil {
15+
return nil, fmt.Errorf("error initializing http client: %w", err)
16+
}
17+
for _, crt := range roots {
18+
pool.AddCert(crt)
19+
}
20+
21+
tr, ok := http.DefaultTransport.(*http.Transport)
22+
if !ok {
23+
return nil, fmt.Errorf("error initializing http client: type is not *http.Transport")
24+
}
25+
tr = tr.Clone()
26+
tr.TLSClientConfig = &tls.Config{
27+
MinVersion: tls.VersionTLS12,
28+
RootCAs: pool,
29+
}
30+
31+
return &http.Client{
32+
Transport: tr,
33+
}, nil
34+
}

‎authority/http_client_test.go

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package authority
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
"time"
13+
14+
"github.com/smallstep/certificates/authority/provisioner"
15+
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/require"
17+
"go.step.sm/crypto/jose"
18+
"go.step.sm/crypto/keyutil"
19+
"go.step.sm/crypto/x509util"
20+
)
21+
22+
func mustCertificate(t *testing.T, a *Authority, csr *x509.CertificateRequest) []*x509.Certificate {
23+
t.Helper()
24+
25+
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
26+
27+
now := time.Now()
28+
signOpts := provisioner.SignOptions{
29+
NotBefore: provisioner.NewTimeDuration(now),
30+
NotAfter: provisioner.NewTimeDuration(now.Add(5 * time.Minute)),
31+
Backdate: 1 * time.Minute,
32+
}
33+
34+
sans := []string{}
35+
sans = append(sans, csr.DNSNames...)
36+
sans = append(sans, csr.EmailAddresses...)
37+
for _, s := range csr.IPAddresses {
38+
sans = append(sans, s.String())
39+
}
40+
for _, s := range csr.URIs {
41+
sans = append(sans, s.String())
42+
}
43+
44+
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
45+
require.NoError(t, err)
46+
47+
token, err := generateToken(csr.Subject.CommonName, "step-cli", testAudiences.Sign[0], sans, now, key)
48+
require.NoError(t, err)
49+
50+
extraOpts, err := a.Authorize(ctx, token)
51+
require.NoError(t, err)
52+
53+
chain, err := a.SignWithContext(ctx, csr, signOpts, extraOpts...)
54+
require.NoError(t, err)
55+
56+
return chain
57+
}
58+
59+
func Test_newHTTPClient(t *testing.T) {
60+
signer, err := keyutil.GenerateDefaultSigner()
61+
require.NoError(t, err)
62+
63+
csr, err := x509util.CreateCertificateRequest("test", []string{"localhost", "127.0.0.1", "[::1]"}, signer)
64+
require.NoError(t, err)
65+
66+
auth := testAuthority(t)
67+
chain := mustCertificate(t, auth, csr)
68+
69+
t.Run("SystemCertPool", func(t *testing.T) {
70+
resp, err := auth.httpClient.Get("https://smallstep.com")
71+
require.NoError(t, err)
72+
assert.Equal(t, http.StatusOK, resp.StatusCode)
73+
b, err := io.ReadAll(resp.Body)
74+
assert.NoError(t, err)
75+
assert.NotEmpty(t, b)
76+
assert.NoError(t, resp.Body.Close())
77+
})
78+
79+
t.Run("LocalCertPool", func(t *testing.T) {
80+
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81+
fmt.Fprint(w, "ok")
82+
}))
83+
srv.TLS = &tls.Config{
84+
Certificates: []tls.Certificate{
85+
{Certificate: [][]byte{chain[0].Raw, chain[1].Raw}, PrivateKey: signer, Leaf: chain[0]},
86+
},
87+
}
88+
srv.StartTLS()
89+
defer srv.Close()
90+
91+
resp, err := auth.httpClient.Get(srv.URL)
92+
require.NoError(t, err)
93+
assert.Equal(t, http.StatusOK, resp.StatusCode)
94+
b, err := io.ReadAll(resp.Body)
95+
assert.NoError(t, err)
96+
assert.Equal(t, []byte("ok"), b)
97+
assert.NoError(t, resp.Body.Close())
98+
99+
t.Run("DefaultClient", func(t *testing.T) {
100+
client := &http.Client{}
101+
_, err := client.Get(srv.URL)
102+
assert.Error(t, err)
103+
})
104+
})
105+
}

‎authority/provisioner/azure.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,14 @@ func (p *Azure) Init(config Config) (err error) {
251251
p.assertConfig()
252252

253253
// Decode and validate openid-configuration endpoint
254-
if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
254+
if err = getAndDecode(http.DefaultClient, p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
255255
return
256256
}
257257
if err := p.oidcConfig.Validate(); err != nil {
258258
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
259259
}
260260
// Get JWK key set
261-
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
261+
if p.keyStore, err = newKeyStore(http.DefaultClient, p.oidcConfig.JWKSetURI); err != nil {
262262
return
263263
}
264264

‎authority/provisioner/controller.go

+11
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type Controller struct {
2626
policy *policyEngine
2727
webhookClient *http.Client
2828
webhooks []*Webhook
29+
httpClient *http.Client
2930
}
3031

3132
// NewController initializes a new provisioner controller.
@@ -48,9 +49,19 @@ func NewController(p Interface, claims *Claims, config Config, options *Options)
4849
policy: policy,
4950
webhookClient: config.WebhookClient,
5051
webhooks: options.GetWebhooks(),
52+
httpClient: config.HTTPClient,
5153
}, nil
5254
}
5355

56+
// GetHTTPClient returns the configured HTTP client or the default one if none
57+
// is configured.
58+
func (c *Controller) GetHTTPClient() *http.Client {
59+
if c.httpClient != nil {
60+
return c.httpClient
61+
}
62+
return &http.Client{}
63+
}
64+
5465
// GetIdentity returns the identity for a given email.
5566
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
5667
if c.IdentityFunc != nil {

‎authority/provisioner/controller_test.go

+34-8
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ import (
99
"testing"
1010
"time"
1111

12+
"github.com/smallstep/certificates/authority/policy"
13+
"github.com/smallstep/certificates/webhook"
14+
"github.com/stretchr/testify/assert"
1215
"go.step.sm/crypto/pemutil"
1316
"go.step.sm/crypto/x509util"
1417
"go.step.sm/linkedca"
1518
"golang.org/x/crypto/ssh"
16-
17-
"github.com/smallstep/certificates/authority/policy"
18-
"github.com/smallstep/certificates/webhook"
1919
)
2020

2121
var trueValue = true
@@ -79,12 +79,14 @@ func TestNewController(t *testing.T) {
7979
wantErr bool
8080
}{
8181
{"ok", args{&JWK{}, nil, Config{
82-
Claims: globalProvisionerClaims,
83-
Audiences: testAudiences,
82+
Claims: globalProvisionerClaims,
83+
Audiences: testAudiences,
84+
HTTPClient: &http.Client{},
8485
}, nil}, &Controller{
85-
Interface: &JWK{},
86-
Audiences: &testAudiences,
87-
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
86+
Interface: &JWK{},
87+
Audiences: &testAudiences,
88+
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
89+
httpClient: &http.Client{},
8890
}, false},
8991
{"ok with claims", args{&JWK{}, &Claims{
9092
DisableRenewal: &defaultDisableRenewal,
@@ -145,6 +147,30 @@ func TestNewController(t *testing.T) {
145147
}
146148
}
147149

150+
func TestController_GetHTTPClient(t *testing.T) {
151+
srv := generateTLSJWKServer(2)
152+
defer srv.Close()
153+
type fields struct {
154+
httpClient *http.Client
155+
}
156+
tests := []struct {
157+
name string
158+
fields fields
159+
want *http.Client
160+
}{
161+
{"ok custom", fields{srv.Client()}, srv.Client()},
162+
{"ok default", fields{http.DefaultClient}, http.DefaultClient},
163+
}
164+
for _, tt := range tests {
165+
t.Run(tt.name, func(t *testing.T) {
166+
c := &Controller{
167+
httpClient: tt.fields.httpClient,
168+
}
169+
assert.Equal(t, tt.want, c.GetHTTPClient())
170+
})
171+
}
172+
}
173+
148174
func TestController_GetIdentity(t *testing.T) {
149175
ctx := context.Background()
150176
type fields struct {

‎authority/provisioner/gcp.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ func (p *GCP) Init(config Config) (err error) {
228228
p.assertConfig()
229229

230230
// Initialize key store
231-
if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
231+
if p.keyStore, err = newKeyStore(http.DefaultClient, p.config.CertsURL); err != nil {
232232
return
233233
}
234234

‎authority/provisioner/keystore.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,21 @@ var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
2222

2323
type keyStore struct {
2424
sync.RWMutex
25+
client *http.Client
2526
uri string
2627
keySet jose.JSONWebKeySet
2728
timer *time.Timer
2829
expiry time.Time
2930
jitter time.Duration
3031
}
3132

32-
func newKeyStore(uri string) (*keyStore, error) {
33-
keys, age, err := getKeysFromJWKsURI(uri)
33+
func newKeyStore(client *http.Client, uri string) (*keyStore, error) {
34+
keys, age, err := getKeysFromJWKsURI(client, uri)
3435
if err != nil {
3536
return nil, err
3637
}
3738
ks := &keyStore{
39+
client: client,
3840
uri: uri,
3941
keySet: keys,
4042
expiry: getExpirationTime(age),
@@ -64,7 +66,7 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
6466

6567
func (ks *keyStore) reload() {
6668
var next time.Duration
67-
keys, age, err := getKeysFromJWKsURI(ks.uri)
69+
keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri)
6870
if err != nil {
6971
next = ks.nextReloadDuration(ks.jitter / 2)
7072
} else {
@@ -90,9 +92,9 @@ func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
9092
return abs(age)
9193
}
9294

93-
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
95+
func getKeysFromJWKsURI(client *http.Client, uri string) (jose.JSONWebKeySet, time.Duration, error) {
9496
var keys jose.JSONWebKeySet
95-
resp, err := http.Get(uri) //nolint:gosec // openid-configuration jwks_uri
97+
resp, err := client.Get(uri)
9698
if err != nil {
9799
return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
98100
}

0 commit comments

Comments
 (0)