forked from smallstep/certificates
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrenew.go
191 lines (168 loc) · 5.27 KB
/
renew.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
package ca
import (
"context"
"crypto/tls"
"math/rand"
"sync"
"time"
"github.com/pkg/errors"
)
// RenewFunc defines the type of the functions used to get a new tls
// certificate.
//
// TLSRenewer calls it each time a renewal is attempted. On success the
// returned certificate becomes the current one and its Leaf.NotAfter is read
// to schedule the next renewal, so the returned certificate needs its Leaf
// field populated.
type RenewFunc func() (*tls.Certificate, error)
// minCertDuration is the shortest certificate validity period
// (NotAfter - NotBefore) that NewTLSRenewer accepts.
var minCertDuration = time.Minute
// TLSRenewer automatically renews a tls certificate using a RenewFunc.
type TLSRenewer struct {
// The embedded lock is held while cert and certNotAfter are read or
// written, and while Run and renewCertificate set or reset the timer.
sync.RWMutex
// RenewCertificate is called to obtain the replacement certificate.
RenewCertificate RenewFunc
// cert is the certificate currently handed out by the Get* methods.
cert *tls.Certificate
// timer fires renewCertificate; it is nil until Run is called.
timer *time.Timer
// renewBefore is how long before expiration a renewal is scheduled.
// NewTLSRenewer defaults it to 1/3 of the validity period.
renewBefore time.Duration
// renewJitter bounds the random amount subtracted from the renewal delay;
// half of it is also the base delay before retrying a failed renewal.
// NewTLSRenewer defaults it to 1/20 of the validity period.
renewJitter time.Duration
// certNotAfter is cert.Leaf.NotAfter minus one minute; once this instant
// has passed, getCertificateForCA renews on demand.
certNotAfter time.Time
}
// tlsRenewerOptions is a functional option applied by NewTLSRenewer to the
// renewer being built. A non-nil error aborts the construction.
type tlsRenewerOptions func(r *TLSRenewer) error
// WithRenewBefore returns an option that sets the renewBefore attribute of a
// TLSRenewer, i.e. how long before the certificate expires the renewal is
// scheduled. The returned option never fails.
func WithRenewBefore(b time.Duration) func(r *TLSRenewer) error {
	apply := func(renewer *TLSRenewer) error {
		renewer.renewBefore = b
		return nil
	}
	return apply
}
// WithRenewJitter returns an option that sets the renewJitter attribute of a
// TLSRenewer, i.e. the upper bound of the random offset applied to the
// renewal schedule. The returned option never fails.
func WithRenewJitter(j time.Duration) func(r *TLSRenewer) error {
	apply := func(renewer *TLSRenewer) error {
		renewer.renewJitter = j
		return nil
	}
	return apply
}
// NewTLSRenewer creates a TLSRenewer for the given cert. It will use the given
// RenewFunc to get a new certificate when required.
//
// cert must be non-nil and have its Leaf field populated, because the
// validity window is read from cert.Leaf. Options are applied in order and
// the first one that fails aborts the construction. When no option sets
// them, renewBefore defaults to 1/3 of the validity period and renewJitter to
// 1/20 of it.
//
// An error is returned if cert or cert.Leaf is nil, if an option fails, if
// the validity period is shorter than minCertDuration, or if the configured
// jitter is negative.
func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOptions) (*TLSRenewer, error) {
	// tls.Certificate.Leaf is optional; report a missing leaf as an error
	// instead of panicking on the nil dereference below.
	if cert == nil || cert.Leaf == nil {
		return nil, errors.New("certificate and its parsed leaf are required")
	}
	leaf := cert.Leaf

	r := &TLSRenewer{
		RenewCertificate: fn,
		cert:             cert,
		certNotAfter:     leaf.NotAfter.Add(-1 * time.Minute),
	}

	for _, f := range opts {
		if err := f(r); err != nil {
			return nil, errors.Wrap(err, "error applying options")
		}
	}

	period := leaf.NotAfter.Sub(leaf.NotBefore)
	if period < minCertDuration {
		return nil, errors.Errorf("period must be greater than or equal to %s, but got %v.", minCertDuration, period)
	}

	// A negative jitter would make rand.Int63n panic when the next renewal
	// is scheduled, so reject it up front.
	if r.renewJitter < 0 {
		return nil, errors.Errorf("renew jitter must not be negative, but got %v", r.renewJitter)
	}

	// By default we will try to renew the cert before 2/3 of the validity
	// period have expired.
	if r.renewBefore == 0 {
		r.renewBefore = period / 3
	}

	// By default we set the jitter to 1/20th of the validity period.
	if r.renewJitter == 0 {
		r.renewJitter = period / 20
	}

	return r, nil
}
// Run starts the certificate renewer for the given certificate.
//
// It schedules renewCertificate to fire after the delay computed from the
// current certificate's expiration and returns immediately; the renewal
// itself runs when the timer fires. If Run is called again, the previously
// created timer is stopped before it is replaced.
func (r *TLSRenewer) Run() {
	cert := r.getCertificate()
	next := r.nextRenewDuration(cert.Leaf.NotAfter)

	r.Lock()
	defer r.Unlock()
	if r.timer != nil {
		// Otherwise the earlier timer would still fire and trigger a
		// redundant renewal next to the one scheduled below.
		r.timer.Stop()
	}
	r.timer = time.AfterFunc(next, r.renewCertificate)
}
// RunContext starts the certificate renewer like Run does and stops it once
// the given context is done.
func (r *TLSRenewer) RunContext(ctx context.Context) {
	r.Run()

	// Wait for the context in the background so the caller is not blocked.
	stopWhenDone := func() {
		<-ctx.Done()
		r.Stop()
	}
	go stopWhenDone()
}
// Stop prevents the renew timer from firing.
//
// It returns the result of the timer's Stop call, or true when no timer
// exists yet because Run has not been called.
//
// NOTE: a renewal that is already executing resets the timer when it
// finishes, so Stop does not cancel a renewal in progress.
func (r *TLSRenewer) Stop() bool {
	// Run assigns r.timer under the write lock; read the field under the lock
	// as well so the two do not race.
	r.RLock()
	timer := r.timer
	r.RUnlock()

	if timer == nil {
		return true
	}
	return timer.Stop()
}
// GetCertificate returns the current server certificate. The returned error
// is always nil.
//
// This method is set in the tls.Config GetCertificate property.
func (r *TLSRenewer) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	current := r.getCertificate()
	return current, nil
}
// GetCertificateForCA returns the current server certificate, renewing it
// first if it is about to expire. It can only be used if the renew function
// creates the new certificate itself and does not use a TLS request. It is
// intended to be used by the certificate authority server. The returned
// error is always nil.
//
// This method is set in the tls.Config GetCertificate property.
func (r *TLSRenewer) GetCertificateForCA(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	current := r.getCertificateForCA()
	return current, nil
}
// GetClientCertificate returns the current client certificate. The returned
// error is always nil.
//
// This method is set in the tls.Config GetClientCertificate property.
func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
	current := r.getCertificate()
	return current, nil
}
// getCertificate returns the current certificate while holding the read lock.
//
// Known issue: It cannot renew an expired certificate because the /renew
// endpoint requires a valid client certificate. The certificate can expire
// if the timer does not fire, e.g. when the CA is run from a laptop that
// enters sleep mode.
func (r *TLSRenewer) getCertificate() *tls.Certificate {
	r.RLock()
	defer r.RUnlock()
	return r.cert
}
// getCertificateForCA returns the current certificate, read under the read
// lock. If the renewal deadline (certNotAfter) has already passed, it renews
// the certificate synchronously before returning it.
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
	// Take the deadline check and the certificate in one critical section.
	r.RLock()
	overdue := time.Now().After(r.certNotAfter)
	current := r.cert
	r.RUnlock()

	if !overdue {
		return current
	}

	// Force certificate renewal if the timer didn't run.
	// This is a special case that can happen after a computer sleep.
	// The read lock is released at this point because the renewal takes the
	// write lock.
	r.renewCertificate()
	return r.getCertificate()
}
// setCertificate replaces the current certificate while holding the write
// lock. It also stores certNotAfter as the certificate's expiration minus one
// minute; this forces the renewal of the certificate if it is about to
// expire. The certificate's Leaf must be set.
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
	r.Lock()
	defer r.Unlock()

	r.cert = cert
	r.certNotAfter = cert.Leaf.NotAfter.Add(-time.Minute)
}
// renewCertificate requests a new certificate through RenewCertificate and
// re-arms the renewal timer.
//
// On success the new certificate is installed and the next renewal is
// scheduled relative to its expiration. If the renewal fails, or returns a
// certificate without a parsed Leaf (which cannot be installed or scheduled),
// the current certificate is kept and a retry is scheduled after a random
// delay in [renewJitter/2, renewJitter).
//
// It is called by the timer created in Run and by getCertificateForCA.
func (r *TLSRenewer) renewCertificate() {
	var next time.Duration

	cert, err := r.RenewCertificate()
	if err != nil || cert == nil || cert.Leaf == nil {
		next = r.renewJitter / 2
		// rand.Int63n panics when its argument is not positive, which
		// happens if the jitter is unset or smaller than 2ns.
		if next > 0 {
			next += time.Duration(rand.Int63n(int64(next)))
		}
	} else {
		r.setCertificate(cert)
		next = r.nextRenewDuration(cert.Leaf.NotAfter)
	}

	r.Lock()
	defer r.Unlock()
	// getCertificateForCA can reach this point before Run has created the
	// timer; there is nothing to re-arm in that case.
	if r.timer != nil {
		r.timer.Reset(next)
	}
}
// nextRenewDuration returns how long to wait before renewing a certificate
// that expires at notAfter.
//
// The delay is the time left until notAfter minus renewBefore, reduced by a
// random amount in [0, renewJitter). The result is never negative: a value
// of 0 means the renewal is due immediately.
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {
	d := time.Until(notAfter) - r.renewBefore

	// rand.Int63n panics when its argument is not positive, so the jitter is
	// only applied when one is configured.
	if r.renewJitter > 0 {
		d -= time.Duration(rand.Int63n(int64(r.renewJitter)))
	}

	if d < 0 {
		d = 0
	}
	return d
}