From 72d9da99543da81dfcc1a2a0b7f32ce45d6f41a7 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Sat, 7 Apr 2018 15:38:14 +0900 Subject: [PATCH 1/4] Implement Connector and DriverContext interface --- connector.go | 139 ++++++++++++++++++++++++++++++++++++++++++ driver.go | 111 ++------------------------------- driver_go1.10.go | 45 ++++++++++++++ driver_go1.10_test.go | 17 ++++++ 4 files changed, 206 insertions(+), 106 deletions(-) create mode 100644 connector.go create mode 100644 driver_go1.10.go create mode 100644 driver_go1.10_test.go diff --git a/connector.go b/connector.go new file mode 100644 index 000000000..876ad976d --- /dev/null +++ b/connector.go @@ -0,0 +1,139 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "net" +) + +type connector struct { + cfg *Config // immutable private copy. +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.cfg, + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + // TODO: needs RegisterDialContext + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + mc.netConn, err = dial(mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + if plugin == "" { + plugin = defaultAuthPlugin + } + + // Send Client Authentication Packet + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, addNUL, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +// Driver implements driver.Connector interface. +// Driver returns &MySQLDriver{}. +func (c *connector) Driver() driver.Driver { + return &MySQLDriver{} +} diff --git a/driver.go b/driver.go index 9f4967087..48fb82ba7 100644 --- a/driver.go +++ b/driver.go @@ -17,6 +17,7 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" "net" @@ -52,116 +53,14 @@ func RegisterDial(net string, dial DialFunc) { // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how // the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { - var err error - - // New mysqlConn - mc := &mysqlConn{ - maxAllowedPacket: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - closech: make(chan struct{}), - } - mc.cfg, err = ParseDSN(dsn) - if err != nil { - return nil, err - } - mc.parseTime = mc.cfg.ParseTime - - // Connect to Server - dialsLock.RLock() - dial, ok := dials[mc.cfg.Net] - dialsLock.RUnlock() - if ok { - mc.netConn, err = dial(mc.cfg.Addr) - } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) - } - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - errLog.Print("net.Error from Dial()': ", nerr.Error()) - return nil, driver.ErrBadConn - } - return nil, err - } - - // Enable TCP Keepalives on TCP connections - if tc, ok := mc.netConn.(*net.TCPConn); ok { - if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err - } - } - - // Call startWatcher for context support (From Go 1.8) - mc.startWatcher() - - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout - - // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + cfg, err := ParseDSN(dsn) if err != nil { - mc.cleanup() return nil, err } - if plugin == "" { - plugin = defaultAuthPlugin + c := &connector{ + cfg: cfg, } - - // Send Client Authentication Packet - authResp, err := mc.auth(authData, plugin) - if err != nil { - // try the default auth plugin, if using the requested plugin failed - errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) - plugin = defaultAuthPlugin - authResp, err = mc.auth(authData, plugin) - if err != nil { - mc.cleanup() - return nil, err - } - } - if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { - mc.cleanup() - return nil, err - } - - // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { - // Authentication failed and MySQL has already closed the connection - // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). - // Do not send COM_QUIT, just cleanup and return the error. - mc.cleanup() - return nil, err - } - - if mc.cfg.MaxAllowedPacket > 0 { - mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket - } else { - // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") - if err != nil { - mc.Close() - return nil, err - } - mc.maxAllowedPacket = stringToInt(maxap) - 1 - } - if mc.maxAllowedPacket < maxPacketSize { - mc.maxWriteSize = mc.maxAllowedPacket - } - - // Handle DSN Params - err = mc.handleParams() - if err != nil { - mc.Close() - return nil, err - } - - return mc, nil + return c.Connect(context.Background()) } func init() { diff --git a/driver_go1.10.go b/driver_go1.10.go new file mode 100644 index 000000000..b0ee134bd --- /dev/null +++ b/driver_go1.10.go @@ -0,0 +1,45 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "crypto/rsa" + "database/sql/driver" + "math/big" +) + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) driver.Connector { + copyCfg := *cfg + copyCfg.tls = cfg.tls.Clone() + copyCfg.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + copyCfg.Params[k] = v + } + if cfg.pubKey != nil { + copyCfg.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &connector{cfg: ©Cfg} +} + +// OpenConnector implements driver.DriverContext. +func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/driver_go1.10_test.go b/driver_go1.10_test.go new file mode 100644 index 000000000..a9d784c52 --- /dev/null +++ b/driver_go1.10_test.go @@ -0,0 +1,17 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "database/sql/driver" +) + +var _ driver.DriverContext = &MySQLDriver{} From 3ce299164ae9be24dec82eb10ae27704f9ca2ed4 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Tue, 26 Mar 2019 16:13:05 +0100 Subject: [PATCH 2/4] connector: Rebase on top of master --- appengine.go | 7 +++- connector.go | 14 ++++--- driver.go | 27 ++++++++++--- driver_go1.10.go => driver_go110.go | 22 ++++------ driver_go1.10_test.go => driver_go110_test.go | 0 driver_test.go | 7 ++-- dsn.go | 21 ++++++++++ dsn_test.go | 40 +++++++++++++++++++ 8 files changed, 109 insertions(+), 29 deletions(-) rename driver_go1.10.go => driver_go110.go (63%) rename driver_go1.10_test.go => driver_go110_test.go (100%) diff --git a/appengine.go b/appengine.go index be41f2ee6..44c0fd7c7 100644 --- a/appengine.go +++ b/appengine.go @@ -11,9 +11,14 @@ package mysql import ( + "context" + "google.golang.org/appengine/cloudsql" ) func init() { - RegisterDial("cloudsql", cloudsql.Dial) + RegisterDialContext("cloudsql", func(_ context.Context, instance addr) (net.Conn, error) { + // XXX: the cloudsql driver still does not export a Context-aware dialer. + return cloudsql.Dial(instance) + }) } diff --git a/connector.go b/connector.go index 876ad976d..5aaaba43e 100644 --- a/connector.go +++ b/connector.go @@ -33,17 +33,21 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server - // TODO: needs RegisterDialContext dialsLock.RLock() dial, ok := dials[mc.cfg.Net] dialsLock.RUnlock() if ok { - mc.netConn, err = dial(mc.cfg.Addr) + mc.netConn, err = dial(ctx, mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) } + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } return nil, err } @@ -82,18 +86,18 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, addNUL, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err } diff --git a/driver.go b/driver.go index 48fb82ba7..1f9decf80 100644 --- a/driver.go +++ b/driver.go @@ -30,25 +30,42 @@ type MySQLDriver struct{} // DialFunc is a function which can be used to establish the network connection. // Custom dial functions must be registered with RegisterDial +// +// Deprecated: users should register a DialContextFunc instead type DialFunc func(addr string) (net.Conn, error) +// DialContextFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDialContext +type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) + var ( dialsLock sync.RWMutex - dials map[string]DialFunc + dials map[string]DialContextFunc ) -// RegisterDial registers a custom dial function. It can then be used by the +// RegisterDialContext registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. -// addr is passed as a parameter to the dial function. -func RegisterDial(net string, dial DialFunc) { +// The current context for the connection and its address is passed to the dial function. +func RegisterDialContext(net string, dial DialContextFunc) { dialsLock.Lock() defer dialsLock.Unlock() if dials == nil { - dials = make(map[string]DialFunc) + dials = make(map[string]DialContextFunc) } dials[net] = dial } +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +// +// Deprecated: users should call RegisterDialContext instead +func RegisterDial(network string, dial DialFunc) { + RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) { + return dial(addr) + }) +} + // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how // the DSN string is formatted diff --git a/driver_go1.10.go b/driver_go110.go similarity index 63% rename from driver_go1.10.go rename to driver_go110.go index b0ee134bd..eb5a8fe9b 100644 --- a/driver_go1.10.go +++ b/driver_go110.go @@ -11,26 +11,18 @@ package mysql import ( - "crypto/rsa" "database/sql/driver" - "math/big" ) // NewConnector returns new driver.Connector. -func NewConnector(cfg *Config) driver.Connector { - copyCfg := *cfg - copyCfg.tls = cfg.tls.Clone() - copyCfg.Params = make(map[string]string, len(cfg.Params)) - for k, v := range cfg.Params { - copyCfg.Params[k] = v - } - if cfg.pubKey != nil { - copyCfg.pubKey = &rsa.PublicKey{ - N: new(big.Int).Set(cfg.pubKey.N), - E: cfg.pubKey.E, - } +func NewConnector(cfg *Config) (driver.Connector, error) { + cfg = cfg.Clone() + // normalize the contents of cfg so calls to NewConnector have the same + // behavior as MySQLDriver.OpenConnector + if err := cfg.normalize(); err != nil { + return nil, err } - return &connector{cfg: ©Cfg} + return &connector{cfg: cfg}, nil } // OpenConnector implements driver.DriverContext. diff --git a/driver_go1.10_test.go b/driver_go110_test.go similarity index 100% rename from driver_go1.10_test.go rename to driver_go110_test.go diff --git a/driver_test.go b/driver_test.go index 9c3d286ce..b45a81eb1 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1846,7 +1846,7 @@ func TestConcurrent(t *testing.T) { } func testDialError(t *testing.T, dialErr error, expectErr error) { - RegisterDial("mydial", func(addr string) (net.Conn, error) { + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { return nil, dialErr }) @@ -1884,8 +1884,9 @@ func TestCustomDial(t *testing.T) { } // our custom dial function which justs wraps net.Dial here - RegisterDial("mydial", func(addr string) (net.Conn, error) { - return net.Dial(prot, addr) + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, prot, addr) }) db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) diff --git a/dsn.go b/dsn.go index b9134722e..6e19ab717 100644 --- a/dsn.go +++ b/dsn.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "errors" "fmt" + "math/big" "net" "net/url" "sort" @@ -72,6 +73,26 @@ func NewConfig() *Config { } } +func (cfg *Config) Clone() *Config { + cp := *cfg + if cp.tls != nil { + cp.tls = cfg.tls.Clone() + } + if len(cp.Params) > 0 { + cp.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + cp.Params[k] = v + } + } + if cfg.pubKey != nil { + cp.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &cp +} + func (cfg *Config) normalize() error { if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { return errInvalidDSNUnsafeCollation diff --git a/dsn_test.go b/dsn_test.go index 1cd095496..280fdf61f 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -318,6 +318,46 @@ func TestParamsAreSorted(t *testing.T) { } } +func TestCloneConfig(t *testing.T) { + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey" + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err.Error()) + } + + cfg2 := cfg.Clone() + if cfg == cfg2 { + t.Errorf("Config.Clone did not create a separate config struct") + } + + if cfg2.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + cfg2.tls.ServerName = "example2.com" + if cfg.tls.ServerName == cfg2.tls.ServerName { + t.Errorf("changed cfg.tls.Server name should not propagate to original Config") + } + + if _, ok := cfg2.Params["foobar"]; !ok { + t.Errorf("cloned Config is missing custom params") + } + + delete(cfg2.Params, "foobar") + + if _, ok := cfg.Params["foobar"]; !ok { + t.Errorf("custom params in cloned Config should not propagate to original Config") + } + + if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) { + t.Errorf("public key in Config should be identical") + } +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs() From 42cb96eece294e59efc2dabd8cb8e641ad545655 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Thu, 4 Apr 2019 12:31:53 +0200 Subject: [PATCH 3/4] test: Implement TestConnectorObeysDialTimeouts --- driver_go110_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/driver_go110_test.go b/driver_go110_test.go index a9d784c52..2e18a693a 100644 --- a/driver_go110_test.go +++ b/driver_go110_test.go @@ -11,7 +11,37 @@ package mysql import ( + "context" + "database/sql" "database/sql/driver" + "fmt" + "net" + "testing" ) var _ driver.DriverContext = &MySQLDriver{} + +type dialCtxKey struct{} + +func TestConnectorObeysDialTimeouts(t *testing.T) { + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + if !ctx.Value(dialCtxKey{}).(bool) { + return nil, fmt.Errorf("test error: query context is not propagated to our dialer") + } + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + ctx := context.WithValue(context.Background(), dialCtxKey{}, true) + + _, err = db.ExecContext(ctx, "DO 1") + if err != nil { + t.Fatal(err) + } +} From a87822a385d9b633b555949e37449f8d19e49f33 Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Thu, 4 Apr 2019 13:05:47 +0200 Subject: [PATCH 4/4] test: Implement more timeout tests for connector --- driver_go110_test.go | 90 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/driver_go110_test.go b/driver_go110_test.go index 2e18a693a..19a0e5956 100644 --- a/driver_go110_test.go +++ b/driver_go110_test.go @@ -17,6 +17,7 @@ import ( "fmt" "net" "testing" + "time" ) var _ driver.DriverContext = &MySQLDriver{} @@ -24,6 +25,10 @@ var _ driver.DriverContext = &MySQLDriver{} type dialCtxKey struct{} func TestConnectorObeysDialTimeouts(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { var d net.Dialer if !ctx.Value(dialCtxKey{}).(bool) { @@ -45,3 +50,88 @@ func TestConnectorObeysDialTimeouts(t *testing.T) { t.Fatal(err) } } + +func configForTests(t *testing.T) *Config { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mycnf := NewConfig() + mycnf.User = user + mycnf.Passwd = pass + mycnf.Addr = addr + mycnf.Net = prot + mycnf.DBName = dbname + return mycnf +} + +func TestNewConnector(t *testing.T) { + mycnf := configForTests(t) + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +type slowConnection struct { + net.Conn + slowdown time.Duration +} + +func (sc *slowConnection) Read(b []byte) (int, error) { + time.Sleep(sc.slowdown) + return sc.Conn.Read(b) +} + +type connectorHijack struct { + driver.Connector + connErr error +} + +func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + conn, cw.connErr = cw.Connector.Connect(ctx) + return conn, cw.connErr +} + +func TestConnectorTimeoutsDuringOpen(t *testing.T) { + RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, prot, addr) + if err != nil { + return nil, err + } + return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil + }) + + mycnf := configForTests(t) + mycnf.Net = "slowconn" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + hijack := &connectorHijack{Connector: conn} + + db := sql.OpenDB(hijack) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = db.ExecContext(ctx, "DO 1") + if err != context.DeadlineExceeded { + t.Fatalf("ExecContext should have timed out") + } + if hijack.connErr != context.DeadlineExceeded { + t.Fatalf("(*Connector).Connect should have timed out") + } +}