Skip to content

Commit 7b6a3ea

Browse files
committed
Add client methods for provisioning endpoints.
1 parent 378166a commit 7b6a3ea

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

ca/client.go

+37
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,43 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
286286
return &sign, nil
287287
}
288288

289+
// Provisioners performs the provisioners request to the CA and returns the
290+
// api.ProvisionersResponse struct with a map of provisioners.
291+
func (c *Client) Provisioners() (*api.ProvisionersResponse, error) {
292+
u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners"})
293+
resp, err := c.client.Get(u.String())
294+
if err != nil {
295+
return nil, errors.Wrapf(err, "client GET %s failed", u)
296+
}
297+
if resp.StatusCode >= 400 {
298+
return nil, readError(resp.Body)
299+
}
300+
var provisioners api.ProvisionersResponse
301+
if err := readJSON(resp.Body, &provisioners); err != nil {
302+
return nil, errors.Wrapf(err, "error reading %s", u)
303+
}
304+
return &provisioners, nil
305+
}
306+
307+
// ProvisionerKey performs the request to the CA to get the encrypted key for
308+
// the given provisioner kid and returns the api.ProvisionerKeyResponse struct
309+
// with the encrypted key.
310+
func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) {
311+
u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"})
312+
resp, err := c.client.Get(u.String())
313+
if err != nil {
314+
return nil, errors.Wrapf(err, "client GET %s failed", u)
315+
}
316+
if resp.StatusCode >= 400 {
317+
return nil, readError(resp.Body)
318+
}
319+
var key api.ProvisionerKeyResponse
320+
if err := readJSON(resp.Body, &key); err != nil {
321+
return nil, errors.Wrapf(err, "error reading %s", u)
322+
}
323+
return &key, nil
324+
}
325+
289326
// CreateSignRequest is a helper function that given an x509 OTT returns a
290327
// simple but secure sign request as well as the private key used.
291328
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {

ca/client_test.go

+120
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/smallstep/ca-component/api"
16+
"github.com/smallstep/cli/jose"
1617
)
1718

1819
const (
@@ -386,3 +387,122 @@ func TestClient_Renew(t *testing.T) {
386387
})
387388
}
388389
}
390+
391+
func TestClient_Provisioners(t *testing.T) {
392+
ok := &api.ProvisionersResponse{
393+
Provisioners: map[string]*jose.JSONWebKeySet{},
394+
}
395+
internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error"))
396+
397+
tests := []struct {
398+
name string
399+
response interface{}
400+
responseCode int
401+
wantErr bool
402+
}{
403+
{"ok", ok, 200, false},
404+
{"fail", internalServerError, 500, true},
405+
}
406+
407+
srv := httptest.NewServer(nil)
408+
defer srv.Close()
409+
410+
for _, tt := range tests {
411+
t.Run(tt.name, func(t *testing.T) {
412+
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
413+
if err != nil {
414+
t.Errorf("NewClient() error = %v", err)
415+
return
416+
}
417+
418+
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
419+
expected := "/provisioners"
420+
if req.RequestURI != expected {
421+
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
422+
}
423+
w.WriteHeader(tt.responseCode)
424+
api.JSON(w, tt.response)
425+
})
426+
427+
got, err := c.Provisioners()
428+
if (err != nil) != tt.wantErr {
429+
t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr)
430+
return
431+
}
432+
433+
switch {
434+
case err != nil:
435+
if got != nil {
436+
t.Errorf("Client.Provisioners() = %v, want nil", got)
437+
}
438+
if !reflect.DeepEqual(err, tt.response) {
439+
t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response)
440+
}
441+
default:
442+
if !reflect.DeepEqual(got, tt.response) {
443+
t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response)
444+
}
445+
}
446+
})
447+
}
448+
}
449+
450+
func TestClient_ProvisionerKey(t *testing.T) {
451+
ok := &api.ProvisionerKeyResponse{
452+
Key: "an encrypted key",
453+
}
454+
notFound := api.NotFound(fmt.Errorf("Not Found"))
455+
456+
tests := []struct {
457+
name string
458+
kid string
459+
response interface{}
460+
responseCode int
461+
wantErr bool
462+
}{
463+
{"ok", "kid", ok, 200, false},
464+
{"fail", "invalid", notFound, 500, true},
465+
}
466+
467+
srv := httptest.NewServer(nil)
468+
defer srv.Close()
469+
470+
for _, tt := range tests {
471+
t.Run(tt.name, func(t *testing.T) {
472+
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
473+
if err != nil {
474+
t.Errorf("NewClient() error = %v", err)
475+
return
476+
}
477+
478+
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
479+
expected := "/provisioners/" + tt.kid + "/encrypted-key"
480+
if req.RequestURI != expected {
481+
t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected)
482+
}
483+
w.WriteHeader(tt.responseCode)
484+
api.JSON(w, tt.response)
485+
})
486+
487+
got, err := c.ProvisionerKey(tt.kid)
488+
if (err != nil) != tt.wantErr {
489+
t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr)
490+
return
491+
}
492+
493+
switch {
494+
case err != nil:
495+
if got != nil {
496+
t.Errorf("Client.ProvisionerKey() = %v, want nil", got)
497+
}
498+
if !reflect.DeepEqual(err, tt.response) {
499+
t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response)
500+
}
501+
default:
502+
if !reflect.DeepEqual(got, tt.response) {
503+
t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response)
504+
}
505+
}
506+
})
507+
}
508+
}

0 commit comments

Comments
 (0)