@@ -20,6 +20,7 @@ import (
2020 "regexp"
2121 "runtime"
2222 "strings"
23+ "sync"
2324 "testing"
2425 "time"
2526
@@ -1318,9 +1319,6 @@ func TestSSH(t *testing.T) {
13181319
13191320 tmpdir := tempDirUnixSocket (t )
13201321 localSock := filepath .Join (tmpdir , "local.sock" )
1321- l , err := net .Listen ("unix" , localSock )
1322- require .NoError (t , err )
1323- defer l .Close ()
13241322 remoteSock := filepath .Join (tmpdir , "remote.sock" )
13251323
13261324 inv , root := clitest .New (t ,
@@ -1332,23 +1330,62 @@ func TestSSH(t *testing.T) {
13321330 clitest .SetupConfig (t , client , root )
13331331 pty := ptytest .New (t ).Attach (inv )
13341332 inv .Stderr = pty .Output ()
1335- cmdDone := tGo (t , func () {
1336- err := inv .WithContext (ctx ).Run ()
1337- assert .NoError (t , err , "ssh command failed" )
1338- })
13391333
1340- // Wait for the prompt or any output really to indicate the command has
1341- // started and accepting input on stdin.
1334+ w := clitest .StartWithWaiter (t , inv .WithContext (ctx ))
1335+ defer w .Wait () // We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
1336+
1337+ // Since something was output, it should be safe to write input.
1338+ // This could show a prompt or "running startup scripts", so it's
1339+ // not indicative of the SSH connection being ready.
13421340 _ = pty .Peek (ctx , 1 )
13431341
1344- // This needs to support most shells on Linux or macOS
1345- // We can't include exactly what's expected in the input, as that will always be matched
1346- pty .WriteLine (fmt .Sprintf (`echo "results: $(netstat -an | grep %s | wc -l | tr -d ' ')"` , remoteSock ))
1347- pty .ExpectMatchContext (ctx , "results: 1" )
1342+ // Ensure the SSH connection is ready by testing the shell
1343+ // input/output.
1344+ pty .WriteLine ("echo ping' 'pong" )
1345+ pty .ExpectMatchContext (ctx , "ping pong" )
1346+
1347+ // Start the listener on the "local machine".
1348+ l , err := net .Listen ("unix" , localSock )
1349+ require .NoError (t , err )
1350+ defer l .Close ()
1351+ testutil .Go (t , func () {
1352+ var wg sync.WaitGroup
1353+ defer wg .Wait ()
1354+ for {
1355+ fd , err := l .Accept ()
1356+ if err != nil {
1357+ if ! errors .Is (err , net .ErrClosed ) {
1358+ assert .NoError (t , err , "listener accept failed" )
1359+ }
1360+ return
1361+ }
1362+
1363+ wg .Add (1 )
1364+ go func () {
1365+ defer wg .Done ()
1366+ defer fd .Close ()
1367+ agentssh .Bicopy (ctx , fd , fd )
1368+ }()
1369+ }
1370+ })
1371+
1372+ // Dial the forwarded socket on the "remote machine".
1373+ d := & net.Dialer {}
1374+ fd , err := d .DialContext (ctx , "unix" , remoteSock )
1375+ require .NoError (t , err )
1376+ defer fd .Close ()
1377+
1378+ // Ping / pong to ensure the socket is working.
1379+ _ , err = fd .Write ([]byte ("hello world" ))
1380+ require .NoError (t , err )
1381+
1382+ buf := make ([]byte , 11 )
1383+ _ , err = fd .Read (buf )
1384+ require .NoError (t , err )
1385+ require .Equal (t , "hello world" , string (buf ))
13481386
13491387 // And we're done.
13501388 pty .WriteLine ("exit" )
1351- <- cmdDone
13521389 })
13531390
13541391 // Test that we can forward a local unix socket to a remote unix socket and
@@ -1377,6 +1414,8 @@ func TestSSH(t *testing.T) {
13771414 require .NoError (t , err )
13781415 defer l .Close ()
13791416 testutil .Go (t , func () {
1417+ var wg sync.WaitGroup
1418+ defer wg .Wait ()
13801419 for {
13811420 fd , err := l .Accept ()
13821421 if err != nil {
@@ -1386,10 +1425,12 @@ func TestSSH(t *testing.T) {
13861425 return
13871426 }
13881427
1389- testutil .Go (t , func () {
1428+ wg .Add (1 )
1429+ go func () {
1430+ defer wg .Done ()
13901431 defer fd .Close ()
13911432 agentssh .Bicopy (ctx , fd , fd )
1392- })
1433+ }( )
13931434 }
13941435 })
13951436
@@ -1522,6 +1563,8 @@ func TestSSH(t *testing.T) {
15221563 require .NoError (t , err )
15231564 defer l .Close () //nolint:revive // Defer is fine in this loop, we only run it twice.
15241565 testutil .Go (t , func () {
1566+ var wg sync.WaitGroup
1567+ defer wg .Wait ()
15251568 for {
15261569 fd , err := l .Accept ()
15271570 if err != nil {
@@ -1531,10 +1574,12 @@ func TestSSH(t *testing.T) {
15311574 return
15321575 }
15331576
1534- testutil .Go (t , func () {
1577+ wg .Add (1 )
1578+ go func () {
1579+ defer wg .Done ()
15351580 defer fd .Close ()
15361581 agentssh .Bicopy (ctx , fd , fd )
1537- })
1582+ }( )
15381583 }
15391584 })
15401585
0 commit comments