diff --git a/debian/arduino-router/DEBIAN/postinst b/debian/arduino-router/DEBIAN/postinst index e04dfef..cba8391 100755 --- a/debian/arduino-router/DEBIAN/postinst +++ b/debian/arduino-router/DEBIAN/postinst @@ -1,13 +1,5 @@ #!/bin/sh -if [ -d /run/systemd/system ]; then - systemctl daemon-reload -fi - systemctl enable arduino-router systemctl enable arduino-router-serial -if [ -d /run/systemd/system ]; then - systemctl start arduino-router - systemctl start arduino-router-serial -fi diff --git a/debian/arduino-router/DEBIAN/prerm b/debian/arduino-router/DEBIAN/prerm index ae00407..193d5be 100755 --- a/debian/arduino-router/DEBIAN/prerm +++ b/debian/arduino-router/DEBIAN/prerm @@ -3,7 +3,3 @@ systemctl disable arduino-router systemctl disable arduino-router-serial -if [ -d /run/systemd/system ]; then - systemctl stop arduino-router - systemctl stop arduino-router-serial -fi diff --git a/hciapi/hci-api.go b/internal/hciapi/hci-api.go similarity index 99% rename from hciapi/hci-api.go rename to internal/hciapi/hci-api.go index a1a7860..655d5a7 100644 --- a/hciapi/hci-api.go +++ b/internal/hciapi/hci-api.go @@ -26,7 +26,7 @@ import ( "golang.org/x/sys/unix" - "github.com/arduino/arduino-router/msgpackrouter" + "github.com/arduino/arduino-router/internal/msgpackrouter" "github.com/arduino/arduino-router/msgpackrpc" ) diff --git a/monitorapi/monitor-api.go b/internal/monitorapi/monitor-api.go similarity index 98% rename from monitorapi/monitor-api.go rename to internal/monitorapi/monitor-api.go index 1f6f88a..c541c1e 100644 --- a/monitorapi/monitor-api.go +++ b/internal/monitorapi/monitor-api.go @@ -27,7 +27,7 @@ import ( "github.com/djherbis/buffer" "github.com/djherbis/nio/v3" - "github.com/arduino/arduino-router/msgpackrouter" + "github.com/arduino/arduino-router/internal/msgpackrouter" "github.com/arduino/arduino-router/msgpackrpc" ) diff --git a/msgpackrouter/errors.go b/internal/msgpackrouter/errors.go similarity index 100% rename from msgpackrouter/errors.go rename to internal/msgpackrouter/errors.go diff --git a/msgpackrouter/router.go b/internal/msgpackrouter/router.go similarity index 100% rename from msgpackrouter/router.go rename to internal/msgpackrouter/router.go diff --git a/msgpackrouter/router_test.go b/internal/msgpackrouter/router_test.go similarity index 98% rename from msgpackrouter/router_test.go rename to internal/msgpackrouter/router_test.go index 74ef598..83d2711 100644 --- a/msgpackrouter/router_test.go +++ b/internal/msgpackrouter/router_test.go @@ -24,7 +24,7 @@ import ( "testing" "time" - "github.com/arduino/arduino-router/msgpackrouter" + "github.com/arduino/arduino-router/internal/msgpackrouter" "github.com/arduino/arduino-router/msgpackrpc" "github.com/djherbis/buffer" diff --git a/internal/network-api/network-api.go b/internal/network-api/network-api.go new file mode 100644 index 0000000..6a1ead2 --- /dev/null +++ b/internal/network-api/network-api.go @@ -0,0 +1,605 @@ +// This file is part of arduino-router +// +// Copyright 2025 ARDUINO SA (http://www.arduino.cc/) +// +// This software is released under the GNU General Public License version 3, +// which covers the main part of arduino-router +// The terms of this license can be found at: +// https://www.gnu.org/licenses/gpl-3.0.en.html +// +// You can be released from the requirements of the above licenses by purchasing +// a commercial license. Buying such a license is mandatory if you want to +// modify or otherwise use the software for commercial activities involving the +// Arduino software without disclosing the source code of your own applications. +// To purchase a commercial license, send an email to license@arduino.cc. + +package networkapi + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "os" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/arduino/arduino-router/internal/msgpackrouter" + "github.com/arduino/arduino-router/msgpackrpc" +) + +// Register the Network API methods +func Register(router *msgpackrouter.Router) { + _ = router.RegisterMethod("tcp/connect", tcpConnect) + + _ = router.RegisterMethod("tcp/listen", tcpListen) + _ = router.RegisterMethod("tcp/closeListener", tcpCloseListener) + + _ = router.RegisterMethod("tcp/accept", tcpAccept) + _ = router.RegisterMethod("tcp/read", tcpRead) + _ = router.RegisterMethod("tcp/write", tcpWrite) + _ = router.RegisterMethod("tcp/close", tcpClose) + + _ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL) + + _ = router.RegisterMethod("udp/connect", udpConnect) + _ = router.RegisterMethod("udp/beginPacket", udpBeginPacket) + _ = router.RegisterMethod("udp/write", udpWrite) + _ = router.RegisterMethod("udp/endPacket", udpEndPacket) + _ = router.RegisterMethod("udp/awaitPacket", udpAwaitPacket) + _ = router.RegisterMethod("udp/read", udpRead) + _ = router.RegisterMethod("udp/dropPacket", udpDropPacket) + _ = router.RegisterMethod("udp/close", udpClose) +} + +var lock sync.RWMutex +var liveConnections = make(map[uint]net.Conn) +var liveListeners = make(map[uint]net.Listener) +var liveUdpConnections = make(map[uint]net.PacketConn) +var udpReadBuffers = make(map[uint][]byte) +var udpWriteTargets = make(map[uint]*net.UDPAddr) +var udpWriteBuffers = make(map[uint][]byte) +var nextConnectionID atomic.Uint32 + +// takeLockAndGenerateNextID generates a new unique ID for a connection or listener. +// It locks the global lock to ensure thread safety and checks for existing IDs. +// It returns the new ID and a function to unlock the global lock. +func takeLockAndGenerateNextID() (newID uint, unlock func()) { + lock.Lock() + for { + id := uint(nextConnectionID.Add(1)) + _, exists1 := liveConnections[id] + _, exists2 := liveListeners[id] + if !exists1 && !exists2 { + return id, func() { + lock.Unlock() + } + } + } +} + +func tcpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected server address and port"} + } + serverAddr, ok := params[0].(string) + if !ok { + return nil, []any{1, "Invalid parameter type, expected string for server address"} + } + serverPort, ok := msgpackrpc.ToUint(params[1]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} + } + + serverAddr = net.JoinHostPort(serverAddr, strconv.FormatUint(uint64(serverPort), 10)) + + conn, err := net.Dial("tcp", serverAddr) + if err != nil { + return nil, []any{2, "Failed to connect to server: " + err.Error()} + } + + // Successfully connected to the server + + id, unlock := takeLockAndGenerateNextID() + liveConnections[id] = conn + unlock() + return id, nil +} + +func tcpListen(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected listen address and port"} + } + listenAddr, ok := params[0].(string) + if !ok { + return nil, []any{1, "Invalid parameter type, expected string for listen address"} + } + listenPort, ok := msgpackrpc.ToUint(params[1]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint16 for listen port"} + } + + listenAddr = net.JoinHostPort(listenAddr, strconv.FormatUint(uint64(listenPort), 10)) + + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, []any{2, "Failed to start listening on address: " + err.Error()} + } + + id, unlock := takeLockAndGenerateNextID() + liveListeners[id] = listener + unlock() + return id, nil +} + +func tcpAccept(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 { + return nil, []any{1, "Invalid number of parameters, expected listener ID"} + } + listenerID, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for listener ID"} + } + + lock.RLock() + listener, exists := liveListeners[listenerID] + lock.RUnlock() + + if !exists { + return nil, []any{2, fmt.Sprintf("Listener not found for ID: %d", listenerID)} + } + + conn, err := listener.Accept() + if err != nil { + return nil, []any{3, "Failed to accept connection: " + err.Error()} + } + + // Successfully accepted a connection + + connID, unlock := takeLockAndGenerateNextID() + liveConnections[connID] = conn + unlock() + return connID, nil +} + +func tcpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 { + return nil, []any{1, "Invalid number of parameters, expected connection ID"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for connection ID"} + } + + lock.Lock() + conn, existsConn := liveConnections[id] + if existsConn { + delete(liveConnections, id) + } + lock.Unlock() + + if !existsConn { + return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)} + } + + // Close the connection if it exists + // We do not return an error to the caller if the close operation fails, as it is not critical, + // but we only log the error for debugging purposes. + if err := conn.Close(); err != nil { + return err.Error(), nil + } + return "", nil +} + +func tcpCloseListener(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 { + return nil, []any{1, "Invalid number of parameters, expected listener ID"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for listener ID"} + } + + lock.Lock() + listener, existsListener := liveListeners[id] + if existsListener { + delete(liveListeners, id) + } + lock.Unlock() + + if !existsListener { + return nil, []any{2, fmt.Sprintf("Listener not found for ID: %d", id)} + } + + // Close the listener if it exists + // We do not return an error to the caller if the close operation fails, as it is not critical, + // but we only log the error for debugging purposes. + if err := listener.Close(); err != nil { + return err.Error(), nil + } + return "", nil +} + +func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 && len(params) != 3 { + return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read[, optional timeout in ms])"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for connection ID"} + } + lock.RLock() + conn, ok := liveConnections[id] + lock.RUnlock() + if !ok { + return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)} + } + maxBytes, ok := msgpackrpc.ToUint(params[1]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for max bytes to read"} + } + var deadline time.Time // default value == no timeout + if len(params) == 2 { + // It seems that there is no way to set a 0 ms timeout (immediate return) on a TCP connection. + // Setting the read deadline to time.Now() will always returns an empty (zero bytes) + // read, so we set it by default to a very short duration in the future (1 ms). + deadline = time.Now().Add(time.Millisecond) + } else if ms, ok := msgpackrpc.ToInt(params[2]); !ok { + return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"} + } else if ms > 0 { + deadline = time.Now().Add(time.Duration(ms) * time.Millisecond) + } + + buffer := make([]byte, maxBytes) + if err := conn.SetReadDeadline(deadline); err != nil { + return nil, []any{3, "Failed to set read timeout: " + err.Error()} + } + n, err := conn.Read(buffer) + if errors.Is(err, os.ErrDeadlineExceeded) { + // timeout + } else if err != nil { + return nil, []any{3, "Failed to read from connection: " + err.Error()} + } + + return buffer[:n], nil +} + +func tcpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected (connection ID, data to write)"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for connection ID"} + } + lock.RLock() + conn, ok := liveConnections[id] + lock.RUnlock() + if !ok { + return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)} + } + data, ok := params[1].([]byte) + if !ok { + if dataStr, ok := params[1].(string); ok { + data = []byte(dataStr) + } else { + // If data is not []byte or string, return an error + return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"} + } + } + + n, err := conn.Write(data) + if err != nil { + return nil, []any{3, "Failed to write to connection: " + err.Error()} + } + + return n, nil +} + +func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + n := len(params) + if n < 1 || n > 3 { + return nil, []any{1, "Invalid number of parameters, expected server address, port and optional TLS cert"} + } + serverAddr, ok := params[0].(string) + if !ok { + return nil, []any{1, "Invalid parameter type, expected string for server address"} + } + serverPort, ok := msgpackrpc.ToUint(params[1]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} + } + + serverAddr = net.JoinHostPort(serverAddr, strconv.FormatUint(uint64(serverPort), 10)) + + var tlsConfig *tls.Config + if n == 3 { + cert, ok := params[2].(string) + if !ok { + return nil, []any{1, "Invalid parameter type, expected string for TLS cert"} + } + + if len(cert) > 0 { + // parse TLS cert in pem format + certs := x509.NewCertPool() + if !certs.AppendCertsFromPEM([]byte(cert)) { + return nil, []any{1, "Failed to parse TLS certificate"} + } + tlsConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: certs, + } + } + } + + conn, err := tls.Dial("tcp", serverAddr, tlsConfig) + if err != nil { + return nil, []any{2, "Failed to connect to server: " + err.Error()} + } + + // Successfully connected to the server + + id, unlock := takeLockAndGenerateNextID() + liveConnections[id] = conn + unlock() + return id, nil +} + +func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected server address and port"} + } + serverAddr, ok := params[0].(string) + if !ok { + return nil, []any{1, "Invalid parameter type, expected string for server address"} + } + serverPort, ok := msgpackrpc.ToUint(params[1]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} + } + + serverAddr = net.JoinHostPort(serverAddr, fmt.Sprintf("%d", serverPort)) + udpAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + return nil, []any{2, "Failed to resolve UDP address: " + err.Error()} + } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, []any{2, "Failed to connect to server: " + err.Error()} + } + + // Successfully opened UDP channel + + id, unlock := takeLockAndGenerateNextID() + liveUdpConnections[id] = udpConn + unlock() + return id, nil +} + +func udpBeginPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 3 { + return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"} + } + targetIP, ok := params[1].(string) + if !ok { + return nil, []any{1, "Invalid parameter type, expected string for server address"} + } + targetPort, ok := msgpackrpc.ToUint(params[2]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} + } + + lock.RLock() + defer lock.RUnlock() + if _, ok := liveUdpConnections[id]; !ok { + return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)} + } + targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort)) + addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching + if err != nil { + return nil, []any{3, "Failed to resolve target address: " + err.Error()} + } + udpWriteTargets[id] = addr + udpWriteBuffers[id] = nil + return true, nil +} + +func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected udpConnId, payload"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"} + } + data, ok := params[1].([]byte) + if !ok { + if dataStr, ok := params[1].(string); ok { + data = []byte(dataStr) + } else { + // If data is not []byte or string, return an error + return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"} + } + } + + lock.RLock() + udpBuffer, ok := udpWriteBuffers[id] + if ok { + udpWriteBuffers[id] = append(udpBuffer, data...) + } + lock.RUnlock() + if !ok { + return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)} + } + return len(data), nil +} + +func udpEndPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 { + return nil, []any{1, "Invalid number of parameters, expected expected udpConnId"} + } + id, buffExists := msgpackrpc.ToUint(params[0]) + if !buffExists { + return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"} + } + + var udpBuffer []byte + var udpAddr *net.UDPAddr + lock.RLock() + udpConn, connExists := liveUdpConnections[id] + if connExists { + udpBuffer, buffExists = udpWriteBuffers[id] + udpAddr = udpWriteTargets[id] + delete(udpWriteBuffers, id) + delete(udpWriteTargets, id) + } + lock.RUnlock() + if !connExists { + return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)} + } + if !buffExists { + return nil, []any{3, fmt.Sprintf("No UDP packet begun for ID: %d", id)} + } + + if n, err := udpConn.WriteTo(udpBuffer, udpAddr); err != nil { + return nil, []any{4, "Failed to write to UDP connection: " + err.Error()} + } else { + return n, nil + } +} + +func udpAwaitPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 && len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"} + } + var deadline time.Time // default value == no timeout + if len(params) == 2 { + if ms, ok := msgpackrpc.ToInt(params[1]); !ok { + return nil, []any{1, "Invalid parameter type, expected int for timeout in ms"} + } else if ms > 0 { + deadline = time.Now().Add(time.Duration(ms) * time.Millisecond) + } + } + + lock.RLock() + udpConn, ok := liveUdpConnections[id] + lock.RUnlock() + if !ok { + return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)} + } + if err := udpConn.SetReadDeadline(deadline); err != nil { + return nil, []any{3, "Failed to set read deadline: " + err.Error()} + } + buffer := make([]byte, 64*1024) // 64 KB buffer + n, addr, err := udpConn.ReadFrom(buffer) + if errors.Is(err, os.ErrDeadlineExceeded) { + // timeout + return nil, []any{5, "Timeout"} + } + if err != nil { + return nil, []any{3, "Failed to read from UDP connection: " + err.Error()} + } + host, portStr, err := net.SplitHostPort(addr.String()) + if err != nil { + // Should never fail, but... + return nil, []any{4, "Failed to parse source address: " + err.Error()} + } + port, err := strconv.Atoi(portStr) + if err != nil { + // Should never fail, but... + return nil, []any{4, "Failed to parse source address: " + err.Error()} + } + + lock.Lock() + udpReadBuffers[id] = buffer[:n] + lock.Unlock() + return []any{n, host, port}, nil +} + +func udpDropPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 && len(params) != 2 { + return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"} + } + + lock.RLock() + delete(udpReadBuffers, id) + lock.RUnlock() + if !ok { + return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)} + } + return true, nil +} + +func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 2 && len(params) != 3 { + return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"} + } + maxBytes, ok := msgpackrpc.ToUint(params[1]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected uint for max bytes to read"} + } + + lock.Lock() + buffer, exists := udpReadBuffers[id] + n := uint(len(buffer)) + if exists { + // keep the remainder of the buffer for the next read + if n > maxBytes { + udpReadBuffers[id] = buffer[maxBytes:] + n = maxBytes + } else { + delete(udpReadBuffers, id) + } + } + lock.Unlock() + + return buffer[:n], nil +} + +func udpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { + if len(params) != 1 { + return nil, []any{1, "Invalid number of parameters, expected UDP connection ID"} + } + id, ok := msgpackrpc.ToUint(params[0]) + if !ok { + return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"} + } + + lock.Lock() + udpConn, existsConn := liveUdpConnections[id] + delete(liveUdpConnections, id) + delete(udpReadBuffers, id) + lock.Unlock() + + if !existsConn { + return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)} + } + + // Close the connection if it exists + // We do not return an error to the caller if the close operation fails, as it is not critical, + // but we only log the error for debugging purposes. + if err := udpConn.Close(); err != nil { + return err.Error(), nil + } + return "", nil +} diff --git a/network-api/network-api_test.go b/internal/network-api/network-api_test.go similarity index 67% rename from network-api/network-api_test.go rename to internal/network-api/network-api_test.go index 41166ff..7f24fd3 100644 --- a/network-api/network-api_test.go +++ b/internal/network-api/network-api_test.go @@ -19,6 +19,7 @@ import ( "fmt" "sync" "testing" + "time" "github.com/arduino/arduino-router/msgpackrpc" @@ -156,7 +157,7 @@ const testCert = "-----BEGIN CERTIFICATE-----\n" + "HAIgNadMPgxv01dy59kCgzehgKzmKdTF0rG1SniYqnkLqPA=\n" + "-----END CERTIFICATE-----\n" -func TestNetworkAPI(t *testing.T) { +func TestTCPNetworkAPI(t *testing.T) { ctx := t.Context() var rpc *msgpackrpc.Connection listID, err := tcpListen(ctx, rpc, []any{"localhost", 9999}) @@ -236,3 +237,206 @@ func TestNetworkAPI(t *testing.T) { wg.Wait() } + +func TestUDPNetworkAPI(t *testing.T) { + ctx := t.Context() + conn1, err := udpConnect(ctx, nil, []any{"0.0.0.0", 9800}) + require.Nil(t, err) + + conn2, err := udpConnect(ctx, nil, []any{"0.0.0.0", 9900}) + require.Nil(t, err) + require.NotEqual(t, conn1, conn2) + + { + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")}) + require.Nil(t, err) + require.Equal(t, 5, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 5, res) + } + { + res, err := udpAwaitPacket(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, []any{5, "127.0.0.1", 9800}, res) + + res2, err := udpRead(ctx, nil, []any{conn2, 100}) + require.Nil(t, err) + require.Equal(t, []uint8("Hello"), res2) + } + { + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("On")}) + require.Nil(t, err) + require.Equal(t, 2, res) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("e")}) + require.Nil(t, err) + require.Equal(t, 1, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 3, res) + } + { + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")}) + require.Nil(t, err) + require.Equal(t, 3, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 3, res) + } + { + res, err := udpAwaitPacket(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, []any{3, "127.0.0.1", 9800}, res) + + // A partial read of a packet is allowed + res2, err := udpRead(ctx, nil, []any{conn2, 2}) + require.Nil(t, err) + require.Equal(t, []uint8("On"), res2) + } + { + // Even if the previous packet was only partially read, + // the next packet can be received + res, err := udpAwaitPacket(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, []any{3, "127.0.0.1", 9800}, res) + + res2, err := udpRead(ctx, nil, []any{conn2, 100}) + require.Nil(t, err) + require.Equal(t, []uint8("Two"), res2) + } + { + res, err := udpClose(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, "", res) + } + { + res, err := udpClose(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, "", res) + } +} + +func TestUDPNetworkUnboundClientAPI(t *testing.T) { + ctx := t.Context() + conn1, err := udpConnect(ctx, nil, []any{"", 0}) + require.Nil(t, err) + + conn2, err := udpConnect(ctx, nil, []any{"0.0.0.0", 9901}) + require.Nil(t, err) + require.NotEqual(t, conn1, conn2) + + { + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")}) + require.Nil(t, err) + require.Equal(t, 5, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 5, res) + } + { + res, err := udpAwaitPacket(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, 5, res.([]any)[0]) + + res2, err := udpRead(ctx, nil, []any{conn2, 2}) + require.Nil(t, err) + require.Equal(t, []uint8("He"), res2) + + res2, err = udpRead(ctx, nil, []any{conn2, 20}) + require.Nil(t, err) + require.Equal(t, []uint8("llo"), res2) + } + { + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("One")}) + require.Nil(t, err) + require.Equal(t, 3, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 3, res) + } + { + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")}) + require.Nil(t, err) + require.Equal(t, 3, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 3, res) + } + { + res, err := udpAwaitPacket(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, 3, res.([]any)[0]) + + res2, err := udpRead(ctx, nil, []any{conn2, 100}) + require.Nil(t, err) + require.Equal(t, []uint8("One"), res2) + } + { + res, err := udpAwaitPacket(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, 3, res.([]any)[0]) + + res2, err := udpRead(ctx, nil, []any{conn2, 100}) + require.Nil(t, err) + require.Equal(t, []uint8("Two"), res2) + } + + // Check timeouts + go func() { + time.Sleep(200 * time.Millisecond) + res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901}) + require.Nil(t, err) + require.True(t, res.(bool)) + res, err = udpWrite(ctx, nil, []any{conn1, []byte("Three")}) + require.Nil(t, err) + require.Equal(t, 5, res) + res, err = udpEndPacket(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, 5, res) + }() + { + start := time.Now() + res, err := udpAwaitPacket(ctx, nil, []any{conn2, 10}) + require.Less(t, time.Since(start), 20*time.Millisecond) + require.Equal(t, []any{5, "Timeout"}, err) + require.Nil(t, res) + } + { + res, err := udpAwaitPacket(ctx, nil, []any{conn2, 0}) + require.Nil(t, err) + require.Equal(t, 5, res.([]any)[0]) + + res2, err := udpRead(ctx, nil, []any{conn2, 100, 0}) + require.Nil(t, err) + require.Equal(t, []uint8("Three"), res2) + } + + { + res, err := udpClose(ctx, nil, []any{conn1}) + require.Nil(t, err) + require.Equal(t, "", res) + } + { + res, err := udpClose(ctx, nil, []any{conn2}) + require.Nil(t, err) + require.Equal(t, "", res) + } +} diff --git a/main.go b/main.go index 9cba6de..afdbea4 100644 --- a/main.go +++ b/main.go @@ -29,11 +29,11 @@ import ( "syscall" "time" - "github.com/arduino/arduino-router/hciapi" - "github.com/arduino/arduino-router/monitorapi" - "github.com/arduino/arduino-router/msgpackrouter" + "github.com/arduino/arduino-router/internal/hciapi" + "github.com/arduino/arduino-router/internal/monitorapi" + "github.com/arduino/arduino-router/internal/msgpackrouter" + networkapi "github.com/arduino/arduino-router/internal/network-api" "github.com/arduino/arduino-router/msgpackrpc" - networkapi "github.com/arduino/arduino-router/network-api" "github.com/spf13/cobra" "go.bug.st/f" @@ -56,15 +56,10 @@ type Config struct { func main() { var cfg Config var verbose bool - var printVersion bool cmd := &cobra.Command{ Use: "arduino-router", Long: "Arduino router for msgpack RPC service protocol", Run: func(cmd *cobra.Command, args []string) { - if printVersion { - fmt.Println("Arduino router v" + Version) - return - } if verbose { cfg.LogLevel = slog.LevelDebug } else { @@ -80,12 +75,20 @@ func main() { }, } cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose logging") - cmd.Flags().BoolVar(&printVersion, "version", false, "Print version information") cmd.Flags().StringVarP(&cfg.ListenTCPAddr, "listen-port", "l", "", "Listening port for RPC services") cmd.Flags().StringVarP(&cfg.ListenUnixAddr, "unix-port", "u", "/var/run/arduino-router.sock", "Listening port for RPC services") cmd.Flags().StringVarP(&cfg.SerialPortAddr, "serial-port", "p", "", "Serial port address") cmd.Flags().IntVarP(&cfg.SerialBaudRate, "serial-baudrate", "b", 115200, "Serial port baud rate") cmd.Flags().StringVarP(&cfg.MonitorPortAddr, "monitor-port", "m", "127.0.0.1:7500", "Listening port for MCU monitor proxy") + + cmd.AddCommand(&cobra.Command{ + Use: "version", + Long: "Print version information", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("Arduino Router " + Version) + }, + }) + if err := cmd.Execute(); err != nil { slog.Error("Error executing command.", "error", err) } @@ -160,6 +163,13 @@ func startRouter(cfg Config) error { // Register HCI API methods hciapi.Register(router) + // Register monitor version API methods + if err := router.RegisterMethod("$/version", func(_ context.Context, _ *msgpackrpc.Connection, _ []any) (any, any) { + return Version, nil + }); err != nil { + slog.Error("Failed to register version API", "err", err) + } + // Register monitor API methods if err := monitorapi.Register(router, cfg.MonitorPortAddr); err != nil { slog.Error("Failed to register monitor API", "err", err) diff --git a/network-api/network-api.go b/network-api/network-api.go deleted file mode 100644 index 3411ce6..0000000 --- a/network-api/network-api.go +++ /dev/null @@ -1,330 +0,0 @@ -// This file is part of arduino-router -// -// Copyright 2025 ARDUINO SA (http://www.arduino.cc/) -// -// This software is released under the GNU General Public License version 3, -// which covers the main part of arduino-router -// The terms of this license can be found at: -// https://www.gnu.org/licenses/gpl-3.0.en.html -// -// You can be released from the requirements of the above licenses by purchasing -// a commercial license. Buying such a license is mandatory if you want to -// modify or otherwise use the software for commercial activities involving the -// Arduino software without disclosing the source code of your own applications. -// To purchase a commercial license, send an email to license@arduino.cc. - -package networkapi - -import ( - "context" - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "net" - "os" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/arduino/arduino-router/msgpackrouter" - "github.com/arduino/arduino-router/msgpackrpc" -) - -// Register the Network API methods -func Register(router *msgpackrouter.Router) { - _ = router.RegisterMethod("tcp/connect", tcpConnect) - - _ = router.RegisterMethod("tcp/listen", tcpListen) - _ = router.RegisterMethod("tcp/closeListener", tcpCloseListener) - - _ = router.RegisterMethod("tcp/accept", tcpAccept) - _ = router.RegisterMethod("tcp/read", tcpRead) - _ = router.RegisterMethod("tcp/write", tcpWrite) - _ = router.RegisterMethod("tcp/close", tcpClose) - - _ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL) - -} - -var lock sync.RWMutex -var liveConnections = make(map[uint]net.Conn) -var liveListeners = make(map[uint]net.Listener) -var nextConnectionID atomic.Uint32 - -// takeLockAndGenerateNextID generates a new unique ID for a connection or listener. -// It locks the global lock to ensure thread safety and checks for existing IDs. -// It returns the new ID and a function to unlock the global lock. -func takeLockAndGenerateNextID() (newID uint, unlock func()) { - lock.Lock() - for { - id := uint(nextConnectionID.Add(1)) - _, exists1 := liveConnections[id] - _, exists2 := liveListeners[id] - if !exists1 && !exists2 { - return id, func() { - lock.Unlock() - } - } - } -} - -func tcpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 2 { - return nil, []any{1, "Invalid number of parameters, expected server address and port"} - } - serverAddr, ok := params[0].(string) - if !ok { - return nil, []any{1, "Invalid parameter type, expected string for server address"} - } - serverPort, ok := msgpackrpc.ToUint(params[1]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} - } - - serverAddr = net.JoinHostPort(serverAddr, strconv.FormatUint(uint64(serverPort), 10)) - - conn, err := net.Dial("tcp", serverAddr) - if err != nil { - return nil, []any{2, "Failed to connect to server: " + err.Error()} - } - - // Successfully connected to the server - - id, unlock := takeLockAndGenerateNextID() - liveConnections[id] = conn - unlock() - return id, nil -} - -func tcpListen(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 2 { - return nil, []any{1, "Invalid number of parameters, expected listen address and port"} - } - listenAddr, ok := params[0].(string) - if !ok { - return nil, []any{1, "Invalid parameter type, expected string for listen address"} - } - listenPort, ok := msgpackrpc.ToUint(params[1]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected uint16 for listen port"} - } - - listenAddr = net.JoinHostPort(listenAddr, strconv.FormatUint(uint64(listenPort), 10)) - - listener, err := net.Listen("tcp", listenAddr) - if err != nil { - return nil, []any{2, "Failed to start listening on address: " + err.Error()} - } - - id, unlock := takeLockAndGenerateNextID() - liveListeners[id] = listener - unlock() - return id, nil -} - -func tcpAccept(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 1 { - return nil, []any{1, "Invalid number of parameters, expected listener ID"} - } - listenerID, ok := msgpackrpc.ToUint(params[0]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected int for listener ID"} - } - - lock.RLock() - listener, exists := liveListeners[listenerID] - lock.RUnlock() - - if !exists { - return nil, []any{2, fmt.Sprintf("Listener not found for ID: %d", listenerID)} - } - - conn, err := listener.Accept() - if err != nil { - return nil, []any{3, "Failed to accept connection: " + err.Error()} - } - - // Successfully accepted a connection - - connID, unlock := takeLockAndGenerateNextID() - liveConnections[connID] = conn - unlock() - return connID, nil -} - -func tcpClose(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 1 { - return nil, []any{1, "Invalid number of parameters, expected connection ID"} - } - id, ok := msgpackrpc.ToUint(params[0]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected int for connection ID"} - } - - lock.Lock() - conn, existsConn := liveConnections[id] - if existsConn { - delete(liveConnections, id) - } - lock.Unlock() - - if !existsConn { - return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)} - } - - // Close the connection if it exists - // We do not return an error to the caller if the close operation fails, as it is not critical, - // but we only log the error for debugging purposes. - if err := conn.Close(); err != nil { - return err.Error(), nil - } - return "", nil -} - -func tcpCloseListener(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 1 { - return nil, []any{1, "Invalid number of parameters, expected listener ID"} - } - id, ok := msgpackrpc.ToUint(params[0]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected int for listener ID"} - } - - lock.Lock() - listener, existsListener := liveListeners[id] - if existsListener { - delete(liveListeners, id) - } - lock.Unlock() - - if !existsListener { - return nil, []any{2, fmt.Sprintf("Listener not found for ID: %d", id)} - } - - // Close the listener if it exists - // We do not return an error to the caller if the close operation fails, as it is not critical, - // but we only log the error for debugging purposes. - if err := listener.Close(); err != nil { - return err.Error(), nil - } - return "", nil -} - -func tcpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 2 { - return nil, []any{1, "Invalid number of parameters, expected (connection ID, max bytes to read)"} - } - id, ok := msgpackrpc.ToUint(params[0]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected int for connection ID"} - } - lock.RLock() - conn, ok := liveConnections[id] - lock.RUnlock() - if !ok { - return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)} - } - maxBytes, ok := msgpackrpc.ToUint(params[1]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected int for max bytes to read"} - } - - buffer := make([]byte, maxBytes) - // It seems that the only way to make a non-blocking read is to set a read deadline. - // BTW setting the read deadline to time.Now() will always returns an empty (zero bytes) - // read, so we set it to a very short duration in the future. - if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { - return nil, []any{3, "Failed to set read timeout: " + err.Error()} - } - n, err := conn.Read(buffer) - if errors.Is(err, os.ErrDeadlineExceeded) { - // timeout - } else if err != nil { - return nil, []any{3, "Failed to read from connection: " + err.Error()} - } - - return buffer[:n], nil -} - -func tcpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - if len(params) != 2 { - return nil, []any{1, "Invalid number of parameters, expected (connection ID, data to write)"} - } - id, ok := msgpackrpc.ToUint(params[0]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected int for connection ID"} - } - lock.RLock() - conn, ok := liveConnections[id] - lock.RUnlock() - if !ok { - return nil, []any{2, fmt.Sprintf("Connection not found for ID: %d", id)} - } - data, ok := params[1].([]byte) - if !ok { - if dataStr, ok := params[1].(string); ok { - data = []byte(dataStr) - } else { - // If data is not []byte or string, return an error - return nil, []any{1, "Invalid parameter type, expected []byte or string for data to write"} - } - } - - n, err := conn.Write(data) - if err != nil { - return nil, []any{3, "Failed to write to connection: " + err.Error()} - } - - return n, nil -} - -func tcpConnectSSL(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) { - n := len(params) - if n < 1 || n > 3 { - return nil, []any{1, "Invalid number of parameters, expected server address, port and optional TLS cert"} - } - serverAddr, ok := params[0].(string) - if !ok { - return nil, []any{1, "Invalid parameter type, expected string for server address"} - } - serverPort, ok := msgpackrpc.ToUint(params[1]) - if !ok { - return nil, []any{1, "Invalid parameter type, expected uint16 for server port"} - } - - serverAddr = net.JoinHostPort(serverAddr, strconv.FormatUint(uint64(serverPort), 10)) - - var tlsConfig *tls.Config - if n == 3 { - cert, ok := params[2].(string) - if !ok { - return nil, []any{1, "Invalid parameter type, expected string for TLS cert"} - } - - if len(cert) > 0 { - // parse TLS cert in pem format - certs := x509.NewCertPool() - if !certs.AppendCertsFromPEM([]byte(cert)) { - return nil, []any{1, "Failed to parse TLS certificate"} - } - tlsConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: certs, - } - } - } - - conn, err := tls.Dial("tcp", serverAddr, tlsConfig) - if err != nil { - return nil, []any{2, "Failed to connect to server: " + err.Error()} - } - - // Successfully connected to the server - - id, unlock := takeLockAndGenerateNextID() - liveConnections[id] = conn - unlock() - return id, nil -}