Skip to content

Commit 25eba1a

Browse files
committed
WIP on the safely rotate of root and federated certificates.
Fixes smallstep#23
1 parent bacbf85 commit 25eba1a

File tree

4 files changed

+286
-74
lines changed

4 files changed

+286
-74
lines changed

ca/mutable_tls_config.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package ca
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"sync"
7+
8+
"github.com/smallstep/certificates/api"
9+
)
10+
11+
// mutableTLSConfig allows to use a tls.Config with mutable cert pools.
12+
type mutableTLSConfig struct {
13+
sync.RWMutex
14+
config *tls.Config
15+
clientCerts []*x509.Certificate
16+
rootCerts []*x509.Certificate
17+
mutClientCerts []*x509.Certificate
18+
mutRootCerts []*x509.Certificate
19+
}
20+
21+
// newMutableTLSConfig creates a new mutableTLSConfig using the passed one as
22+
// the base one.
23+
func newMutableTLSConfig() *mutableTLSConfig {
24+
return &mutableTLSConfig{
25+
clientCerts: []*x509.Certificate{},
26+
rootCerts: []*x509.Certificate{},
27+
mutClientCerts: []*x509.Certificate{},
28+
mutRootCerts: []*x509.Certificate{},
29+
}
30+
}
31+
32+
// Init initializes the mutable tls.Config with the given tls.Config.
33+
func (c *mutableTLSConfig) Init(base *tls.Config) {
34+
c.Lock()
35+
c.config = base.Clone()
36+
c.Unlock()
37+
}
38+
39+
// TLSConfig returns the updated tls.Config it it has changed. It's is used in
40+
// the tls.Config GetConfigForClient.
41+
func (c *mutableTLSConfig) TLSConfig() (config *tls.Config) {
42+
c.RLock()
43+
config = c.config
44+
c.RUnlock()
45+
return
46+
}
47+
48+
// Reload reloads the tls.Config with the new CAs.
49+
func (c *mutableTLSConfig) Reload() {
50+
// Prepare new pools
51+
c.RLock()
52+
rootCAs := x509.NewCertPool()
53+
clientCAs := x509.NewCertPool()
54+
// Fixed certs
55+
for _, cert := range c.rootCerts {
56+
rootCAs.AddCert(cert)
57+
}
58+
for _, cert := range c.clientCerts {
59+
clientCAs.AddCert(cert)
60+
}
61+
// Mutable certs
62+
for _, cert := range c.mutRootCerts {
63+
rootCAs.AddCert(cert)
64+
}
65+
for _, cert := range c.mutClientCerts {
66+
clientCAs.AddCert(cert)
67+
}
68+
c.RUnlock()
69+
70+
// Set new pool
71+
c.Lock()
72+
c.config.RootCAs = rootCAs
73+
c.config.ClientCAs = clientCAs
74+
c.mutRootCerts = []*x509.Certificate{}
75+
c.mutClientCerts = []*x509.Certificate{}
76+
c.Unlock()
77+
}
78+
79+
// AddFixedClientCACert add an in-mutable cert to ClientCAs.
80+
func (c *mutableTLSConfig) AddInmutableClientCACert(cert *x509.Certificate) {
81+
c.Lock()
82+
c.clientCerts = append(c.clientCerts, cert)
83+
c.Unlock()
84+
}
85+
86+
// AddInmutableRootCACert add an in-mutable cert to RootCas.
87+
func (c *mutableTLSConfig) AddInmutableRootCACert(cert *x509.Certificate) {
88+
c.Lock()
89+
c.rootCerts = append(c.rootCerts, cert)
90+
c.Unlock()
91+
}
92+
93+
// AddClientCAs add mutable certs to ClientCAs.
94+
func (c *mutableTLSConfig) AddClientCAs(certs []api.Certificate) {
95+
c.Lock()
96+
for _, cert := range certs {
97+
c.mutClientCerts = append(c.mutClientCerts, cert.Certificate)
98+
}
99+
c.Unlock()
100+
}
101+
102+
// AddRootCAs add mutable certs to RootCAs.
103+
func (c *mutableTLSConfig) AddRootCAs(certs []api.Certificate) {
104+
c.Lock()
105+
for _, cert := range certs {
106+
c.mutRootCerts = append(c.mutRootCerts, cert.Certificate)
107+
}
108+
c.Unlock()
109+
}

ca/tls.go

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,21 @@ import (
2121
// sign certificate, and a new certificate pool with the sign root certificate.
2222
// The client certificate will automatically rotate before expiring.
2323
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*tls.Config, error) {
24-
cert, err := TLSCertificate(sign, pk)
24+
tlsConfig, _, err := c.getClientTLSConfig(ctx, sign, pk, options)
2525
if err != nil {
2626
return nil, err
2727
}
28+
return tlsConfig, nil
29+
}
30+
31+
func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options []TLSOption) (*tls.Config, *http.Transport, error) {
32+
cert, err := TLSCertificate(sign, pk)
33+
if err != nil {
34+
return nil, nil, err
35+
}
2836
renewer, err := NewTLSRenewer(cert, nil)
2937
if err != nil {
30-
return nil, err
38+
return nil, nil, err
3139
}
3240

3341
tlsConfig := getDefaultTLSConfig(sign)
@@ -43,22 +51,24 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
4351
// Apply options if given
4452
tlsCtx := newTLSOptionCtx(c, tlsConfig)
4553
if err := tlsCtx.apply(options); err != nil {
46-
return nil, err
54+
return nil, nil, err
4755
}
4856

4957
// Update renew function with transport
5058
tr, err := getDefaultTransport(tlsConfig)
5159
if err != nil {
52-
return nil, err
60+
return nil, nil, err
5361
}
62+
// Use mutable tls.Config on renew
63+
tr.DialTLS = c.buildDialTLS(tlsCtx)
5464
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
5565

5666
// Update client transport
5767
c.client.Transport = tr
5868

5969
// Start renewer
6070
renewer.RunContext(ctx)
61-
return tlsConfig, nil
71+
return tlsConfig, tr, nil
6272
}
6373

6474
// GetServerTLSConfig returns a tls.Config for server use configured with the
@@ -96,11 +106,18 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
96106
return nil, err
97107
}
98108

109+
// GetConfigForClient allows seamless root and federated roots rotation.
110+
// If the return of the callback is not-nil, it will use the returned
111+
// tls.Config instead of the default one.
112+
tlsConfig.GetConfigForClient = c.buildGetConfigForClient(tlsCtx)
113+
99114
// Update renew function with transport
100115
tr, err := getDefaultTransport(tlsConfig)
101116
if err != nil {
102117
return nil, err
103118
}
119+
// Use mutable tls.Config on renew
120+
tr.DialTLS = c.buildDialTLS(tlsCtx)
104121
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
105122

106123
// Update client transport
@@ -113,11 +130,34 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
113130

114131
// Transport returns an http.Transport configured to use the client certificate from the sign response.
115132
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey, options ...TLSOption) (*http.Transport, error) {
116-
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk, options...)
133+
_, tr, err := c.getClientTLSConfig(ctx, sign, pk, options)
117134
if err != nil {
118135
return nil, err
119136
}
120-
return getDefaultTransport(tlsConfig)
137+
return tr, nil
138+
}
139+
140+
// buildGetConfigForClient returns an implementation of GetConfigForClient
141+
// callback in tls.Config.
142+
//
143+
// If the implementation returns a nil tls.Config, the original Config will be
144+
// used, but if it's non-nil, the returned Config will be used to handle this
145+
// connection.
146+
func (c *Client) buildGetConfigForClient(ctx *TLSOptionCtx) func(*tls.ClientHelloInfo) (*tls.Config, error) {
147+
return func(*tls.ClientHelloInfo) (*tls.Config, error) {
148+
return ctx.mutableConfig.TLSConfig(), nil
149+
}
150+
}
151+
152+
// buildDialTLS returns an implementation of DialTLS callback in http.Transport.
153+
func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net.Conn, error) {
154+
return func(network, addr string) (net.Conn, error) {
155+
return tls.DialWithDialer(&net.Dialer{
156+
Timeout: 30 * time.Second,
157+
KeepAlive: 30 * time.Second,
158+
DualStack: true,
159+
}, network, addr, ctx.mutableConfig.TLSConfig())
160+
}
121161
}
122162

123163
// Certificate returns the server or client certificate from the sign response.

0 commit comments

Comments
 (0)