Skip to content

Commit dccbdf3

Browse files
committed
Introduce generalized statusCoder errors and loads of ssh unit tests.
* StatusCoder api errors that have friendly user messages. * Unit tests for SSH sign/renew/rekey/revoke across all provisioners.
1 parent 549291c commit dccbdf3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+5211
-2120
lines changed

api/api.go

-108
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"crypto/dsa"
66
"crypto/ecdsa"
77
"crypto/rsa"
8-
"crypto/tls"
98
"crypto/x509"
109
"encoding/asn1"
1110
"encoding/base64"
@@ -209,14 +208,6 @@ type RootResponse struct {
209208
RootPEM Certificate `json:"ca"`
210209
}
211210

212-
// SignRequest is the request body for a certificate signature request.
213-
type SignRequest struct {
214-
CsrPEM CertificateRequest `json:"csr"`
215-
OTT string `json:"ott"`
216-
NotAfter TimeDuration `json:"notAfter"`
217-
NotBefore TimeDuration `json:"notBefore"`
218-
}
219-
220211
// ProvisionersResponse is the response object that returns the list of
221212
// provisioners.
222213
type ProvisionersResponse struct {
@@ -230,31 +221,6 @@ type ProvisionerKeyResponse struct {
230221
Key string `json:"key"`
231222
}
232223

233-
// Validate checks the fields of the SignRequest and returns nil if they are ok
234-
// or an error if something is wrong.
235-
func (s *SignRequest) Validate() error {
236-
if s.CsrPEM.CertificateRequest == nil {
237-
return errs.BadRequest(errors.New("missing csr"))
238-
}
239-
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
240-
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
241-
}
242-
if s.OTT == "" {
243-
return errs.BadRequest(errors.New("missing ott"))
244-
}
245-
246-
return nil
247-
}
248-
249-
// SignResponse is the response object of the certificate signature request.
250-
type SignResponse struct {
251-
ServerPEM Certificate `json:"crt"`
252-
CaPEM Certificate `json:"ca"`
253-
CertChainPEM []Certificate `json:"certChain"`
254-
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
255-
TLS *tls.ConnectionState `json:"-"`
256-
}
257-
258224
// RootsResponse is the response object of the roots request.
259225
type RootsResponse struct {
260226
Certificates []Certificate `json:"crts"`
@@ -344,80 +310,6 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
344310
return certChainPEM
345311
}
346312

347-
// Sign is an HTTP handler that reads a certificate request and an
348-
// one-time-token (ott) from the body and creates a new certificate with the
349-
// information in the certificate request.
350-
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
351-
var body SignRequest
352-
if err := ReadJSON(r.Body, &body); err != nil {
353-
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
354-
return
355-
}
356-
357-
logOtt(w, body.OTT)
358-
if err := body.Validate(); err != nil {
359-
WriteError(w, err)
360-
return
361-
}
362-
363-
opts := provisioner.Options{
364-
NotBefore: body.NotBefore,
365-
NotAfter: body.NotAfter,
366-
}
367-
368-
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
369-
if err != nil {
370-
WriteError(w, errs.Unauthorized(err))
371-
return
372-
}
373-
374-
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
375-
if err != nil {
376-
WriteError(w, errs.Forbidden(err))
377-
return
378-
}
379-
certChainPEM := certChainToPEM(certChain)
380-
var caPEM Certificate
381-
if len(certChainPEM) > 0 {
382-
caPEM = certChainPEM[1]
383-
}
384-
logCertificate(w, certChain[0])
385-
JSONStatus(w, &SignResponse{
386-
ServerPEM: certChainPEM[0],
387-
CaPEM: caPEM,
388-
CertChainPEM: certChainPEM,
389-
TLSOptions: h.Authority.GetTLSOptions(),
390-
}, http.StatusCreated)
391-
}
392-
393-
// Renew uses the information of certificate in the TLS connection to create a
394-
// new one.
395-
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
396-
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
397-
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
398-
return
399-
}
400-
401-
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
402-
if err != nil {
403-
WriteError(w, errs.Forbidden(err))
404-
return
405-
}
406-
certChainPEM := certChainToPEM(certChain)
407-
var caPEM Certificate
408-
if len(certChainPEM) > 0 {
409-
caPEM = certChainPEM[1]
410-
}
411-
412-
logCertificate(w, certChain[0])
413-
JSONStatus(w, &SignResponse{
414-
ServerPEM: certChainPEM[0],
415-
CaPEM: caPEM,
416-
CertChainPEM: certChainPEM,
417-
TLSOptions: h.Authority.GetTLSOptions(),
418-
}, http.StatusCreated)
419-
}
420-
421313
// Provisioners returns the list of provisioners configured in the authority.
422314
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
423315
cursor, limit, err := parseCursor(r)

api/api_test.go

+19-12
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/smallstep/assert"
2929
"github.com/smallstep/certificates/authority"
3030
"github.com/smallstep/certificates/authority/provisioner"
31+
"github.com/smallstep/certificates/errs"
3132
"github.com/smallstep/certificates/logging"
3233
"github.com/smallstep/certificates/sshutil"
3334
"github.com/smallstep/certificates/templates"
@@ -914,7 +915,7 @@ func Test_caHandler_Renew(t *testing.T) {
914915
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
915916
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
916917
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
917-
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
918+
{"renew error", cs, nil, nil, errs.Forbidden(fmt.Errorf("an error")), http.StatusForbidden},
918919
}
919920

920921
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
@@ -934,13 +935,13 @@ func Test_caHandler_Renew(t *testing.T) {
934935
res := w.Result()
935936

936937
if res.StatusCode != tt.statusCode {
937-
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
938+
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
938939
}
939940

940941
body, err := ioutil.ReadAll(res.Body)
941942
res.Body.Close()
942943
if err != nil {
943-
t.Errorf("caHandler.Root unexpected error = %v", err)
944+
t.Errorf("caHandler.Renew unexpected error = %v", err)
944945
}
945946
if tt.statusCode < http.StatusBadRequest {
946947
if !bytes.Equal(bytes.TrimSpace(body), expected) {
@@ -1009,8 +1010,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
10091010
t.Fatal(err)
10101011
}
10111012

1012-
expectedError400 := []byte(`{"status":400,"message":"Bad Request"}`)
1013-
expectedError500 := []byte(`{"status":500,"message":"Internal Server Error"}`)
1013+
expectedError400 := errs.BadRequest(errors.New("force"))
1014+
expectedError400Bytes, err := json.Marshal(expectedError400)
1015+
assert.FatalError(t, err)
1016+
expectedError500 := errs.InternalServerError(errors.New("force"))
1017+
expectedError500Bytes, err := json.Marshal(expectedError500)
1018+
assert.FatalError(t, err)
10141019
for _, tt := range tests {
10151020
t.Run(tt.name, func(t *testing.T) {
10161021
h := &caHandler{
@@ -1035,12 +1040,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
10351040
} else {
10361041
switch tt.statusCode {
10371042
case 400:
1038-
if !bytes.Equal(bytes.TrimSpace(body), expectedError400) {
1039-
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400)
1043+
if !bytes.Equal(bytes.TrimSpace(body), expectedError400Bytes) {
1044+
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError400Bytes)
10401045
}
10411046
case 500:
1042-
if !bytes.Equal(bytes.TrimSpace(body), expectedError500) {
1043-
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500)
1047+
if !bytes.Equal(bytes.TrimSpace(body), expectedError500Bytes) {
1048+
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError500Bytes)
10441049
}
10451050
default:
10461051
t.Errorf("caHandler.Provisioner unexpected status code = %d", tt.statusCode)
@@ -1077,7 +1082,9 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
10771082
}
10781083

10791084
expected := []byte(`{"key":"` + privKey + `"}`)
1080-
expectedError := []byte(`{"status":404,"message":"Not Found"}`)
1085+
expectedError404 := errs.NotFound(errors.New("force"))
1086+
expectedError404Bytes, err := json.Marshal(expectedError404)
1087+
assert.FatalError(t, err)
10811088

10821089
for _, tt := range tests {
10831090
t.Run(tt.name, func(t *testing.T) {
@@ -1101,8 +1108,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
11011108
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expected)
11021109
}
11031110
} else {
1104-
if !bytes.Equal(bytes.TrimSpace(body), expectedError) {
1105-
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError)
1111+
if !bytes.Equal(bytes.TrimSpace(body), expectedError404Bytes) {
1112+
t.Errorf("caHandler.Provisioners Body = %s, wants %s", body, expectedError404Bytes)
11061113
}
11071114
}
11081115
})

api/renew.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package api
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/pkg/errors"
7+
"github.com/smallstep/certificates/errs"
8+
)
9+
10+
// Renew uses the information of certificate in the TLS connection to create a
11+
// new one.
12+
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
13+
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
14+
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
15+
return
16+
}
17+
18+
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
19+
if err != nil {
20+
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
21+
return
22+
}
23+
certChainPEM := certChainToPEM(certChain)
24+
var caPEM Certificate
25+
if len(certChainPEM) > 0 {
26+
caPEM = certChainPEM[1]
27+
}
28+
29+
logCertificate(w, certChain[0])
30+
JSONStatus(w, &SignResponse{
31+
ServerPEM: certChainPEM[0],
32+
CaPEM: caPEM,
33+
CertChainPEM: certChainPEM,
34+
TLSOptions: h.Authority.GetTLSOptions(),
35+
}, http.StatusCreated)
36+
}

api/sign.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package api
2+
3+
import (
4+
"crypto/tls"
5+
"net/http"
6+
7+
"github.com/pkg/errors"
8+
"github.com/smallstep/certificates/authority/provisioner"
9+
"github.com/smallstep/certificates/errs"
10+
"github.com/smallstep/cli/crypto/tlsutil"
11+
)
12+
13+
// SignRequest is the request body for a certificate signature request.
14+
type SignRequest struct {
15+
CsrPEM CertificateRequest `json:"csr"`
16+
OTT string `json:"ott"`
17+
NotAfter TimeDuration `json:"notAfter"`
18+
NotBefore TimeDuration `json:"notBefore"`
19+
}
20+
21+
// Validate checks the fields of the SignRequest and returns nil if they are ok
22+
// or an error if something is wrong.
23+
func (s *SignRequest) Validate() error {
24+
if s.CsrPEM.CertificateRequest == nil {
25+
return errs.BadRequest(errors.New("missing csr"))
26+
}
27+
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
28+
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
29+
}
30+
if s.OTT == "" {
31+
return errs.BadRequest(errors.New("missing ott"))
32+
}
33+
34+
return nil
35+
}
36+
37+
// SignResponse is the response object of the certificate signature request.
38+
type SignResponse struct {
39+
ServerPEM Certificate `json:"crt"`
40+
CaPEM Certificate `json:"ca"`
41+
CertChainPEM []Certificate `json:"certChain"`
42+
TLSOptions *tlsutil.TLSOptions `json:"tlsOptions,omitempty"`
43+
TLS *tls.ConnectionState `json:"-"`
44+
}
45+
46+
// Sign is an HTTP handler that reads a certificate request and an
47+
// one-time-token (ott) from the body and creates a new certificate with the
48+
// information in the certificate request.
49+
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
50+
var body SignRequest
51+
if err := ReadJSON(r.Body, &body); err != nil {
52+
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
53+
return
54+
}
55+
56+
logOtt(w, body.OTT)
57+
if err := body.Validate(); err != nil {
58+
WriteError(w, err)
59+
return
60+
}
61+
62+
opts := provisioner.Options{
63+
NotBefore: body.NotBefore,
64+
NotAfter: body.NotAfter,
65+
}
66+
67+
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
68+
if err != nil {
69+
WriteError(w, errs.Unauthorized(err))
70+
return
71+
}
72+
73+
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
74+
if err != nil {
75+
WriteError(w, errs.Forbidden(err))
76+
return
77+
}
78+
certChainPEM := certChainToPEM(certChain)
79+
var caPEM Certificate
80+
if len(certChainPEM) > 0 {
81+
caPEM = certChainPEM[1]
82+
}
83+
logCertificate(w, certChain[0])
84+
JSONStatus(w, &SignResponse{
85+
ServerPEM: certChainPEM[0],
86+
CaPEM: caPEM,
87+
CertChainPEM: certChainPEM,
88+
TLSOptions: h.Authority.GetTLSOptions(),
89+
}, http.StatusCreated)
90+
}

api/ssh.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
282282
ValidAfter: body.ValidAfter,
283283
}
284284

285-
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
285+
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
286286
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
287287
if err != nil {
288288
WriteError(w, errs.Unauthorized(err))

api/sshRekey.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
5656
return
5757
}
5858

59-
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod)
59+
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
6060
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
6161
if err != nil {
6262
WriteError(w, errs.Unauthorized(err))
6363
return
6464
}
65-
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
65+
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
6666
if err != nil {
6767
WriteError(w, errs.InternalServerError(err))
6868
}

api/sshRenew.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
4646
return
4747
}
4848

49-
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod)
49+
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
5050
_, err := h.Authority.Authorize(ctx, body.OTT)
5151
if err != nil {
5252
WriteError(w, errs.Unauthorized(err))
5353
return
5454
}
55-
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
55+
oldCert, _, err := provisioner.ExtractSSHPOPCert(body.OTT)
5656
if err != nil {
5757
WriteError(w, errs.InternalServerError(err))
5858
}

0 commit comments

Comments
 (0)