Skip to content

Commit 8c8db0d

Browse files
committed
Modify errs.BadRequestErr() to always return an error to the client.
1 parent 8ce807a commit 8c8db0d

14 files changed

+65
-46
lines changed

api/api.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
318318
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
319319
cursor, limit, err := ParseCursor(r)
320320
if err != nil {
321-
WriteError(w, errs.BadRequestErr(err))
321+
WriteError(w, err)
322322
return
323323
}
324324

@@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
435435
if v := q.Get("limit"); len(v) > 0 {
436436
limit, err = strconv.Atoi(v)
437437
if err != nil {
438-
return "", 0, errors.Wrapf(err, "error converting %s to integer", v)
438+
return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
439439
}
440440
}
441441
return

api/api_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) {
403403
fields fields
404404
err error
405405
}{
406-
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")},
406+
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
407407
{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
408-
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")},
408+
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
409409
}
410410
for _, tt := range tests {
411411
t.Run(tt.name, func(t *testing.T) {
@@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
10871087
t.Fatal(err)
10881088
}
10891089

1090-
expectedError400 := errs.BadRequestErr(errors.New("force"))
1090+
expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
10911091
expectedError400Bytes, err := json.Marshal(expectedError400)
10921092
assert.FatalError(t, err)
10931093
expectedError500 := errs.InternalServer("force")

api/revoke_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ func TestRevokeRequestValidate(t *testing.T) {
2828
tests := map[string]test{
2929
"error/missing serial": {
3030
rr: &RevokeRequest{},
31-
err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest},
31+
err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
3232
},
3333
"error/bad reasonCode": {
3434
rr: &RevokeRequest{
3535
Serial: "sn",
3636
ReasonCode: 15,
3737
Passive: true,
3838
},
39-
err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest},
39+
err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
4040
},
4141
"error/non-passive not implemented": {
4242
rr: &RevokeRequest{

api/sign.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
2626
return errs.BadRequest("missing csr")
2727
}
2828
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
29-
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
29+
return errs.BadRequestErr(err, "invalid csr")
3030
}
3131
if s.OTT == "" {
3232
return errs.BadRequest("missing ott")
@@ -50,7 +50,7 @@ type SignResponse struct {
5050
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
5151
var body SignRequest
5252
if err := ReadJSON(r.Body, &body); err != nil {
53-
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
53+
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
5454
return
5555
}
5656

api/ssh.go

+12-12
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ type SSHSignRequest struct {
4949
func (s *SSHSignRequest) Validate() error {
5050
switch {
5151
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
52-
return errors.Errorf("unknown certType %s", s.CertType)
52+
return errs.BadRequest("invalid certType '%s'", s.CertType)
5353
case len(s.PublicKey) == 0:
54-
return errors.New("missing or empty publicKey")
54+
return errs.BadRequest("missing or empty publicKey")
5555
case s.OTT == "":
56-
return errors.New("missing or empty ott")
56+
return errs.BadRequest("missing or empty ott")
5757
default:
5858
// Validate identity signature if provided
5959
if s.IdentityCSR.CertificateRequest != nil {
6060
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
61-
return errors.Wrap(err, "invalid identityCSR")
61+
return errs.BadRequestErr(err, "invalid identityCSR")
6262
}
6363
}
6464
return nil
@@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
185185
case provisioner.SSHUserCert, provisioner.SSHHostCert:
186186
return nil
187187
default:
188-
return errors.Errorf("unsupported type %s", r.Type)
188+
return errs.BadRequest("invalid type '%s'", r.Type)
189189
}
190190
}
191191

@@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
208208
func (r *SSHCheckPrincipalRequest) Validate() error {
209209
switch {
210210
case r.Type != provisioner.SSHHostCert:
211-
return errors.Errorf("unsupported type %s", r.Type)
211+
return errs.BadRequest("unsupported type '%s'", r.Type)
212212
case r.Principal == "":
213-
return errors.New("missing or empty principal")
213+
return errs.BadRequest("missing or empty principal")
214214
default:
215215
return nil
216216
}
@@ -232,7 +232,7 @@ type SSHBastionRequest struct {
232232
// Validate checks the values of the SSHBastionRequest.
233233
func (r *SSHBastionRequest) Validate() error {
234234
if r.Hostname == "" {
235-
return errors.New("missing or empty hostname")
235+
return errs.BadRequest("missing or empty hostname")
236236
}
237237
return nil
238238
}
@@ -256,7 +256,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
256256

257257
logOtt(w, body.OTT)
258258
if err := body.Validate(); err != nil {
259-
WriteError(w, errs.BadRequestErr(err))
259+
WriteError(w, err)
260260
return
261261
}
262262

@@ -398,7 +398,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
398398
return
399399
}
400400
if err := body.Validate(); err != nil {
401-
WriteError(w, errs.BadRequestErr(err))
401+
WriteError(w, err)
402402
return
403403
}
404404

@@ -430,7 +430,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
430430
return
431431
}
432432
if err := body.Validate(); err != nil {
433-
WriteError(w, errs.BadRequestErr(err))
433+
WriteError(w, err)
434434
return
435435
}
436436

@@ -469,7 +469,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
469469
return
470470
}
471471
if err := body.Validate(); err != nil {
472-
WriteError(w, errs.BadRequestErr(err))
472+
WriteError(w, err)
473473
return
474474
}
475475

api/sshRekey.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"net/http"
55
"time"
66

7-
"github.com/pkg/errors"
87
"github.com/smallstep/certificates/authority/provisioner"
98
"github.com/smallstep/certificates/errs"
109
"golang.org/x/crypto/ssh"
@@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
2019
func (s *SSHRekeyRequest) Validate() error {
2120
switch {
2221
case s.OTT == "":
23-
return errors.New("missing or empty ott")
22+
return errs.BadRequest("missing or empty ott")
2423
case len(s.PublicKey) == 0:
25-
return errors.New("missing or empty public key")
24+
return errs.BadRequest("missing or empty public key")
2625
default:
2726
return nil
2827
}
@@ -46,7 +45,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
4645

4746
logOtt(w, body.OTT)
4847
if err := body.Validate(); err != nil {
49-
WriteError(w, errs.BadRequestErr(err))
48+
WriteError(w, err)
5049
return
5150
}
5251

api/sshRenew.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type SSHRenewRequest struct {
1919
func (s *SSHRenewRequest) Validate() error {
2020
switch {
2121
case s.OTT == "":
22-
return errors.New("missing or empty ott")
22+
return errs.BadRequest("missing or empty ott")
2323
default:
2424
return nil
2525
}
@@ -43,7 +43,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
4343

4444
logOtt(w, body.OTT)
4545
if err := body.Validate(); err != nil {
46-
WriteError(w, errs.BadRequestErr(err))
46+
WriteError(w, err)
4747
return
4848
}
4949

authority/provisioner/sshpop_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
258258
p: p,
259259
token: tok,
260260
code: http.StatusBadRequest,
261-
err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."),
261+
err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
262262
}
263263
},
264264
"ok": func(t *testing.T) test {
@@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
337337
p: p,
338338
token: tok,
339339
code: http.StatusBadRequest,
340-
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."),
340+
err: errors.New("sshpop certificate must be a host ssh certificate"),
341341
}
342342
},
343343
"ok": func(t *testing.T) test {
@@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
419419
p: p,
420420
token: tok,
421421
code: http.StatusBadRequest,
422-
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."),
422+
err: errors.New("sshpop certificate must be a host ssh certificate"),
423423
}
424424
},
425425
"ok": func(t *testing.T) test {

authority/ssh.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
9494

9595
// Check for required variables.
9696
if err := t.ValidateRequiredData(data); err != nil {
97-
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err))
97+
return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
9898
}
9999

100100
o, err := t.Output(mergedData)

authority/ssh_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
912912
cert: &ssh.Certificate{},
913913
key: pub,
914914
signOpts: []provisioner.SignOption{},
915-
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."),
915+
err: errors.New("cannot rekey a certificate without validity period"),
916916
code: http.StatusBadRequest,
917917
}
918918
},
@@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
923923
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
924924
key: pub,
925925
signOpts: []provisioner.SignOption{},
926-
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."),
926+
err: errors.New("cannot rekey a certificate without validity period"),
927927
code: http.StatusBadRequest,
928928
}
929929
},
@@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
956956
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
957957
key: pub,
958958
signOpts: []provisioner.SignOption{},
959-
err: errors.New("The request could not be completed: unexpected certificate type '0'."),
959+
err: errors.New("unexpected certificate type '0'"),
960960
code: http.StatusBadRequest,
961961
}
962962
},

authority/tls_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
11871187
Reason: reason,
11881188
OTT: raw,
11891189
},
1190-
err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"),
1190+
err: errors.New("certificate with serial number 'sn' is already revoked"),
11911191
code: http.StatusBadRequest,
11921192
checkErrDetails: func(err *errs.Error) {
11931193
assert.Equals(t, err.Details["token"], raw)

ca/ca_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
115115
ca: ca,
116116
body: "invalid json",
117117
status: http.StatusBadRequest,
118-
errMsg: errs.BadRequestDefaultMsg,
118+
errMsg: errs.BadRequestPrefix,
119119
}
120120
},
121121
"fail invalid-csr-sig": func(t *testing.T) *signTest {
@@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
153153
ca: ca,
154154
body: string(body),
155155
status: http.StatusBadRequest,
156-
errMsg: errs.BadRequestDefaultMsg,
156+
errMsg: errs.BadRequestPrefix,
157157
}
158158
},
159159
"fail unauthorized-ott": func(t *testing.T) *signTest {

ca/client.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -1108,8 +1108,7 @@ retry:
11081108
retried = true
11091109
goto retry
11101110
}
1111-
1112-
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
1111+
return nil, readError(resp.Body)
11131112
}
11141113
var check api.SSHCheckPrincipalResponse
11151114
if err := readJSON(resp.Body, &check); err != nil {

errs/error.go

+28-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type Option func(e *Error) error
2525
// message only if it is empty.
2626
func withDefaultMessage(format string, args ...interface{}) Option {
2727
return func(e *Error) error {
28-
if len(e.Msg) > 0 {
28+
if e.Msg != "" {
2929
return e
3030
}
3131
e.Msg = fmt.Sprintf(format, args...)
@@ -164,7 +164,8 @@ type Messenger interface {
164164
func StatusCodeError(code int, e error, opts ...Option) error {
165165
switch code {
166166
case http.StatusBadRequest:
167-
return BadRequestErr(e, opts...)
167+
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
168+
return NewErr(http.StatusBadRequest, e, opts...)
168169
case http.StatusUnauthorized:
169170
return UnauthorizedErr(e, opts...)
170171
case http.StatusForbidden:
@@ -200,6 +201,15 @@ var (
200201
BadRequestPrefix = "The request could not be completed: "
201202
)
202203

204+
func formatMessage(status int, msg string) string {
205+
switch status {
206+
case http.StatusBadRequest:
207+
return BadRequestPrefix + msg + "."
208+
default:
209+
return msg
210+
}
211+
}
212+
203213
// splitOptionArgs splits the variadic length args into string formatting args
204214
// and Option(s) to apply to an Error.
205215
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
@@ -229,11 +239,24 @@ func New(status int, format string, args ...interface{}) error {
229239
msg := fmt.Sprintf(format, args...)
230240
return &Error{
231241
Status: status,
232-
Msg: msg,
242+
Msg: formatMessage(status, msg),
233243
Err: errors.New(msg),
234244
}
235245
}
236246

247+
// NewError creates a new http error with the given error and message.
248+
func NewError(status int, err error, format string, args ...interface{}) error {
249+
msg := fmt.Sprintf(format, args...)
250+
if _, ok := err.(StackTracer); !ok {
251+
err = errors.Wrap(err, msg)
252+
}
253+
return &Error{
254+
Status: status,
255+
Msg: formatMessage(status, msg),
256+
Err: err,
257+
}
258+
}
259+
237260
// NewErr returns a new Error. If the given error implements the StatusCoder
238261
// interface we will ignore the given status.
239262
func NewErr(status int, err error, opts ...Option) error {
@@ -308,14 +331,12 @@ func NotImplementedErr(err error, opts ...Option) error {
308331

309332
// BadRequest creates a 400 error with the given format and arguments.
310333
func BadRequest(format string, args ...interface{}) error {
311-
format = BadRequestPrefix + format + "."
312334
return New(http.StatusBadRequest, format, args...)
313335
}
314336

315337
// BadRequestErr returns an 400 error with the given error.
316-
func BadRequestErr(err error, opts ...Option) error {
317-
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
318-
return NewErr(http.StatusBadRequest, err, opts...)
338+
func BadRequestErr(err error, format string, args ...interface{}) error {
339+
return NewError(http.StatusBadRequest, err, format, args...)
319340
}
320341

321342
// Unauthorized creates a 401 error with the given format and arguments.

0 commit comments

Comments
 (0)