Skip to content

Commit fa41633

Browse files
committed
Add context to tests.
1 parent c49a9d5 commit fa41633

File tree

5 files changed

+61
-61
lines changed

5 files changed

+61
-61
lines changed

Diff for: api/api_test.go

+37-37
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,6 @@ type mockAuthority struct {
550550
getTLSOptions func() *tlsutil.TLSOptions
551551
root func(shasum string) (*x509.Certificate, error)
552552
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
553-
signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
554-
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
555553
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
556554
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
557555
loadProvisionerByID func(provID string) (provisioner.Interface, error)
@@ -560,14 +558,16 @@ type mockAuthority struct {
560558
getEncryptedKey func(kid string) (string, error)
561559
getRoots func() ([]*x509.Certificate, error)
562560
getFederation func() ([]*x509.Certificate, error)
563-
renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error)
564-
rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
565-
getSSHHosts func(*x509.Certificate) ([]sshutil.Host, error)
566-
getSSHRoots func() (*authority.SSHKeys, error)
567-
getSSHFederation func() (*authority.SSHKeys, error)
568-
getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error)
561+
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
562+
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
563+
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
564+
rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
565+
getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
566+
getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error)
567+
getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error)
568+
getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
569569
checkSSHHost func(ctx context.Context, principal, token string) (bool, error)
570-
getSSHBastion func(user string, hostname string) (*authority.Bastion, error)
570+
getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
571571
version func() authority.Version
572572
}
573573

@@ -604,20 +604,6 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio
604604
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
605605
}
606606

607-
func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
608-
if m.signSSH != nil {
609-
return m.signSSH(key, opts, signOpts...)
610-
}
611-
return m.ret1.(*ssh.Certificate), m.err
612-
}
613-
614-
func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
615-
if m.signSSHAddUser != nil {
616-
return m.signSSHAddUser(key, cert)
617-
}
618-
return m.ret1.(*ssh.Certificate), m.err
619-
}
620-
621607
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
622608
if m.renew != nil {
623609
return m.renew(cert)
@@ -674,44 +660,58 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
674660
return m.ret1.([]*x509.Certificate), m.err
675661
}
676662

677-
func (m *mockAuthority) RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) {
663+
func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
664+
if m.signSSH != nil {
665+
return m.signSSH(ctx, key, opts, signOpts...)
666+
}
667+
return m.ret1.(*ssh.Certificate), m.err
668+
}
669+
670+
func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
671+
if m.signSSHAddUser != nil {
672+
return m.signSSHAddUser(ctx, key, cert)
673+
}
674+
return m.ret1.(*ssh.Certificate), m.err
675+
}
676+
677+
func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
678678
if m.renewSSH != nil {
679-
return m.renewSSH(cert)
679+
return m.renewSSH(ctx, cert)
680680
}
681681
return m.ret1.(*ssh.Certificate), m.err
682682
}
683683

684-
func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
684+
func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
685685
if m.rekeySSH != nil {
686-
return m.rekeySSH(cert, key, signOpts...)
686+
return m.rekeySSH(ctx, cert, key, signOpts...)
687687
}
688688
return m.ret1.(*ssh.Certificate), m.err
689689
}
690690

691-
func (m *mockAuthority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
691+
func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
692692
if m.getSSHHosts != nil {
693-
return m.getSSHHosts(cert)
693+
return m.getSSHHosts(ctx, cert)
694694
}
695695
return m.ret1.([]sshutil.Host), m.err
696696
}
697697

698-
func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) {
698+
func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
699699
if m.getSSHRoots != nil {
700-
return m.getSSHRoots()
700+
return m.getSSHRoots(ctx)
701701
}
702702
return m.ret1.(*authority.SSHKeys), m.err
703703
}
704704

705-
func (m *mockAuthority) GetSSHFederation() (*authority.SSHKeys, error) {
705+
func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
706706
if m.getSSHFederation != nil {
707-
return m.getSSHFederation()
707+
return m.getSSHFederation(ctx)
708708
}
709709
return m.ret1.(*authority.SSHKeys), m.err
710710
}
711711

712-
func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
712+
func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
713713
if m.getSSHConfig != nil {
714-
return m.getSSHConfig(typ, data)
714+
return m.getSSHConfig(ctx, typ, data)
715715
}
716716
return m.ret1.([]templates.Output), m.err
717717
}
@@ -723,9 +723,9 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
723723
return m.ret1.(bool), m.err
724724
}
725725

726-
func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) {
726+
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) {
727727
if m.getSSHBastion != nil {
728-
return m.getSSHBastion(user, hostname)
728+
return m.getSSHBastion(ctx, user, hostname)
729729
}
730730
return m.ret1.(*authority.Bastion), m.err
731731
}

Diff for: api/ssh_test.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,10 @@ func Test_caHandler_SSHSign(t *testing.T) {
319319
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
320320
return []provisioner.SignOption{}, tt.authErr
321321
},
322-
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
322+
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
323323
return tt.signCert, tt.signErr
324324
},
325-
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
325+
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
326326
return tt.addUserCert, tt.addUserErr
327327
},
328328
sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
@@ -379,7 +379,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
379379
for _, tt := range tests {
380380
t.Run(tt.name, func(t *testing.T) {
381381
h := New(&mockAuthority{
382-
getSSHRoots: func() (*authority.SSHKeys, error) {
382+
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
383383
return tt.keys, tt.keysErr
384384
},
385385
}).(*caHandler)
@@ -433,7 +433,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
433433
for _, tt := range tests {
434434
t.Run(tt.name, func(t *testing.T) {
435435
h := New(&mockAuthority{
436-
getSSHFederation: func() (*authority.SSHKeys, error) {
436+
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
437437
return tt.keys, tt.keysErr
438438
},
439439
}).(*caHandler)
@@ -493,7 +493,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
493493
for _, tt := range tests {
494494
t.Run(tt.name, func(t *testing.T) {
495495
h := New(&mockAuthority{
496-
getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) {
496+
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
497497
return tt.output, tt.err
498498
},
499499
}).(*caHandler)
@@ -591,7 +591,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
591591
for _, tt := range tests {
592592
t.Run(tt.name, func(t *testing.T) {
593593
h := New(&mockAuthority{
594-
getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) {
594+
getSSHHosts: func(context.Context, *x509.Certificate) ([]sshutil.Host, error) {
595595
return tt.hosts, tt.err
596596
},
597597
}).(*caHandler)
@@ -646,7 +646,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
646646
for _, tt := range tests {
647647
t.Run(tt.name, func(t *testing.T) {
648648
h := New(&mockAuthority{
649-
getSSHBastion: func(user, hostname string) (*authority.Bastion, error) {
649+
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
650650
return tt.bastion, tt.bastionErr
651651
},
652652
}).(*caHandler)

Diff for: authority/provisioner/oidc_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -485,10 +485,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
485485
assert.FatalError(t, p4.Init(config))
486486
assert.FatalError(t, p5.Init(config))
487487

488-
p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
488+
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
489489
return &Identity{Usernames: []string{"max", "mariano"}}, nil
490490
}
491-
p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
491+
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
492492
return nil, errors.New("force")
493493
}
494494

Diff for: authority/provisioner/provisioner_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func TestDefaultIdentityFunc(t *testing.T) {
9292
for name, get := range tests {
9393
t.Run(name, func(t *testing.T) {
9494
tc := get(t)
95-
identity, err := DefaultIdentityFunc(tc.p, tc.email)
95+
identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email)
9696
if err != nil {
9797
if assert.NotNil(t, tc.err) {
9898
assert.Equals(t, tc.err.Error(), err.Error())

Diff for: authority/ssh_test.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func TestAuthority_SignSSH(t *testing.T) {
153153
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
154154
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
155155

156-
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
156+
got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...)
157157
if (err != nil) != tt.wantErr {
158158
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
159159
return
@@ -242,7 +242,7 @@ func TestAuthority_SignSSHAddUser(t *testing.T) {
242242
AddUserPrincipal: tt.fields.addUserPrincipal,
243243
AddUserCommand: tt.fields.addUserCommand,
244244
}
245-
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
245+
got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject)
246246
if (err != nil) != tt.wantErr {
247247
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
248248
return
@@ -295,7 +295,7 @@ func TestAuthority_GetSSHRoots(t *testing.T) {
295295
a.sshCAUserCerts = tt.fields.sshCAUserCerts
296296
a.sshCAHostCerts = tt.fields.sshCAHostCerts
297297

298-
got, err := a.GetSSHRoots()
298+
got, err := a.GetSSHRoots(context.Background())
299299
if (err != nil) != tt.wantErr {
300300
t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr)
301301
return
@@ -337,7 +337,7 @@ func TestAuthority_GetSSHFederation(t *testing.T) {
337337
a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts
338338
a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts
339339

340-
got, err := a.GetSSHFederation()
340+
got, err := a.GetSSHFederation(context.Background())
341341
if (err != nil) != tt.wantErr {
342342
t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr)
343343
return
@@ -463,7 +463,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
463463
a.sshCAUserCertSignKey = tt.fields.userSigner
464464
a.sshCAHostCertSignKey = tt.fields.hostSigner
465465

466-
got, err := a.GetSSHConfig(tt.args.typ, tt.args.data)
466+
got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data)
467467
if (err != nil) != tt.wantErr {
468468
t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr)
469469
return
@@ -614,7 +614,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
614614
}
615615
type fields struct {
616616
config *Config
617-
sshBastionFunc func(user, hostname string) (*Bastion, error)
617+
sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
618618
}
619619
type args struct {
620620
user string
@@ -630,8 +630,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
630630
{"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false},
631631
{"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false},
632632
{"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false},
633-
{"func", fields{&Config{}, func(_, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
634-
{"func err", fields{&Config{}, func(_, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
633+
{"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
634+
{"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
635635
{"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true},
636636
}
637637
for _, tt := range tests {
@@ -640,7 +640,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
640640
config: tt.fields.config,
641641
sshBastionFunc: tt.fields.sshBastionFunc,
642642
}
643-
got, err := a.GetSSHBastion(tt.args.user, tt.args.hostname)
643+
got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname)
644644
if (err != nil) != tt.wantErr {
645645
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
646646
return
@@ -659,7 +659,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
659659
a := testAuthority(t)
660660

661661
type test struct {
662-
getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error)
662+
getHostsFunc func(context.Context, *x509.Certificate) ([]sshutil.Host, error)
663663
auth *Authority
664664
cert *x509.Certificate
665665
cmp func(got []sshutil.Host)
@@ -669,7 +669,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
669669
tests := map[string]func(t *testing.T) *test{
670670
"fail/getHostsFunc-fail": func(t *testing.T) *test {
671671
return &test{
672-
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
672+
getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
673673
return nil, errors.New("force")
674674
},
675675
cert: &x509.Certificate{},
@@ -684,7 +684,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
684684
}
685685

686686
return &test{
687-
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
687+
getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
688688
return hosts, nil
689689
},
690690
cert: &x509.Certificate{},
@@ -732,7 +732,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
732732
}
733733
auth.sshGetHostsFunc = tc.getHostsFunc
734734

735-
hosts, err := auth.GetSSHHosts(tc.cert)
735+
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
736736
if err != nil {
737737
if assert.NotNil(t, tc.err) {
738738
sc, ok := err.(errs.StatusCoder)
@@ -901,7 +901,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
901901
a.sshCAUserCertSignKey = tc.userSigner
902902
a.sshCAHostCertSignKey = tc.hostSigner
903903

904-
cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...)
904+
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
905905
if err != nil {
906906
if assert.NotNil(t, tc.err) {
907907
sc, ok := err.(errs.StatusCoder)

0 commit comments

Comments
 (0)