Skip to content

Commit c76dad8

Browse files
committedFeb 8, 2024
Improve tests for CRL HTTP handler
1 parent 69f5f8d commit c76dad8

File tree

3 files changed

+105
-40
lines changed

3 files changed

+105
-40
lines changed
 

‎api/api_test.go

-39
Original file line numberDiff line numberDiff line change
@@ -789,45 +789,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (
789789
return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err
790790
}
791791

792-
func Test_CRLGeneration(t *testing.T) {
793-
tests := []struct {
794-
name string
795-
err error
796-
statusCode int
797-
expected []byte
798-
}{
799-
{"empty", nil, http.StatusOK, nil},
800-
}
801-
802-
chiCtx := chi.NewRouteContext()
803-
req := httptest.NewRequest("GET", "http://example.com/crl", http.NoBody)
804-
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
805-
806-
for _, tt := range tests {
807-
t.Run(tt.name, func(t *testing.T) {
808-
mockMustAuthority(t, &mockAuthority{ret1: tt.expected, err: tt.err})
809-
w := httptest.NewRecorder()
810-
CRL(w, req)
811-
res := w.Result()
812-
813-
if res.StatusCode != tt.statusCode {
814-
t.Errorf("caHandler.CRL StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
815-
}
816-
817-
body, err := io.ReadAll(res.Body)
818-
res.Body.Close()
819-
if err != nil {
820-
t.Errorf("caHandler.Root unexpected error = %v", err)
821-
}
822-
if tt.statusCode == 200 {
823-
if !bytes.Equal(bytes.TrimSpace(body), tt.expected) {
824-
t.Errorf("caHandler.Root CRL = %s, wants %s", body, tt.expected)
825-
}
826-
}
827-
})
828-
}
829-
}
830-
831792
func Test_caHandler_Route(t *testing.T) {
832793
type fields struct {
833794
Authority Authority

‎api/crl.go

+12-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"time"
77

88
"github.com/smallstep/certificates/api/render"
9+
"github.com/smallstep/certificates/errs"
910
)
1011

1112
// CRL is an HTTP handler that returns the current CRL in DER or PEM format
@@ -16,7 +17,17 @@ func CRL(w http.ResponseWriter, r *http.Request) {
1617
return
1718
}
1819

19-
w.Header().Add("Expires", crlInfo.ExpiresAt.Format(time.RFC1123))
20+
if crlInfo == nil {
21+
render.Error(w, errs.New(http.StatusInternalServerError, "no CRL available"))
22+
return
23+
}
24+
25+
expires := crlInfo.ExpiresAt
26+
if expires.IsZero() {
27+
expires = time.Now()
28+
}
29+
30+
w.Header().Add("Expires", expires.Format(time.RFC1123))
2031

2132
_, formatAsPEM := r.URL.Query()["pem"]
2233
if formatAsPEM {

‎api/crl_test.go

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package api
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/pem"
7+
"io"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
"time"
12+
13+
"github.com/go-chi/chi/v5"
14+
"github.com/pkg/errors"
15+
"github.com/smallstep/certificates/authority"
16+
"github.com/smallstep/certificates/errs"
17+
"github.com/stretchr/testify/assert"
18+
"github.com/stretchr/testify/require"
19+
)
20+
21+
func Test_CRL(t *testing.T) {
22+
data := []byte{1, 2, 3, 4}
23+
pemData := pem.EncodeToMemory(&pem.Block{
24+
Type: "X509 CRL",
25+
Bytes: data,
26+
})
27+
pemData = bytes.TrimSpace(pemData)
28+
emptyPEMData := pem.EncodeToMemory(&pem.Block{
29+
Type: "X509 CRL",
30+
Bytes: nil,
31+
})
32+
emptyPEMData = bytes.TrimSpace(emptyPEMData)
33+
tests := []struct {
34+
name string
35+
url string
36+
err error
37+
statusCode int
38+
crlInfo *authority.CertificateRevocationListInfo
39+
expectedBody []byte
40+
expectedHeaders http.Header
41+
expectedErrorJSON string
42+
}{
43+
{"ok", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, data, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""},
44+
{"ok/pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: data}, pemData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""},
45+
{"ok/empty", "http://example.com/crl", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, nil, http.Header{"Content-Type": []string{"application/pkix-crl"}, "Content-Disposition": []string{`attachment; filename="crl.der"`}}, ""},
46+
{"ok/empty-pem", "http://example.com/crl?pem=true", nil, http.StatusOK, &authority.CertificateRevocationListInfo{Data: nil}, emptyPEMData, http.Header{"Content-Type": []string{"application/x-pem-file"}, "Content-Disposition": []string{`attachment; filename="crl.pem"`}}, ""},
47+
{"fail/internal", "http://example.com/crl", errs.Wrap(http.StatusInternalServerError, errors.New("failure"), "authority.GetCertificateRevocationList"), http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info."}`},
48+
{"fail/nil", "http://example.com/crl", nil, http.StatusInternalServerError, nil, nil, http.Header{}, `{"status":500,"message":"no CRL available"}`},
49+
}
50+
51+
for _, tt := range tests {
52+
t.Run(tt.name, func(t *testing.T) {
53+
mockMustAuthority(t, &mockAuthority{ret1: tt.crlInfo, err: tt.err})
54+
55+
chiCtx := chi.NewRouteContext()
56+
req := httptest.NewRequest("GET", tt.url, http.NoBody)
57+
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
58+
w := httptest.NewRecorder()
59+
CRL(w, req)
60+
res := w.Result()
61+
62+
assert.Equal(t, tt.statusCode, res.StatusCode)
63+
64+
body, err := io.ReadAll(res.Body)
65+
res.Body.Close()
66+
require.NoError(t, err)
67+
68+
if tt.statusCode >= 300 {
69+
assert.JSONEq(t, tt.expectedErrorJSON, string(bytes.TrimSpace(body)))
70+
return
71+
}
72+
73+
// check expected header values
74+
for _, h := range []string{"content-type", "content-disposition"} {
75+
v := tt.expectedHeaders.Get(h)
76+
require.NotEmpty(t, v)
77+
78+
actual := res.Header.Get(h)
79+
assert.Equal(t, v, actual)
80+
}
81+
82+
// check expires header value
83+
assert.NotEmpty(t, res.Header.Get("expires"))
84+
t1, err := time.Parse(time.RFC1123, res.Header.Get("expires"))
85+
if assert.NoError(t, err) {
86+
assert.False(t, t1.IsZero())
87+
}
88+
89+
// check body contents
90+
assert.Equal(t, tt.expectedBody, bytes.TrimSpace(body))
91+
})
92+
}
93+
}

0 commit comments

Comments
 (0)