Skip to content

Commit 9c64dbd

Browse files
committed
Add helpers to add direct support for mTLS.
1 parent 272bbc5 commit 9c64dbd

File tree

6 files changed

+483
-14
lines changed

6 files changed

+483
-14
lines changed

ca/bootstrap.go

+53
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,59 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server) (*htt
9090
return base, nil
9191
}
9292

93+
// BootstrapServerWithMTLS is a helper function that using the given token
94+
// returns the given http.Server configured with a TLS certificate signed by the
95+
// Certificate Authority, this server will always require and verify a client
96+
// certificate. By default the server will kick off a routine that will renew
97+
// the certificate after 2/3rd of the certificate's lifetime has expired.
98+
//
99+
// Usage:
100+
// // Default example with certificate rotation.
101+
// srv, err := ca.BootstrapServerWithMTLS(context.Background(), token, &http.Server{
102+
// Addr: ":443",
103+
// Handler: handler,
104+
// })
105+
//
106+
// // Example canceling automatic certificate rotation.
107+
// ctx, cancel := context.WithCancel(context.Background())
108+
// defer cancel()
109+
// srv, err := ca.BootstrapServerWithMTLS(ctx, token, &http.Server{
110+
// Addr: ":443",
111+
// Handler: handler,
112+
// })
113+
// if err != nil {
114+
// return err
115+
// }
116+
// srv.ListenAndServeTLS("", "")
117+
func BootstrapServerWithMTLS(ctx context.Context, token string, base *http.Server) (*http.Server, error) {
118+
if base.TLSConfig != nil {
119+
return nil, errors.New("server TLSConfig is already set")
120+
}
121+
122+
client, err := Bootstrap(token)
123+
if err != nil {
124+
return nil, err
125+
}
126+
127+
req, pk, err := CreateSignRequest(token)
128+
if err != nil {
129+
return nil, err
130+
}
131+
132+
sign, err := client.Sign(req)
133+
if err != nil {
134+
return nil, err
135+
}
136+
137+
tlsConfig, err := client.GetServerMutualTLSConfig(ctx, sign, pk)
138+
if err != nil {
139+
return nil, err
140+
}
141+
142+
base.TLSConfig = tlsConfig
143+
return base, nil
144+
}
145+
93146
// BootstrapClient is a helper function that using the given bootstrap token
94147
// return an http.Client configured with a Transport prepared to do TLS
95148
// connections using the client certificate returned by the certificate

ca/bootstrap_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,52 @@ func TestBootstrapServer(t *testing.T) {
170170
}
171171
}
172172

173+
func TestBootstrapServerWithMTLS(t *testing.T) {
174+
srv := startCABootstrapServer()
175+
defer srv.Close()
176+
token := func() string {
177+
return generateBootstrapToken(srv.URL, "subject", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
178+
}
179+
type args struct {
180+
ctx context.Context
181+
token string
182+
base *http.Server
183+
}
184+
tests := []struct {
185+
name string
186+
args args
187+
wantErr bool
188+
}{
189+
{"ok", args{context.Background(), token(), &http.Server{}}, false},
190+
{"fail", args{context.Background(), "bad-token", &http.Server{}}, true},
191+
{"fail with TLSConfig", args{context.Background(), token(), &http.Server{TLSConfig: &tls.Config{}}}, true},
192+
}
193+
for _, tt := range tests {
194+
t.Run(tt.name, func(t *testing.T) {
195+
got, err := BootstrapServerWithMTLS(tt.args.ctx, tt.args.token, tt.args.base)
196+
if (err != nil) != tt.wantErr {
197+
t.Errorf("BootstrapServerWithMTLS() error = %v, wantErr %v", err, tt.wantErr)
198+
return
199+
}
200+
if tt.wantErr {
201+
if got != nil {
202+
t.Errorf("BootstrapServerWithMTLS() = %v, want nil", got)
203+
}
204+
} else {
205+
expected := &http.Server{
206+
TLSConfig: got.TLSConfig,
207+
}
208+
if !reflect.DeepEqual(got, expected) {
209+
t.Errorf("BootstrapServerWithMTLS() = %v, want %v", got, expected)
210+
}
211+
if got.TLSConfig == nil || got.TLSConfig.ClientCAs == nil || got.TLSConfig.RootCAs == nil || got.TLSConfig.GetCertificate == nil || got.TLSConfig.GetClientCertificate == nil {
212+
t.Errorf("BootstrapServerWithMTLS() invalid TLSConfig = %#v", got.TLSConfig)
213+
}
214+
}
215+
})
216+
}
217+
}
218+
173219
func TestBootstrapClient(t *testing.T) {
174220
srv := startCABootstrapServer()
175221
defer srv.Close()

ca/tls.go

+18-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
// GetClientTLSConfig returns a tls.Config for client use configured with the
2121
// sign certificate, and a new certificate pool with the sign root certificate.
22-
// The certificate will automatically rotate before expiring.
22+
// The client certificate will automatically rotate before expiring.
2323
func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
2424
cert, err := TLSCertificate(sign, pk)
2525
if err != nil {
@@ -32,16 +32,14 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
3232

3333
tlsConfig := getDefaultTLSConfig(sign)
3434
// Note that with GetClientCertificate tlsConfig.Certificates is not used.
35+
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
3536
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
3637
tlsConfig.PreferServerCipherSuites = true
3738
// Build RootCAs with given root certificate
3839
if pool := c.getCertPool(sign); pool != nil {
3940
tlsConfig.RootCAs = pool
4041
}
4142

42-
// Parse Certificates and build NameToCertificate
43-
tlsConfig.BuildNameToCertificate()
44-
4543
// Update renew function with transport
4644
tr, err := getDefaultTransport(tlsConfig)
4745
if err != nil {
@@ -56,7 +54,8 @@ func (c *Client) GetClientTLSConfig(ctx context.Context, sign *api.SignResponse,
5654

5755
// GetServerTLSConfig returns a tls.Config for server use configured with the
5856
// sign certificate, and a new certificate pool with the sign root certificate.
59-
// The certificate will automatically rotate before expiring.
57+
// The returned tls.Config will only verify the client certificate if provided.
58+
// The server certificate will automatically rotate before expiring.
6059
func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
6160
cert, err := TLSCertificate(sign, pk)
6261
if err != nil {
@@ -70,6 +69,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
7069
tlsConfig := getDefaultTLSConfig(sign)
7170
// Note that GetCertificate will only be called if the client supplies SNI
7271
// information or if tlsConfig.Certificates is empty.
72+
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
7373
tlsConfig.GetCertificate = renewer.GetCertificate
7474
tlsConfig.GetClientCertificate = renewer.GetClientCertificate
7575
tlsConfig.PreferServerCipherSuites = true
@@ -93,6 +93,19 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
9393
return tlsConfig, nil
9494
}
9595

96+
// GetServerMutualTLSConfig returns a tls.Config for server use configured with
97+
// the sign certificate, and a new certificate pool with the sign root certificate.
98+
// The returned tls.Config will always require and verify a client certificate.
99+
// The server certificate will automatically rotate before expiring.
100+
func (c *Client) GetServerMutualTLSConfig(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Config, error) {
101+
tlsConfig, err := c.GetServerTLSConfig(ctx, sign, pk)
102+
if err != nil {
103+
return nil, err
104+
}
105+
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
106+
return tlsConfig, nil
107+
}
108+
96109
// Transport returns an http.Transport configured to use the client certificate from the sign response.
97110
func (c *Client) Transport(ctx context.Context, sign *api.SignResponse, pk crypto.PrivateKey) (*http.Transport, error) {
98111
tlsConfig, err := c.GetClientTLSConfig(ctx, sign, pk)

0 commit comments

Comments
 (0)