@@ -10,6 +10,8 @@ import (
1010 "sync"
1111 "syscall"
1212 "time"
13+
14+ "github.com/zrepl/zrepl/util/circlog"
1315)
1416
1517type Endpoint struct {
@@ -168,7 +170,7 @@ func (conn *SSHConn) shutdownProcess() *shutdownResult {
168170 conn .shutdownResult = & shutdownResult {waitErr }
169171 case <- timeout .C :
170172 conn .cmdCancel ()
171- waitErr := <- wait // reuse existing Wait invocation, must not call twice
173+ waitErr := <- wait // reuse existing Wait invocation, must not call twice
172174 conn .shutdownResult = & shutdownResult {waitErr }
173175 }
174176 return conn .shutdownResult
@@ -220,11 +222,6 @@ type SSHError struct {
220222// Error() will try to present a one-line error message unless ssh stderr output is longer than one line
221223func (e * SSHError ) Error () string {
222224
223- if e .RWCError == io .EOF {
224- // rwccmd returns io.EOF on exit status 0, but we do not expect ssh to do that
225- return fmt .Sprintf ("ssh exited unexpectedly with exit status 0" )
226- }
227-
228225 exitErr , ok := e .RWCError .(* exec.ExitError )
229226 if ! ok {
230227 return fmt .Sprintf ("ssh: %s" , e .RWCError )
@@ -286,18 +283,31 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
286283 if err != nil {
287284 return nil , err
288285 }
289- // stderr is required for *exec.ExitErr
286+
287+ stderrBuf , err := circlog .NewCircularLog (1 << 15 )
288+ if err != nil {
289+ panic (err ) // wrong API usage
290+ }
291+ cmd .Stderr = stderrBuf
290292
291293 if err = cmd .Start (); err != nil {
292294 return nil , err
293295 }
296+ cmdWaitErrOrIOErr := func (ioErr error , what string ) * SSHError {
297+ werr := cmd .Wait ()
298+ if werr , ok := werr .(* exec.ExitError ); ok {
299+ werr .Stderr = []byte (stderrBuf .String ())
300+ return & SSHError {werr , what }
301+ }
302+ return & SSHError {ioErr , what }
303+ }
294304
295305 confErrChan := make (chan error , 1 )
296306 go func () {
297307 defer close (confErrChan )
298308 var buf bytes.Buffer
299309 if _ , err := io .CopyN (& buf , stdout , int64 (len (banner_msg ))); err != nil {
300- confErrChan <- & SSHError { err , "read banner" }
310+ confErrChan <- cmdWaitErrOrIOErr ( err , "read banner" )
301311 return
302312 }
303313 resp := buf .Bytes ()
@@ -314,7 +324,7 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
314324 buf .Reset ()
315325 buf .Write (begin_msg )
316326 if _ , err := io .Copy (stdin , & buf ); err != nil {
317- confErrChan <- & SSHError { err , "send begin message" }
327+ confErrChan <- cmdWaitErrOrIOErr ( err , "send begin message" )
318328 return
319329 }
320330 }()
@@ -332,6 +342,9 @@ func Dial(dialCtx context.Context, endpoint Endpoint) (*SSHConn, error) {
332342 for _ = range confErrChan {
333343 }
334344
345+ // TODO collect stderr in this case
346+ // can probably extend *SSHError for this but need to implement net.Error
347+
335348 return nil , dialCtx .Err ()
336349
337350 case err := <- confErrChan :
0 commit comments