Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions internal/msgpackrouter/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,26 @@ type Router struct {
routesLock sync.Mutex
routes map[string]*msgpackrpc.Connection
routesInternal map[string]RouterRequestHandler
sendQueueSize int
}

func New() *Router {
func New(perConnSendQueueSize int) *Router {
return &Router{
routes: make(map[string]*msgpackrpc.Connection),
routesInternal: make(map[string]RouterRequestHandler),
sendQueueSize: perConnSendQueueSize,
}
}

// SetSendQueueSize sets the size of the send queue for each connection.
// Only new connections will be affected by this change, existing connections
// will keep their current send queue size.
func (r *Router) SetSendQueueSize(size int) {
r.routesLock.Lock()
defer r.routesLock.Unlock()
r.sendQueueSize = size
}

func (r *Router) Accept(conn io.ReadWriteCloser) <-chan struct{} {
res := make(chan struct{})
go func() {
Expand Down Expand Up @@ -70,7 +81,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) {
defer conn.Close()

var msgpackconn *msgpackrpc.Connection
msgpackconn = msgpackrpc.NewConnection(conn, conn,
msgpackconn = msgpackrpc.NewConnectionWithMaxWorkers(conn, conn,
func(ctx context.Context, _ msgpackrpc.FunctionLogger, method string, params []any) (_result any, _err any) {
// This handler is called when a request is received from the client
slog.Debug("Received request", "method", method, "params", params)
Expand Down Expand Up @@ -149,6 +160,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) {
}
slog.Error("Error in connection", "err", err)
},
r.sendQueueSize,
)

msgpackconn.Run()
Expand Down
54 changes: 53 additions & 1 deletion internal/msgpackrouter/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestBasicRouterFunctionality(t *testing.T) {
})
go cl2.Run()

router := msgpackrouter.New()
router := msgpackrouter.New(0)
router.Accept(ch1b)
router.Accept(ch2b)

Expand Down Expand Up @@ -160,3 +160,55 @@ func TestBasicRouterFunctionality(t *testing.T) {
require.Contains(t, cl1Notifications.String(), "notification: ping [b 14 true true]")
cl1NotificationsMux.Unlock()
}

func TestMessageForwarderCongestionControl(t *testing.T) {
// Test parameters
queueSize := 5
msgLatency := 100 * time.Millisecond
// Run a batch of 20 requests, and expect them to take more than 400 ms
// in total because the router should throttle requests in batch of 5.
batchSize := queueSize * 4
expectedLatency := msgLatency * time.Duration(batchSize/queueSize)

// Make a client that simulates a slow response
ch1a, ch1b := newFullPipe()
cl1 := msgpackrpc.NewConnection(ch1a, ch1a, func(ctx context.Context, logger msgpackrpc.FunctionLogger, method string, params []any) (_result any, _err any) {
time.Sleep(msgLatency)
return true, nil
}, nil, nil)
go cl1.Run()

// Make a second client to send requests, without any delay
ch2a, ch2b := newFullPipe()
cl2 := msgpackrpc.NewConnection(ch2a, ch2a, nil, nil, nil)
go cl2.Run()

// Setup router
router := msgpackrouter.New(queueSize) // max 5 pending messages per connection
router.Accept(ch1b)
router.Accept(ch2b)

{
// Register a method on the first client
result, reqErr, err := cl1.SendRequest(context.Background(), "$/register", []any{"test"})
require.Equal(t, true, result)
require.Nil(t, reqErr)
require.NoError(t, err)
}

// Run batch of requests from cl2 to cl1
start := time.Now()
var wg sync.WaitGroup
for range batchSize {
wg.Go(func() {
_, _, err := cl2.SendRequest(t.Context(), "test", []any{})
require.NoError(t, err)
})
}
wg.Wait()
elapsed := time.Since(start)

// Check that the elapsed time is greater than expectedLatency
fmt.Println("Elapsed time for requests:", elapsed)
require.Greater(t, elapsed, expectedLatency, "Expected elapsed time to be greater than %s", expectedLatency)
}
17 changes: 9 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ var Version string = "0.0.0-dev"

// Server configuration
type Config struct {
LogLevel slog.Level
ListenTCPAddr string
ListenUnixAddr string
SerialPortAddr string
SerialBaudRate int
MonitorPortAddr string
LogLevel slog.Level
ListenTCPAddr string
ListenUnixAddr string
SerialPortAddr string
SerialBaudRate int
MonitorPortAddr string
MaxPendingRequestsPerClient int
}

func main() {
Expand Down Expand Up @@ -80,7 +81,7 @@ func main() {
cmd.Flags().StringVarP(&cfg.SerialPortAddr, "serial-port", "p", "", "Serial port address")
cmd.Flags().IntVarP(&cfg.SerialBaudRate, "serial-baudrate", "b", 115200, "Serial port baud rate")
cmd.Flags().StringVarP(&cfg.MonitorPortAddr, "monitor-port", "m", "127.0.0.1:7500", "Listening port for MCU monitor proxy")

cmd.Flags().IntVarP(&cfg.MaxPendingRequestsPerClient, "max-pending-requests", "", 25, "Maximum number of pending requests per client connection")
cmd.AddCommand(&cobra.Command{
Use: "version",
Long: "Print version information",
Expand Down Expand Up @@ -155,7 +156,7 @@ func startRouter(cfg Config) error {
}

// Run router
router := msgpackrouter.New()
router := msgpackrouter.New(cfg.MaxPendingRequestsPerClient)

// Register TCP network API methods
networkapi.Register(router)
Expand Down
33 changes: 29 additions & 4 deletions msgpackrpc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ type Connection struct {
activeOutRequests map[MessageID]*outRequest
activeOutRequestsMutex sync.Mutex
lastOutRequestsIndex atomic.Uint32

workerSlots chan bool
}

type inRequest struct {
Expand Down Expand Up @@ -79,8 +81,14 @@ type NotificationHandler func(logger FunctionLogger, method string, params []any
// sending a request or notification.
type ErrorHandler func(error)

// NewConnection starts a new
// NewConnection creates a new MessagePack-RPC Connection handler.
func NewConnection(in io.ReadCloser, out io.WriteCloser, requestHandler RequestHandler, notificationHandler NotificationHandler, errorHandler ErrorHandler) *Connection {
return NewConnectionWithMaxWorkers(in, out, requestHandler, notificationHandler, errorHandler, 0)
}

// NewConnectionWithMaxWorkers creates a new MessagePack-RPC Connection handler
// with a specified maximum number of worker goroutines to handle incoming requests.
func NewConnectionWithMaxWorkers(in io.ReadCloser, out io.WriteCloser, requestHandler RequestHandler, notificationHandler NotificationHandler, errorHandler ErrorHandler, maxWorkers int) *Connection {
outEncoder := msgpack.NewEncoder(out)
outEncoder.UseCompactInts(true)
if requestHandler == nil {
Expand Down Expand Up @@ -109,9 +117,24 @@ func NewConnection(in io.ReadCloser, out io.WriteCloser, requestHandler RequestH
activeOutRequests: map[MessageID]*outRequest{},
logger: NullLogger{},
}
if maxWorkers > 0 {
conn.workerSlots = make(chan bool, maxWorkers)
}
return conn
}

func (c *Connection) startWorker(cb func()) {
if c.workerSlots == nil {
go cb()
return
}
c.workerSlots <- true
go func() {
defer func() { <-c.workerSlots }()
cb()
}()
}

func (c *Connection) SetLogger(l Logger) {
c.loggerMutex.Lock()
c.logger = l
Expand Down Expand Up @@ -215,7 +238,7 @@ func (c *Connection) handleIncomingRequest(id MessageID, method string, params [
logger := c.logger.LogIncomingRequest(id, method, params)
c.loggerMutex.Unlock()

go func() {
c.startWorker(func() {
reqResult, reqError := c.requestHandler(ctx, logger, method, params)

var existing *inRequest
Expand All @@ -238,7 +261,7 @@ func (c *Connection) handleIncomingRequest(id MessageID, method string, params [
c.errorHandler(fmt.Errorf("error sending response: %w", err))
c.Close()
}
}()
})
}

func (c *Connection) handleIncomingNotification(method string, params []any) {
Expand All @@ -261,7 +284,9 @@ func (c *Connection) handleIncomingNotification(method string, params []any) {
logger := c.logger.LogIncomingNotification(method, params)
c.loggerMutex.Unlock()

go c.notificationHandler(logger, method, params)
c.startWorker(func() {
c.notificationHandler(logger, method, params)
})
}

func (c *Connection) handleIncomingResponse(id MessageID, reqError any, reqResult any) {
Expand Down