Skip to content

Commit d67f01e

Browse files
authored
Udp network API (part 2) (#14)
* Split udp/write operation into a 3-step beginPacket/write/endPacket * Renamed udp/awaitRead -> udp/awaitPacket * Added udp/dropPacket * Enforce awaitPacket behaviour on unit tests * Updated tests * Fix typo
1 parent 3a4b461 commit d67f01e

File tree

2 files changed

+153
-31
lines changed

2 files changed

+153
-31
lines changed

network-api/network-api.go

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ func Register(router *msgpackrouter.Router) {
4747
_ = router.RegisterMethod("tcp/connectSSL", tcpConnectSSL)
4848

4949
_ = router.RegisterMethod("udp/connect", udpConnect)
50+
_ = router.RegisterMethod("udp/beginPacket", udpBeginPacket)
5051
_ = router.RegisterMethod("udp/write", udpWrite)
51-
_ = router.RegisterMethod("udp/awaitRead", udpAwaitRead)
52+
_ = router.RegisterMethod("udp/endPacket", udpEndPacket)
53+
_ = router.RegisterMethod("udp/awaitPacket", udpAwaitPacket)
5254
_ = router.RegisterMethod("udp/read", udpRead)
55+
_ = router.RegisterMethod("udp/dropPacket", udpDropPacket)
5356
_ = router.RegisterMethod("udp/close", udpClose)
5457
}
5558

@@ -58,6 +61,8 @@ var liveConnections = make(map[uint]net.Conn)
5861
var liveListeners = make(map[uint]net.Listener)
5962
var liveUdpConnections = make(map[uint]net.PacketConn)
6063
var udpReadBuffers = make(map[uint][]byte)
64+
var udpWriteTargets = make(map[uint]*net.UDPAddr)
65+
var udpWriteBuffers = make(map[uint][]byte)
6166
var nextConnectionID atomic.Uint32
6267

6368
// takeLockAndGenerateNextID generates a new unique ID for a connection or listener.
@@ -375,9 +380,9 @@ func udpConnect(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (
375380
return id, nil
376381
}
377382

378-
func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
379-
if len(params) != 4 {
380-
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port, payload"}
383+
func udpBeginPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
384+
if len(params) != 3 {
385+
return nil, []any{1, "Invalid number of parameters, expected udpConnId, dest address, dest port"}
381386
}
382387
id, ok := msgpackrpc.ToUint(params[0])
383388
if !ok {
@@ -391,9 +396,33 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
391396
if !ok {
392397
return nil, []any{1, "Invalid parameter type, expected uint16 for server port"}
393398
}
394-
data, ok := params[3].([]byte)
399+
400+
lock.RLock()
401+
defer lock.RUnlock()
402+
if _, ok := liveUdpConnections[id]; !ok {
403+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
404+
}
405+
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
406+
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
407+
if err != nil {
408+
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
409+
}
410+
udpWriteTargets[id] = addr
411+
udpWriteBuffers[id] = nil
412+
return true, nil
413+
}
414+
415+
func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
416+
if len(params) != 2 {
417+
return nil, []any{1, "Invalid number of parameters, expected udpConnId, payload"}
418+
}
419+
id, ok := msgpackrpc.ToUint(params[0])
395420
if !ok {
396-
if dataStr, ok := params[3].(string); ok {
421+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
422+
}
423+
data, ok := params[1].([]byte)
424+
if !ok {
425+
if dataStr, ok := params[1].(string); ok {
397426
data = []byte(dataStr)
398427
} else {
399428
// If data is not []byte or string, return an error
@@ -402,25 +431,52 @@ func udpWrite(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_r
402431
}
403432

404433
lock.RLock()
405-
udpConn, ok := liveUdpConnections[id]
434+
udpBuffer, ok := udpWriteBuffers[id]
435+
if ok {
436+
udpWriteBuffers[id] = append(udpBuffer, data...)
437+
}
406438
lock.RUnlock()
407439
if !ok {
408440
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
409441
}
442+
return len(data), nil
443+
}
410444

411-
targetAddr := net.JoinHostPort(targetIP, fmt.Sprintf("%d", targetPort))
412-
addr, err := net.ResolveUDPAddr("udp", targetAddr) // TODO: This is inefficient, implement some caching
413-
if err != nil {
414-
return nil, []any{3, "Failed to resolve target address: " + err.Error()}
445+
func udpEndPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
446+
if len(params) != 1 {
447+
return nil, []any{1, "Invalid number of parameters, expected expected udpConnId"}
448+
}
449+
id, buffExists := msgpackrpc.ToUint(params[0])
450+
if !buffExists {
451+
return nil, []any{1, "Invalid parameter type, expected int for UDP connection ID"}
452+
}
453+
454+
var udpBuffer []byte
455+
var udpAddr *net.UDPAddr
456+
lock.RLock()
457+
udpConn, connExists := liveUdpConnections[id]
458+
if connExists {
459+
udpBuffer, buffExists = udpWriteBuffers[id]
460+
udpAddr = udpWriteTargets[id]
461+
delete(udpWriteBuffers, id)
462+
delete(udpWriteTargets, id)
463+
}
464+
lock.RUnlock()
465+
if !connExists {
466+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
415467
}
416-
if n, err := udpConn.WriteTo(data, addr); err != nil {
468+
if !buffExists {
469+
return nil, []any{3, fmt.Sprintf("No UDP packet begun for ID: %d", id)}
470+
}
471+
472+
if n, err := udpConn.WriteTo(udpBuffer, udpAddr); err != nil {
417473
return nil, []any{4, "Failed to write to UDP connection: " + err.Error()}
418474
} else {
419475
return n, nil
420476
}
421477
}
422478

423-
func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
479+
func udpAwaitPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
424480
if len(params) != 1 && len(params) != 2 {
425481
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
426482
}
@@ -472,6 +528,24 @@ func udpAwaitRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any)
472528
return []any{n, host, port}, nil
473529
}
474530

531+
func udpDropPacket(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
532+
if len(params) != 1 && len(params) != 2 {
533+
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID[, optional timeout in ms])"}
534+
}
535+
id, ok := msgpackrpc.ToUint(params[0])
536+
if !ok {
537+
return nil, []any{1, "Invalid parameter type, expected uint for UDP connection ID"}
538+
}
539+
540+
lock.RLock()
541+
delete(udpReadBuffers, id)
542+
lock.RUnlock()
543+
if !ok {
544+
return nil, []any{2, fmt.Sprintf("UDP connection not found for ID: %d", id)}
545+
}
546+
return true, nil
547+
}
548+
475549
func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_result any, _err any) {
476550
if len(params) != 2 && len(params) != 3 {
477551
return nil, []any{1, "Invalid number of parameters, expected (UDP connection ID, max bytes to read)"}
@@ -494,7 +568,7 @@ func udpRead(ctx context.Context, rpc *msgpackrpc.Connection, params []any) (_re
494568
udpReadBuffers[id] = buffer[maxBytes:]
495569
n = maxBytes
496570
} else {
497-
udpReadBuffers[id] = nil
571+
delete(udpReadBuffers, id)
498572
}
499573
}
500574
lock.Unlock()

network-api/network-api_test.go

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,18 @@ func TestUDPNetworkAPI(t *testing.T) {
248248
require.NotEqual(t, conn1, conn2)
249249

250250
{
251-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Hello")})
251+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
252+
require.Nil(t, err)
253+
require.True(t, res.(bool))
254+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")})
255+
require.Nil(t, err)
256+
require.Equal(t, 5, res)
257+
res, err = udpEndPacket(ctx, nil, []any{conn1})
252258
require.Nil(t, err)
253259
require.Equal(t, 5, res)
254260
}
255261
{
256-
res, err := udpAwaitRead(ctx, nil, []any{conn2})
262+
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
257263
require.Nil(t, err)
258264
require.Equal(t, []any{5, "127.0.0.1", 9800}, res)
259265

@@ -262,26 +268,44 @@ func TestUDPNetworkAPI(t *testing.T) {
262268
require.Equal(t, []uint8("Hello"), res2)
263269
}
264270
{
265-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("One")})
271+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
272+
require.Nil(t, err)
273+
require.True(t, res.(bool))
274+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("On")})
275+
require.Nil(t, err)
276+
require.Equal(t, 2, res)
277+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("e")})
278+
require.Nil(t, err)
279+
require.Equal(t, 1, res)
280+
res, err = udpEndPacket(ctx, nil, []any{conn1})
266281
require.Nil(t, err)
267282
require.Equal(t, 3, res)
268283
}
269284
{
270-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9900, []byte("Two")})
285+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9900})
286+
require.Nil(t, err)
287+
require.True(t, res.(bool))
288+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")})
289+
require.Nil(t, err)
290+
require.Equal(t, 3, res)
291+
res, err = udpEndPacket(ctx, nil, []any{conn1})
271292
require.Nil(t, err)
272293
require.Equal(t, 3, res)
273294
}
274295
{
275-
res, err := udpAwaitRead(ctx, nil, []any{conn2})
296+
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
276297
require.Nil(t, err)
277298
require.Equal(t, []any{3, "127.0.0.1", 9800}, res)
278299

279-
res2, err := udpRead(ctx, nil, []any{conn2, 100})
300+
// A partial read of a packet is allowed
301+
res2, err := udpRead(ctx, nil, []any{conn2, 2})
280302
require.Nil(t, err)
281-
require.Equal(t, []uint8("One"), res2)
303+
require.Equal(t, []uint8("On"), res2)
282304
}
283305
{
284-
res, err := udpAwaitRead(ctx, nil, []any{conn2})
306+
// Even if the previous packet was only partially read,
307+
// the next packet can be received
308+
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
285309
require.Nil(t, err)
286310
require.Equal(t, []any{3, "127.0.0.1", 9800}, res)
287311

@@ -311,12 +335,18 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
311335
require.NotEqual(t, conn1, conn2)
312336

313337
{
314-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("Hello")})
338+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
339+
require.Nil(t, err)
340+
require.True(t, res.(bool))
341+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Hello")})
342+
require.Nil(t, err)
343+
require.Equal(t, 5, res)
344+
res, err = udpEndPacket(ctx, nil, []any{conn1})
315345
require.Nil(t, err)
316346
require.Equal(t, 5, res)
317347
}
318348
{
319-
res, err := udpAwaitRead(ctx, nil, []any{conn2})
349+
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
320350
require.Nil(t, err)
321351
require.Equal(t, 5, res.([]any)[0])
322352

@@ -329,17 +359,29 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
329359
require.Equal(t, []uint8("llo"), res2)
330360
}
331361
{
332-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("One")})
362+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
363+
require.Nil(t, err)
364+
require.True(t, res.(bool))
365+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("One")})
366+
require.Nil(t, err)
367+
require.Equal(t, 3, res)
368+
res, err = udpEndPacket(ctx, nil, []any{conn1})
333369
require.Nil(t, err)
334370
require.Equal(t, 3, res)
335371
}
336372
{
337-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("Two")})
373+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
374+
require.Nil(t, err)
375+
require.True(t, res.(bool))
376+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Two")})
377+
require.Nil(t, err)
378+
require.Equal(t, 3, res)
379+
res, err = udpEndPacket(ctx, nil, []any{conn1})
338380
require.Nil(t, err)
339381
require.Equal(t, 3, res)
340382
}
341383
{
342-
res, err := udpAwaitRead(ctx, nil, []any{conn2})
384+
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
343385
require.Nil(t, err)
344386
require.Equal(t, 3, res.([]any)[0])
345387

@@ -348,7 +390,7 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
348390
require.Equal(t, []uint8("One"), res2)
349391
}
350392
{
351-
res, err := udpAwaitRead(ctx, nil, []any{conn2})
393+
res, err := udpAwaitPacket(ctx, nil, []any{conn2})
352394
require.Nil(t, err)
353395
require.Equal(t, 3, res.([]any)[0])
354396

@@ -360,19 +402,25 @@ func TestUDPNetworkUnboundClientAPI(t *testing.T) {
360402
// Check timeouts
361403
go func() {
362404
time.Sleep(200 * time.Millisecond)
363-
res, err := udpWrite(ctx, nil, []any{conn1, "127.0.0.1", 9901, []byte("Three")})
405+
res, err := udpBeginPacket(ctx, nil, []any{conn1, "127.0.0.1", 9901})
406+
require.Nil(t, err)
407+
require.True(t, res.(bool))
408+
res, err = udpWrite(ctx, nil, []any{conn1, []byte("Three")})
409+
require.Nil(t, err)
410+
require.Equal(t, 5, res)
411+
res, err = udpEndPacket(ctx, nil, []any{conn1})
364412
require.Nil(t, err)
365413
require.Equal(t, 5, res)
366414
}()
367415
{
368416
start := time.Now()
369-
res, err := udpAwaitRead(ctx, nil, []any{conn2, 10})
417+
res, err := udpAwaitPacket(ctx, nil, []any{conn2, 10})
370418
require.Less(t, time.Since(start), 20*time.Millisecond)
371419
require.Equal(t, []any{5, "Timeout"}, err)
372420
require.Nil(t, res)
373421
}
374422
{
375-
res, err := udpAwaitRead(ctx, nil, []any{conn2, 0})
423+
res, err := udpAwaitPacket(ctx, nil, []any{conn2, 0})
376424
require.Nil(t, err)
377425
require.Equal(t, 5, res.([]any)[0])
378426

0 commit comments

Comments
 (0)