Skip to content

Commit 3658cc4

Browse files
cmaglieCopilot
andauthored
Implementaion of TCP/IP Network RPC API
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6f0bc96 commit 3658cc4

File tree

4 files changed

+502
-6
lines changed

4 files changed

+502
-6
lines changed

cmd/router/main.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import (
1414
"time"
1515

1616
"github.com/arduino/router/msgpackrouter"
17+
"github.com/arduino/router/msgpackrpc"
18+
networkapi "github.com/arduino/router/network-api"
1719

1820
"github.com/arduino/go-paths-helper"
1921
"github.com/spf13/cobra"
@@ -107,13 +109,16 @@ func startRouter(cfg Config) error {
107109
// Run router
108110
router := msgpackrouter.New()
109111

112+
// Register TCP network API methods
113+
networkapi.Register(router)
114+
110115
// Open serial port if specified
111116
if cfg.SerialPortAddr != "" {
112117
var serialLock sync.Mutex
113118
var serialOpened = sync.NewCond(&serialLock)
114119
var serialClosed = sync.NewCond(&serialLock)
115120
var serialCloseSignal = make(chan struct{})
116-
err := router.RegisterMethod("$/serial/open", func(ctx context.Context, params []any) (result any, err any) {
121+
err := router.RegisterMethod("$/serial/open", func(ctx context.Context, _ *msgpackrpc.Connection, params []any) (result any, err any) {
117122
if len(params) != 1 {
118123
return nil, []any{1, "Invalid number of parameters"}
119124
}
@@ -134,7 +139,7 @@ func startRouter(cfg Config) error {
134139
return true, nil
135140
})
136141
f.Assert(err == nil, "Failed to register $/serial/open method")
137-
err = router.RegisterMethod("$/serial/close", func(ctx context.Context, params []any) (result any, err any) {
142+
err = router.RegisterMethod("$/serial/close", func(ctx context.Context, _ *msgpackrpc.Connection, params []any) (result any, err any) {
138143
if len(params) != 1 {
139144
return nil, []any{1, "Invalid number of parameters"}
140145
}

cmd/router/msgpackrouter/router.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ import (
1212
"github.com/arduino/router/msgpackrpc"
1313
)
1414

15-
type RouterRequestHandler func(ctx context.Context, params []any) (result any, err any)
15+
type RouterRequestHandler func(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (result any, err any)
1616

1717
type Router struct {
1818
routesLock sync.Mutex
1919
routes map[string]*msgpackrpc.Connection
2020
routesInternal map[string]RouterRequestHandler
2121
}
2222

23-
func New() Router {
24-
return Router{
23+
func New() *Router {
24+
return &Router{
2525
routes: make(map[string]*msgpackrpc.Connection),
2626
routesInternal: make(map[string]RouterRequestHandler),
2727
}
@@ -92,7 +92,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) {
9292
// Check if the method is an internal method
9393
if handler, ok := r.routesInternal[method]; ok {
9494
// Call the internal method handler
95-
return handler(ctx, params)
95+
return handler(ctx, msgpackconn, params)
9696
}
9797

9898
// Check if the method is registered
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
package networkapi
2+
3+
import (
4+
"context"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"errors"
8+
"fmt"
9+
"net"
10+
"os"
11+
"sync"
12+
"sync/atomic"
13+
"time"
14+
15+
"github.com/arduino/router/msgpackrouter"
16+
"github.com/arduino/router/msgpackrpc"
17+
)
18+
19+
// Register the Network API methods
20+
func Register(router *msgpackrouter.Router) {
21+
_ = router.RegisterMethod("tcp/connect", tcpConnect)
22+
23+
_ = router.RegisterMethod("tcp/listen", tcpListen)
24+
_ = router.RegisterMethod("tcp/accept", tcpAccept)
25+
26+
_ = router.RegisterMethod("tcp/read", tcpRead)
27+
_ = router.RegisterMethod("tcp/write", tcpWrite)
28+
_ = router.RegisterMethod("tcp/close", tcpClose)
29+
30+
_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)
31+
32+
}
33+
34+
var lock sync.RWMutex
35+
var liveConnections = make(map[uint]net.Conn)
36+
var liveListeners = make(map[uint]net.Listener)
37+
var nextConnectionID atomic.Uint32
38+
39+
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
40+
// It locks the global lock to ensure thread safety and checks for existing IDs.
41+
// It returns the new ID and a function to unlock the global lock.
42+
func takeLockAndGenerateNextID() (newID uint, unlock func()) {
43+
lock.Lock()
44+
for {
45+
id := uint(nextConnectionID.Add(1))
46+
_, exists1 := liveConnections[id]
47+
_, exists2 := liveListeners[id]
48+
if !exists1 && !exists2 {
49+
return id, func() {
50+
lock.Unlock()
51+
}
52+
}
53+
}
54+
}
55+
56+
func tcpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
57+
if len(params) != 1 {
58+
return nil, []any{1, "Invalid number of parameters, expected server address"}
59+
}
60+
serverAddr, ok := params[0].(string)
61+
if !ok {
62+
return nil, []any{1, "Invalid parameter type, expected string for server address"}
63+
}
64+
conn, err := net.Dial("tcp", serverAddr)
65+
if err != nil {
66+
return nil, []any{2, "Failed to connect to server: " + err.Error()}
67+
}
68+
69+
// Successfully connected to the server
70+
71+
id, unlock := takeLockAndGenerateNextID()
72+
liveConnections[id] = conn
73+
unlock()
74+
return id, nil
75+
}
76+
77+
func tcpListen(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
78+
if len(params) != 1 {
79+
return nil, []any{1, "Invalid number of parameters, expected listen address"}
80+
}
81+
listenAddr, ok := params[0].(string)
82+
if !ok {
83+
return nil, []any{1, "Invalid parameter type, expected string for listen address"}
84+
}
85+
86+
listener, err := net.Listen("tcp", listenAddr)
87+
if err != nil {
88+
return nil, []any{2, "Failed to start listening on address: " + err.Error()}
89+
}
90+
91+
id, unlock := takeLockAndGenerateNextID()
92+
liveListeners[id] = listener
93+
unlock()
94+
return id, nil
95+
}
96+
97+
func tcpAccept(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
98+
if len(params) != 1 {
99+
return nil, []any{1, "Invalid number of parameters, expected listener ID"}
100+
}
101+
listenerID, ok := msgpackrpc.ToUint(params[0])
102+
if !ok {
103+
return nil, []any{1, "Invalid parameter type, expected int for listener ID"}
104+
}
105+
106+
lock.RLock()
107+
listener, exists := liveListeners[listenerID]
108+
lock.RUnlock()
109+
110+
if !exists {
111+
return nil, []any{2, fmt.Sprintf("Listener not found for ID: %d", listenerID)}
112+
}
113+
114+
conn, err := listener.Accept()
115+
if err != nil {
116+
return nil, []any{3, "Failed to accept connection: " + err.Error()}
117+
}
118+
119+
// Successfully accepted a connection
120+
121+
connID, unlock := takeLockAndGenerateNextID()
122+
liveConnections[connID] = conn
123+
unlock()
124+
return connID, nil
125+
}
126+
127+
func tcpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
128+
if len(params) != 1 {
129+
return nil, []any{1, "Invalid number of parameters, expected connection ID"}
130+
}
131+
id, ok := msgpackrpc.ToUint(params[0])
132+
if !ok {
133+
return nil, []any{1, "Invalid parameter type, expected int for connection ID"}
134+
}
135+
136+
lock.Lock()
137+
conn, existsConn := liveConnections[id]
138+
listener, existsListener := liveListeners[id]
139+
if existsConn {
140+
delete(liveConnections, id)
141+
}
142+
if existsListener {
143+
delete(liveListeners, id)
144+
}
145+
lock.Unlock()
146+
147+
if !existsConn && !existsListener {
148+
return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)}
149+
}
150+
151+
// Close the connection or listener if it exists
152+
// We do not return an error if the close operation fails, as it is not critical,
153+
// but we only log the error for debugging purposes.
154+
if existsConn {
155+
if err := conn.Close(); err != nil {
156+
return err.Error(), nil
157+
}
158+
}
159+
if existsListener {
160+
if err := listener.Close(); err != nil {
161+
return err.Error(), nil
162+
}
163+
}
164+
165+
return "", nil
166+
}
167+
168+
func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
169+
if len(params) != 2 {
170+
return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read)"}
171+
}
172+
id, ok := msgpackrpc.ToUint(params[0])
173+
if !ok {
174+
return nil, []any{1, "Invalid parameter type, expected int for connection ID"}
175+
}
176+
lock.RLock()
177+
conn, ok := liveConnections[id]
178+
lock.RUnlock()
179+
if !ok {
180+
return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)}
181+
}
182+
maxBytes, ok := msgpackrpc.ToUint(params[1])
183+
if !ok {
184+
return nil, []any{1, "Invalid parameter type, expected int for max bytes to read"}
185+
}
186+
187+
buffer := make([]byte, maxBytes)
188+
// It seems that the only way to make a non-blocking read is to set a read deadline.
189+
// BTW setting the read deadline to time.Now() will always returns an empty (zero bytes)
190+
// read, so we set it to a very short duration in the future.
191+
if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
192+
return nil, []any{3, "Failed to set read timeout: " + err.Error()}
193+
}
194+
n, err := conn.Read(buffer)
195+
if errors.Is(err, os.ErrDeadlineExceeded) {
196+
// timeout
197+
} else if err != nil {
198+
return nil, []any{3, "Failed to read from connection: " + err.Error()}
199+
}
200+
201+
return buffer[:n], nil
202+
}
203+
204+
func tcpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
205+
if len(params) != 2 {
206+
return nil, []any{1, "Invalid number of parameters, expected (connection ID, data to write)"}
207+
}
208+
id, ok := msgpackrpc.ToUint(params[0])
209+
if !ok {
210+
return nil, []any{1, "Invalid parameter type, expected int for connection ID"}
211+
}
212+
lock.RLock()
213+
conn, ok := liveConnections[id]
214+
lock.RUnlock()
215+
if !ok {
216+
return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)}
217+
}
218+
data, ok := params[1].([]byte)
219+
if !ok {
220+
if dataStr, ok := params[1].(string); ok {
221+
data = []byte(dataStr)
222+
} else {
223+
// If data is not []byte or string, return an error
224+
return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"}
225+
}
226+
}
227+
228+
n, err := conn.Write(data)
229+
if err != nil {
230+
return nil, []any{3, "Failed to write to connection: " + err.Error()}
231+
}
232+
233+
return n, nil
234+
}
235+
236+
func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
237+
n := len(params)
238+
if n < 1 || n > 2 {
239+
return nil, []any{1, "Invalid number of parameters, expected server address and optional TLS cert"}
240+
}
241+
serverAddr, ok := params[0].(string)
242+
if !ok {
243+
return nil, []any{1, "Invalid parameter type, expected string for server address"}
244+
}
245+
var tlsConfig *tls.Config
246+
if n == 2 {
247+
cert, ok := params[1].(string)
248+
if !ok {
249+
return nil, []any{1, "Invalid parameter type, expected string for TLS cert"}
250+
}
251+
252+
if len(cert) > 0 {
253+
// parse TLS cert in pem format
254+
certs := x509.NewCertPool()
255+
if !certs.AppendCertsFromPEM([]byte(cert)) {
256+
return nil, []any{1, "Failed to parse TLS certificate"}
257+
}
258+
tlsConfig = &tls.Config{
259+
MinVersion: tls.VersionTLS12,
260+
RootCAs: certs,
261+
}
262+
}
263+
}
264+
265+
conn, err := tls.Dial("tcp", serverAddr, tlsConfig)
266+
if err != nil {
267+
return nil, []any{2, "Failed to connect to server: " + err.Error()}
268+
}
269+
270+
// Successfully connected to the server
271+
272+
id, unlock := takeLockAndGenerateNextID()
273+
liveConnections[id] = conn
274+
unlock()
275+
return id, nil
276+
}

0 commit comments

Comments
 (0)