Skip to content

Commit d940ab7

Browse files
committed
Add getSSHHosts injection func
1 parent 414a94b commit d940ab7

File tree

4 files changed

+33
-12
lines changed

4 files changed

+33
-12
lines changed

Diff for: api/ssh.go

+13-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type SSHAuthority interface {
2323
GetSSHFederation() (*authority.SSHKeys, error)
2424
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
2525
CheckSSHHost(principal string) (bool, error)
26-
GetSSHHosts() ([]string, error)
26+
GetSSHHosts(user string) ([]string, error)
2727
GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
2828
}
2929

@@ -406,7 +406,18 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
406406

407407
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
408408
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
409-
hosts, err := h.Authority.GetSSHHosts()
409+
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
410+
WriteError(w, BadRequest(errors.New("missing peer certificate")))
411+
return
412+
}
413+
414+
cert := r.TLS.PeerCertificates[0]
415+
email := cert.EmailAddresses[0]
416+
if len(email) == 0 {
417+
WriteError(w, BadRequest(errors.New("client certificate missing email SAN")))
418+
return
419+
}
420+
hosts, err := h.Authority.GetSSHHosts(email)
410421
if err != nil {
411422
WriteError(w, InternalServerError(err))
412423
return

Diff for: authority/authority.go

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type Authority struct {
4141
initOnce bool
4242
// Custom functions
4343
sshBastionFunc func(user, hostname string) (*Bastion, error)
44+
sshGetHostsFunc func(user string) ([]string, error)
4445
getIdentityFunc provisioner.GetIdentityFunc
4546
}
4647

Diff for: authority/options.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ func WithDatabase(db db.AuthDB) Option {
1616
}
1717
}
1818

19+
// WithGetIdentityFunc sets a custom function to retrieve the identity from
20+
// an external resource.
21+
func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
22+
return func(a *Authority) {
23+
a.getIdentityFunc = fn
24+
}
25+
}
26+
1927
// WithSSHBastionFunc sets a custom function to get the bastion for a
2028
// given user-host pair.
2129
func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
@@ -24,10 +32,10 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
2432
}
2533
}
2634

27-
// WithGetIdentityFunc sets a custom function to retrieve the identity from
28-
// an external resource.
29-
func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
35+
// WithSSHGetHosts sets a custom function to get the bastion for a
36+
// given user-host pair.
37+
func WithSSHGetHosts(fn func(user string) ([]string, error)) Option {
3038
return func(a *Authority) {
31-
a.getIdentityFunc = fn
39+
a.sshGetHostsFunc = fn
3240
}
3341
}

Diff for: authority/ssh.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -673,13 +673,14 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
673673
}
674674

675675
// GetSSHHosts returns a list of valid host principals.
676-
func (a *Authority) GetSSHHosts() ([]string, error) {
677-
ps, err := a.db.GetSSHHostPrincipals()
678-
if err != nil {
679-
return nil, err
676+
func (a *Authority) GetSSHHosts(email string) ([]string, error) {
677+
if a.sshBastionFunc != nil {
678+
return a.sshGetHostsFunc(email)
679+
}
680+
return nil, &apiError{
681+
err: errors.New("getSSHHosts is not configured"),
682+
code: http.StatusNotFound,
680683
}
681-
682-
return ps, nil
683684
}
684685

685686
func (a *Authority) getAddUserPrincipal() (cmd string) {

0 commit comments

Comments
 (0)