Skip to content

Commit 300c19f

Browse files
committed
Add a custom enforcer that can be used to modify a cert.
1 parent 09a9b3e commit 300c19f

File tree

4 files changed

+139
-35
lines changed

4 files changed

+139
-35
lines changed

authority/authority.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type Authority struct {
5050
rootX509CertPool *x509.CertPool
5151
federatedX509Certs []*x509.Certificate
5252
certificates *sync.Map
53+
x509Enforcers []provisioner.CertificateEnforcer
5354

5455
// SCEP CA
5556
scepService *scep.Service

authority/options.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ func WithLinkedCAToken(token string) Option {
241241
}
242242
}
243243

244+
// WithX509Enforcers is an option that allows to define custom certificate
245+
// modifiers that will be processed just before the signing of the certificate.
246+
func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
247+
return func(a *Authority) error {
248+
a.x509Enforcers = ces
249+
return nil
250+
}
251+
}
252+
244253
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
245254
var block *pem.Block
246255
var certs []*x509.Certificate

authority/tls.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,17 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
180180
}
181181
}
182182

183+
// Process injected modifiers after validation
184+
for _, m := range a.x509Enforcers {
185+
if err := m.Enforce(leaf); err != nil {
186+
return nil, errs.ApplyOptions(
187+
errs.ForbiddenErr(err, "error creating certificate"),
188+
opts...,
189+
)
190+
}
191+
}
192+
193+
// Sign certificate
183194
lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate))
184195
resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
185196
Template: leaf,

authority/tls_test.go

Lines changed: 118 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,17 @@ type basicConstraints struct {
205205
MaxPathLen int `asn1:"optional,default:-1"`
206206
}
207207

208+
type testEnforcer struct {
209+
enforcer func(*x509.Certificate) error
210+
}
211+
212+
func (e *testEnforcer) Enforce(cert *x509.Certificate) error {
213+
if e.enforcer != nil {
214+
return e.enforcer(cert)
215+
}
216+
return nil
217+
}
218+
208219
func TestAuthority_Sign(t *testing.T) {
209220
pub, priv, err := keyutil.GenerateDefaultKeyPair()
210221
assert.FatalError(t, err)
@@ -238,14 +249,15 @@ func TestAuthority_Sign(t *testing.T) {
238249
assert.FatalError(t, err)
239250

240251
type signTest struct {
241-
auth *Authority
242-
csr *x509.CertificateRequest
243-
signOpts provisioner.SignOptions
244-
extraOpts []provisioner.SignOption
245-
notBefore time.Time
246-
notAfter time.Time
247-
err error
248-
code int
252+
auth *Authority
253+
csr *x509.CertificateRequest
254+
signOpts provisioner.SignOptions
255+
extraOpts []provisioner.SignOption
256+
notBefore time.Time
257+
notAfter time.Time
258+
extensionsCount int
259+
err error
260+
code int
249261
}
250262
tests := map[string]func(*testing.T) *signTest{
251263
"fail invalid signature": func(t *testing.T) *signTest {
@@ -454,22 +466,66 @@ ZYtQ9Ot36qc=
454466
code: http.StatusInternalServerError,
455467
}
456468
},
457-
"ok": func(t *testing.T) *signTest {
469+
"fail with provisioner enforcer": func(t *testing.T) *signTest {
458470
csr := getCSR(t, priv)
459-
_a := testAuthority(t)
460-
_a.db = &db.MockAuthDB{
471+
aa := testAuthority(t)
472+
aa.db = &db.MockAuthDB{
461473
MStoreCertificate: func(crt *x509.Certificate) error {
462474
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
463475
return nil
464476
},
465477
}
478+
466479
return &signTest{
467-
auth: a,
480+
auth: aa,
481+
csr: csr,
482+
extraOpts: append(extraOpts, &testEnforcer{
483+
enforcer: func(crt *x509.Certificate) error { return fmt.Errorf("an error") },
484+
}),
485+
signOpts: signOpts,
486+
err: errors.New("error creating certificate"),
487+
code: http.StatusForbidden,
488+
}
489+
},
490+
"fail with custom enforcer": func(t *testing.T) *signTest {
491+
csr := getCSR(t, priv)
492+
aa := testAuthority(t, WithX509Enforcers(&testEnforcer{
493+
enforcer: func(cert *x509.Certificate) error {
494+
return fmt.Errorf("an error")
495+
},
496+
}))
497+
aa.db = &db.MockAuthDB{
498+
MStoreCertificate: func(crt *x509.Certificate) error {
499+
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
500+
return nil
501+
},
502+
}
503+
return &signTest{
504+
auth: aa,
468505
csr: csr,
469506
extraOpts: extraOpts,
470507
signOpts: signOpts,
471-
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
472-
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
508+
err: errors.New("error creating certificate"),
509+
code: http.StatusForbidden,
510+
}
511+
},
512+
"ok": func(t *testing.T) *signTest {
513+
csr := getCSR(t, priv)
514+
_a := testAuthority(t)
515+
_a.db = &db.MockAuthDB{
516+
MStoreCertificate: func(crt *x509.Certificate) error {
517+
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
518+
return nil
519+
},
520+
}
521+
return &signTest{
522+
auth: a,
523+
csr: csr,
524+
extraOpts: extraOpts,
525+
signOpts: signOpts,
526+
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
527+
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
528+
extensionsCount: 6,
473529
}
474530
},
475531
"ok with enforced modifier": func(t *testing.T) *signTest {
@@ -497,12 +553,13 @@ ZYtQ9Ot36qc=
497553
},
498554
}
499555
return &signTest{
500-
auth: a,
501-
csr: csr,
502-
extraOpts: enforcedExtraOptions,
503-
signOpts: signOpts,
504-
notBefore: now.Truncate(time.Second),
505-
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
556+
auth: a,
557+
csr: csr,
558+
extraOpts: enforcedExtraOptions,
559+
signOpts: signOpts,
560+
notBefore: now.Truncate(time.Second),
561+
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
562+
extensionsCount: 6,
506563
}
507564
},
508565
"ok with custom template": func(t *testing.T) *signTest {
@@ -530,12 +587,13 @@ ZYtQ9Ot36qc=
530587
},
531588
}
532589
return &signTest{
533-
auth: testAuthority,
534-
csr: csr,
535-
extraOpts: testExtraOpts,
536-
signOpts: signOpts,
537-
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
538-
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
590+
auth: testAuthority,
591+
csr: csr,
592+
extraOpts: testExtraOpts,
593+
signOpts: signOpts,
594+
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
595+
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
596+
extensionsCount: 6,
539597
}
540598
},
541599
"ok/csr with no template critical SAN extension": func(t *testing.T) *signTest {
@@ -558,12 +616,39 @@ ZYtQ9Ot36qc=
558616
},
559617
}
560618
return &signTest{
561-
auth: _a,
562-
csr: csr,
563-
extraOpts: enforcedExtraOptions,
564-
signOpts: provisioner.SignOptions{},
565-
notBefore: now.Truncate(time.Second),
566-
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
619+
auth: _a,
620+
csr: csr,
621+
extraOpts: enforcedExtraOptions,
622+
signOpts: provisioner.SignOptions{},
623+
notBefore: now.Truncate(time.Second),
624+
notAfter: now.Add(365 * 24 * time.Hour).Truncate(time.Second),
625+
extensionsCount: 5,
626+
}
627+
},
628+
"ok with custom enforcer": func(t *testing.T) *signTest {
629+
csr := getCSR(t, priv)
630+
aa := testAuthority(t, WithX509Enforcers(&testEnforcer{
631+
enforcer: func(cert *x509.Certificate) error {
632+
cert.CRLDistributionPoints = []string{"http://ca.example.org/leaf.crl"}
633+
return nil
634+
},
635+
}))
636+
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
637+
aa.db = &db.MockAuthDB{
638+
MStoreCertificate: func(crt *x509.Certificate) error {
639+
assert.Equals(t, crt.Subject.CommonName, "smallstep test")
640+
assert.Equals(t, crt.CRLDistributionPoints, []string{"http://ca.example.org/leaf.crl"})
641+
return nil
642+
},
643+
}
644+
return &signTest{
645+
auth: aa,
646+
csr: csr,
647+
extraOpts: extraOpts,
648+
signOpts: signOpts,
649+
notBefore: signOpts.NotBefore.Time().Truncate(time.Second),
650+
notAfter: signOpts.NotAfter.Time().Truncate(time.Second),
651+
extensionsCount: 7,
567652
}
568653
},
569654
}
@@ -645,16 +730,14 @@ ZYtQ9Ot36qc=
645730
// Empty CSR subject test does not use any provisioner extensions.
646731
// So provisioner ID ext will be missing.
647732
found = 1
648-
assert.Len(t, 5, leaf.Extensions)
649-
} else {
650-
assert.Len(t, 6, leaf.Extensions)
651733
}
652734
}
653735
}
654736
assert.Equals(t, found, 1)
655737
realIntermediate, err := x509.ParseCertificate(issuer.Raw)
656738
assert.FatalError(t, err)
657739
assert.Equals(t, intermediate, realIntermediate)
740+
assert.Len(t, tc.extensionsCount, leaf.Extensions)
658741
}
659742
}
660743
})

0 commit comments

Comments
 (0)