Skip to content

Commit 6b6d61d

Browse files
authored
Merge pull request smallstep#53 from smallstep/claims-omitempty
Proper omitempty on claims
2 parents 095ab89 + 68ff077 commit 6b6d61d

File tree

8 files changed

+106
-54
lines changed

8 files changed

+106
-54
lines changed

authority/config.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,22 @@ type AuthConfig struct {
6060

6161
// Validate validates the authority configuration.
6262
func (c *AuthConfig) Validate(audiences []string) error {
63-
var err error
6463
if c == nil {
6564
return errors.New("authority cannot be undefined")
6665
}
6766
if len(c.Provisioners) == 0 {
6867
return errors.New("authority.provisioners cannot be empty")
6968
}
7069

71-
if c.Claims, err = c.Claims.Init(&globalProvisionerClaims); err != nil {
70+
// Merge global and configuration claims
71+
claimer, err := provisioner.NewClaimer(c.Claims, globalProvisionerClaims)
72+
if err != nil {
7273
return err
7374
}
7475

7576
// Initialize provisioners
7677
config := provisioner.Config{
77-
Claims: *c.Claims,
78+
Claims: claimer.Claims(),
7879
Audiences: audiences,
7980
}
8081
for _, p := range c.Provisioners {

authority/config_test.go

+11
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,17 @@ func TestAuthConfigValidate(t *testing.T) {
272272
err: errors.New("provisioner type cannot be empty"),
273273
}
274274
},
275+
"fail-invalid-claims": func(t *testing.T) AuthConfigValidateTest {
276+
return AuthConfigValidateTest{
277+
ac: &AuthConfig{
278+
Provisioners: p,
279+
Claims: &provisioner.Claims{
280+
MinTLSDur: &provisioner.Duration{-1},
281+
},
282+
},
283+
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
284+
}
285+
},
275286
"ok-empty-asn1dn-template": func(t *testing.T) AuthConfigValidateTest {
276287
return AuthConfigValidateTest{
277288
ac: &AuthConfig{

authority/provisioner/claims.go

+47-33
Original file line numberDiff line numberDiff line change
@@ -8,76 +8,90 @@ import (
88

99
// Claims so that individual provisioners can override global claims.
1010
type Claims struct {
11-
globalClaims *Claims
1211
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
1312
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
1413
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
1514
DisableRenewal *bool `json:"disableRenewal,omitempty"`
1615
}
1716

18-
// Init initializes and validates the individual provisioner claims.
19-
func (pc *Claims) Init(global *Claims) (*Claims, error) {
20-
if pc == nil {
21-
pc = &Claims{}
17+
// Claimer is the type that controls claims. It provides an interface around the
18+
// current claim and the global one.
19+
type Claimer struct {
20+
global Claims
21+
claims *Claims
22+
}
23+
24+
// NewClaimer initializes a new claimer with the given claims.
25+
func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
26+
c := &Claimer{global: global, claims: claims}
27+
return c, c.Validate()
28+
}
29+
30+
// Claims returns the merge of the inner and global claims.
31+
func (c *Claimer) Claims() Claims {
32+
disableRenewal := c.IsDisableRenewal()
33+
return Claims{
34+
MinTLSDur: &Duration{c.MinTLSCertDuration()},
35+
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
36+
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
37+
DisableRenewal: &disableRenewal,
2238
}
23-
pc.globalClaims = global
24-
return pc, pc.Validate()
2539
}
2640

2741
// DefaultTLSCertDuration returns the default TLS cert duration for the
2842
// provisioner. If the default is not set within the provisioner, then the global
2943
// default from the authority configuration will be used.
30-
func (pc *Claims) DefaultTLSCertDuration() time.Duration {
31-
if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 {
32-
return pc.globalClaims.DefaultTLSCertDuration()
44+
func (c *Claimer) DefaultTLSCertDuration() time.Duration {
45+
if c.claims == nil || c.claims.DefaultTLSDur == nil {
46+
return c.global.DefaultTLSDur.Duration
3347
}
34-
return pc.DefaultTLSDur.Duration
48+
return c.claims.DefaultTLSDur.Duration
3549
}
3650

3751
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
3852
// If the minimum is not set within the provisioner, then the global
3953
// minimum from the authority configuration will be used.
40-
func (pc *Claims) MinTLSCertDuration() time.Duration {
41-
if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 {
42-
return pc.globalClaims.MinTLSCertDuration()
54+
func (c *Claimer) MinTLSCertDuration() time.Duration {
55+
if c.claims == nil || c.claims.MinTLSDur == nil {
56+
return c.global.MinTLSDur.Duration
4357
}
44-
return pc.MinTLSDur.Duration
58+
return c.claims.MinTLSDur.Duration
4559
}
4660

4761
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
4862
// If the maximum is not set within the provisioner, then the global
4963
// maximum from the authority configuration will be used.
50-
func (pc *Claims) MaxTLSCertDuration() time.Duration {
51-
if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 {
52-
return pc.globalClaims.MaxTLSCertDuration()
64+
func (c *Claimer) MaxTLSCertDuration() time.Duration {
65+
if c.claims == nil || c.claims.MaxTLSDur == nil {
66+
return c.global.MaxTLSDur.Duration
5367
}
54-
return pc.MaxTLSDur.Duration
68+
return c.claims.MaxTLSDur.Duration
5569
}
5670

5771
// IsDisableRenewal returns if the renewal flow is disabled for the
5872
// provisioner. If the property is not set within the provisioner, then the
5973
// global value from the authority configuration will be used.
60-
func (pc *Claims) IsDisableRenewal() bool {
61-
if pc.DisableRenewal == nil {
62-
return pc.globalClaims.IsDisableRenewal()
74+
func (c *Claimer) IsDisableRenewal() bool {
75+
if c.claims == nil || c.claims.DisableRenewal == nil {
76+
return *c.global.DisableRenewal
6377
}
64-
return *pc.DisableRenewal
78+
return *c.claims.DisableRenewal
6579
}
6680

6781
// Validate validates and modifies the Claims with default values.
68-
func (pc *Claims) Validate() error {
82+
func (c *Claimer) Validate() error {
6983
var (
70-
min = pc.MinTLSCertDuration()
71-
max = pc.MaxTLSCertDuration()
72-
def = pc.DefaultTLSCertDuration()
84+
min = c.MinTLSCertDuration()
85+
max = c.MaxTLSCertDuration()
86+
def = c.DefaultTLSCertDuration()
7387
)
7488
switch {
75-
case min == 0:
76-
return errors.Errorf("claims: MinTLSCertDuration cannot be empty")
77-
case max == 0:
78-
return errors.Errorf("claims: MaxTLSCertDuration cannot be empty")
79-
case def == 0:
80-
return errors.Errorf("claims: DefaultTLSCertDuration cannot be empty")
89+
case min <= 0:
90+
return errors.Errorf("claims: MinTLSCertDuration must be greater than 0")
91+
case max <= 0:
92+
return errors.Errorf("claims: MaxTLSCertDuration must be greater than 0")
93+
case def <= 0:
94+
return errors.Errorf("claims: DefaultTLSCertDuration must be greater than 0")
8195
case max < min:
8296
return errors.Errorf("claims: MaxCertDuration cannot be less "+
8397
"than MinCertDuration: MaxCertDuration - %v, MinCertDuration - %v", max, min)

authority/provisioner/jwk.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type JWK struct {
2323
Key *jose.JSONWebKey `json:"key"`
2424
EncryptedKey string `json:"encryptedKey,omitempty"`
2525
Claims *Claims `json:"claims,omitempty"`
26+
claimer *Claimer
2627
audiences []string
2728
}
2829

@@ -57,7 +58,12 @@ func (p *JWK) Init(config Config) (err error) {
5758
case p.Key == nil:
5859
return errors.New("provisioner key cannot be empty")
5960
}
60-
p.Claims, err = p.Claims.Init(&config.Claims)
61+
62+
// Update claims with global ones
63+
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
64+
return err
65+
}
66+
6167
p.audiences = config.Audiences
6268
return err
6369
}
@@ -104,15 +110,15 @@ func (p *JWK) Authorize(token string) ([]SignOption, error) {
104110
commonNameValidator(claims.Subject),
105111
dnsNamesValidator(dnsNames),
106112
ipAddressesValidator(ips),
107-
profileDefaultDuration(p.Claims.DefaultTLSCertDuration()),
113+
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
108114
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
109-
newValidityValidator(p.Claims.MinTLSCertDuration(), p.Claims.MaxTLSCertDuration()),
115+
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
110116
}, nil
111117
}
112118

113119
// AuthorizeRenewal returns an error if the renewal is disabled.
114120
func (p *JWK) AuthorizeRenewal(cert *x509.Certificate) error {
115-
if p.Claims.IsDisableRenewal() {
121+
if p.claimer.IsDisableRenewal() {
116122
return errors.Errorf("renew is disabled for provisioner %s", p.GetID())
117123
}
118124
return nil

authority/provisioner/jwk_test.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ func TestJWK_Init(t *testing.T) {
7878
err: errors.New("provisioner key cannot be empty"),
7979
}
8080
},
81+
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
82+
return ProvisionerValidateTest{
83+
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
84+
err: errors.New("claims: DefaultTLSCertDuration must be greater than 0"),
85+
}
86+
},
8187
"ok": func(t *testing.T) ProvisionerValidateTest {
8288
return ProvisionerValidateTest{
8389
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
@@ -201,10 +207,9 @@ func TestJWK_AuthorizeRenewal(t *testing.T) {
201207

202208
// disable renewal
203209
disable := true
204-
p2.Claims = &Claims{
205-
globalClaims: &globalProvisionerClaims,
206-
DisableRenewal: &disable,
207-
}
210+
p2.Claims = &Claims{DisableRenewal: &disable}
211+
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
212+
assert.FatalError(t, err)
208213

209214
type args struct {
210215
cert *x509.Certificate

authority/provisioner/oidc.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ type OIDC struct {
5555
Claims *Claims `json:"claims,omitempty"`
5656
configuration openIDConfiguration
5757
keyStore *keyStore
58+
claimer *Claimer
5859
}
5960

6061
// IsAdmin returns true if the given email is in the Admins whitelist, false
@@ -111,9 +112,10 @@ func (o *OIDC) Init(config Config) (err error) {
111112
}
112113

113114
// Update claims with global ones
114-
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil {
115+
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
115116
return err
116117
}
118+
117119
// Decode and validate openid-configuration endpoint
118120
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
119121
return err
@@ -202,23 +204,23 @@ func (o *OIDC) Authorize(token string) ([]SignOption, error) {
202204
// Admins should be able to authorize any SAN
203205
if o.IsAdmin(claims.Email) {
204206
return []SignOption{
205-
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
207+
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
206208
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
207-
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
209+
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
208210
}, nil
209211
}
210212

211213
return []SignOption{
212214
emailOnlyIdentity(claims.Email),
213-
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
215+
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
214216
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
215-
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
217+
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
216218
}, nil
217219
}
218220

219221
// AuthorizeRenewal returns an error if the renewal is disabled.
220222
func (o *OIDC) AuthorizeRenewal(cert *x509.Certificate) error {
221-
if o.Claims.IsDisableRenewal() {
223+
if o.claimer.IsDisableRenewal() {
222224
return errors.Errorf("renew is disabled for provisioner %s", o.GetID())
223225
}
224226
return nil

authority/provisioner/oidc_test.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ func TestOIDC_Init(t *testing.T) {
6464
config := Config{
6565
Claims: globalProvisionerClaims,
6666
}
67+
badClaims := &Claims{
68+
DefaultTLSDur: &Duration{0},
69+
}
6770

6871
type fields struct {
6972
Type string
@@ -93,6 +96,7 @@ func TestOIDC_Init(t *testing.T) {
9396
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
9497
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
9598
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
99+
{"bad-claims", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", badClaims, nil, nil}, args{config}, true},
96100
}
97101
for _, tt := range tests {
98102
t.Run(tt.name, func(t *testing.T) {
@@ -241,10 +245,9 @@ func TestOIDC_AuthorizeRenewal(t *testing.T) {
241245

242246
// disable renewal
243247
disable := true
244-
p2.Claims = &Claims{
245-
globalClaims: &globalProvisionerClaims,
246-
DisableRenewal: &disable,
247-
}
248+
p2.Claims = &Claims{DisableRenewal: &disable}
249+
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
250+
assert.FatalError(t, err)
248251

249252
type args struct {
250253
cert *x509.Certificate

authority/provisioner/utils_test.go

+10
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,18 @@ func generateJWK() (*JWK, error) {
109109
if err != nil {
110110
return nil, err
111111
}
112+
claimer, err := NewClaimer(nil, globalProvisionerClaims)
113+
if err != nil {
114+
return nil, err
115+
}
112116
return &JWK{
113117
Name: name,
114118
Type: "JWK",
115119
Key: &public,
116120
EncryptedKey: encrypted,
117121
Claims: &globalProvisionerClaims,
118122
audiences: testAudiences,
123+
claimer: claimer,
119124
}, nil
120125
}
121126

@@ -136,6 +141,10 @@ func generateOIDC() (*OIDC, error) {
136141
if err != nil {
137142
return nil, err
138143
}
144+
claimer, err := NewClaimer(nil, globalProvisionerClaims)
145+
if err != nil {
146+
return nil, err
147+
}
139148
return &OIDC{
140149
Name: name,
141150
Type: "OIDC",
@@ -150,6 +159,7 @@ func generateOIDC() (*OIDC, error) {
150159
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
151160
expiry: time.Now().Add(24 * time.Hour),
152161
},
162+
claimer: claimer,
153163
}, nil
154164
}
155165

0 commit comments

Comments
 (0)