Skip to content

Commit 66aecfb

Browse files
authoredFeb 14, 2020
Support tcp and http proxy on same port (dask#203)
* Support tcp and http proxy on same port Adds support for running both the TCP and HTTP proxies on the same port. This relies on the SNI only being used for requests made by dask itself. If the dask-gateway-server is running behind another proxy that makes use of SNIs, then this method won't work and the proxies will have to use separate ports (which is still supported). Running all on the same point is nice from an administrative end (only need to expose one port, not two), as well as a configuration end (users only need to remember the main address, not the proxy address). * Filter valid SNIs Previously we'd accept any SNI, we now filter to only proxy SNIs that start with `daskgateway-`. This helps prevent accidentally mixing HTTPS and dask comms, since browsers will send an SNI on connect if connecting via HTTPS to a domain with a non-ip hostname.
1 parent 1255d43 commit 66aecfb

File tree

11 files changed

+181
-251
lines changed

11 files changed

+181
-251
lines changed
 

‎dask-gateway-server/dask-gateway-proxy/main.go

+90-89
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package main
22

33
import (
44
"bufio"
5-
"bytes"
65
"context"
76
"encoding/json"
87
"errors"
@@ -89,20 +88,18 @@ func (r *Route) Validate(full bool) error {
8988
}
9089

9190
type Proxy struct {
92-
tlsTimeout time.Duration
9391
logger *Logger
9492
sniRoutes map[string]string
9593
router *Router
96-
proxy *httputil.ReverseProxy
9794
routesLock sync.RWMutex
95+
proxy *httputil.ReverseProxy
9896
}
9997

100-
func NewProxy(tlsTimeout time.Duration, logLevel LogLevel) *Proxy {
98+
func NewProxy(logLevel LogLevel) *Proxy {
10199
out := Proxy{
102-
tlsTimeout: tlsTimeout,
103-
logger: NewLogger("Proxy", logLevel),
104-
sniRoutes: make(map[string]string),
105-
router: NewRouter(),
100+
logger: NewLogger("Proxy", logLevel),
101+
sniRoutes: make(map[string]string),
102+
router: NewRouter(),
106103
}
107104
out.proxy = &httputil.ReverseProxy{
108105
Director: out.director,
@@ -122,12 +119,38 @@ func awaitShutdown() {
122119
}
123120
}
124121

125-
func (p *Proxy) run(httpAddress, tlsAddress, apiURL, apiToken, tlsCert, tlsKey string, isChildProcess bool) {
122+
func (p *Proxy) run(address, tcpAddress, apiURL, apiToken, tlsCert, tlsKey string, isChildProcess bool) {
126123
if isChildProcess {
127124
go awaitShutdown()
128125
}
129-
go p.runHTTP(httpAddress, tlsCert, tlsKey)
130-
go p.runTLS(tlsAddress)
126+
127+
if tcpAddress == "" {
128+
tcpAddress = address
129+
}
130+
131+
tlsListener, err := net.Listen("tcp", tcpAddress)
132+
if err != nil {
133+
p.logger.Errorf("%s", err)
134+
os.Exit(1)
135+
}
136+
p.logger.Infof("TCP Proxy serving at %s", tcpAddress)
137+
138+
var httpListener net.Listener
139+
var forwarder *Forwarder = nil
140+
if address == tcpAddress {
141+
forwarder = newForwarder(tlsListener)
142+
httpListener = forwarder
143+
} else {
144+
httpListener, err = net.Listen("tcp", address)
145+
if err != nil {
146+
p.logger.Errorf("%s", err)
147+
os.Exit(1)
148+
}
149+
}
150+
p.logger.Infof("HTTP Proxy serving at %s", address)
151+
152+
go p.runHTTP(httpListener, tlsCert, tlsKey)
153+
go p.runTCP(tlsListener, forwarder)
131154
p.watchRoutes(apiURL, apiToken, 15*time.Second)
132155
}
133156

@@ -408,110 +431,97 @@ func (p *Proxy) director(req *http.Request) {
408431
}
409432
}
410433

411-
func (p *Proxy) runHTTP(address, tlsCert, tlsKey string) {
434+
func (p *Proxy) runHTTP(ln net.Listener, tlsCert, tlsKey string) {
412435
server := &http.Server{
413-
Addr: address,
414436
Handler: p,
415437
}
416438
var err error
417439
if tlsCert == "" {
418-
p.logger.Infof("HTTP Proxy serving at http://%s", address)
419-
err = server.ListenAndServe()
440+
err = server.Serve(ln)
420441
} else {
421-
p.logger.Infof("HTTP Proxy serving at https://%s", address)
422-
err = server.ListenAndServeTLS(tlsCert, tlsKey)
442+
err = server.ServeTLS(ln, tlsCert, tlsKey)
423443
}
424444
if err != nil && err != http.ErrServerClosed {
425445
p.logger.Errorf("%s", err)
426446
os.Exit(1)
427447
}
428448
}
429449

430-
// TLS proxy implementation
450+
// TCP proxy implementation
431451

432-
type tlsConn struct {
433-
inConn *net.TCPConn
434-
outConn *net.TCPConn
435-
sni string
436-
outAddr string
437-
tlsMinor int
452+
type Forwarder struct {
453+
net.Listener
454+
conns chan net.Conn
438455
}
439456

440-
type tlsAlert int8
441-
442-
const (
443-
internalError tlsAlert = 80
444-
unrecognizedName = 112
445-
)
446-
447-
func (p *Proxy) sendAlert(c *tlsConn, alert tlsAlert, format string, args ...interface{}) {
448-
p.logger.Debugf(format, args...)
449-
450-
alertMsg := []byte{21, 3, byte(c.tlsMinor), 0, 2, 2, byte(alert)}
451-
452-
if err := c.inConn.SetWriteDeadline(time.Now().Add(p.tlsTimeout)); err != nil {
453-
p.logger.Debugf("Error while setting write deadline during abort: %s", err)
454-
return
457+
func newForwarder(ln net.Listener) *Forwarder {
458+
return &Forwarder{
459+
Listener: ln,
460+
conns: make(chan net.Conn),
455461
}
462+
}
456463

457-
if _, err := c.inConn.Write(alertMsg); err != nil {
458-
p.logger.Debugf("Error while sending alert: %s", err)
459-
}
464+
func (h *Forwarder) Forward(conn net.Conn) {
465+
h.conns <- conn
460466
}
461467

462-
func (p *Proxy) handleConnection(c *tlsConn) {
463-
defer c.inConn.Close()
468+
func (h *Forwarder) Accept() (net.Conn, error) {
469+
conn := <-h.conns
470+
return conn, nil
471+
}
464472

473+
func (p *Proxy) handleConnection(inConn *net.TCPConn, forwarder *Forwarder) {
465474
var err error
466475

467-
if err = c.inConn.SetReadDeadline(time.Now().Add(p.tlsTimeout)); err != nil {
468-
p.sendAlert(c, internalError, "Setting read deadline for handshake: %s", err)
469-
return
470-
}
471-
472-
var handshakeBuf bytes.Buffer
473-
c.sni, c.tlsMinor, err = readVerAndSNI(io.TeeReader(c.inConn, &handshakeBuf))
476+
sni, isTLS, pInConn, err := readSNI(inConn)
474477
if err != nil {
475-
p.sendAlert(c, internalError, "Extracting SNI: %s", err)
478+
p.logger.Debugf("Error extracting SNI: %s", err)
479+
inConn.Close()
476480
return
477481
}
478482

479-
if err = c.inConn.SetReadDeadline(time.Time{}); err != nil {
480-
p.sendAlert(c, internalError, "Clearing read deadline for handshake: %s", err)
481-
return
483+
const sniPrefix = "daskgateway-"
484+
if !isTLS || !strings.HasPrefix(sni, sniPrefix) {
485+
if forwarder != nil {
486+
forwarder.Forward(pInConn)
487+
return
488+
} else {
489+
p.logger.Debug("Invalid connection attempt, closing")
490+
inConn.Close()
491+
return
492+
}
482493
}
483494

495+
sni = sni[len(sniPrefix):]
496+
497+
defer inConn.Close()
498+
484499
p.routesLock.RLock()
485-
c.outAddr = p.sniRoutes[c.sni]
500+
outAddr := p.sniRoutes[sni]
486501
p.routesLock.RUnlock()
487-
if c.outAddr == "" {
488-
p.sendAlert(c, unrecognizedName, "SNI %q not found", c.sni)
502+
503+
if outAddr == "" {
504+
p.logger.Infof("SNI %q not found", sni)
489505
return
490506
}
491507

492-
p.logger.Infof("SNI %q -> %q", c.sni, c.outAddr)
508+
p.logger.Infof("SNI %q -> %q", sni, outAddr)
493509

494-
outConn, err := net.DialTimeout("tcp", c.outAddr, 10*time.Second)
510+
outConn, err := net.DialTimeout("tcp", outAddr, 10*time.Second)
495511
if err != nil {
496-
p.sendAlert(c, internalError, "Failed to connect to destination %q: %s", c.outAddr, err)
497-
return
498-
}
499-
c.outConn = outConn.(*net.TCPConn)
500-
defer c.outConn.Close()
501-
502-
if _, err = io.Copy(c.outConn, &handshakeBuf); err != nil {
503-
p.sendAlert(c, internalError, "Failed to replay handshake to %q: %s", c.outAddr, err)
512+
p.logger.Debugf("Failed to connect to destination %q: %s", outAddr, err)
504513
return
505514
}
515+
defer outConn.Close()
506516

507517
var wg sync.WaitGroup
508518
wg.Add(2)
509-
go p.proxyConnections(&wg, c.inConn, c.outConn)
510-
go p.proxyConnections(&wg, c.outConn, c.inConn)
519+
go p.proxyConnections(&wg, pInConn, outConn.(*net.TCPConn))
520+
go p.proxyConnections(&wg, outConn.(*net.TCPConn), pInConn)
511521
wg.Wait()
512522
}
513523

514-
func (p *Proxy) proxyConnections(wg *sync.WaitGroup, in, out *net.TCPConn) {
524+
func (p *Proxy) proxyConnections(wg *sync.WaitGroup, in, out tcpConn) {
515525
defer wg.Done()
516526
if _, err := io.Copy(in, out); err != nil {
517527
p.logger.Debugf("Error proxying %q -> %q: %s", in.RemoteAddr(), out.RemoteAddr(), err)
@@ -520,47 +530,38 @@ func (p *Proxy) proxyConnections(wg *sync.WaitGroup, in, out *net.TCPConn) {
520530
out.CloseWrite()
521531
}
522532

523-
func (p *Proxy) runTLS(tlsAddress string) {
524-
p.logger.Infof("TLS Proxy serving at %s", tlsAddress)
525-
l, err := net.Listen("tcp", tlsAddress)
526-
if err != nil {
527-
p.logger.Errorf("%s", err)
528-
os.Exit(1)
529-
}
533+
func (p *Proxy) runTCP(ln net.Listener, forwarder *Forwarder) {
530534
for {
531-
c, err := l.Accept()
535+
c, err := ln.Accept()
532536
if err != nil {
533537
p.logger.Debugf("Failed to accept new connection: %s", err)
534538
}
535539

536-
conn := &tlsConn{inConn: c.(*net.TCPConn)}
537-
go p.handleConnection(conn)
540+
go p.handleConnection(c.(*net.TCPConn), forwarder)
538541
}
539542
}
540543

541544
func main() {
542545
var (
543-
httpAddress string
544-
tlsAddress string
546+
address string
547+
tcpAddress string
545548
apiURL string
546549
logLevelString string
547550
isChildProcess bool
548551
tlsCert string
549552
tlsKey string
550-
tlsTimeout time.Duration
551553
)
552554

553555
command := flag.NewFlagSet("dask-gateway-proxy", flag.ExitOnError)
554-
command.StringVar(&httpAddress, "address", ":8000", "HTTP proxy listening address")
555-
command.StringVar(&tlsAddress, "tls-address", ":8786", "TLS proxy listening address")
556+
command.StringVar(&address, "address", ":8000", "HTTP proxy listening address")
557+
command.StringVar(&tcpAddress, "tcp-address", "", "TCP proxy listening address. If empty, `address` will be used.")
556558
command.StringVar(&apiURL, "api-url", "", "The URL the proxy should watch for updating the routing table")
557559
command.StringVar(&logLevelString, "log-level", "INFO",
558560
"The log level. One of {DEBUG, INFO, WARN, ERROR}")
559561
command.BoolVar(&isChildProcess, "is-child-process", false,
560562
"If set, will exit when stdin EOFs. Useful when running as a child process.")
561563
command.StringVar(&tlsCert, "tls-cert", "", "TLS cert to use, if any.")
562564
command.StringVar(&tlsKey, "tls-key", "", "TLS key to use, if any.")
563-
command.DurationVar(&tlsTimeout, "tls-timeout", 3*time.Second, "Timeout for TLS Handshake initialization")
564565

565566
command.Parse(os.Args[1:])
566567

@@ -582,7 +583,7 @@ func main() {
582583
panic("Invalid -api-url: " + err.Error())
583584
}
584585

585-
proxy := NewProxy(tlsTimeout, logLevel)
586+
proxy := NewProxy(logLevel)
586587

587-
proxy.run(httpAddress, tlsAddress, apiURL, token, tlsCert, tlsKey, isChildProcess)
588+
proxy.run(address, tcpAddress, apiURL, token, tlsCert, tlsKey, isChildProcess)
588589
}
+55-128
Original file line numberDiff line numberDiff line change
@@ -1,154 +1,81 @@
11
package main
22

33
import (
4-
"encoding/binary"
5-
"errors"
6-
"fmt"
4+
"bufio"
5+
"bytes"
6+
"crypto/tls"
77
"io"
8+
"net"
89
)
910

10-
const extensionID = 0
11-
const hostnameID = 0
12-
13-
// Parses a Vector object (length prefixed bytes) per the TLS spec
14-
func parseVector(buf []byte, lenBytes int) ([]byte, []byte, error) {
15-
if len(buf) < lenBytes {
16-
return nil, nil, errors.New("Not enough space in packet for vector")
17-
}
18-
var l int
19-
for _, b := range buf[:lenBytes] {
20-
l = (l << 8) + int(b)
21-
}
22-
if len(buf) < l+lenBytes {
23-
return nil, nil, errors.New("Not enough space in packet for vector")
24-
}
25-
return buf[lenBytes : l+lenBytes], buf[l+lenBytes:], nil
11+
type tcpConn interface {
12+
net.Conn
13+
CloseWrite() error
14+
CloseRead() error
2615
}
2716

28-
func readSNI(buf []byte) (string, error) {
29-
if len(buf) == 0 {
30-
return "", errors.New("Zero length handshake record")
31-
}
32-
if buf[0] != 1 {
33-
return "", fmt.Errorf("Non-ClientHello handshake record type %d", buf[0])
34-
}
17+
type peekedTCPConn struct {
18+
peeked []byte
19+
*net.TCPConn
20+
}
3521

36-
buf, _, err := parseVector(buf[1:], 3)
37-
if err != nil {
38-
return "", fmt.Errorf("Reading ClientHello: %s", err)
22+
func (c *peekedTCPConn) Read(p []byte) (n int, err error) {
23+
if len(c.peeked) > 0 {
24+
n = copy(p, c.peeked)
25+
c.peeked = c.peeked[n:]
26+
if len(c.peeked) == 0 {
27+
c.peeked = nil
28+
}
29+
return n, nil
3930
}
31+
return c.TCPConn.Read(p)
32+
}
4033

41-
if len(buf) < 34 {
42-
return "", errors.New("ClientHello packet too short")
43-
}
34+
func wrapPeeked(inConn *net.TCPConn, br *bufio.Reader) tcpConn {
35+
peeked, _ := br.Peek(br.Buffered())
36+
return &peekedTCPConn{TCPConn: inConn, peeked: peeked}
37+
}
4438

45-
if buf[0] != 3 || buf[1] < 1 || buf[1] > 3 {
46-
return "", fmt.Errorf("ClientHello has unsupported version %d.%d", buf[0], buf[1])
47-
}
39+
type readonly struct {
40+
r io.Reader
41+
net.Conn
42+
}
4843

49-
// Skip version and random struct
50-
buf = buf[34:]
44+
func (c readonly) Read(p []byte) (int, error) { return c.r.Read(p) }
45+
func (readonly) Write(p []byte) (int, error) { return 0, io.EOF }
5146

52-
vec, buf, err := parseVector(buf, 1)
47+
func readSNI(inConn *net.TCPConn) (string, bool, tcpConn, error) {
48+
br := bufio.NewReader(inConn)
49+
hdr, err := br.Peek(1)
5350
if err != nil {
54-
return "", fmt.Errorf("Reading ClientHello SessionID: %s", err)
55-
}
56-
if len(vec) > 32 {
57-
return "", fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
51+
return "", false, nil, err
5852
}
5953

60-
vec, buf, err = parseVector(buf, 2)
61-
if err != nil {
62-
return "", fmt.Errorf("Reading ClientHello CipherSuites: %s", err)
63-
}
64-
if len(vec) < 2 || len(vec)%2 != 0 {
65-
return "", fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
54+
if hdr[0] != 0x16 {
55+
// Not a TLS handshake
56+
return "", false, wrapPeeked(inConn, br), nil
6657
}
6758

68-
vec, buf, err = parseVector(buf, 1)
59+
const headerLen = 5
60+
hdr, err = br.Peek(headerLen)
6961
if err != nil {
70-
return "", fmt.Errorf("Reading ClientHello CompressionMethods: %s", err)
71-
}
72-
if len(vec) < 1 {
73-
return "", fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
74-
}
75-
76-
if len(buf) != 0 {
77-
// Check vector is proper length for remaining msg
78-
vec, buf, err = parseVector(buf, 2)
79-
if err != nil {
80-
return "", fmt.Errorf("Reading ClientHello extensions: %s", err)
81-
}
82-
if len(buf) != 0 {
83-
return "", fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(buf))
84-
}
85-
buf = vec
86-
87-
for len(buf) >= 4 {
88-
typ := binary.BigEndian.Uint16(buf[:2])
89-
vec, buf, err = parseVector(buf[2:], 2)
90-
if err != nil {
91-
return "", fmt.Errorf("Reading ClientHello extension %d: %s", typ, err)
92-
}
93-
if typ == extensionID {
94-
// We found an SNI extension, attempt to extract the hostname
95-
buf, _, err := parseVector(vec, 2)
96-
if err != nil {
97-
return "", err
98-
}
99-
100-
for len(buf) >= 3 {
101-
typ := buf[0]
102-
vec, buf, err = parseVector(buf[1:], 2)
103-
if err != nil {
104-
return "", errors.New("Truncated SNI extension")
105-
}
106-
107-
// This vec is a hostname, return
108-
if typ == hostnameID {
109-
return string(vec), nil
110-
}
111-
}
112-
113-
if len(buf) != 0 {
114-
return "", errors.New("Trailing garbage at end of SNI extension")
115-
}
116-
117-
return "", nil
118-
}
119-
}
62+
return "", false, wrapPeeked(inConn, br), nil
12063
}
121-
return "", errors.New("No SNI found")
122-
}
12364

124-
func readVerAndSNI(r io.Reader) (string, int, error) {
125-
var header struct {
126-
Type uint8
127-
Major, Minor uint8
128-
Length uint16
129-
}
130-
if err := binary.Read(r, binary.BigEndian, &header); err != nil {
131-
return "", 0, fmt.Errorf("Error reading TLS record header: %s", err)
132-
}
133-
134-
if header.Type != 22 {
135-
return "", 0, fmt.Errorf("TLS record is not a handshake")
136-
}
137-
138-
if header.Major != 3 || header.Minor < 1 || header.Minor > 3 {
139-
return "", 0, fmt.Errorf("TLS record has unsupported version %d.%d",
140-
header.Major, header.Minor)
141-
}
142-
143-
if header.Length > 16384 {
144-
return "", 0, errors.New("TLS record is malformed")
65+
recLen := int(hdr[3])<<8 | int(hdr[4])
66+
helloBytes, err := br.Peek(headerLen + recLen)
67+
if err != nil {
68+
return "", true, wrapPeeked(inConn, br), nil
14569
}
14670

147-
buf := make([]byte, header.Length)
148-
if _, err := io.ReadFull(r, buf); err != nil {
149-
return "", 0, err
150-
}
71+
sni := ""
72+
server := tls.Server(readonly{r: bytes.NewReader(helloBytes)}, &tls.Config{
73+
GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
74+
sni = hello.ServerName
75+
return nil, nil
76+
},
77+
})
78+
server.Handshake()
15179

152-
sni, err := readSNI(buf)
153-
return sni, int(header.Minor), err
80+
return sni, true, wrapPeeked(inConn, br), nil
15481
}

‎dask-gateway-server/dask_gateway_server/proxy/core.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ class Proxy(LoggingConfigurable):
5353
config=True,
5454
)
5555

56-
scheduler_address = Unicode(
57-
":8786",
56+
tcp_address = Unicode(
5857
help="""
59-
The address the scheduler proxy should *listen* at.
58+
The address the TCP (scheduler) proxy should *listen* at.
6059
6160
Should be of the form ``{hostname}:{port}``
6261
@@ -65,11 +64,17 @@ class Proxy(LoggingConfigurable):
6564
- ``hostname`` sets the hostname to *listen* at. Set to ``""`` or
6665
``"0.0.0.0"`` to listen on all interfaces.
6766
- ``port`` sets the port to *listen* at.
67+
68+
If not specified, will default to `address`.
6869
""",
6970
config=True,
7071
)
7172

72-
@validate("address", "scheduler_address")
73+
@default("tcp_address")
74+
def _default_tcp_address(self):
75+
return self.address
76+
77+
@validate("address", "tcp_address")
7378
def _validate_addresses(self, proposal):
7479
return normalize_address(proposal.value)
7580

@@ -192,8 +197,8 @@ def get_start_command(self, is_child_process=True):
192197
_PROXY_EXE,
193198
"-address",
194199
self.address,
195-
"-tls-address",
196-
self.scheduler_address,
200+
"-tcp-address",
201+
self.tcp_address,
197202
"-api-url",
198203
self.gateway_url + "/api/routes",
199204
"-log-level",
@@ -232,9 +237,7 @@ def start_proxy_process(self):
232237
"https" if self.tls_cert else "http",
233238
self.address,
234239
)
235-
self.log.info(
236-
"- Scheduler routes listening at tls://%s", self.scheduler_address
237-
)
240+
self.log.info("- Scheduler routes listening at gateway://%s", self.tcp_address)
238241

239242
async def monitor_proxy_process(self):
240243
backoff = default_backoff = 0.5

‎dask-gateway/dask_gateway/client.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -244,9 +244,10 @@ class Gateway(object):
244244
address : str, optional
245245
The address to the gateway server.
246246
proxy_address : str, int, optional
247-
The address of the scheduler proxy server. If an int, it's used as the
248-
port, with the host/ip taken from ``address``. Provide a full address
249-
if a different host/ip should be used.
247+
The address of the scheduler proxy server. Defaults to `address` if not
248+
provided. If an int, it's used as the port, with the host/ip taken from
249+
``address``. Provide a full address if a different host/ip should be
250+
used.
250251
auth : GatewayAuth, optional
251252
The authentication method to use.
252253
asynchronous : bool, optional
@@ -277,9 +278,7 @@ def __init__(
277278
if proxy_address is None:
278279
proxy_address = format_template(dask.config.get("gateway.proxy-address"))
279280
if proxy_address is None:
280-
raise ValueError(
281-
"No dask-gateway proxy address provided or found in configuration"
282-
)
281+
proxy_address = address
283282
if isinstance(proxy_address, int):
284283
parsed = urlparse(address)
285284
proxy_netloc = "%s:%d" % (parsed.hostname, proxy_address)

‎dask-gateway/dask_gateway/comm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ class GatewayConnector(Connector):
3333
client = TCPClient(resolver=_resolver)
3434

3535
async def connect(self, address, deserialize=True, **connection_args):
36-
ip, port, sni = parse_gateway_address(address)
36+
ip, port, path = parse_gateway_address(address)
37+
sni = "daskgateway-" + path
3738
ctx = connection_args.get("ssl_context")
3839
if not isinstance(ctx, ssl.SSLContext):
3940
raise TypeError(

‎dask-gateway/dask_gateway/gateway.yaml

+3-2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ gateway:
99
# If `None` (default), `gateway.address` will be used.
1010
# May be a template string.
1111

12-
proxy-address: 8786 # The full address or port to the dask-gateway
12+
proxy-address: null # The full address or port to the dask-gateway
1313
# scheduler proxy. If a port, the host/ip is taken from
14-
# ``address``. May also be a template string.
14+
# ``address``. If null, defaults to `address`.
15+
# May also be a template string.
1516

1617
auth:
1718
type: basic # The authentication type to use. Options are basic,

‎tests/test_cli.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_proxy_cli(tmpdir, monkeypatch):
4646
text = (
4747
"c.DaskGateway.address = '127.0.0.1:8888'\n"
4848
"c.Proxy.address = '127.0.0.1:8866'\n"
49-
"c.Proxy.scheduler_address = '127.0.0.1:8867'\n"
49+
"c.Proxy.tcp_address = '127.0.0.1:8867'\n"
5050
"c.Proxy.log_level = 'debug'\n"
5151
"c.Proxy.api_token = 'abcde'"
5252
)
@@ -71,7 +71,7 @@ def mock_execle(*args):
7171
"dask-gateway-proxy",
7272
"-address",
7373
"127.0.0.1:8866",
74-
"-tls-address",
74+
"-tcp-address",
7575
"127.0.0.1:8867",
7676
"-api-url",
7777
"http://127.0.0.1:8888/api/routes",

‎tests/test_client.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,8 @@ def test_client_init():
133133
config["gateway"]["proxy-address"] = None
134134

135135
with dask.config.set(config):
136-
# No proxy-address provided
137-
with pytest.raises(ValueError):
138-
Gateway()
136+
g = Gateway()
137+
assert g.proxy_address == "gateway://127.0.0.1:8888"
139138

140139

141140
def test_gateway_addresses_template_environment_vars(monkeypatch):

‎tests/test_db_backend.py

-1
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,6 @@ async def test_gateway_resume_clusters_after_shutdown(tmpdir):
748748
config.LocalTestingBackend.check_timeouts_period = 0.5
749749
config.DaskGateway.address = "127.0.0.1:%d" % random_port()
750750
config.Proxy.address = "127.0.0.1:%d" % random_port()
751-
config.Proxy.scheduler_address = "127.0.0.1:%d" % random_port()
752751

753752
async with temp_gateway(config=config) as g:
754753
async with g.gateway_client() as gateway:

‎tests/test_proxies.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def __init__(self, **kwargs):
2121
self.proxy = Proxy(
2222
address="127.0.0.1:0",
2323
prefix="/foobar",
24-
scheduler_address="127.0.0.1:0",
2524
gateway_address=f"127.0.0.1:{self._port}",
2625
log_level="debug",
2726
proxy_status_period=0.5,
@@ -43,9 +42,10 @@ async def __aexit__(self, *args):
4342
await self.proxy.cleanup()
4443

4544

46-
@pytest.fixture
47-
async def proxy():
48-
async with temp_proxy() as proxy:
45+
@pytest.fixture(params=["separate", "shared"])
46+
async def proxy(request):
47+
kwargs = {"tcp_address": "127.0.0.1:0"} if request.param == "separate" else {}
48+
async with temp_proxy(**kwargs) as proxy:
4949
yield proxy
5050

5151

@@ -140,8 +140,8 @@ async def test_502():
140140
await with_retries(test_502, 5)
141141

142142

143-
@pytest.fixture
144-
async def ca_and_tls_proxy(tmpdir_factory):
143+
@pytest.fixture(params=["separate", "shared"])
144+
async def ca_and_tls_proxy(request, tmpdir_factory):
145145
trustme = pytest.importorskip("trustme")
146146
ca = trustme.CA()
147147
cert = ca.issue_cert("127.0.0.1")
@@ -153,7 +153,8 @@ async def ca_and_tls_proxy(tmpdir_factory):
153153
cert.private_key_pem.write_to_path(tls_key)
154154
cert.cert_chain_pems[0].write_to_path(tls_cert)
155155

156-
async with temp_proxy(tls_key=tls_key, tls_cert=tls_cert) as proxy:
156+
kwargs = {"tcp_address": "127.0.0.1:0"} if request.param == "separate" else {}
157+
async with temp_proxy(tls_key=tls_key, tls_cert=tls_cert, **kwargs) as proxy:
157158
yield ca, proxy
158159

159160

@@ -196,7 +197,7 @@ async def test_fails():
196197
async def test_scheduler_proxy(proxy, cluster_and_security):
197198
cluster, security = cluster_and_security
198199

199-
proxied_addr = f"gateway://{proxy.scheduler_address}/temp"
200+
proxied_addr = f"gateway://{proxy.tcp_address}/temp"
200201

201202
# Add a route
202203
await proxy.add_route(kind="SNI", sni="temp", target=cluster.scheduler_address)

‎tests/utils_test.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def __init__(self, **kwargs):
4646

4747
c.DaskGateway.address = "127.0.0.1:0"
4848
c.Proxy.address = "127.0.0.1:0"
49-
c.Proxy.scheduler_address = "127.0.0.1:0"
5049
c.DaskGateway.authenticator_class = (
5150
"dask_gateway_server.auth.SimpleAuthenticator"
5251
)
@@ -63,7 +62,7 @@ async def __aenter__(self):
6362
await self.gateway.setup()
6463
await self.gateway.backend.proxy._proxy_contacted
6564
self.address = f"http://{self.gateway.backend.proxy.address}"
66-
self.proxy_address = f"tls://{self.gateway.backend.proxy.scheduler_address}"
65+
self.proxy_address = f"gateway://{self.gateway.backend.proxy.tcp_address}"
6766
return self
6867

6968
async def __aexit__(self, *args):

0 commit comments

Comments
 (0)
Please sign in to comment.