|
8 | 8 | "fmt" |
9 | 9 | "net" |
10 | 10 | "os" |
| 11 | + "strconv" |
11 | 12 | "sync" |
12 | 13 | "sync/atomic" |
13 | 14 | "time" |
@@ -54,13 +55,20 @@ func takeLockAndGenerateNextID() (newID uint, unlock func()) { |
54 | 55 | } |
55 | 56 |
|
56 | 57 | func tcpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { |
57 | | - if len(params) != 1 { |
58 | | - return nil, []any{1, "Invalid number of parameters, expected server address"} |
| 58 | + if len(params) != 2 { |
| 59 | + return nil, []any{1, "Invalid number of parameters, expected server address and port"} |
59 | 60 | } |
60 | 61 | serverAddr, ok := params[0].(string) |
61 | 62 | if !ok { |
62 | 63 | return nil, []any{1, "Invalid parameter type, expected string for server address"} |
63 | 64 | } |
| 65 | + serverPort, ok := msgpackrpc.ToUint(params[1]) |
| 66 | + if !ok { |
| 67 | + return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} |
| 68 | + } |
| 69 | + |
| 70 | + serverAddr = net.JoinHostPort(serverAddr, strconv.FormatUint(uint64(serverPort), 10)) |
| 71 | + |
64 | 72 | conn, err := net.Dial("tcp", serverAddr) |
65 | 73 | if err != nil { |
66 | 74 | return nil, []any{2, "Failed to connect to server: " + err.Error()} |
@@ -235,16 +243,23 @@ func tcpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r |
235 | 243 |
|
236 | 244 | func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { |
237 | 245 | n := len(params) |
238 | | - if n < 1 || n > 2 { |
239 | | - return nil, []any{1, "Invalid number of parameters, expected server address and optional TLS cert"} |
| 246 | + if n < 1 || n > 3 { |
| 247 | + return nil, []any{1, "Invalid number of parameters, expected server address, port and optional TLS cert"} |
240 | 248 | } |
241 | 249 | serverAddr, ok := params[0].(string) |
242 | 250 | if !ok { |
243 | 251 | return nil, []any{1, "Invalid parameter type, expected string for server address"} |
244 | 252 | } |
| 253 | + serverPort, ok := msgpackrpc.ToUint(params[1]) |
| 254 | + if !ok { |
| 255 | + return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} |
| 256 | + } |
| 257 | + |
| 258 | + serverAddr = net.JoinHostPort(serverAddr, strconv.FormatUint(uint64(serverPort), 10)) |
| 259 | + |
245 | 260 | var tlsConfig *tls.Config |
246 | | - if n == 2 { |
247 | | - cert, ok := params[1].(string) |
| 261 | + if n == 3 { |
| 262 | + cert, ok := params[2].(string) |
248 | 263 | if !ok { |
249 | 264 | return nil, []any{1, "Invalid parameter type, expected string for TLS cert"} |
250 | 265 | } |
|
0 commit comments