Skip to content

Commit 259e959

Browse files
committed
Add support for the provisioner controller
The claimer, audiences and custom callback methods are now managed by the provisioner controller in an uniform way.
1 parent 3c2ff33 commit 259e959

26 files changed

+450
-475
lines changed

Diff for: authority/authorize.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,14 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
276276
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
277277
serial := cert.SerialNumber.String()
278278
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
279+
279280
isRevoked, err := a.IsRevoked(serial)
280281
if err != nil {
281282
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
282283
}
283284
if isRevoked {
284285
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
285286
}
286-
287287
p, ok := a.provisioners.LoadByCertificate(cert)
288288
if !ok {
289289
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)

Diff for: authority/authorize_test.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ func TestAuthority_Authorize(t *testing.T) {
753753

754754
func TestAuthority_authorizeRenew(t *testing.T) {
755755
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
756+
fooCrt.NotAfter = time.Now().Add(time.Hour)
756757
assert.FatalError(t, err)
757758

758759
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
@@ -822,7 +823,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
822823
return &authorizeTest{
823824
auth: a,
824825
cert: renewDisabledCrt,
825-
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"),
826+
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
826827
code: http.StatusUnauthorized,
827828
}
828829
},
@@ -909,6 +910,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
909910
}
910911

911912
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
913+
now := time.Now()
912914
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
913915
if err != nil {
914916
return nil, nil, err
@@ -917,6 +919,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
917919
if err != nil {
918920
return nil, nil, err
919921
}
922+
if cert.ValidAfter == 0 {
923+
cert.ValidAfter = uint64(now.Unix())
924+
}
925+
if cert.ValidBefore == 0 {
926+
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
927+
}
920928
if err := cert.SignCert(rand.Reader, signer); err != nil {
921929
return nil, nil, err
922930
}

Diff for: authority/provisioner/acme.go

+7-15
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"time"
77

88
"github.com/pkg/errors"
9-
"github.com/smallstep/certificates/errs"
109
)
1110

1211
// ACME is the acme provisioner type, an entity that can authorize the ACME
@@ -24,7 +23,7 @@ type ACME struct {
2423
RequireEAB bool `json:"requireEAB,omitempty"`
2524
Claims *Claims `json:"claims,omitempty"`
2625
Options *Options `json:"options,omitempty"`
27-
claimer *Claimer
26+
ctl *Controller
2827
}
2928

3029
// GetID returns the provisioner unique identifier.
@@ -69,7 +68,7 @@ func (p *ACME) GetOptions() *Options {
6968
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
7069
// the provisioner.
7170
func (p *ACME) DefaultTLSCertDuration() time.Duration {
72-
return p.claimer.DefaultTLSCertDuration()
71+
return p.ctl.Claimer.DefaultTLSCertDuration()
7372
}
7473

7574
// Init initializes and validates the fields of a JWK type.
@@ -81,12 +80,8 @@ func (p *ACME) Init(config Config) (err error) {
8180
return errors.New("provisioner name cannot be empty")
8281
}
8382

84-
// Update claims with global ones
85-
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
86-
return err
87-
}
88-
89-
return err
83+
p.ctl, err = NewController(p, p.Claims, config)
84+
return
9085
}
9186

9287
// AuthorizeSign does not do any validation, because all validation is handled
@@ -97,10 +92,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
9792
// modifiers / withOptions
9893
newProvisionerExtensionOption(TypeACME, p.Name, ""),
9994
newForceCNOption(p.ForceCN),
100-
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
95+
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
10196
// validators
10297
defaultPublicKeyValidator{},
103-
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
98+
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
10499
}, nil
105100
}
106101

@@ -118,8 +113,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
118113
// revocation status. Just confirms that the provisioner that created the
119114
// certificate was configured to allow renewals.
120115
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
121-
if p.claimer.IsDisableRenewal() {
122-
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
123-
}
124-
return nil
116+
return p.ctl.AuthorizeRenew(ctx, cert)
125117
}

Diff for: authority/provisioner/acme_test.go

+16-9
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) {
9191
}
9292

9393
func TestACME_AuthorizeRenew(t *testing.T) {
94+
now := time.Now()
9495
type test struct {
9596
p *ACME
9697
cert *x509.Certificate
@@ -104,21 +105,27 @@ func TestACME_AuthorizeRenew(t *testing.T) {
104105
// disable renewal
105106
disable := true
106107
p.Claims = &Claims{DisableRenewal: &disable}
107-
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
108+
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
108109
assert.FatalError(t, err)
109110
return test{
110-
p: p,
111-
cert: &x509.Certificate{},
111+
p: p,
112+
cert: &x509.Certificate{
113+
NotBefore: now,
114+
NotAfter: now.Add(time.Hour),
115+
},
112116
code: http.StatusUnauthorized,
113-
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()),
117+
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
114118
}
115119
},
116120
"ok": func(t *testing.T) test {
117121
p, err := generateACME()
118122
assert.FatalError(t, err)
119123
return test{
120-
p: p,
121-
cert: &x509.Certificate{},
124+
p: p,
125+
cert: &x509.Certificate{
126+
NotBefore: now,
127+
NotAfter: now.Add(time.Hour),
128+
},
122129
}
123130
},
124131
}
@@ -179,11 +186,11 @@ func TestACME_AuthorizeSign(t *testing.T) {
179186
case *forceCNOption:
180187
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
181188
case profileDefaultDuration:
182-
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
189+
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
183190
case defaultPublicKeyValidator:
184191
case *validityValidator:
185-
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
186-
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
192+
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
193+
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
187194
default:
188195
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
189196
}

Diff for: authority/provisioner/aws.go

+12-18
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,8 @@ type AWS struct {
264264
IIDRoots string `json:"iidRoots,omitempty"`
265265
Claims *Claims `json:"claims,omitempty"`
266266
Options *Options `json:"options,omitempty"`
267-
claimer *Claimer
268267
config *awsConfig
269-
audiences Audiences
268+
ctl *Controller
270269
}
271270

272271
// GetID returns the provisioner unique identifier.
@@ -400,15 +399,11 @@ func (p *AWS) Init(config Config) (err error) {
400399
case p.InstanceAge.Value() < 0:
401400
return errors.New("provisioner instanceAge cannot be negative")
402401
}
403-
// Update claims with global ones
404-
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
405-
return err
406-
}
402+
407403
// Add default config
408404
if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
409405
return err
410406
}
411-
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
412407

413408
// validate IMDS versions
414409
if len(p.IMDSVersions) == 0 {
@@ -425,7 +420,9 @@ func (p *AWS) Init(config Config) (err error) {
425420
}
426421
}
427422

428-
return nil
423+
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
424+
p.ctl, err = NewController(p, p.Claims, config)
425+
return
429426
}
430427

431428
// AuthorizeSign validates the given token and returns the sign options that
@@ -473,11 +470,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
473470
templateOptions,
474471
// modifiers / withOptions
475472
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
476-
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
473+
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
477474
// validators
478475
defaultPublicKeyValidator{},
479476
commonNameValidator(payload.Claims.Subject),
480-
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
477+
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
481478
), nil
482479
}
483480

@@ -486,10 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
486483
// revocation status. Just confirms that the provisioner that created the
487484
// certificate was configured to allow renewals.
488485
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
489-
if p.claimer.IsDisableRenewal() {
490-
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
491-
}
492-
return nil
486+
return p.ctl.AuthorizeRenew(ctx, cert)
493487
}
494488

495489
// assertConfig initializes the config if it has not been initialized
@@ -664,7 +658,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
664658
}
665659

666660
// validate audiences with the defaults
667-
if !matchesAudience(payload.Audience, p.audiences.Sign) {
661+
if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
668662
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
669663
}
670664

@@ -704,7 +698,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
704698

705699
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
706700
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
707-
if !p.claimer.IsSSHCAEnabled() {
701+
if !p.ctl.Claimer.IsSSHCAEnabled() {
708702
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
709703
}
710704
claims, err := p.authorizeToken(token)
@@ -752,11 +746,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
752746
// Validate user SignSSHOptions.
753747
sshCertOptionsValidator(defaults),
754748
// Set the validity bounds if not set.
755-
&sshDefaultDuration{p.claimer},
749+
&sshDefaultDuration{p.ctl.Claimer},
756750
// Validate public key
757751
&sshDefaultPublicKeyValidator{},
758752
// Validate the validity period.
759-
&sshCertValidityValidator{p.claimer},
753+
&sshCertValidityValidator{p.ctl.Claimer},
760754
// Require all the fields in the SSH certificate
761755
&sshCertDefaultValidator{},
762756
), nil

Diff for: authority/provisioner/aws_test.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -682,13 +682,13 @@ func TestAWS_AuthorizeSign(t *testing.T) {
682682
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
683683
assert.Len(t, 2, v.KeyValuePairs)
684684
case profileDefaultDuration:
685-
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration())
685+
assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
686686
case commonNameValidator:
687687
assert.Equals(t, string(v), tt.args.cn)
688688
case defaultPublicKeyValidator:
689689
case *validityValidator:
690-
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration())
691-
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration())
690+
assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
691+
assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
692692
case ipAddressesValidator:
693693
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
694694
case emailAddressesValidator:
@@ -726,7 +726,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
726726
// disable sshCA
727727
disable := false
728728
p3.Claims = &Claims{EnableSSHCA: &disable}
729-
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
729+
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
730730
assert.FatalError(t, err)
731731

732732
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
@@ -747,7 +747,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
747747
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
748748
assert.FatalError(t, err)
749749

750-
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
750+
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
751751
expectedHostOptions := &SignSSHOptions{
752752
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
753753
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@@ -824,6 +824,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
824824
}
825825

826826
func TestAWS_AuthorizeRenew(t *testing.T) {
827+
now := time.Now()
827828
p1, err := generateAWS()
828829
assert.FatalError(t, err)
829830
p2, err := generateAWS()
@@ -832,7 +833,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
832833
// disable renewal
833834
disable := true
834835
p2.Claims = &Claims{DisableRenewal: &disable}
835-
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
836+
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
836837
assert.FatalError(t, err)
837838

838839
type args struct {
@@ -845,8 +846,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
845846
code int
846847
wantErr bool
847848
}{
848-
{"ok", p1, args{nil}, http.StatusOK, false},
849-
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
849+
{"ok", p1, args{&x509.Certificate{
850+
NotBefore: now,
851+
NotAfter: now.Add(time.Hour),
852+
}}, http.StatusOK, false},
853+
{"fail/renew-disabled", p2, args{&x509.Certificate{
854+
NotBefore: now,
855+
NotAfter: now.Add(time.Hour),
856+
}}, http.StatusUnauthorized, true},
850857
}
851858
for _, tt := range tests {
852859
t.Run(tt.name, func(t *testing.T) {

0 commit comments

Comments
 (0)