|
| 1 | +package netssh |
| 2 | + |
| 3 | +import ( |
| 4 | + "net" |
| 5 | + "fmt" |
| 6 | + "context" |
| 7 | + "io" |
| 8 | + "bytes" |
| 9 | + "github.com/problame/go-rwccmd" |
| 10 | +) |
| 11 | + |
| 12 | +type Endpoint struct { |
| 13 | + Host string |
| 14 | + User string |
| 15 | + Port uint16 |
| 16 | + IdentityFile string |
| 17 | + SSHCommand string |
| 18 | + Options []string |
| 19 | +} |
| 20 | + |
| 21 | +func (e Endpoint) CmdArgs() (cmd string, args []string, env []string) { |
| 22 | + |
| 23 | + if e.SSHCommand != "" { |
| 24 | + cmd = e.SSHCommand |
| 25 | + } else { |
| 26 | + cmd = "ssh" |
| 27 | + } |
| 28 | + |
| 29 | + args = make([]string, 0, 2*len(e.Options)+4) |
| 30 | + args = append(args, |
| 31 | + "-p", fmt.Sprintf("%d", e.Port), |
| 32 | + "-q", |
| 33 | + "-i", e.IdentityFile, |
| 34 | + "-o", "BatchMode=yes", |
| 35 | + ) |
| 36 | + for _, option := range e.Options { |
| 37 | + args = append(args, "-o", option) |
| 38 | + } |
| 39 | + args = append(args, fmt.Sprintf("%s@%s", e.User, e.Host)) |
| 40 | + |
| 41 | + env = []string{} |
| 42 | + |
| 43 | + return |
| 44 | +} |
| 45 | + |
| 46 | +// FIXME: should conform to net.Conn one day, but deadlines as required by net.Conn are complicated: |
| 47 | +// it requires to keep the connection open when the deadline is exceeded, but rwcconn.Cmd does not provide Deadlines |
| 48 | +// for good reason, see their docs for details. |
| 49 | +type SSHConn struct { |
| 50 | + c *rwccmd.Cmd |
| 51 | +} |
| 52 | + |
| 53 | +const go_network string = "SSH" |
| 54 | + |
| 55 | +type addr struct { |
| 56 | + pid int |
| 57 | +} |
| 58 | + |
| 59 | +func (a addr) Network() string { |
| 60 | + return go_network |
| 61 | +} |
| 62 | + |
| 63 | +func (a addr) String() string { |
| 64 | + return fmt.Sprintf("pid=%d", a.pid) |
| 65 | +} |
| 66 | + |
| 67 | +func (conn *SSHConn) LocalAddr() net.Addr { |
| 68 | + return addr{conn.c.Pid()} |
| 69 | +} |
| 70 | + |
| 71 | +func (conn *SSHConn) RemoteAddr() net.Addr { |
| 72 | + return addr{conn.c.Pid()} |
| 73 | +} |
| 74 | + |
| 75 | +func (conn *SSHConn) Read(p []byte) (int, error) { |
| 76 | + return conn.c.Read(p) |
| 77 | +} |
| 78 | + |
| 79 | +func (conn *SSHConn) Write(p []byte) (int, error) { |
| 80 | + return conn.c.Write(p) |
| 81 | +} |
| 82 | + |
| 83 | +func (conn *SSHConn) Close() (error) { |
| 84 | + return conn.c.Close() |
| 85 | +} |
| 86 | + |
| 87 | +// Use at your own risk... |
| 88 | +func (conn *SSHConn) Cmd() *rwccmd.Cmd { |
| 89 | + return conn.c |
| 90 | +} |
| 91 | + |
| 92 | +const bannerMessageLen = 31 |
| 93 | +var messages = make(map[string][]byte) |
| 94 | +func mustMessage(str string) []byte { |
| 95 | + if len(str) > bannerMessageLen { |
| 96 | + panic("message length must be smaller than bannerMessageLen") |
| 97 | + } |
| 98 | + if _, ok := messages[str]; ok { |
| 99 | + panic("duplicate message") |
| 100 | + } |
| 101 | + var buf bytes.Buffer |
| 102 | + n, _ := buf.WriteString(str) |
| 103 | + if n != len(str) { |
| 104 | + panic("message must only contain ascii / 8-bit chars") |
| 105 | + } |
| 106 | + buf.Write(bytes.Repeat([]byte{0}, bannerMessageLen-n)) |
| 107 | + return buf.Bytes() |
| 108 | +} |
| 109 | +var banner_msg = mustMessage("SSHCON_HELO") |
| 110 | +var proxy_error_msg = mustMessage("SSHCON_PROXY_ERROR") |
| 111 | +var begin_msg = mustMessage("SSHCON_BEGIN") |
| 112 | + |
| 113 | +func Dial(ctx context.Context, endpoint Endpoint) (*SSHConn , error) { |
| 114 | + |
| 115 | + sshCmd, sshArgs, sshEnv := endpoint.CmdArgs() |
| 116 | + cmd, err := rwccmd.CommandContext(ctx, sshCmd, sshArgs, sshEnv) |
| 117 | + if err != nil { |
| 118 | + return nil, err |
| 119 | + } |
| 120 | + if err = cmd.Start(); err != nil { |
| 121 | + return nil, err |
| 122 | + } |
| 123 | + |
| 124 | + confErrChan := make(chan error) |
| 125 | + go func() { |
| 126 | + var buf bytes.Buffer |
| 127 | + if _, err := io.CopyN(&buf, cmd, int64(len(banner_msg))); err != nil { |
| 128 | + confErrChan <- fmt.Errorf("error reading banner: %s", err) |
| 129 | + return |
| 130 | + } |
| 131 | + resp := buf.Bytes() |
| 132 | + switch { |
| 133 | + case bytes.Equal(resp, banner_msg): |
| 134 | + break |
| 135 | + case bytes.Equal(resp, proxy_error_msg): |
| 136 | + confErrChan <- fmt.Errorf("proxy error, check remote configuration") |
| 137 | + return |
| 138 | + default: |
| 139 | + confErrChan <- fmt.Errorf("unknown banner message: %v", resp) |
| 140 | + return |
| 141 | + } |
| 142 | + buf.Reset() |
| 143 | + buf.Write(begin_msg) |
| 144 | + if _, err := io.Copy(cmd, &buf); err != nil { |
| 145 | + confErrChan <- fmt.Errorf("error sending begin message: %s", err) |
| 146 | + return |
| 147 | + } |
| 148 | + close(confErrChan) |
| 149 | + }() |
| 150 | + |
| 151 | + select { |
| 152 | + case <-ctx.Done(): |
| 153 | + return nil, ctx.Err() |
| 154 | + case err := <-confErrChan: |
| 155 | + if err != nil { |
| 156 | + return nil, err |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + return &SSHConn{cmd}, err |
| 161 | +} |
0 commit comments