Skip to content

Commit d872f09

Browse files
committed
Use mTLS by default on SDK methods.
Add options to modify the tls.Config for different configurations. Fixes smallstep#7
1 parent bb03aad commit d872f09

File tree

9 files changed

+415
-417
lines changed

9 files changed

+415
-417
lines changed

ca/bootstrap.go

+10-57
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ func Bootstrap(token string) (*Client, error) {
4343
// Authority. By default the server will kick off a routine that will renew the
4444
// certificate after 2/3rd of the certificate's lifetime has expired.
4545
//
46+
// Without any extra option the server will be configured for mTLS, it will
47+
// require and verify clients certificates, but options can be used to drop this
48+
// requirement, the most common will be only verify the certs if given with
49+
// ca.VerifyClientCertIfGiven(), or add extra CAs with
50+
// ca.AddClientCA(*x509.Certificate).
51+
//
4652
// Usage:
4753
// // Default example with certificate rotation.
4854
// srv, err := ca.BootstrapServer(context.Background(), token, &http.Server{
@@ -61,60 +67,7 @@ func Bootstrap(token string) (*Client, error) {
6167
// return err
6268
// }
6369
// srv.ListenAndServeTLS("", "")
64-
func BootstrapServer(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
65-
if base.TLSConfig != nil {
66-
return nil, errors.New("server TLSConfig is already set")
67-
}
68-
69-
client, err := Bootstrap(token)
70-
if err != nil {
71-
return nil, err
72-
}
73-
74-
req, pk, err := CreateSignRequest(token)
75-
if err != nil {
76-
return nil, err
77-
}
78-
79-
sign, err := client.Sign(req)
80-
if err != nil {
81-
return nil, err
82-
}
83-
84-
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk)
85-
if err != nil {
86-
return nil, err
87-
}
88-
89-
base.TLSConfig = tlsConfig
90-
return base, nil
91-
}
92-
93-
// BootstrapServerWithMTLS is a helper function that using the given token
94-
// returns the given http.Server configured with a TLS certificate signed by the
95-
// Certificate Authority, this server will always require and verify a client
96-
// certificate. By default the server will kick off a routine that will renew
97-
// the certificate after 2/3rd of the certificate's lifetime has expired.
98-
//
99-
// Usage:
100-
// // Default example with certificate rotation.
101-
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
102-
// Addr: ":443",
103-
// Handler: handler,
104-
// })
105-
//
106-
// // Example canceling automatic certificate rotation.
107-
// ctx, cancel := context.WithCancel(context.Background())
108-
// defer cancel()
109-
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
110-
// Addr: ":443",
111-
// Handler: handler,
112-
// })
113-
// if err != nil {
114-
// return err
115-
// }
116-
// srv.ListenAndServeTLS("", "")
117-
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
70+
func BootstrapServer(ctx context.Context, token string, base *http.Server, options ...TLSOption) (*http.Server, error) {
11871
if base.TLSConfig != nil {
11972
return nil, errors.New("server TLSConfig is already set")
12073
}
@@ -134,7 +87,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve
13487
return nil, err
13588
}
13689

137-
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
90+
tlsConfig, err := client.GetServerTLSConfig(ctx, sign, pk, options...)
13891
if err != nil {
13992
return nil, err
14093
}
@@ -161,7 +114,7 @@ func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Serve
161114
// return err
162115
// }
163116
// resp, err := client.Get("https://internal.smallstep.com")
164-
func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
117+
func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) {
165118
client, err := Bootstrap(token)
166119
if err != nil {
167120
return nil, err
@@ -177,7 +130,7 @@ func BootstrapClient(ctx context.Context, token string) (*http.Client, error) {
177130
return nil, err
178131
}
179132

180-
transport, err := client.Transport(ctx, sign, pk)
133+
transport, err := client.Transport(ctx, sign, pk, options...)
181134
if err != nil {
182135
return nil, err
183136
}

ca/bootstrap_test.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func TestBootstrap(t *testing.T) {
124124
}
125125
}
126126

127-
func TestBootstrapServer(t *testing.T) {
127+
func TestBootstrapServerWithoutMTLS(t *testing.T) {
128128
srv := startCABootstrapServer()
129129
defer srv.Close()
130130
token := func() string {
@@ -146,7 +146,7 @@ func TestBootstrapServer(t *testing.T) {
146146
}
147147
for _, tt := range tests {
148148
t.Run(tt.name, func(t *testing.T) {
149-
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
149+
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base, VerifyClientCertIfGiven())
150150
if (err != nil) != tt.wantErr {
151151
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
152152
return
@@ -192,24 +192,24 @@ func TestBootstrapServerWithMTLS(t *testing.T) {
192192
}
193193
for _, tt := range tests {
194194
t.Run(tt.name, func(t *testing.T) {
195-
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
195+
got, err := BootstrapServer(tt.args.ctx, tt.args.token, tt.args.base)
196196
if (err != nil) != tt.wantErr {
197-
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
197+
t.Errorf("BootstrapServer() error = %v, wantErr %v", err, tt.wantErr)
198198
return
199199
}
200200
if tt.wantErr {
201201
if got != nil {
202-
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
202+
t.Errorf("BootstrapServer() = %v, want nil", got)
203203
}
204204
} else {
205205
expected := &http.Server{
206206
TLSConfig: got.TLSConfig,
207207
}
208208
if !reflect.DeepEqual(got, expected) {
209-
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
209+
t.Errorf("BootstrapServer() = %v, want %v", got, expected)
210210
}
211211
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
212-
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
212+
t.Errorf("BootstrapServer() invalid TLSConfig = %#v", got.TLSConfig)
213213
}
214214
}
215215
})

ca/tls.go

+28-36
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
// GetClientTLSConfig returns a tls.Config for client use configured with the
2121
// sign certificate, and a new certificate pool with the sign root certificate.
2222
// The client certificate will automatically rotate before expiring.
23-
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
23+
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
2424
cert, err := TLSCertificate(sign, pk)
2525
if err != nil {
2626
return nil, err
@@ -36,10 +36,15 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
3636
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
3737
tlsConfig.PreferServerCipherSuites = true
3838
// Build RootCAs with given root certificate
39-
if pool := c.getCertPool(sign); pool != nil {
39+
if pool := getCertPool(sign); pool != nil {
4040
tlsConfig.RootCAs = pool
4141
}
4242

43+
// Apply options if given
44+
if err := setTLSOptions(tlsConfig, options); err != nil {
45+
return nil, err
46+
}
47+
4348
// Update renew function with transport
4449
tr, err := getDefaultTransport(tlsConfig)
4550
if err != nil {
@@ -56,7 +61,7 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
5661
// sign certificate, and a new certificate pool with the sign root certificate.
5762
// The returned tls.Config will only verify the client certificate if provided.
5863
// The server certificate will automatically rotate before expiring.
59-
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
64+
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
6065
cert, err := TLSCertificate(sign, pk)
6166
if err != nil {
6267
return nil, err
@@ -74,13 +79,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
7479
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
7580
tlsConfig.PreferServerCipherSuites = true
7681
// Build RootCAs with given root certificate
77-
if pool := c.getCertPool(sign); pool != nil {
82+
if pool := getCertPool(sign); pool != nil {
7883
tlsConfig.ClientCAs = pool
79-
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
84+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
8085
// Add RootCAs for refresh client
8186
tlsConfig.RootCAs = pool
8287
}
8388

89+
// Apply options if given
90+
if err := setTLSOptions(tlsConfig, options); err != nil {
91+
return nil, err
92+
}
93+
8494
// Update renew function with transport
8595
tr, err := getDefaultTransport(tlsConfig)
8696
if err != nil {
@@ -93,44 +103,15 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
93103
return tlsConfig, nil
94104
}
95105

96-
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
97-
// the sign certificate, and a new certificate pool with the sign root certificate.
98-
// The returned tls.Config will always require and verify a client certificate.
99-
// The server certificate will automatically rotate before expiring.
100-
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
101-
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
102-
if err != nil {
103-
return nil, err
104-
}
105-
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
106-
return tlsConfig, nil
107-
}
108-
109106
// Transport returns an http.Transport configured to use the client certificate from the sign response.
110-
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
111-
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)
107+
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
108+
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...)
112109
if err != nil {
113110
return nil, err
114111
}
115112
return getDefaultTransport(tlsConfig)
116113
}
117114

118-
// getCertPool returns the transport x509.CertPool or the one from the sign
119-
// request.
120-
func (c *Client) getCertPool(sign *api.SignResponse) *x509.CertPool {
121-
// Return the transport certPool
122-
if c.certPool != nil {
123-
return c.certPool
124-
}
125-
// Return certificate used in sign request.
126-
if root, err := RootCertificate(sign); err == nil {
127-
pool := x509.NewCertPool()
128-
pool.AddCert(root)
129-
return pool
130-
}
131-
return nil
132-
}
133-
134115
// Certificate returns the server or client certificate from the sign response.
135116
func Certificate(sign *api.SignResponse) (*x509.Certificate, error) {
136117
if sign.ServerPEM.Certificate == nil {
@@ -189,6 +170,17 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
189170
return &cert, nil
190171
}
191172

173+
// getCertPool returns the transport x509.CertPool or the one from the sign
174+
// request.
175+
func getCertPool(sign *api.SignResponse) *x509.CertPool {
176+
if root, err := RootCertificate(sign); err == nil {
177+
pool := x509.NewCertPool()
178+
pool.AddCert(root)
179+
return pool
180+
}
181+
return nil
182+
}
183+
192184
func getDefaultTLSConfig(sign *api.SignResponse) *tls.Config {
193185
if sign.TLSOptions != nil {
194186
return sign.TLSOptions.TLSConfig()

ca/tls_options.go

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package ca
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
)
7+
8+
// TLSOption defines the type of a function that modifies a tls.Config.
9+
type TLSOption func(c *tls.Config) error
10+
11+
// setTLSOptions takes one or more option function and applies them in order to
12+
// a tls.Config.
13+
func setTLSOptions(c *tls.Config, options []TLSOption) error {
14+
for _, opt := range options {
15+
if err := opt(c); err != nil {
16+
return err
17+
}
18+
}
19+
return nil
20+
}
21+
22+
// RequireAndVerifyClientCert is a tls.Config option used on servers to enforce
23+
// a valid TLS client certificate. This is the default option for mTLS servers.
24+
func RequireAndVerifyClientCert() TLSOption {
25+
return func(c *tls.Config) error {
26+
c.ClientAuth = tls.RequireAndVerifyClientCert
27+
return nil
28+
}
29+
}
30+
31+
// VerifyClientCertIfGiven is a tls.Config option used on on servers to validate
32+
// a TLS client certificate if it is provided. It does not requires a certificate.
33+
func VerifyClientCertIfGiven() TLSOption {
34+
return func(c *tls.Config) error {
35+
c.ClientAuth = tls.VerifyClientCertIfGiven
36+
return nil
37+
}
38+
}
39+
40+
// AddRootCA adds to the tls.Config RootCAs the given certificate. RootCAs
41+
// defines the set of root certificate authorities that clients use when
42+
// verifying server certificates.
43+
func AddRootCA(cert *x509.Certificate) TLSOption {
44+
return func(c *tls.Config) error {
45+
if c.RootCAs == nil {
46+
c.RootCAs = x509.NewCertPool()
47+
}
48+
c.RootCAs.AddCert(cert)
49+
return nil
50+
}
51+
}
52+
53+
// AddClientCA adds to the tls.Config ClientCAs the given certificate. ClientCAs
54+
// defines the set of root certificate authorities that servers use if required
55+
// to verify a client certificate by the policy in ClientAuth.
56+
func AddClientCA(cert *x509.Certificate) TLSOption {
57+
return func(c *tls.Config) error {
58+
if c.ClientCAs == nil {
59+
c.ClientCAs = x509.NewCertPool()
60+
}
61+
c.ClientCAs.AddCert(cert)
62+
return nil
63+
}
64+
}

0 commit comments

Comments
 (0)