@@ -2,7 +2,6 @@ package main
2
2
3
3
import (
4
4
"bufio"
5
- "bytes"
6
5
"context"
7
6
"encoding/json"
8
7
"errors"
@@ -89,20 +88,18 @@ func (r *Route) Validate(full bool) error {
89
88
}
90
89
91
90
type Proxy struct {
92
- tlsTimeout time.Duration
93
91
logger * Logger
94
92
sniRoutes map [string ]string
95
93
router * Router
96
- proxy * httputil.ReverseProxy
97
94
routesLock sync.RWMutex
95
+ proxy * httputil.ReverseProxy
98
96
}
99
97
100
- func NewProxy (tlsTimeout time. Duration , logLevel LogLevel ) * Proxy {
98
+ func NewProxy (logLevel LogLevel ) * Proxy {
101
99
out := Proxy {
102
- tlsTimeout : tlsTimeout ,
103
- logger : NewLogger ("Proxy" , logLevel ),
104
- sniRoutes : make (map [string ]string ),
105
- router : NewRouter (),
100
+ logger : NewLogger ("Proxy" , logLevel ),
101
+ sniRoutes : make (map [string ]string ),
102
+ router : NewRouter (),
106
103
}
107
104
out .proxy = & httputil.ReverseProxy {
108
105
Director : out .director ,
@@ -122,12 +119,38 @@ func awaitShutdown() {
122
119
}
123
120
}
124
121
125
- func (p * Proxy ) run (httpAddress , tlsAddress , apiURL , apiToken , tlsCert , tlsKey string , isChildProcess bool ) {
122
+ func (p * Proxy ) run (address , tcpAddress , apiURL , apiToken , tlsCert , tlsKey string , isChildProcess bool ) {
126
123
if isChildProcess {
127
124
go awaitShutdown ()
128
125
}
129
- go p .runHTTP (httpAddress , tlsCert , tlsKey )
130
- go p .runTLS (tlsAddress )
126
+
127
+ if tcpAddress == "" {
128
+ tcpAddress = address
129
+ }
130
+
131
+ tlsListener , err := net .Listen ("tcp" , tcpAddress )
132
+ if err != nil {
133
+ p .logger .Errorf ("%s" , err )
134
+ os .Exit (1 )
135
+ }
136
+ p .logger .Infof ("TCP Proxy serving at %s" , tcpAddress )
137
+
138
+ var httpListener net.Listener
139
+ var forwarder * Forwarder = nil
140
+ if address == tcpAddress {
141
+ forwarder = newForwarder (tlsListener )
142
+ httpListener = forwarder
143
+ } else {
144
+ httpListener , err = net .Listen ("tcp" , address )
145
+ if err != nil {
146
+ p .logger .Errorf ("%s" , err )
147
+ os .Exit (1 )
148
+ }
149
+ }
150
+ p .logger .Infof ("HTTP Proxy serving at %s" , address )
151
+
152
+ go p .runHTTP (httpListener , tlsCert , tlsKey )
153
+ go p .runTCP (tlsListener , forwarder )
131
154
p .watchRoutes (apiURL , apiToken , 15 * time .Second )
132
155
}
133
156
@@ -408,110 +431,97 @@ func (p *Proxy) director(req *http.Request) {
408
431
}
409
432
}
410
433
411
- func (p * Proxy ) runHTTP (address , tlsCert , tlsKey string ) {
434
+ func (p * Proxy ) runHTTP (ln net. Listener , tlsCert , tlsKey string ) {
412
435
server := & http.Server {
413
- Addr : address ,
414
436
Handler : p ,
415
437
}
416
438
var err error
417
439
if tlsCert == "" {
418
- p .logger .Infof ("HTTP Proxy serving at http://%s" , address )
419
- err = server .ListenAndServe ()
440
+ err = server .Serve (ln )
420
441
} else {
421
- p .logger .Infof ("HTTP Proxy serving at https://%s" , address )
422
- err = server .ListenAndServeTLS (tlsCert , tlsKey )
442
+ err = server .ServeTLS (ln , tlsCert , tlsKey )
423
443
}
424
444
if err != nil && err != http .ErrServerClosed {
425
445
p .logger .Errorf ("%s" , err )
426
446
os .Exit (1 )
427
447
}
428
448
}
429
449
430
- // TLS proxy implementation
450
+ // TCP proxy implementation
431
451
432
- type tlsConn struct {
433
- inConn * net.TCPConn
434
- outConn * net.TCPConn
435
- sni string
436
- outAddr string
437
- tlsMinor int
452
+ type Forwarder struct {
453
+ net.Listener
454
+ conns chan net.Conn
438
455
}
439
456
440
- type tlsAlert int8
441
-
442
- const (
443
- internalError tlsAlert = 80
444
- unrecognizedName = 112
445
- )
446
-
447
- func (p * Proxy ) sendAlert (c * tlsConn , alert tlsAlert , format string , args ... interface {}) {
448
- p .logger .Debugf (format , args ... )
449
-
450
- alertMsg := []byte {21 , 3 , byte (c .tlsMinor ), 0 , 2 , 2 , byte (alert )}
451
-
452
- if err := c .inConn .SetWriteDeadline (time .Now ().Add (p .tlsTimeout )); err != nil {
453
- p .logger .Debugf ("Error while setting write deadline during abort: %s" , err )
454
- return
457
+ func newForwarder (ln net.Listener ) * Forwarder {
458
+ return & Forwarder {
459
+ Listener : ln ,
460
+ conns : make (chan net.Conn ),
455
461
}
462
+ }
456
463
457
- if _ , err := c .inConn .Write (alertMsg ); err != nil {
458
- p .logger .Debugf ("Error while sending alert: %s" , err )
459
- }
464
+ func (h * Forwarder ) Forward (conn net.Conn ) {
465
+ h .conns <- conn
460
466
}
461
467
462
- func (p * Proxy ) handleConnection (c * tlsConn ) {
463
- defer c .inConn .Close ()
468
+ func (h * Forwarder ) Accept () (net.Conn , error ) {
469
+ conn := <- h .conns
470
+ return conn , nil
471
+ }
464
472
473
+ func (p * Proxy ) handleConnection (inConn * net.TCPConn , forwarder * Forwarder ) {
465
474
var err error
466
475
467
- if err = c .inConn .SetReadDeadline (time .Now ().Add (p .tlsTimeout )); err != nil {
468
- p .sendAlert (c , internalError , "Setting read deadline for handshake: %s" , err )
469
- return
470
- }
471
-
472
- var handshakeBuf bytes.Buffer
473
- c .sni , c .tlsMinor , err = readVerAndSNI (io .TeeReader (c .inConn , & handshakeBuf ))
476
+ sni , isTLS , pInConn , err := readSNI (inConn )
474
477
if err != nil {
475
- p .sendAlert (c , internalError , "Extracting SNI: %s" , err )
478
+ p .logger .Debugf ("Error extracting SNI: %s" , err )
479
+ inConn .Close ()
476
480
return
477
481
}
478
482
479
- if err = c .inConn .SetReadDeadline (time.Time {}); err != nil {
480
- p .sendAlert (c , internalError , "Clearing read deadline for handshake: %s" , err )
481
- return
483
+ const sniPrefix = "daskgateway-"
484
+ if ! isTLS || ! strings .HasPrefix (sni , sniPrefix ) {
485
+ if forwarder != nil {
486
+ forwarder .Forward (pInConn )
487
+ return
488
+ } else {
489
+ p .logger .Debug ("Invalid connection attempt, closing" )
490
+ inConn .Close ()
491
+ return
492
+ }
482
493
}
483
494
495
+ sni = sni [len (sniPrefix ):]
496
+
497
+ defer inConn .Close ()
498
+
484
499
p .routesLock .RLock ()
485
- c . outAddr = p .sniRoutes [c . sni ]
500
+ outAddr : = p .sniRoutes [sni ]
486
501
p .routesLock .RUnlock ()
487
- if c .outAddr == "" {
488
- p .sendAlert (c , unrecognizedName , "SNI %q not found" , c .sni )
502
+
503
+ if outAddr == "" {
504
+ p .logger .Infof ("SNI %q not found" , sni )
489
505
return
490
506
}
491
507
492
- p .logger .Infof ("SNI %q -> %q" , c . sni , c . outAddr )
508
+ p .logger .Infof ("SNI %q -> %q" , sni , outAddr )
493
509
494
- outConn , err := net .DialTimeout ("tcp" , c . outAddr , 10 * time .Second )
510
+ outConn , err := net .DialTimeout ("tcp" , outAddr , 10 * time .Second )
495
511
if err != nil {
496
- p .sendAlert (c , internalError , "Failed to connect to destination %q: %s" , c .outAddr , err )
497
- return
498
- }
499
- c .outConn = outConn .(* net.TCPConn )
500
- defer c .outConn .Close ()
501
-
502
- if _ , err = io .Copy (c .outConn , & handshakeBuf ); err != nil {
503
- p .sendAlert (c , internalError , "Failed to replay handshake to %q: %s" , c .outAddr , err )
512
+ p .logger .Debugf ("Failed to connect to destination %q: %s" , outAddr , err )
504
513
return
505
514
}
515
+ defer outConn .Close ()
506
516
507
517
var wg sync.WaitGroup
508
518
wg .Add (2 )
509
- go p .proxyConnections (& wg , c . inConn , c . outConn )
510
- go p .proxyConnections (& wg , c . outConn , c . inConn )
519
+ go p .proxyConnections (& wg , pInConn , outConn .( * net. TCPConn ) )
520
+ go p .proxyConnections (& wg , outConn .( * net. TCPConn ), pInConn )
511
521
wg .Wait ()
512
522
}
513
523
514
- func (p * Proxy ) proxyConnections (wg * sync.WaitGroup , in , out * net. TCPConn ) {
524
+ func (p * Proxy ) proxyConnections (wg * sync.WaitGroup , in , out tcpConn ) {
515
525
defer wg .Done ()
516
526
if _ , err := io .Copy (in , out ); err != nil {
517
527
p .logger .Debugf ("Error proxying %q -> %q: %s" , in .RemoteAddr (), out .RemoteAddr (), err )
@@ -520,47 +530,38 @@ func (p *Proxy) proxyConnections(wg *sync.WaitGroup, in, out *net.TCPConn) {
520
530
out .CloseWrite ()
521
531
}
522
532
523
- func (p * Proxy ) runTLS (tlsAddress string ) {
524
- p .logger .Infof ("TLS Proxy serving at %s" , tlsAddress )
525
- l , err := net .Listen ("tcp" , tlsAddress )
526
- if err != nil {
527
- p .logger .Errorf ("%s" , err )
528
- os .Exit (1 )
529
- }
533
+ func (p * Proxy ) runTCP (ln net.Listener , forwarder * Forwarder ) {
530
534
for {
531
- c , err := l .Accept ()
535
+ c , err := ln .Accept ()
532
536
if err != nil {
533
537
p .logger .Debugf ("Failed to accept new connection: %s" , err )
534
538
}
535
539
536
- conn := & tlsConn {inConn : c .(* net.TCPConn )}
537
- go p .handleConnection (conn )
540
+ go p .handleConnection (c .(* net.TCPConn ), forwarder )
538
541
}
539
542
}
540
543
541
544
func main () {
542
545
var (
543
- httpAddress string
544
- tlsAddress string
546
+ address string
547
+ tcpAddress string
545
548
apiURL string
546
549
logLevelString string
547
550
isChildProcess bool
548
551
tlsCert string
549
552
tlsKey string
550
- tlsTimeout time.Duration
551
553
)
552
554
553
555
command := flag .NewFlagSet ("dask-gateway-proxy" , flag .ExitOnError )
554
- command .StringVar (& httpAddress , "address" , ":8000" , "HTTP proxy listening address" )
555
- command .StringVar (& tlsAddress , "tls -address" , ":8786 " , "TLS proxy listening address" )
556
+ command .StringVar (& address , "address" , ":8000" , "HTTP proxy listening address" )
557
+ command .StringVar (& tcpAddress , "tcp -address" , "" , "TCP proxy listening address. If empty, `address` will be used. " )
556
558
command .StringVar (& apiURL , "api-url" , "" , "The URL the proxy should watch for updating the routing table" )
557
559
command .StringVar (& logLevelString , "log-level" , "INFO" ,
558
560
"The log level. One of {DEBUG, INFO, WARN, ERROR}" )
559
561
command .BoolVar (& isChildProcess , "is-child-process" , false ,
560
562
"If set, will exit when stdin EOFs. Useful when running as a child process." )
561
563
command .StringVar (& tlsCert , "tls-cert" , "" , "TLS cert to use, if any." )
562
564
command .StringVar (& tlsKey , "tls-key" , "" , "TLS key to use, if any." )
563
- command .DurationVar (& tlsTimeout , "tls-timeout" , 3 * time .Second , "Timeout for TLS Handshake initialization" )
564
565
565
566
command .Parse (os .Args [1 :])
566
567
@@ -582,7 +583,7 @@ func main() {
582
583
panic ("Invalid -api-url: " + err .Error ())
583
584
}
584
585
585
- proxy := NewProxy (tlsTimeout , logLevel )
586
+ proxy := NewProxy (logLevel )
586
587
587
- proxy .run (httpAddress , tlsAddress , apiURL , token , tlsCert , tlsKey , isChildProcess )
588
+ proxy .run (address , tcpAddress , apiURL , token , tlsCert , tlsKey , isChildProcess )
588
589
}
0 commit comments