Skip to content

Commit a6b8e65

Browse files
committed
Retrieve the authority from the context in api methods.
1 parent 900a640 commit a6b8e65

File tree

9 files changed

+121
-90
lines changed

9 files changed

+121
-90
lines changed

api/api.go

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,16 @@ type Authority interface {
5252
Version() authority.Version
5353
}
5454

55+
var errAuthority = errors.New("authority is not in context")
56+
57+
func mustAuthority(ctx context.Context) Authority {
58+
a, ok := authority.FromContext(ctx)
59+
if !ok {
60+
panic(errAuthority)
61+
}
62+
return a
63+
}
64+
5565
// TimeDuration is an alias of provisioner.TimeDuration
5666
type TimeDuration = provisioner.TimeDuration
5767

@@ -251,58 +261,58 @@ func New(auth Authority) RouterHandler {
251261
}
252262

253263
func (h *caHandler) Route(r Router) {
254-
r.MethodFunc("GET", "/version", h.Version)
255-
r.MethodFunc("GET", "/health", h.Health)
256-
r.MethodFunc("GET", "/root/{sha}", h.Root)
257-
r.MethodFunc("POST", "/sign", h.Sign)
258-
r.MethodFunc("POST", "/renew", h.Renew)
259-
r.MethodFunc("POST", "/rekey", h.Rekey)
260-
r.MethodFunc("POST", "/revoke", h.Revoke)
261-
r.MethodFunc("GET", "/provisioners", h.Provisioners)
262-
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey)
263-
r.MethodFunc("GET", "/roots", h.Roots)
264-
r.MethodFunc("GET", "/roots.pem", h.RootsPEM)
265-
r.MethodFunc("GET", "/federation", h.Federation)
264+
r.MethodFunc("GET", "/version", Version)
265+
r.MethodFunc("GET", "/health", Health)
266+
r.MethodFunc("GET", "/root/{sha}", Root)
267+
r.MethodFunc("POST", "/sign", Sign)
268+
r.MethodFunc("POST", "/renew", Renew)
269+
r.MethodFunc("POST", "/rekey", Rekey)
270+
r.MethodFunc("POST", "/revoke", Revoke)
271+
r.MethodFunc("GET", "/provisioners", Provisioners)
272+
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
273+
r.MethodFunc("GET", "/roots", Roots)
274+
r.MethodFunc("GET", "/roots.pem", RootsPEM)
275+
r.MethodFunc("GET", "/federation", Federation)
266276
// SSH CA
267-
r.MethodFunc("POST", "/ssh/sign", h.SSHSign)
268-
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew)
269-
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke)
270-
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey)
271-
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots)
272-
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation)
273-
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
274-
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
275-
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
276-
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts)
277-
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion)
277+
r.MethodFunc("POST", "/ssh/sign", SSHSign)
278+
r.MethodFunc("POST", "/ssh/renew", SSHRenew)
279+
r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
280+
r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
281+
r.MethodFunc("GET", "/ssh/roots", SSHRoots)
282+
r.MethodFunc("GET", "/ssh/federation", SSHFederation)
283+
r.MethodFunc("POST", "/ssh/config", SSHConfig)
284+
r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
285+
r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
286+
r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
287+
r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
278288

279289
// For compatibility with old code:
280-
r.MethodFunc("POST", "/re-sign", h.Renew)
281-
r.MethodFunc("POST", "/sign-ssh", h.SSHSign)
282-
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts)
290+
r.MethodFunc("POST", "/re-sign", Renew)
291+
r.MethodFunc("POST", "/sign-ssh", SSHSign)
292+
r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
283293
}
284294

285295
// Version is an HTTP handler that returns the version of the server.
286-
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
287-
v := h.Authority.Version()
296+
func Version(w http.ResponseWriter, r *http.Request) {
297+
v := mustAuthority(r.Context()).Version()
288298
render.JSON(w, VersionResponse{
289299
Version: v.Version,
290300
RequireClientAuthentication: v.RequireClientAuthentication,
291301
})
292302
}
293303

294304
// Health is an HTTP handler that returns the status of the server.
295-
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
305+
func Health(w http.ResponseWriter, r *http.Request) {
296306
render.JSON(w, HealthResponse{Status: "ok"})
297307
}
298308

299309
// Root is an HTTP handler that using the SHA256 from the URL, returns the root
300310
// certificate for the given SHA256.
301-
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
311+
func Root(w http.ResponseWriter, r *http.Request) {
302312
sha := chi.URLParam(r, "sha")
303313
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
304314
// Load root certificate with the
305-
cert, err := h.Authority.Root(sum)
315+
cert, err := mustAuthority(r.Context()).Root(sum)
306316
if err != nil {
307317
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
308318
return
@@ -320,38 +330,40 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
320330
}
321331

322332
// Provisioners returns the list of provisioners configured in the authority.
323-
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
333+
func Provisioners(w http.ResponseWriter, r *http.Request) {
324334
cursor, limit, err := ParseCursor(r)
325335
if err != nil {
326336
render.Error(w, err)
327337
return
328338
}
329339

330-
p, next, err := h.Authority.GetProvisioners(cursor, limit)
340+
p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
331341
if err != nil {
332342
render.Error(w, errs.InternalServerErr(err))
333343
return
334344
}
345+
335346
render.JSON(w, &ProvisionersResponse{
336347
Provisioners: p,
337348
NextCursor: next,
338349
})
339350
}
340351

341352
// ProvisionerKey returns the encrypted key of a provisioner by it's key id.
342-
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
353+
func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
343354
kid := chi.URLParam(r, "kid")
344-
key, err := h.Authority.GetEncryptedKey(kid)
355+
key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
345356
if err != nil {
346357
render.Error(w, errs.NotFoundErr(err))
347358
return
348359
}
360+
349361
render.JSON(w, &ProvisionerKeyResponse{key})
350362
}
351363

352364
// Roots returns all the root certificates for the CA.
353-
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
354-
roots, err := h.Authority.GetRoots()
365+
func Roots(w http.ResponseWriter, r *http.Request) {
366+
roots, err := mustAuthority(r.Context()).GetRoots()
355367
if err != nil {
356368
render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
357369
return
@@ -368,8 +380,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
368380
}
369381

370382
// RootsPEM returns all the root certificates for the CA in PEM format.
371-
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
372-
roots, err := h.Authority.GetRoots()
383+
func RootsPEM(w http.ResponseWriter, r *http.Request) {
384+
roots, err := mustAuthority(r.Context()).GetRoots()
373385
if err != nil {
374386
render.Error(w, errs.InternalServerErr(err))
375387
return
@@ -391,8 +403,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
391403
}
392404

393405
// Federation returns all the public certificates in the federation.
394-
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
395-
federated, err := h.Authority.GetFederation()
406+
func Federation(w http.ResponseWriter, r *http.Request) {
407+
federated, err := mustAuthority(r.Context()).GetFederation()
396408
if err != nil {
397409
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
398410
return

api/rekey.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
2727
}
2828

2929
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
30-
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
30+
func Rekey(w http.ResponseWriter, r *http.Request) {
3131
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
3232
render.Error(w, errs.BadRequest("missing client certificate"))
3333
return
@@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
4444
return
4545
}
4646

47-
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
47+
a := mustAuthority(r.Context())
48+
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
4849
if err != nil {
4950
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
5051
return
@@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
6061
ServerPEM: certChainPEM[0],
6162
CaPEM: caPEM,
6263
CertChainPEM: certChainPEM,
63-
TLSOptions: h.Authority.GetTLSOptions(),
64+
TLSOptions: a.GetTLSOptions(),
6465
}, http.StatusCreated)
6566
}

api/renew.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ const (
1616

1717
// Renew uses the information of certificate in the TLS connection to create a
1818
// new one.
19-
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
20-
cert, err := h.getPeerCertificate(r)
19+
func Renew(w http.ResponseWriter, r *http.Request) {
20+
cert, err := getPeerCertificate(r)
2121
if err != nil {
2222
render.Error(w, err)
2323
return
2424
}
2525

26-
certChain, err := h.Authority.Renew(cert)
26+
a := mustAuthority(r.Context())
27+
certChain, err := a.Renew(cert)
2728
if err != nil {
2829
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
2930
return
@@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
3940
ServerPEM: certChainPEM[0],
4041
CaPEM: caPEM,
4142
CertChainPEM: certChainPEM,
42-
TLSOptions: h.Authority.GetTLSOptions(),
43+
TLSOptions: a.GetTLSOptions(),
4344
}, http.StatusCreated)
4445
}
4546

46-
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
47+
func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
4748
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
4849
return r.TLS.PeerCertificates[0], nil
4950
}
5051
if s := r.Header.Get(authorizationHeader); s != "" {
5152
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
52-
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
53+
ctx := r.Context()
54+
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
5355
}
5456
}
5557
return nil, errs.BadRequest("missing client certificate")

api/revoke.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package api
22

33
import (
4-
"context"
54
"net/http"
65

76
"golang.org/x/crypto/ocsp"
@@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
4948
// NOTE: currently only Passive revocation is supported.
5049
//
5150
// TODO: Add CRL and OCSP support.
52-
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
51+
func Revoke(w http.ResponseWriter, r *http.Request) {
5352
var body RevokeRequest
5453
if err := read.JSON(r.Body, &body); err != nil {
5554
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
6867
PassiveOnly: body.Passive,
6968
}
7069

71-
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod)
70+
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
71+
a := mustAuthority(ctx)
72+
7273
// A token indicates that we are using the api via a provisioner token,
7374
// otherwise it is assumed that the certificate is revoking itself over mTLS.
7475
if len(body.OTT) > 0 {
7576
logOtt(w, body.OTT)
76-
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
77+
if _, err := a.Authorize(ctx, body.OTT); err != nil {
7778
render.Error(w, errs.UnauthorizedErr(err))
7879
return
7980
}
@@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
9899
opts.MTLS = true
99100
}
100101

101-
if err := h.Authority.Revoke(ctx, opts); err != nil {
102+
if err := a.Revoke(ctx, opts); err != nil {
102103
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
103104
return
104105
}

api/sign.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type SignResponse struct {
4949
// Sign is an HTTP handler that reads a certificate request and an
5050
// one-time-token (ott) from the body and creates a new certificate with the
5151
// information in the certificate request.
52-
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
52+
func Sign(w http.ResponseWriter, r *http.Request) {
5353
var body SignRequest
5454
if err := read.JSON(r.Body, &body); err != nil {
5555
render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@@ -68,13 +68,14 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
6868
TemplateData: body.TemplateData,
6969
}
7070

71-
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
71+
a := mustAuthority(r.Context())
72+
signOpts, err := a.AuthorizeSign(body.OTT)
7273
if err != nil {
7374
render.Error(w, errs.UnauthorizedErr(err))
7475
return
7576
}
7677

77-
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
78+
certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
7879
if err != nil {
7980
render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
8081
return
@@ -89,6 +90,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
8990
ServerPEM: certChainPEM[0],
9091
CaPEM: caPEM,
9192
CertChainPEM: certChainPEM,
92-
TLSOptions: h.Authority.GetTLSOptions(),
93+
TLSOptions: a.GetTLSOptions(),
9394
}, http.StatusCreated)
9495
}

0 commit comments

Comments
 (0)