From ad9fa14acdcf7d0533e7fbe58728f3d216213ade Mon Sep 17 00:00:00 2001 From: Thomas Posch <55388669+thopos@users.noreply.github.com> Date: Wed, 13 Apr 2022 09:25:45 +0200 Subject: [PATCH 001/123] Add SQLState to MySQLError (#1321) Report SQLState in MySQLError to allow library users to distinguish user-defined from client / server errors. --- AUTHORS | 1 + errors.go | 9 +++++++-- errors_test.go | 6 +++--- packets.go | 11 ++++++----- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/AUTHORS b/AUTHORS index 876b2964a..50b9593f0 100644 --- a/AUTHORS +++ b/AUTHORS @@ -110,6 +110,7 @@ Ziheng Lyu Barracuda Networks, Inc. Counting Ltd. DigitalOcean Inc. +dyves labs AG Facebook Inc. GitHub Inc. Google Inc. diff --git a/errors.go b/errors.go index 92cc9a361..7c037e7d6 100644 --- a/errors.go +++ b/errors.go @@ -56,11 +56,16 @@ func SetLogger(logger Logger) error { // MySQLError is an error type which represents a single MySQL error type MySQLError struct { - Number uint16 - Message string + Number uint16 + SQLState [5]byte + Message string } func (me *MySQLError) Error() string { + if me.SQLState != [5]byte{} { + return fmt.Sprintf("Error %d (%s): %s", me.Number, me.SQLState, me.Message) + } + return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } diff --git a/errors_test.go b/errors_test.go index 3a1aef74d..43213f98e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -43,13 +43,13 @@ func TestErrorsStrictIgnoreNotes(t *testing.T) { } func TestMySQLErrIs(t *testing.T) { - infraErr := &MySQLError{1234, "the server is on fire"} - otherInfraErr := &MySQLError{1234, "the datacenter is flooded"} + infraErr := &MySQLError{Number: 1234, Message: "the server is on fire"} + otherInfraErr := &MySQLError{Number: 1234, Message: "the datacenter is flooded"} if !errors.Is(infraErr, otherInfraErr) { t.Errorf("expected errors to be the same: %+v %+v", infraErr, otherInfraErr) } - differentCodeErr := &MySQLError{5678, "the server is on fire"} + differentCodeErr := &MySQLError{Number: 5678, Message: "the server is on fire"} if errors.Is(infraErr, differentCodeErr) { t.Fatalf("expected errors to be different: %+v %+v", infraErr, differentCodeErr) } diff --git a/packets.go b/packets.go index ab30601ae..003584c25 100644 --- a/packets.go +++ b/packets.go @@ -587,19 +587,20 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { return driver.ErrBadConn } + me := &MySQLError{Number: errno} + pos := 3 // SQL State [optional: # + 5bytes string] if data[3] == 0x23 { - //sqlstate := string(data[4 : 4+5]) + copy(me.SQLState[:], data[4:4+5]) pos = 9 } // Error Message [string] - return &MySQLError{ - Number: errno, - Message: string(data[pos:]), - } + me.Message = string(data[pos:]) + + return me } func readStatus(b []byte) statusFlag { From 0c62bb2791485d4260371bcc6017321de93b2430 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Fri, 19 Aug 2022 16:29:14 +0900 Subject: [PATCH 002/123] Go1.19 is released (#1350) --- .github/workflows/test.yml | 3 +- atomic_bool.go | 19 ++++++++++ atomic_bool_go118.go | 47 +++++++++++++++++++++++++ atomic_bool_test.go | 71 +++++++++++++++++++++++++++++++++++++ auth.go | 35 +++++++++--------- collations.go | 3 +- conncheck.go | 1 + conncheck_dummy.go | 1 + conncheck_test.go | 1 + connection.go | 22 ++++++------ connection_test.go | 2 +- driver.go | 6 ++-- fuzz.go | 1 + infile.go | 28 +++++++-------- nulltime.go | 18 +++++----- statement.go | 8 ++--- transaction.go | 4 +-- utils.go | 72 +++++++++++--------------------------- utils_test.go | 60 ------------------------------- 19 files changed, 226 insertions(+), 176 deletions(-) create mode 100644 atomic_bool.go create mode 100644 atomic_bool_go118.go create mode 100644 atomic_bool_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f8c472832..b558eba28 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,8 +22,9 @@ jobs: import json go = [ # Keep the most recent production release at the top - '1.18', + '1.19', # Older production releases + '1.18', '1.17', '1.16', '1.15', diff --git a/atomic_bool.go b/atomic_bool.go new file mode 100644 index 000000000..1b7e19f3e --- /dev/null +++ b/atomic_bool.go @@ -0,0 +1,19 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build go1.19 +// +build go1.19 + +package mysql + +import "sync/atomic" + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +type atomicBool = atomic.Bool diff --git a/atomic_bool_go118.go b/atomic_bool_go118.go new file mode 100644 index 000000000..2e9a7f0b6 --- /dev/null +++ b/atomic_bool_go118.go @@ -0,0 +1,47 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !go1.19 +// +build !go1.19 + +package mysql + +import "sync/atomic" + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// atomicBool is an implementation of atomic.Bool for older version of Go. +// it is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _ noCopy + value uint32 +} + +// Load returns whether the current boolean value is true +func (ab *atomicBool) Load() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Store sets the value of the bool regardless of the previous value +func (ab *atomicBool) Store(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// Swap sets the value of the bool and returns the old value. +func (ab *atomicBool) Swap(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) > 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} diff --git a/atomic_bool_test.go b/atomic_bool_test.go new file mode 100644 index 000000000..a3b4ea0e8 --- /dev/null +++ b/atomic_bool_test.go @@ -0,0 +1,71 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !go1.19 +// +build !go1.19 + +package mysql + +import ( + "testing" +) + +func TestAtomicBool(t *testing.T) { + var ab atomicBool + if ab.Load() { + t.Fatal("Expected value to be false") + } + + ab.Store(true) + if ab.value != 1 { + t.Fatal("Set(true) did not set value to 1") + } + if !ab.Load() { + t.Fatal("Expected value to be true") + } + + ab.Store(true) + if !ab.Load() { + t.Fatal("Expected value to be true") + } + + ab.Store(false) + if ab.value != 0 { + t.Fatal("Set(false) did not set value to 0") + } + if ab.Load() { + t.Fatal("Expected value to be false") + } + + ab.Store(false) + if ab.Load() { + t.Fatal("Expected value to be false") + } + if ab.Swap(false) { + t.Fatal("Expected the old value to be false") + } + if ab.Swap(true) { + t.Fatal("Expected the old value to be false") + } + if !ab.Load() { + t.Fatal("Expected value to be true") + } + + ab.Store(true) + if !ab.Load() { + t.Fatal("Expected value to be true") + } + if !ab.Swap(true) { + t.Fatal("Expected the old value to be true") + } + if !ab.Swap(false) { + t.Fatal("Expected the old value to be true") + } + if ab.Load() { + t.Fatal("Expected value to be false") + } +} diff --git a/auth.go b/auth.go index a25353429..26f8723f5 100644 --- a/auth.go +++ b/auth.go @@ -33,27 +33,26 @@ var ( // Note: The provided rsa.PublicKey instance is exclusively owned by the driver // after registering it and may not be modified. // -// data, err := ioutil.ReadFile("mykey.pem") -// if err != nil { -// log.Fatal(err) -// } +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } // -// block, _ := pem.Decode(data) -// if block == nil || block.Type != "PUBLIC KEY" { -// log.Fatal("failed to decode PEM block containing public key") -// } +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } // -// pub, err := x509.ParsePKIXPublicKey(block.Bytes) -// if err != nil { -// log.Fatal(err) -// } -// -// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { -// mysql.RegisterServerPubKey("mykey", rsaPubKey) -// } else { -// log.Fatal("not a RSA public key") -// } +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } // +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { serverPubKeyLock.Lock() if serverPubKeyRegistry == nil { diff --git a/collations.go b/collations.go index 326a9f7fa..295bfbe52 100644 --- a/collations.go +++ b/collations.go @@ -13,7 +13,8 @@ const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: -// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID // // Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. // diff --git a/conncheck.go b/conncheck.go index 024eb2858..0ea721720 100644 --- a/conncheck.go +++ b/conncheck.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos // +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos package mysql diff --git a/conncheck_dummy.go b/conncheck_dummy.go index ea7fb607a..a56c138f2 100644 --- a/conncheck_dummy.go +++ b/conncheck_dummy.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build !linux && !darwin && !dragonfly && !freebsd && !netbsd && !openbsd && !solaris && !illumos // +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos package mysql diff --git a/conncheck_test.go b/conncheck_test.go index 53995517b..f7e025680 100644 --- a/conncheck_test.go +++ b/conncheck_test.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd || solaris || illumos // +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos package mysql diff --git a/connection.go b/connection.go index 835f89729..9539077cb 100644 --- a/connection.go +++ b/connection.go @@ -104,7 +104,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { } func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -123,7 +123,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if !mc.closed.IsSet() { + if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } @@ -137,7 +137,7 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { - if !mc.closed.TrySet(true) { + if mc.closed.Swap(true) { return } @@ -152,7 +152,7 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) error() error { - if mc.closed.IsSet() { + if mc.closed.Load() { if err := mc.canceled.Value(); err != nil { return err } @@ -162,7 +162,7 @@ func (mc *mysqlConn) error() error { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -295,7 +295,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -356,7 +356,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -450,7 +450,7 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { - if mc.closed.IsSet() { + if mc.closed.Load() { errLog.Print(ErrInvalidConn) return driver.ErrBadConn } @@ -469,7 +469,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if mc.closed.IsSet() { + if mc.closed.Load() { return nil, driver.ErrBadConn } @@ -636,7 +636,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.IsSet() { + if mc.closed.Load() { return driver.ErrBadConn } mc.reset = true @@ -646,5 +646,5 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.IsSet() + return !mc.closed.Load() } diff --git a/connection_test.go b/connection_test.go index a6d677308..b6764a2f6 100644 --- a/connection_test.go +++ b/connection_test.go @@ -147,7 +147,7 @@ func TestCleanCancel(t *testing.T) { t.Errorf("expected context.Canceled, got %#v", err) } - if mc.closed.IsSet() { + if mc.closed.Load() { t.Error("expected mc is not closed, closed actually") } diff --git a/driver.go b/driver.go index c1bdf1199..ad7aec215 100644 --- a/driver.go +++ b/driver.go @@ -8,10 +8,10 @@ // // The driver should be used via the database/sql package: // -// import "database/sql" -// import _ "github.com/go-sql-driver/mysql" +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" // -// db, err := sql.Open("mysql", "user:password@/dbname") +// db, err := sql.Open("mysql", "user:password@/dbname") // // See https://github.com/go-sql-driver/mysql#usage for details package mysql diff --git a/fuzz.go b/fuzz.go index fa75adf6a..3a4ec25a9 100644 --- a/fuzz.go +++ b/fuzz.go @@ -6,6 +6,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build gofuzz // +build gofuzz package mysql diff --git a/infile.go b/infile.go index e6323aea4..3279dcffd 100644 --- a/infile.go +++ b/infile.go @@ -28,12 +28,11 @@ var ( // Alternatively you can allow the use of all local files with // the DSN parameter 'allowAllFiles=true' // -// filePath := "/home/gopher/data.csv" -// mysql.RegisterLocalFile(filePath) -// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") -// if err != nil { -// ... -// +// filePath := "/home/gopher/data.csv" +// mysql.RegisterLocalFile(filePath) +// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") +// if err != nil { +// ... func RegisterLocalFile(filePath string) { fileRegisterLock.Lock() // lazy map init @@ -58,15 +57,14 @@ func DeregisterLocalFile(filePath string) { // If the handler returns a io.ReadCloser Close() is called when the // request is finished. // -// mysql.RegisterReaderHandler("data", func() io.Reader { -// var csvReader io.Reader // Some Reader that returns CSV data -// ... // Open Reader here -// return csvReader -// }) -// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") -// if err != nil { -// ... -// +// mysql.RegisterReaderHandler("data", func() io.Reader { +// var csvReader io.Reader // Some Reader that returns CSV data +// ... // Open Reader here +// return csvReader +// }) +// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") +// if err != nil { +// ... func RegisterReaderHandler(name string, handler func() io.Reader) { readerRegisterLock.Lock() // lazy map init diff --git a/nulltime.go b/nulltime.go index 17af92ddc..36c8a42c5 100644 --- a/nulltime.go +++ b/nulltime.go @@ -19,16 +19,16 @@ import ( // NullTime implements the Scanner interface so // it can be used as a scan destination: // -// var nt NullTime -// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) -// ... -// if nt.Valid { -// // use nt.Time -// } else { -// // NULL value -// } +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } // -// This NullTime implementation is not driver-specific +// # This NullTime implementation is not driver-specific // // Deprecated: NullTime doesn't honor the loc DSN parameter. // NullTime.Scan interprets a time as UTC, not the loc DSN parameter. diff --git a/statement.go b/statement.go index 18a3ae498..10ece8bd6 100644 --- a/statement.go +++ b/statement.go @@ -23,7 +23,7 @@ type mysqlStmt struct { } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.closed.IsSet() { + if stmt.mc == nil || stmt.mc.closed.Load() { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. // See also Issue #450 and golang/go#16019. @@ -50,7 +50,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.closed.IsSet() { + if stmt.mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -98,7 +98,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { - if stmt.mc.closed.IsSet() { + if stmt.mc.closed.Load() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -157,7 +157,7 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if driver.IsValue(sv) { return sv, nil } - // A value returend from the Valuer interface can be "a type handled by + // A value returned from the Valuer interface can be "a type handled by // a database driver's NamedValueChecker interface" so we should accept // uint64 here as well. if u, ok := sv.(uint64); ok { diff --git a/transaction.go b/transaction.go index 417d72793..4a4b61001 100644 --- a/transaction.go +++ b/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.closed.IsSet() { + if tx.mc == nil || tx.mc.closed.Load() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.closed.IsSet() { + if tx.mc == nil || tx.mc.closed.Load() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/utils.go b/utils.go index 5a024aa0a..60f1a91c6 100644 --- a/utils.go +++ b/utils.go @@ -35,26 +35,25 @@ var ( // Note: The provided tls.Config is exclusively owned by the driver after // registering it. // -// rootCertPool := x509.NewCertPool() -// pem, err := ioutil.ReadFile("/path/ca-cert.pem") -// if err != nil { -// log.Fatal(err) -// } -// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { -// log.Fatal("Failed to append PEM.") -// } -// clientCert := make([]tls.Certificate, 0, 1) -// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") -// if err != nil { -// log.Fatal(err) -// } -// clientCert = append(clientCert, certs) -// mysql.RegisterTLSConfig("custom", &tls.Config{ -// RootCAs: rootCertPool, -// Certificates: clientCert, -// }) -// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") -// +// rootCertPool := x509.NewCertPool() +// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// if err != nil { +// log.Fatal(err) +// } +// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { +// log.Fatal("Failed to append PEM.") +// } +// clientCert := make([]tls.Certificate, 0, 1) +// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") +// if err != nil { +// log.Fatal(err) +// } +// clientCert = append(clientCert, certs) +// mysql.RegisterTLSConfig("custom", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: clientCert, +// }) +// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") func RegisterTLSConfig(key string, config *tls.Config) error { if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { return fmt.Errorf("key '%s' is reserved", key) @@ -796,39 +795,10 @@ func (*noCopy) Lock() {} // https://github.com/golang/go/issues/26165 func (*noCopy) Unlock() {} -// atomicBool is a wrapper around uint32 for usage as a boolean value with -// atomic access. -type atomicBool struct { - _noCopy noCopy - value uint32 -} - -// IsSet returns whether the current boolean value is true -func (ab *atomicBool) IsSet() bool { - return atomic.LoadUint32(&ab.value) > 0 -} - -// Set sets the value of the bool regardless of the previous value -func (ab *atomicBool) Set(value bool) { - if value { - atomic.StoreUint32(&ab.value, 1) - } else { - atomic.StoreUint32(&ab.value, 0) - } -} - -// TrySet sets the value of the bool and returns whether the value changed -func (ab *atomicBool) TrySet(value bool) bool { - if value { - return atomic.SwapUint32(&ab.value, 1) == 0 - } - return atomic.SwapUint32(&ab.value, 0) > 0 -} - // atomicError is a wrapper for atomically accessed error values type atomicError struct { - _noCopy noCopy - value atomic.Value + _ noCopy + value atomic.Value } // Set sets the error value regardless of the previous value. diff --git a/utils_test.go b/utils_test.go index b0069251e..8296ac2aa 100644 --- a/utils_test.go +++ b/utils_test.go @@ -173,66 +173,6 @@ func TestEscapeQuotes(t *testing.T) { expect("foo\"bar", "foo\"bar") // not affected } -func TestAtomicBool(t *testing.T) { - var ab atomicBool - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - - ab.Set(true) - if ab.value != 1 { - t.Fatal("Set(true) did not set value to 1") - } - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - - ab.Set(true) - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - - ab.Set(false) - if ab.value != 0 { - t.Fatal("Set(false) did not set value to 0") - } - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - - ab.Set(false) - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - if ab.TrySet(false) { - t.Fatal("Expected TrySet(false) to fail") - } - if !ab.TrySet(true) { - t.Fatal("Expected TrySet(true) to succeed") - } - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - - ab.Set(true) - if !ab.IsSet() { - t.Fatal("Expected value to be true") - } - if ab.TrySet(true) { - t.Fatal("Expected TrySet(true) to fail") - } - if !ab.TrySet(false) { - t.Fatal("Expected TrySet(false) to succeed") - } - if ab.IsSet() { - t.Fatal("Expected value to be false") - } - - // we've "tested" them ¯\_(ツ)_/¯ - ab._noCopy.Lock() - defer ab._noCopy.Unlock() -} - func TestAtomicError(t *testing.T) { var ae atomicError if ae.Value() != nil { From a477f69f3c2abaf4646680bdc3a65d5172a6566e Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sat, 20 Aug 2022 00:19:54 +0900 Subject: [PATCH 003/123] fix: benchmarkExecContext is unused (#1351) --- benchmark_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark_test.go b/benchmark_test.go index 1030ddc52..97ed781f8 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -314,7 +314,7 @@ func BenchmarkExecContext(b *testing.B) { defer db.Close() for _, p := range []int{1, 2, 3, 4} { b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { - benchmarkQueryContext(b, db, p) + benchmarkExecContext(b, db, p) }) } } From 803c0e06f2b703d30b18e168e7349ffc66e7fa86 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 7 Nov 2022 21:26:59 +0900 Subject: [PATCH 004/123] migrate set-output to the environment file (#1368) --- .github/workflows/test.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b558eba28..703203258 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,6 +20,7 @@ jobs: id: set-matrix run: | import json + import os go = [ # Keep the most recent production release at the top '1.19', @@ -55,7 +56,8 @@ jobs: 'include': includes } output = json.dumps(matrix, separators=(',', ':')) - print('::set-output name=matrix::{0}'.format(output)) + with open(os.environ["GITHUB_OUTPUT"], 'a', encoding="utf-8") as f: + f.write('matrix={0}\n'.format(output)) shell: python test: needs: list From 05bed834d054b8361595c6544146567d70713dc1 Mon Sep 17 00:00:00 2001 From: "lgtm-com[bot]" <43144390+lgtm-com[bot]@users.noreply.github.com> Date: Thu, 10 Nov 2022 06:19:39 +0900 Subject: [PATCH 005/123] Add CodeQL workflow for GitHub code scanning (#1369) Co-authored-by: LGTM Migrator --- .github/workflows/codeql.yml | 41 ++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 .github/workflows/codeql.yml diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000..d9d29a8b7 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,41 @@ +name: "CodeQL" + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + schedule: + - cron: "18 19 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ go ] + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + queries: +security-and-quality + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{ matrix.language }}" From fa1e4ed592daa59bcd70003263b5fc72e3de0137 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 10 Nov 2022 15:01:30 +0900 Subject: [PATCH 006/123] Fix parsing 0 year. (#1257) Fix #1252 --- utils.go | 10 ---------- utils_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/utils.go b/utils.go index 60f1a91c6..15dbd8d16 100644 --- a/utils.go +++ b/utils.go @@ -117,10 +117,6 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { if err != nil { return time.Time{}, err } - if year <= 0 { - year = 1 - } - if b[4] != '-' { return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) } @@ -129,9 +125,6 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { if err != nil { return time.Time{}, err } - if m <= 0 { - m = 1 - } month := time.Month(m) if b[7] != '-' { @@ -142,9 +135,6 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { if err != nil { return time.Time{}, err } - if day <= 0 { - day = 1 - } if len(b) == 10 { return time.Date(year, month, day, 0, 0, 0, 0, loc), nil } diff --git a/utils_test.go b/utils_test.go index 8296ac2aa..4e5fc3cb7 100644 --- a/utils_test.go +++ b/utils_test.go @@ -380,6 +380,33 @@ func TestParseDateTime(t *testing.T) { } } +func TestInvalidDateTime(t *testing.T) { + cases := []struct { + name string + str string + want time.Time + }{ + { + name: "parse datetime without day", + str: "0000-00-00 21:30:45", + want: time.Date(0, 0, 0, 21, 30, 45, 0, time.UTC), + }, + } + + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + got, err := parseDateTime([]byte(cc.str), time.UTC) + if err != nil { + t.Fatal(err) + } + + if !cc.want.Equal(got) { + t.Fatalf("want: %v, but got %v", cc.want, got) + } + }) + } +} + func TestParseDateTimeFail(t *testing.T) { cases := []struct { name string From 41dd159e6ec9afad00d2b90144bbc083ea860db1 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Mon, 28 Nov 2022 14:26:20 +0800 Subject: [PATCH 007/123] Add `AllowFallbackToPlaintext` and `TLS` to config (#1370) --- AUTHORS | 1 + README.md | 11 ++++++++ auth.go | 4 +-- auth_test.go | 10 ++++---- dsn.go | 72 ++++++++++++++++++++++++++++++++-------------------- dsn_test.go | 47 +++++++++++++++++----------------- packets.go | 12 ++++----- 7 files changed, 94 insertions(+), 63 deletions(-) diff --git a/AUTHORS b/AUTHORS index 50b9593f0..051327519 100644 --- a/AUTHORS +++ b/AUTHORS @@ -61,6 +61,7 @@ Kamil Dziedzic Kei Kamikawa Kevin Malachowski Kieron Woodhouse +Lance Tian Lennart Rudolph Leonardo YongUk Kim Linh Tran Tuan diff --git a/README.md b/README.md index ded6e3b16..25de2e5aa 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,17 @@ Default: false `allowCleartextPasswords=true` allows using the [cleartext client side plugin](https://dev.mysql.com/doc/en/cleartext-pluggable-authentication.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network. + +##### `allowFallbackToPlaintext` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +`allowFallbackToPlaintext=true` acts like a `--ssl-mode=PREFERRED` MySQL client as described in [Command Options for Connecting to the Server](https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode) + ##### `allowNativePasswords` ``` diff --git a/auth.go b/auth.go index 26f8723f5..1ff203e57 100644 --- a/auth.go +++ b/auth.go @@ -275,7 +275,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { } // unlike caching_sha2_password, sha256_password does not accept // cleartext password on unix transport. - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { // write cleartext auth packet return append([]byte(mc.cfg.Passwd), 0), nil } @@ -351,7 +351,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } case cachingSha2PasswordPerformFullAuthentication: - if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + if mc.cfg.TLS != nil || mc.cfg.Net == "unix" { // write cleartext auth packet err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { diff --git a/auth_test.go b/auth_test.go index 3bce7fe22..3ce0ea6e0 100644 --- a/auth_test.go +++ b/auth_test.go @@ -291,7 +291,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { // Hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} // check written auth response authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 @@ -663,7 +663,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, 62, 94, 83, 80, 52, 85} @@ -676,7 +676,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { } // unset TLS config to prevent the actual establishment of a TLS wrapper - mc.cfg.tls = nil + mc.cfg.TLS = nil err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { @@ -866,7 +866,7 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { // Hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} // auth switch request conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, @@ -1299,7 +1299,7 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { // Hack to make the caching_sha2_password plugin believe that the connection // is secure - mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + mc.cfg.TLS = &tls.Config{InsecureSkipVerify: true} // auth switch request conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, diff --git a/dsn.go b/dsn.go index a306d66a3..4b71aaab0 100644 --- a/dsn.go +++ b/dsn.go @@ -46,22 +46,23 @@ type Config struct { ServerPubKey string // Server public key name pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name - tls *tls.Config // TLS configuration + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout - AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE - AllowCleartextPasswords bool // Allows the cleartext client side plugin - AllowNativePasswords bool // Allows the native password authentication method - AllowOldPasswords bool // Allows the old insecure password method - CheckConnLiveness bool // Check connections for liveness before using them - ClientFoundRows bool // Return number of matching rows instead of rows changed - ColumnsWithAlias bool // Prepend table alias to column names - InterpolateParams bool // Interpolate placeholders into query string - MultiStatements bool // Allow multiple statements in one query - ParseTime bool // Parse time values to time.Time - RejectReadOnly bool // Reject read-only connections + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS + AllowNativePasswords bool // Allows the native password authentication method + AllowOldPasswords bool // Allows the old insecure password method + CheckConnLiveness bool // Check connections for liveness before using them + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders into query string + MultiStatements bool // Allow multiple statements in one query + ParseTime bool // Parse time values to time.Time + RejectReadOnly bool // Reject read-only connections } // NewConfig creates a new Config and sets default values. @@ -77,8 +78,8 @@ func NewConfig() *Config { func (cfg *Config) Clone() *Config { cp := *cfg - if cp.tls != nil { - cp.tls = cfg.tls.Clone() + if cp.TLS != nil { + cp.TLS = cfg.TLS.Clone() } if len(cp.Params) > 0 { cp.Params = make(map[string]string, len(cfg.Params)) @@ -119,24 +120,29 @@ func (cfg *Config) normalize() error { cfg.Addr = ensureHavePort(cfg.Addr) } - switch cfg.TLSConfig { - case "false", "": - // don't set anything - case "true": - cfg.tls = &tls.Config{} - case "skip-verify", "preferred": - cfg.tls = &tls.Config{InsecureSkipVerify: true} - default: - cfg.tls = getTLSConfigClone(cfg.TLSConfig) - if cfg.tls == nil { - return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + if cfg.TLS == nil { + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.TLS = &tls.Config{} + case "skip-verify": + cfg.TLS = &tls.Config{InsecureSkipVerify: true} + case "preferred": + cfg.TLS = &tls.Config{InsecureSkipVerify: true} + cfg.AllowFallbackToPlaintext = true + default: + cfg.TLS = getTLSConfigClone(cfg.TLSConfig) + if cfg.TLS == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } } } - if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify { host, _, err := net.SplitHostPort(cfg.Addr) if err == nil { - cfg.tls.ServerName = host + cfg.TLS.ServerName = host } } @@ -204,6 +210,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") } + if cfg.AllowFallbackToPlaintext { + writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true") + } + if !cfg.AllowNativePasswords { writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false") } @@ -391,6 +401,14 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // Allow fallback to unencrypted connection if server does not support TLS + case "allowFallbackToPlaintext": + var isBool bool + cfg.AllowFallbackToPlaintext, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + // Use native password authentication case "allowNativePasswords": var isBool bool diff --git a/dsn_test.go b/dsn_test.go index fc6eea9c8..41a6a29fa 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -42,8 +42,8 @@ var testDSNs = []struct { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, }, { - "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, + "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, @@ -82,7 +82,7 @@ func TestDSNParser(t *testing.T) { } // pointer not static - cfg.tls = nil + cfg.TLS = nil if !reflect.DeepEqual(cfg, tst.out) { t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) @@ -100,6 +100,7 @@ func TestDSNParserInvalid(t *testing.T) { "User:pass@tcp(1.2.3.4:3306)", // no trailing slash "net()/", // unknown default addr "user:pass@tcp(127.0.0.1:3306)/db/name", // invalid dbname + "user:password@/dbname?allowFallbackToPlaintext=PREFERRED", // wrong bool flag //"/dbname?arg=/some/unescaped/path", } @@ -118,7 +119,7 @@ func TestDSNReformat(t *testing.T) { t.Error(err.Error()) continue } - cfg1.tls = nil // pointer not static + cfg1.TLS = nil // pointer not static res1 := fmt.Sprintf("%+v", cfg1) dsn2 := cfg1.FormatDSN() @@ -127,7 +128,7 @@ func TestDSNReformat(t *testing.T) { t.Error(err.Error()) continue } - cfg2.tls = nil // pointer not static + cfg2.TLS = nil // pointer not static res2 := fmt.Sprintf("%+v", cfg2) if res1 != res2 { @@ -203,7 +204,7 @@ func TestDSNWithCustomTLS(t *testing.T) { if err != nil { t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + } else if cfg.TLS.ServerName != name { t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) } @@ -214,7 +215,7 @@ func TestDSNWithCustomTLS(t *testing.T) { if err != nil { t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + } else if cfg.TLS.ServerName != name { t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) } else if tlsCfg.ServerName != "" { t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) @@ -229,11 +230,11 @@ func TestDSNTLSConfig(t *testing.T) { if err != nil { t.Error(err.Error()) } - if cfg.tls == nil { + if cfg.TLS == nil { t.Error("cfg.tls should not be nil") } - if cfg.tls.ServerName != expectedServerName { - t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + if cfg.TLS.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName) } dsn = "tcp(example.com)/?tls=true" @@ -241,11 +242,11 @@ func TestDSNTLSConfig(t *testing.T) { if err != nil { t.Error(err.Error()) } - if cfg.tls == nil { + if cfg.TLS == nil { t.Error("cfg.tls should not be nil") } - if cfg.tls.ServerName != expectedServerName { - t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + if cfg.TLS.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.TLS.ServerName) } } @@ -262,7 +263,7 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) { if err != nil { t.Error(err.Error()) - } else if cfg.tls.ServerName != name { + } else if cfg.TLS.ServerName != name { t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) } } @@ -335,12 +336,12 @@ func TestCloneConfig(t *testing.T) { t.Errorf("Config.Clone did not create a separate config struct") } - if cfg2.tls.ServerName != expectedServerName { - t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + if cfg2.TLS.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.TLS.ServerName) } - cfg2.tls.ServerName = "example2.com" - if cfg.tls.ServerName == cfg2.tls.ServerName { + cfg2.TLS.ServerName = "example2.com" + if cfg.TLS.ServerName == cfg2.TLS.ServerName { t.Errorf("changed cfg.tls.Server name should not propagate to original Config") } @@ -384,20 +385,20 @@ func TestNormalizeTLSConfig(t *testing.T) { cfg.normalize() - if cfg.tls == nil { + if cfg.TLS == nil { if tc.want != nil { t.Fatal("wanted a tls config but got nil instead") } return } - if cfg.tls.ServerName != tc.want.ServerName { + if cfg.TLS.ServerName != tc.want.ServerName { t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", - tc.want.ServerName, cfg.tls.ServerName) + tc.want.ServerName, cfg.TLS.ServerName) } - if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + if cfg.TLS.InsecureSkipVerify != tc.want.InsecureSkipVerify { t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", - tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + tc.want.InsecureSkipVerify, cfg.TLS.InsecureSkipVerify) } }) } diff --git a/packets.go b/packets.go index 003584c25..ee05c95a8 100644 --- a/packets.go +++ b/packets.go @@ -222,9 +222,9 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if mc.flags&clientProtocol41 == 0 { return nil, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - if mc.cfg.TLSConfig == "preferred" { - mc.cfg.tls = nil + if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if mc.cfg.AllowFallbackToPlaintext { + mc.cfg.TLS = nil } else { return nil, "", ErrNoTLS } @@ -292,7 +292,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // To enable TLS / SSL - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { clientFlags |= clientSSL } @@ -356,14 +356,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // SSL Connection Request Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest - if mc.cfg.tls != nil { + if mc.cfg.TLS != nil { // Send TLS / SSL request packet if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { return err } // Switch to TLS - tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) if err := tlsConn.Handshake(); err != nil { return err } From 5cee457661043566c72c86b89aadbab7b88cce7a Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Fri, 2 Dec 2022 20:40:24 +0900 Subject: [PATCH 008/123] update changelog for Version 1.7 (#1376) --- CHANGELOG.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72a738ed5..77024a820 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,24 @@ +## Version 1.7 (2022-11-29) + +Changes: + + - Drop support of Go 1.12 (#1211) + - Refactoring `(*textRows).readRow` in a more clear way (#1230) + - util: Reduce boundary check in escape functions. (#1316) + - enhancement for mysqlConn handleAuthResult (#1250) + +New Features: + + - support Is comparison on MySQLError (#1210) + - return unsigned in database type name when necessary (#1238) + - Add API to express like a --ssl-mode=PREFERRED MySQL client (#1370) + - Add SQLState to MySQLError (#1321) + +Bugfixes: + + - Fix parsing 0 year. (#1257) + + ## Version 1.6 (2021-04-01) Changes: From 4591e42e65cf483147a7c7a4f4cfeac81b21c917 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Fri, 9 Dec 2022 20:51:20 +0900 Subject: [PATCH 009/123] bump actions/checkout@v3 and actions/setup-go@v3 (#1375) * bump actions/checkout@v3 and actions/setup-go@v3 * enable cache of actions/setup-go * Revert "enable cache of actions/setup-go" I don't know why, but some jobs fail with "Could not get cache folder paths". This reverts commit 185228e0e3110b182759332193ebd75ed7054477. --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 703203258..f5ba6b99c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -66,8 +66,8 @@ jobs: fail-fast: false matrix: ${{ fromJSON(needs.list.outputs.matrix) }} steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - uses: shogo82148/actions-setup-mysql@v1 From af380e92cd245f1815fa16f42ea65472a7cb06ee Mon Sep 17 00:00:00 2001 From: Samantha Date: Wed, 8 Mar 2023 03:16:29 -0500 Subject: [PATCH 010/123] Use SET syntax as specified in the MySQL documentation (#1402) --- AUTHORS | 1 + connection.go | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 051327519..6f7041c7a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -83,6 +83,7 @@ Reed Allman Richard Wilkes Robert Russell Runrioter Wung +Samantha Frank Santhosh Kumar Tekuri Sho Iizuka Sho Ikeda diff --git a/connection.go b/connection.go index 9539077cb..947a883e3 100644 --- a/connection.go +++ b/connection.go @@ -71,10 +71,10 @@ func (mc *mysqlConn) handleParams() (err error) { cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) cmdSet.WriteString("SET ") } else { - cmdSet.WriteByte(',') + cmdSet.WriteString(", ") } cmdSet.WriteString(param) - cmdSet.WriteByte('=') + cmdSet.WriteString(" = ") cmdSet.WriteString(val) } } From d83ecdc268ff92fa198c0bf64356a1479cc83438 Mon Sep 17 00:00:00 2001 From: Phil Porada Date: Wed, 29 Mar 2023 21:34:18 -0400 Subject: [PATCH 011/123] Add go1.20 and mariadb10.11 to the testing matrix (#1403) * Add go1.20 and mariadb10.11 to the testing matrix * Use latest upstream actions-setup-mysql which has support for mariadb 10.11 * Update authors file --- .github/workflows/test.yml | 6 ++++-- AUTHORS | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f5ba6b99c..d45ed0fa9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,8 +23,9 @@ jobs: import os go = [ # Keep the most recent production release at the top - '1.19', + '1.20', # Older production releases + '1.19', '1.18', '1.17', '1.16', @@ -36,6 +37,7 @@ jobs: '8.0', '5.7', '5.6', + 'mariadb-10.11', 'mariadb-10.6', 'mariadb-10.5', 'mariadb-10.4', @@ -70,7 +72,7 @@ jobs: - uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - uses: shogo82148/actions-setup-mysql@v1 + - uses: shogo82148/actions-setup-mysql@v1.15.0 with: mysql-version: ${{ matrix.mysql }} user: ${{ env.MYSQL_TEST_USER }} diff --git a/AUTHORS b/AUTHORS index 6f7041c7a..fb1478c3b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -78,6 +78,7 @@ Olivier Mengué oscarzhao Paul Bonser Peter Schultz +Phil Porada Rebecca Chin Reed Allman Richard Wilkes From f0e16c6977aae7045c058989971467759e470e99 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 14 Apr 2023 19:00:15 +0900 Subject: [PATCH 012/123] Increase default maxAllowedPacket size. (#1411) 64MiB is same to MySQL 8.0. --- README.md | 2 +- const.go | 2 +- errors.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 25de2e5aa..252bbefdf 100644 --- a/README.md +++ b/README.md @@ -282,7 +282,7 @@ Please keep in mind, that param values must be [url.QueryEscape](https://golang. ##### `maxAllowedPacket` ``` Type: decimal number -Default: 4194304 +Default: 64*1024*1024 ``` Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. diff --git a/const.go b/const.go index b1e6b85ef..64e2bced6 100644 --- a/const.go +++ b/const.go @@ -10,7 +10,7 @@ package mysql const ( defaultAuthPlugin = "mysql_native_password" - defaultMaxAllowedPacket = 4 << 20 // 4 MiB + defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" diff --git a/errors.go b/errors.go index 7c037e7d6..ff9a8f088 100644 --- a/errors.go +++ b/errors.go @@ -27,7 +27,7 @@ var ( ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") ErrPktSync = errors.New("commands out of sync. You can't run this command now") ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") - ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the `Config.MaxAllowedPacket`") ErrBusyBuffer = errors.New("busy buffer") // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. From faedeff6d3187aad8fa1d08a206d082f8c371656 Mon Sep 17 00:00:00 2001 From: Simon J Mudd Date: Sat, 15 Apr 2023 15:38:33 +0200 Subject: [PATCH 013/123] Correct maxAllowedPacket default value mentioned in docs to match the new setting (#1412) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 252bbefdf..3b5d229aa 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ Type: decimal number Default: 64*1024*1024 ``` -Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. +Max packet size allowed in bytes. The default value is 64 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*. ##### `multiStatements` From 8503110d880a06b863c5a93c8214068c71f16e9d Mon Sep 17 00:00:00 2001 From: cui fliter Date: Tue, 25 Apr 2023 13:46:24 +0800 Subject: [PATCH 014/123] fix some comments (#1417) Signed-off-by: cui fliter --- driver_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/driver_test.go b/driver_test.go index 4850498d0..a1c776728 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2703,7 +2703,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { if err := row.Scan(&v); err != nil { dbt.Fatal(err) } - // Because writer transaction wasn't commited yet, it should be available + // Because writer transaction wasn't committed yet, it should be available if v != 0 { dbt.Errorf("expected val to be 0, got %d", v) } @@ -2717,7 +2717,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { if err := row.Scan(&v); err != nil { dbt.Fatal(err) } - // Data written by writer transaction is already commited, it should be selectable + // Data written by writer transaction is already committed, it should be selectable if v != 1 { dbt.Errorf("expected val to be 1, got %d", v) } From f20b2863636093e5fbf1481b59bdaff3b0fbb779 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 25 Apr 2023 19:02:15 +0900 Subject: [PATCH 015/123] Update changelog for version 1.7.1 (#1418) --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 77024a820..5166e4adb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,16 @@ +## Version 1.7.1 (2023-04-25) + +Changes: + + - bump actions/checkout@v3 and actions/setup-go@v3 (#1375) + - Add go1.20 and mariadb10.11 to the testing matrix (#1403) + - Increase default maxAllowedPacket size. (#1411) + +Bugfixes: + + - Use SET syntax as specified in the MySQL documentation (#1402) + + ## Version 1.7 (2022-11-29) Changes: From aa0194dbeccdb9e79d5775f0a8903c3cdbb4e753 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 25 Apr 2023 20:36:57 +0900 Subject: [PATCH 016/123] Drop Go 1.13-17 support (#1420) Start v1.8 development --- .github/workflows/test.yml | 11 +++-------- README.md | 4 ++-- go.mod | 2 +- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d45ed0fa9..cd474767b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,11 +27,6 @@ jobs: # Older production releases '1.19', '1.18', - '1.17', - '1.16', - '1.15', - '1.14', - '1.13', ] mysql = [ '8.0', @@ -47,7 +42,7 @@ jobs: includes = [] # Go versions compatibility check for v in go[1:]: - includes.append({'os': 'ubuntu-latest', 'go': v, 'mysql': mysql[0]}) + includes.append({'os': 'ubuntu-latest', 'go': v, 'mysql': mysql[0]}) matrix = { # OS vs MySQL versions @@ -69,10 +64,10 @@ jobs: matrix: ${{ fromJSON(needs.list.outputs.matrix) }} steps: - uses: actions/checkout@v3 - - uses: actions/setup-go@v3 + - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - - uses: shogo82148/actions-setup-mysql@v1.15.0 + - uses: shogo82148/actions-setup-mysql@v1.16.0 with: mysql-version: ${{ matrix.mysql }} user: ${{ env.MYSQL_TEST_USER }} diff --git a/README.md b/README.md index 3b5d229aa..5a242e9d7 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,8 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Optional placeholder interpolation ## Requirements - * Go 1.13 or higher. We aim to support the 3 latest versions of Go. - * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) + * Go 1.18 or higher. We aim to support the 3 latest versions of Go. + * MySQL (5.6+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- diff --git a/go.mod b/go.mod index 251110478..77bbb8dbf 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/go-sql-driver/mysql -go 1.13 +go 1.18 From cffc85ce9efe406a98c1d82749a237cc0338a8b2 Mon Sep 17 00:00:00 2001 From: Evil Puncker Date: Tue, 25 Apr 2023 19:10:42 -0300 Subject: [PATCH 017/123] Reduced allocation on connection.go (#1421) reduces allocations when there is only one param because current calculation is off by 2 --- connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connection.go b/connection.go index 947a883e3..0aeef207b 100644 --- a/connection.go +++ b/connection.go @@ -68,7 +68,7 @@ func (mc *mysqlConn) handleParams() (err error) { default: if cmdSet.Len() == 0 { // Heuristic: 29 chars for each other key=value to reduce reallocations - cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) cmdSet.WriteString("SET ") } else { cmdSet.WriteString(", ") From fbfb3f6a34bd0d4e73e1569831e054ec36b38ce9 Mon Sep 17 00:00:00 2001 From: jypelle <52546084+jypelle@users.noreply.github.com> Date: Mon, 1 May 2023 17:52:55 +0200 Subject: [PATCH 018/123] Adding DeregisterDialContext (#1422) Co-authored-by: jypelle --- AUTHORS | 1 + driver.go | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/AUTHORS b/AUTHORS index fb1478c3b..ea9b96789 100644 --- a/AUTHORS +++ b/AUTHORS @@ -47,6 +47,7 @@ INADA Naoki Jacek Szwec James Harr Janek Vedock +Jean-Yves Pellé Jeff Hodges Jeffrey Charles Jerome Meyer diff --git a/driver.go b/driver.go index ad7aec215..8b0c3ec0a 100644 --- a/driver.go +++ b/driver.go @@ -55,6 +55,17 @@ func RegisterDialContext(net string, dial DialContextFunc) { dials[net] = dial } +// DeregisterDialContext removes the custom dial function registered with the given net. +func DeregisterDialContext(net string) { + dialsLock.Lock() + defer dialsLock.Unlock() + if dials != nil { + if _, ok := dials[net]; ok { + delete(dials, net) + } + } +} + // RegisterDial registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. // addr is passed as a parameter to the dial function. From 191a7c4c519ef60cf3e8656fde8728eee9194308 Mon Sep 17 00:00:00 2001 From: frozenbonito Date: Thu, 4 May 2023 23:30:22 +0900 Subject: [PATCH 019/123] Make logger configurable per Connector (#1408) --- AUTHORS | 1 + auth.go | 2 +- connection.go | 16 ++++++++-------- connection_test.go | 1 + connector.go | 2 +- driver_test.go | 2 +- dsn.go | 6 ++++++ dsn_test.go | 34 +++++++++++++++++----------------- errors.go | 12 +++++++++--- errors_test.go | 6 +++--- packets.go | 26 +++++++++++++------------- packets_test.go | 1 + statement.go | 4 ++-- 13 files changed, 64 insertions(+), 49 deletions(-) diff --git a/AUTHORS b/AUTHORS index ea9b96789..129ca665a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -96,6 +96,7 @@ Stan Putrya Stanley Gunawan Steven Hartland Tan Jinhua <312841925 at qq.com> +Tetsuro Aoki Thomas Wodarek Tim Ruffles Tom Jenkinson diff --git a/auth.go b/auth.go index 1ff203e57..b591e7b8a 100644 --- a/auth.go +++ b/auth.go @@ -291,7 +291,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { return enc, err default: - errLog.Print("unknown auth plugin:", plugin) + mc.cfg.Logger.Print("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin } } diff --git a/connection.go b/connection.go index 0aeef207b..a7da9e7e2 100644 --- a/connection.go +++ b/connection.go @@ -105,7 +105,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -147,7 +147,7 @@ func (mc *mysqlConn) cleanup() { return } if err := mc.netConn.Close(); err != nil { - errLog.Print(err) + mc.cfg.Logger.Print(err) } } @@ -163,14 +163,14 @@ func (mc *mysqlConn) error() error { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. - errLog.Print(err) + mc.cfg.Logger.Print(err) return nil, driver.ErrBadConn } @@ -204,7 +204,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf, err := mc.buf.takeCompleteBuffer() if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -296,7 +296,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -357,7 +357,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -451,7 +451,7 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { - errLog.Print(ErrInvalidConn) + mc.cfg.Logger.Print(ErrInvalidConn) return driver.ErrBadConn } diff --git a/connection_test.go b/connection_test.go index b6764a2f6..98c985ae1 100644 --- a/connection_test.go +++ b/connection_test.go @@ -179,6 +179,7 @@ func TestPingErrInvalidConn(t *testing.T) { buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), + cfg: NewConfig(), } err := ms.Ping(context.Background()) diff --git a/connector.go b/connector.go index d567b4e4f..a5c988e13 100644 --- a/connector.go +++ b/connector.go @@ -92,7 +92,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { authResp, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed - errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin authResp, err = mc.auth(authData, plugin) if err != nil { diff --git a/driver_test.go b/driver_test.go index a1c776728..1741a13ef 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1995,7 +1995,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) { func TestUnixSocketAuthFail(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // Save the current logger so we can restore it. - oldLogger := errLog + oldLogger := defaultLogger // Set a new logger so we can capture its output. buffer := bytes.NewBuffer(make([]byte, 0, 64)) diff --git a/dsn.go b/dsn.go index 4b71aaab0..ded459c94 100644 --- a/dsn.go +++ b/dsn.go @@ -50,6 +50,7 @@ type Config struct { Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -71,6 +72,7 @@ func NewConfig() *Config { Collation: defaultCollation, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, + Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, } @@ -153,6 +155,10 @@ func (cfg *Config) normalize() error { } } + if cfg.Logger == nil { + cfg.Logger = defaultLogger + } + return nil } diff --git a/dsn_test.go b/dsn_test.go index 41a6a29fa..cb97d557e 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -22,55 +22,55 @@ var testDSNs = []struct { out *Config }{{ "username:password@protocol(address)/dbname?param=value", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, Logger: defaultLogger, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, }, { "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, Logger: defaultLogger, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", - &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "@/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:p@/ssword@/", - &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "unix/?arg=%2Fsome%2Fpath.ext", - &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(127.0.0.1)/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(de:ad:be:ef::ca:fe)/dbname", - &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, } diff --git a/errors.go b/errors.go index ff9a8f088..5680b6c05 100644 --- a/errors.go +++ b/errors.go @@ -37,20 +37,26 @@ var ( errBadConnNoWrite = errors.New("bad connection") ) -var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) +var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) // Logger is used to log critical error messages. type Logger interface { Print(v ...interface{}) } -// SetLogger is used to set the logger for critical errors. +// NopLogger is a nop implementation of the Logger interface. +type NopLogger struct{} + +// Print implements Logger interface. +func (nl *NopLogger) Print(_ ...interface{}) {} + +// SetLogger is used to set the default logger for critical errors. // The initial logger is os.Stderr. func SetLogger(logger Logger) error { if logger == nil { return errors.New("logger is nil") } - errLog = logger + defaultLogger = logger return nil } diff --git a/errors_test.go b/errors_test.go index 43213f98e..53d634454 100644 --- a/errors_test.go +++ b/errors_test.go @@ -16,9 +16,9 @@ import ( ) func TestErrorsSetLogger(t *testing.T) { - previous := errLog + previous := defaultLogger defer func() { - errLog = previous + defaultLogger = previous }() // set up logger @@ -28,7 +28,7 @@ func TestErrorsSetLogger(t *testing.T) { // print SetLogger(logger) - errLog.Print("test") + defaultLogger.Print("test") // check result if actual := buffer.String(); actual != expected { diff --git a/packets.go b/packets.go index ee05c95a8..8fd67997b 100644 --- a/packets.go +++ b/packets.go @@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - errLog.Print(err) + mc.cfg.Logger.Print(err) mc.Close() return nil, ErrInvalidConn } @@ -56,7 +56,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if pktLen == 0 { // there was no previous packet if prevData == nil { - errLog.Print(ErrMalformPkt) + mc.cfg.Logger.Print(ErrMalformPkt) mc.Close() return nil, ErrInvalidConn } @@ -70,7 +70,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - errLog.Print(err) + mc.cfg.Logger.Print(err) mc.Close() return nil, ErrInvalidConn } @@ -119,7 +119,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { } } if err != nil { - errLog.Print("closing bad idle connection: ", err) + mc.cfg.Logger.Print("closing bad idle connection: ", err) mc.Close() return driver.ErrBadConn } @@ -161,7 +161,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) mc.cleanup() - errLog.Print(ErrMalformPkt) + mc.cfg.Logger.Print(ErrMalformPkt) } else { if cerr := mc.canceled.Value(); cerr != nil { return cerr @@ -171,7 +171,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { return errBadConnNoWrite } mc.cleanup() - errLog.Print(err) + mc.cfg.Logger.Print(err) } return ErrInvalidConn } @@ -322,7 +322,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data, err := mc.buf.takeSmallBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -404,7 +404,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { data, err := mc.buf.takeSmallBuffer(pktLen) if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -424,7 +424,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -443,7 +443,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -464,7 +464,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -938,7 +938,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } @@ -1137,7 +1137,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - errLog.Print(err) + mc.cfg.Logger.Print(err) return errBadConnNoWrite } } diff --git a/packets_test.go b/packets_test.go index b61e4dbf7..cacec1c68 100644 --- a/packets_test.go +++ b/packets_test.go @@ -265,6 +265,7 @@ func TestReadPacketFail(t *testing.T) { mc := &mysqlConn{ buf: newBuffer(conn), closech: make(chan struct{}), + cfg: NewConfig(), } // illegal empty (stand-alone) packet diff --git a/statement.go b/statement.go index 10ece8bd6..d8b3975a5 100644 --- a/statement.go +++ b/statement.go @@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.Load() { - errLog.Print(ErrInvalidConn) + stmt.mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -99,7 +99,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if stmt.mc.closed.Load() { - errLog.Print(ErrInvalidConn) + stmt.mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command From 0b40aee005dafcce42833210b672c1f0930008aa Mon Sep 17 00:00:00 2001 From: wayyoungboy <35394786+wayyoungboy@users.noreply.github.com> Date: Sat, 6 May 2023 17:07:41 +0800 Subject: [PATCH 020/123] Avoid panic in TestRowsColumnTypes (#1426) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * optimized the execution flow of the TestRowsColumnTypes unit test * Update driver_test.go --------- Co-authored-by: 渠磊 Co-authored-by: Inada Naoki --- driver_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/driver_test.go b/driver_test.go index 1741a13ef..d24488a82 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2945,7 +2945,10 @@ func TestRowsColumnTypes(t *testing.T) { continue } } - + // Avoid panic caused by nil scantype. + if t.Failed() { + return + } values := make([]interface{}, len(tt)) for i := range values { values[i] = reflect.New(types[i]).Interface() From 736b6faabe4947c9a0a7fef6407839dc72114011 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 7 May 2023 20:24:09 +0900 Subject: [PATCH 021/123] Stop `ColumnTypeScanType()` from returning `sql.RawBytes` (#1424) ColumnTypeScanType() returns []byte, string, or sql.NullString. It returned sql.RawBytes but it was dangoerous. Fixes #1423 --- driver_test.go | 67 +++++++++++++++++++++++++------------------------- fields.go | 48 +++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/driver_test.go b/driver_test.go index d24488a82..50c617274 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2778,13 +2778,18 @@ func TestRowsColumnTypes(t *testing.T) { nd1 := sql.NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} nd2 := sql.NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} ndNULL := sql.NullTime{Time: time.Time{}, Valid: false} - rbNULL := sql.RawBytes(nil) - rb0 := sql.RawBytes("0") - rb42 := sql.RawBytes("42") - rbTest := sql.RawBytes("Test") - rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 - rbx0 := sql.RawBytes("\x00") - rbx42 := sql.RawBytes("\x42") + bNULL := []byte(nil) + nsNULL := sql.NullString{String: "", Valid: false} + // Helper function to build NullString from string literal. + ns := func(s string) sql.NullString { return sql.NullString{String: s, Valid: true} } + ns0 := ns("0") + b0 := []byte("0") + b42 := []byte("42") + nsTest := ns("Test") + bTest := []byte("Test") + b0pad4 := []byte("0\x00\x00\x00") // BINARY right-pads values with 0x00 + bx0 := []byte("\x00") + bx42 := []byte("\x42") var columns = []struct { name string @@ -2797,7 +2802,7 @@ func TestRowsColumnTypes(t *testing.T) { valuesIn [3]string valuesOut [3]interface{} }{ - {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, + {"bit8null", "BIT(8)", "BIT", scanTypeBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{bx0, bNULL, bx42}}, {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, @@ -2817,24 +2822,24 @@ func TestRowsColumnTypes(t *testing.T) { {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, - {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, - {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, - {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, - {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, - {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, - {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, - {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, - {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, - {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, - {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeString, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.000000", "13.370000", "1234.123456"}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeNullString, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.000000"), nsNULL, ns("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeString, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.0000", "13.3700", "1234.1235"}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeNullString, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.0000"), nsNULL, ns("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeString, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{"0", "13", "-12345"}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeNullString, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{ns0, nsNULL, ns("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0pad4, bNULL, bTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}}, + {"textnull", "TEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, @@ -2959,14 +2964,10 @@ func TestRowsColumnTypes(t *testing.T) { if err != nil { t.Fatalf("failed to scan values in %v", err) } - for j := range values { - value := reflect.ValueOf(values[j]).Elem().Interface() + for j, value := range values { + value := reflect.ValueOf(value).Elem().Interface() if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { - if columns[j].scanType == scanTypeRawBytes { - t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) - } else { - t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) - } + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) } } i++ diff --git a/fields.go b/fields.go index e0654a83d..18c23e0cb 100644 --- a/fields.go +++ b/fields.go @@ -110,21 +110,23 @@ func (mf *mysqlField) typeDatabaseName() string { } var ( - scanTypeFloat32 = reflect.TypeOf(float32(0)) - scanTypeFloat64 = reflect.TypeOf(float64(0)) - scanTypeInt8 = reflect.TypeOf(int8(0)) - scanTypeInt16 = reflect.TypeOf(int16(0)) - scanTypeInt32 = reflect.TypeOf(int32(0)) - scanTypeInt64 = reflect.TypeOf(int64(0)) - scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) - scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) - scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) - scanTypeUint8 = reflect.TypeOf(uint8(0)) - scanTypeUint16 = reflect.TypeOf(uint16(0)) - scanTypeUint32 = reflect.TypeOf(uint32(0)) - scanTypeUint64 = reflect.TypeOf(uint64(0)) - scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) - scanTypeUnknown = reflect.TypeOf(new(interface{})) + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeString = reflect.TypeOf("") + scanTypeNullString = reflect.TypeOf(sql.NullString{}) + scanTypeBytes = reflect.TypeOf([]byte{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) ) type mysqlField struct { @@ -187,12 +189,18 @@ func (mf *mysqlField) scanType() reflect.Type { } return scanTypeNullFloat + case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, + fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: + if mf.charSet == 63 /* binary */ { + return scanTypeBytes + } + fallthrough case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, - fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, - fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, - fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, - fieldTypeTime: - return scanTypeRawBytes + fieldTypeEnum, fieldTypeSet, fieldTypeJSON, fieldTypeTime: + if mf.flags&flagNotNULL != 0 { + return scanTypeString + } + return scanTypeNullString case fieldTypeDate, fieldTypeNewDate, fieldTypeTimestamp, fieldTypeDateTime: From 081308f66228fdc51224614d1cf414c918cc1596 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 7 May 2023 20:25:21 +0900 Subject: [PATCH 022/123] Add benchmark to receive massive rows. (#1415) --- benchmark_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index 97ed781f8..fc70df60d 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -372,3 +372,59 @@ func BenchmarkQueryRawBytes(b *testing.B) { }) } } + +// BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. +func BenchmarkReceiveMassiveRows(b *testing.B) { + // Setup -- prepare 10000 rows. + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") + defer db.Close() + + sval := strings.Repeat("x", 50) + stmt, err := db.Prepare(`INSERT INTO foo (id, val) VALUES (?, ?)` + strings.Repeat(",(?,?)", 99)) + if err != nil { + b.Errorf("failed to prepare query: %v", err) + return + } + for i := 0; i < 10000; i += 100 { + args := make([]any, 200) + for j := 0; j < 100; j++ { + args[j*2] = i + j + args[j*2+1] = sval + } + _, err := stmt.Exec(args...) + if err != nil { + b.Error(err) + return + } + } + stmt.Close() + + // Use b.Run() to skip expensive setup. + b.Run("query", func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + rows, err := db.Query(`SELECT id, val FROM foo`) + if err != nil { + b.Errorf("failed to select: %v", err) + return + } + for rows.Next() { + var i int + var s sql.RawBytes + err = rows.Scan(&i, &s) + if err != nil { + b.Errorf("failed to scan: %v", err) + _ = rows.Close() + return + } + } + if err = rows.Err(); err != nil { + b.Errorf("failed to read rows: %v", err) + } + _ = rows.Close() + } + }) +} From a841e816042356288f94f7c5a586d83040cb63ea Mon Sep 17 00:00:00 2001 From: Evan Elias Date: Wed, 17 May 2023 14:28:03 -0400 Subject: [PATCH 023/123] Fix ColumnType.DatabaseTypeName for mediumint unsigned (#1428) --- AUTHORS | 1 + README.md | 2 +- driver_test.go | 1 + fields.go | 3 +++ 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 129ca665a..24dc43652 100644 --- a/AUTHORS +++ b/AUTHORS @@ -33,6 +33,7 @@ Dave Protasowski DisposaBoy Egor Smolyakov Erwan Martin +Evan Elias Evan Shaw Frederick Mayle Gustavo Kristic diff --git a/README.md b/README.md index 5a242e9d7..ddb5cefc7 100644 --- a/README.md +++ b/README.md @@ -465,7 +465,7 @@ user:password@/ The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively. ## `ColumnType` Support -This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `BIGINT`. +This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported. All Unsigned database type names will be returned `UNSIGNED ` with `INT`, `TINYINT`, `SMALLINT`, `MEDIUMINT`, `BIGINT`. ## `context.Context` Support Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. diff --git a/driver_test.go b/driver_test.go index 50c617274..118c0d7ba 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2816,6 +2816,7 @@ func TestRowsColumnTypes(t *testing.T) { {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"mediumuint", "MEDIUMINT UNSIGNED NOT NULL", "UNSIGNED MEDIUMINT", scanTypeUint32, false, 0, 0, [3]string{"0", "16777215", "42"}, [3]interface{}{uint32(0), uint32(16777215), uint32(42)}}, {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, diff --git a/fields.go b/fields.go index 18c23e0cb..ae709363f 100644 --- a/fields.go +++ b/fields.go @@ -37,6 +37,9 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeGeometry: return "GEOMETRY" case fieldTypeInt24: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED MEDIUMINT" + } return "MEDIUMINT" case fieldTypeJSON: return "JSON" From 72e78ee26806a26405ee462c4cf82406f094a143 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 19 May 2023 23:04:35 +0900 Subject: [PATCH 024/123] README: Update multistatement (#1431) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ddb5cefc7..ad7ca718e 100644 --- a/README.md +++ b/README.md @@ -295,9 +295,9 @@ Valid Values: true, false Default: false ``` -Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded. +Allow multiple statements in one query. This can be used to bach multiple queries. Use [Rows.NextResultSet()](https://pkg.go.dev/database/sql#Rows.NextResultSet) to get result of the second and subsequent queries. -When `multiStatements` is used, `?` parameters must only be used in the first statement. +When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly. ##### `parseTime` From 924f8336da7226f4cd4bfac575d394ffa20aacb4 Mon Sep 17 00:00:00 2001 From: Daemonxiao <35677990+Daemonxiao@users.noreply.github.com> Date: Wed, 24 May 2023 00:44:19 +0800 Subject: [PATCH 025/123] Send connection attributes (#1389) Co-authored-by: Inada Naoki --- .github/workflows/test.yml | 1 + README.md | 9 ++++++++ connection.go | 1 + connector.go | 46 ++++++++++++++++++++++++++++++++++++- connector_test.go | 9 +++++--- const.go | 12 ++++++++++ driver.go | 11 ++++----- driver_test.go | 47 ++++++++++++++++++++++++++++++++++++++ dsn.go | 40 ++++++++++++++++++-------------- packets.go | 13 +++++++++++ packets_test.go | 7 +++++- utils.go | 5 ++++ 12 files changed, 173 insertions(+), 28 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cd474767b..b2ab5e82a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -79,6 +79,7 @@ jobs: ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 + performance_schema=on - name: setup database run: | mysql --user 'root' --host '127.0.0.1' -e 'create database gotest;' diff --git a/README.md b/README.md index ad7ca718e..5935afd0c 100644 --- a/README.md +++ b/README.md @@ -393,6 +393,15 @@ Default: 0 I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. +##### `connectionAttributes` + +``` +Type: comma-delimited string of user-defined "key:value" pairs +Valid Values: (:,:,...) +Default: none +``` + +[Connection attributes](https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html) are key-value pairs that application programs can pass to the server at connect time. ##### System Variables diff --git a/connection.go b/connection.go index a7da9e7e2..67cea1fcb 100644 --- a/connection.go +++ b/connection.go @@ -27,6 +27,7 @@ type mysqlConn struct { affectedRows uint64 insertId uint64 cfg *Config + connector *connector maxAllowedPacket int maxWriteSize int writeTimeout time.Duration diff --git a/connector.go b/connector.go index a5c988e13..6acf3dd50 100644 --- a/connector.go +++ b/connector.go @@ -11,11 +11,54 @@ package mysql import ( "context" "database/sql/driver" + "fmt" "net" + "os" + "strconv" + "strings" ) type connector struct { - cfg *Config // immutable private copy. + cfg *Config // immutable private copy. + encodedAttributes string // Encoded connection attributes. +} + +func encodeConnectionAttributes(textAttributes string) string { + connAttrsBuf := make([]byte, 0, 251) + + // default connection attributes + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + + // user-defined connection attributes + for _, connAttr := range strings.Split(textAttributes, ",") { + attr := strings.SplitN(connAttr, ":", 2) + if len(attr) != 2 { + continue + } + for _, v := range attr { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) + } + } + + return string(connAttrsBuf) +} + +func newConnector(cfg *Config) (*connector, error) { + encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) + if len(encodedAttributes) > 250 { + return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes) + } + return &connector{ + cfg: cfg, + encodedAttributes: encodedAttributes, + }, nil } // Connect implements driver.Connector interface. @@ -29,6 +72,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), cfg: c.cfg, + connector: c, } mc.parseTime = mc.cfg.ParseTime diff --git a/connector_test.go b/connector_test.go index 976903c5b..bedb44ce2 100644 --- a/connector_test.go +++ b/connector_test.go @@ -8,13 +8,16 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector := &connector{&Config{ + connector, err := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, - }} + }) + if err != nil { + t.Fatal(err) + } - _, err := connector.Connect(context.Background()) + _, err = connector.Connect(context.Background()) if err == nil { t.Fatal("error expected") } diff --git a/const.go b/const.go index 64e2bced6..0f2621a6f 100644 --- a/const.go +++ b/const.go @@ -8,12 +8,24 @@ package mysql +import "runtime" + const ( defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" + + // Connection attributes + // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available + connAttrClientName = "_client_name" + connAttrClientNameValue = "Go-MySQL-Driver" + connAttrOS = "_os" + connAttrOSValue = runtime.GOOS + connAttrPlatform = "_platform" + connAttrPlatformValue = runtime.GOARCH + connAttrPid = "_pid" ) // MySQL constants documentation: diff --git a/driver.go b/driver.go index 8b0c3ec0a..c19e04207 100644 --- a/driver.go +++ b/driver.go @@ -85,8 +85,9 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c := &connector{ - cfg: cfg, + c, err := newConnector(cfg) + if err != nil { + return nil, err } return c.Connect(context.Background()) } @@ -103,7 +104,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return &connector{cfg: cfg}, nil + return newConnector(cfg) } // OpenConnector implements driver.DriverContext. @@ -112,7 +113,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return &connector{ - cfg: cfg, - }, nil + return newConnector(cfg) } diff --git a/driver_test.go b/driver_test.go index 118c0d7ba..7c25aa905 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3214,3 +3214,50 @@ func TestConnectorTimeoutsWatchCancel(t *testing.T) { t.Errorf("connection not closed") } } + +func TestConnectionAttributes(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + attr1 := "attr1" + value1 := "value1" + attr2 := "foo" + value2 := "boo" + dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2) + + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + + var attrValue string + queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" + rows := dbt.mustQuery(queryString, connAttrClientName) + if rows.Next() { + rows.Scan(&attrValue) + if attrValue != connAttrClientNameValue { + dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) + } + } else { + dbt.Errorf("no data") + } + rows.Close() + + rows = dbt.mustQuery(queryString, attr2) + if rows.Next() { + rows.Scan(&attrValue) + if attrValue != value2 { + dbt.Errorf("expected %q, got %q", value2, attrValue) + } + } else { + dbt.Errorf("no data") + } + rows.Close() +} diff --git a/dsn.go b/dsn.go index ded459c94..7c788517c 100644 --- a/dsn.go +++ b/dsn.go @@ -34,23 +34,24 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { - User string // Username - Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) - DBName string // Database name - Params map[string]string // Connection parameters - Collation string // Connection collation - Loc *time.Location // Location for time.Time values - MaxAllowedPacket int // Max packet size allowed - ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key - TLSConfig string // TLS configuration name - TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - Timeout time.Duration // Dial timeout - ReadTimeout time.Duration // I/O read timeout - WriteTimeout time.Duration // I/O write timeout - Logger Logger // Logger + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + Logger Logger // Logger AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin @@ -560,6 +561,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + + // Connection attributes + case "connectionAttributes": + cfg.ConnectionAttributes = value + default: // lazy init if cfg.Params == nil { diff --git a/packets.go b/packets.go index 8fd67997b..d6a11fd21 100644 --- a/packets.go +++ b/packets.go @@ -285,6 +285,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | + clientConnectAttrs | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -318,6 +319,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += n + 1 } + // 1 byte to store length of all key-values + // NOTE: Actually, this is length encoded integer. + // But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer + // doesn't support buffer size more than 4096 bytes. + // TODO(methane): Rewrite buffer management. + pktLen += 1 + len(mc.connector.encodedAttributes) + // Calculate packet length and get buffer with that size data, err := mc.buf.takeSmallBuffer(pktLen + 4) if err != nil { @@ -394,6 +402,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data[pos] = 0x00 pos++ + // Connection Attributes + data[pos] = byte(len(mc.connector.encodedAttributes)) + pos++ + pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + // Send Auth packet return mc.writePacket(data[:pos]) } diff --git a/packets_test.go b/packets_test.go index cacec1c68..f429087e9 100644 --- a/packets_test.go +++ b/packets_test.go @@ -96,9 +96,14 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) + connector, err := newConnector(NewConfig()) + if err != nil { + panic(err) + } mc := &mysqlConn{ buf: newBuffer(conn), - cfg: NewConfig(), + cfg: connector.cfg, + connector: connector, netConn: conn, closech: make(chan struct{}), maxAllowedPacket: defaultMaxAllowedPacket, diff --git a/utils.go b/utils.go index 15dbd8d16..753ebd65c 100644 --- a/utils.go +++ b/utils.go @@ -616,6 +616,11 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) } +func appendLengthEncodedString(b []byte, s string) []byte { + b = appendLengthEncodedInteger(b, uint64(len(s))) + return append(b, s...) +} + // reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. // If cap(buf) is not enough, reallocate new buffer. func reserveBuffer(buf []byte, appendSize int) []byte { From d3e4fe64aaa1e99a19f711233dc682f2114ffbfd Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 25 May 2023 23:49:33 +0900 Subject: [PATCH 026/123] Use PathEscape for dbname in DSN. (#1432) Support for slashes in database names via url escape codes. On the other hand, '%' in DSN is now treated as percent-encoding. Co-authored-by: Brian Hendriks --- AUTHORS | 2 ++ README.md | 6 +++++ dsn.go | 8 +++++-- dsn_test.go | 66 ++++++++++++++++++++++++++++++----------------------- 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/AUTHORS b/AUTHORS index 24dc43652..7e4fac5a1 100644 --- a/AUTHORS +++ b/AUTHORS @@ -110,6 +110,7 @@ Xuehong Chan Zhenye Xie Zhixin Wen Ziheng Lyu +Brian Hendriks # Organizations @@ -127,3 +128,4 @@ Percona LLC Pivotal Inc. Stripe Inc. Zendesk Inc. +Dolthub Inc. diff --git a/README.md b/README.md index 5935afd0c..156aaa965 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,12 @@ This has the same effect as an empty DSN string: ``` +`dbname` is escaped by [PathEscape()]()https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes: + +``` +/dbname%2Fwithslash +``` + Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct. #### Password diff --git a/dsn.go b/dsn.go index 7c788517c..3a6537e6c 100644 --- a/dsn.go +++ b/dsn.go @@ -203,7 +203,7 @@ func (cfg *Config) FormatDSN() string { // /dbname buf.WriteByte('/') - buf.WriteString(cfg.DBName) + buf.WriteString(url.PathEscape(cfg.DBName)) // [?param1=value1&...¶mN=valueN] hasParam := false @@ -365,7 +365,11 @@ func ParseDSN(dsn string) (cfg *Config, err error) { break } } - cfg.DBName = dsn[i+1 : j] + + dbname := dsn[i+1 : j] + if cfg.DBName, err = url.PathUnescape(dbname); err != nil { + return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err) + } break } diff --git a/dsn_test.go b/dsn_test.go index cb97d557e..8b623df01 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -50,6 +50,9 @@ var testDSNs = []struct { }, { "/dbname", &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/dbname%2Fwithslash", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "@/", &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, @@ -76,17 +79,20 @@ var testDSNs = []struct { func TestDSNParser(t *testing.T) { for i, tst := range testDSNs { - cfg, err := ParseDSN(tst.in) - if err != nil { - t.Error(err.Error()) - } + t.Run(tst.in, func(t *testing.T) { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + return + } - // pointer not static - cfg.TLS = nil + // pointer not static + cfg.TLS = nil - if !reflect.DeepEqual(cfg, tst.out) { - t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) - } + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + }) } } @@ -113,27 +119,29 @@ func TestDSNParserInvalid(t *testing.T) { func TestDSNReformat(t *testing.T) { for i, tst := range testDSNs { - dsn1 := tst.in - cfg1, err := ParseDSN(dsn1) - if err != nil { - t.Error(err.Error()) - continue - } - cfg1.TLS = nil // pointer not static - res1 := fmt.Sprintf("%+v", cfg1) - - dsn2 := cfg1.FormatDSN() - cfg2, err := ParseDSN(dsn2) - if err != nil { - t.Error(err.Error()) - continue - } - cfg2.TLS = nil // pointer not static - res2 := fmt.Sprintf("%+v", cfg2) + t.Run(tst.in, func(t *testing.T) { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + return + } + cfg1.TLS = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) - if res1 != res2 { - t.Errorf("%d. %q does not match %q", i, res2, res1) - } + dsn2 := cfg1.FormatDSN() + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + return + } + cfg2.TLS = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. %q does not match %q", i, res2, res1) + } + }) } } From 7b4d7eb08bc4e705373ad835b2384df28676fb2f Mon Sep 17 00:00:00 2001 From: uji <49834542+uji@users.noreply.github.com> Date: Fri, 26 May 2023 11:32:30 +0900 Subject: [PATCH 027/123] all: replace ioutil pkg to new package (#1438) --- auth.go | 2 +- driver_test.go | 3 +-- utils.go | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/auth.go b/auth.go index b591e7b8a..e758e6d00 100644 --- a/auth.go +++ b/auth.go @@ -33,7 +33,7 @@ var ( // Note: The provided rsa.PublicKey instance is exclusively owned by the driver // after registering it and may not be modified. // -// data, err := ioutil.ReadFile("mykey.pem") +// data, err := os.ReadFile("mykey.pem") // if err != nil { // log.Fatal(err) // } diff --git a/driver_test.go b/driver_test.go index 7c25aa905..abf91a486 100644 --- a/driver_test.go +++ b/driver_test.go @@ -17,7 +17,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "log" "math" "net" @@ -1245,7 +1244,7 @@ func TestLoadData(t *testing.T) { dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") // Local File - file, err := ioutil.TempFile("", "gotest") + file, err := os.CreateTemp("", "gotest") defer os.Remove(file.Name()) if err != nil { dbt.Fatal(err) diff --git a/utils.go b/utils.go index 753ebd65c..a24197b93 100644 --- a/utils.go +++ b/utils.go @@ -36,7 +36,7 @@ var ( // registering it. // // rootCertPool := x509.NewCertPool() -// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// pem, err := os.ReadFile("/path/ca-cert.pem") // if err != nil { // log.Fatal(err) // } From 7b22099c7ea60190ef92f953ee62263a1808bd4b Mon Sep 17 00:00:00 2001 From: guangwu Date: Fri, 26 May 2023 12:52:11 +0800 Subject: [PATCH 028/123] code optimization (#1439) --- driver.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/driver.go b/driver.go index c19e04207..0ed8fa1c5 100644 --- a/driver.go +++ b/driver.go @@ -60,9 +60,7 @@ func DeregisterDialContext(net string) { dialsLock.Lock() defer dialsLock.Unlock() if dials != nil { - if _, ok := dials[net]; ok { - delete(dials, net) - } + delete(dials, net) } } From 99976f4f587dd1a26900f6dd91ca96f6e3e2f724 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 28 May 2023 01:43:28 +0900 Subject: [PATCH 029/123] Use `SET NAMES charset COLLATE collation`. (#1437) --- README.md | 10 ++++++---- connection.go | 9 +++++++-- dsn.go | 5 ++--- dsn_test.go | 34 +++++++++++++++++----------------- packets.go | 11 +++++++---- 5 files changed, 39 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 156aaa965..d747a7446 100644 --- a/README.md +++ b/README.md @@ -202,8 +202,7 @@ Default: none Sets the charset used for client-server interaction (`"SET NAMES "`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). -Usage of the `charset` parameter is discouraged because it issues additional queries to the server. -Unless you need the fallback behavior, please use `collation` instead. +See also [Unicode Support](#unicode-support). ##### `checkConnLiveness` @@ -232,6 +231,7 @@ The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You s Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). +See also [Unicode Support](#unicode-support). ##### `clientFoundRows` @@ -511,9 +511,11 @@ However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` v ### Unicode support Since version 1.5 Go-MySQL-Driver automatically uses the collation ` utf8mb4_general_ci` by default. -Other collations / charsets can be set using the [`collation`](#collation) DSN parameter. +Other charsets / collations can be set using the [`charset`](#charset) or [`collation`](#collation) DSN parameter. -Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAMES utf8`) to the DSN to enable proper UTF-8 support. This is not necessary anymore. The [`collation`](#collation) parameter should be preferred to set another collation / charset than the default. +- When only the `charset` is specified, the `SET NAMES ` query is sent and the server's default collation is used. +- When both the `charset` and `collation` are specified, the `SET NAMES COLLATE ` query is sent. +- When only the `collation` is specified, the collation is specified in the protocol handshake and the `SET NAMES` query is not sent. This can save one roundtrip, but note that the server may ignore the specified collation silently and use the server's default charset/collation instead. See http://dev.mysql.com/doc/refman/8.0/en/charset-unicode.html for more details on MySQL's Unicode support. diff --git a/connection.go b/connection.go index 67cea1fcb..14a972b40 100644 --- a/connection.go +++ b/connection.go @@ -49,14 +49,19 @@ type mysqlConn struct { // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder + for param, val := range mc.cfg.Params { switch param { // Charset: character_set_connection, character_set_client, character_set_results case "charset": charsets := strings.Split(val, ",") - for i := range charsets { + for _, cs := range charsets { // ignore errors here - a charset may not exist - err = mc.exec("SET NAMES " + charsets[i]) + if mc.cfg.Collation != "" { + err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) + } else { + err = mc.exec("SET NAMES " + cs) + } if err == nil { break } diff --git a/dsn.go b/dsn.go index 3a6537e6c..693aa4e5a 100644 --- a/dsn.go +++ b/dsn.go @@ -70,7 +70,6 @@ type Config struct { // NewConfig creates a new Config and sets default values. func NewConfig() *Config { return &Config{ - Collation: defaultCollation, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, @@ -100,7 +99,7 @@ func (cfg *Config) Clone() *Config { } func (cfg *Config) normalize() error { - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] { return errInvalidDSNUnsafeCollation } @@ -237,7 +236,7 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") } - if col := cfg.Collation; col != defaultCollation && len(col) > 0 { + if col := cfg.Collation; col != "" { writeDSNParam(&buf, &hasParam, "collation", col) } diff --git a/dsn_test.go b/dsn_test.go index 8b623df01..a729d0ef8 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -22,58 +22,58 @@ var testDSNs = []struct { out *Config }{{ "username:password@protocol(address)/dbname?param=value", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, }, { "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", - &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, Logger: defaultLogger, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, }, { "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0&allowFallbackToPlaintext=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, Logger: defaultLogger, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: 0, Logger: defaultLogger, AllowFallbackToPlaintext: true, AllowNativePasswords: false, CheckConnLiveness: false}, }, { "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", - &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/dbname%2Fwithslash", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname/withslash", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "@/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "/", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:p@/ssword@/", - &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "unix/?arg=%2Fsome%2Fpath.ext", - &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(127.0.0.1)/dbname", - &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "tcp(de:ad:be:ef::ca:fe)/dbname", - &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, } diff --git a/packets.go b/packets.go index d6a11fd21..c10072c94 100644 --- a/packets.go +++ b/packets.go @@ -14,7 +14,6 @@ import ( "database/sql/driver" "encoding/binary" "encoding/json" - "errors" "fmt" "io" "math" @@ -346,14 +345,18 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data[10] = 0x00 data[11] = 0x00 - // Charset [1 byte] + // Collation ID [1 byte] + cname := mc.cfg.Collation + if cname == "" { + cname = defaultCollation + } var found bool - data[12], found = collations[mc.cfg.Collation] + data[12], found = collations[cname] if !found { // Note possibility for false negatives: // could be triggered although the collation is valid if the // collations map does not contain entries the server supports. - return errors.New("unknown collation") + return fmt.Errorf("unknown collation: %q", cname) } // Filler [23 bytes] (all 0x00) From f43effaa7c9271606b37b04a6235e5f7ed37c3e0 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 28 May 2023 02:07:46 +0900 Subject: [PATCH 030/123] Reduce map lookup in ColumnTypeDatabaseTypeName. (#1436) --- collations.go | 2 +- fields.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/collations.go b/collations.go index 295bfbe52..1cdf97b67 100644 --- a/collations.go +++ b/collations.go @@ -9,7 +9,7 @@ package mysql const defaultCollation = "utf8mb4_general_ci" -const binaryCollation = "binary" +const binaryCollationID = 63 // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: diff --git a/fields.go b/fields.go index ae709363f..30f31cbfb 100644 --- a/fields.go +++ b/fields.go @@ -18,7 +18,7 @@ func (mf *mysqlField) typeDatabaseName() string { case fieldTypeBit: return "BIT" case fieldTypeBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "TEXT" } return "BLOB" @@ -49,7 +49,7 @@ func (mf *mysqlField) typeDatabaseName() string { } return "INT" case fieldTypeLongBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "LONGTEXT" } return "LONGBLOB" @@ -59,7 +59,7 @@ func (mf *mysqlField) typeDatabaseName() string { } return "BIGINT" case fieldTypeMediumBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "MEDIUMTEXT" } return "MEDIUMBLOB" @@ -77,7 +77,7 @@ func (mf *mysqlField) typeDatabaseName() string { } return "SMALLINT" case fieldTypeString: - if mf.charSet == collations[binaryCollation] { + if mf.charSet == binaryCollationID { return "BINARY" } return "CHAR" @@ -91,17 +91,17 @@ func (mf *mysqlField) typeDatabaseName() string { } return "TINYINT" case fieldTypeTinyBLOB: - if mf.charSet != collations[binaryCollation] { + if mf.charSet != binaryCollationID { return "TINYTEXT" } return "TINYBLOB" case fieldTypeVarChar: - if mf.charSet == collations[binaryCollation] { + if mf.charSet == binaryCollationID { return "VARBINARY" } return "VARCHAR" case fieldTypeVarString: - if mf.charSet == collations[binaryCollation] { + if mf.charSet == binaryCollationID { return "VARBINARY" } return "VARCHAR" @@ -194,7 +194,7 @@ func (mf *mysqlField) scanType() reflect.Type { case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: - if mf.charSet == 63 /* binary */ { + if mf.charSet == binaryCollationID { return scanTypeBytes } fallthrough From 397e2f5323e1c03bc4513d6c9ab345dfd47108cd Mon Sep 17 00:00:00 2001 From: Matthew Herrmann <47012945+mherr-google@users.noreply.github.com> Date: Mon, 29 May 2023 13:33:49 +1000 Subject: [PATCH 031/123] Exec() now provides access to status of multiple statements. (#1309) It now reports the last inserted ID and affected row count for all statements, not just the last one. This is useful to execute batches of statements such as UPDATE with minimal roundtrips. Co-authored-by: Inada Naoki --- README.md | 16 +++++++ auth.go | 6 +-- connection.go | 29 ++++++------- driver_test.go | 112 +++++++++++++++++++++++++++++++++++++++++++++++++ infile.go | 8 ++-- packets.go | 77 ++++++++++++++++++++++++++++------ result.go | 37 ++++++++++++++-- rows.go | 7 +++- statement.go | 17 ++++---- 9 files changed, 259 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index d747a7446..4eade6853 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,22 @@ Allow multiple statements in one query. This can be used to bach multiple querie When `multiStatements` is used, `?` parameters must only be used in the first statement. [interpolateParams](#interpolateparams) can be used to avoid this limitation unless prepared statement is used explicitly. +It's possible to access the last inserted ID and number of affected rows for multiple statements by using `sql.Conn.Raw()` and the `mysql.Result`. For example: + +```go +conn, _ := db.Conn(ctx) +conn.Raw(func(conn interface{}) error { + ex := conn.(driver.Execer) + res, err := ex.Exec(` + UPDATE point SET x = 1 WHERE y = 2; + UPDATE point SET x = 2 WHERE y = 3; + `, nil) + // Both slices have 2 elements. + log.Print(res.(mysql.Result).AllRowsAffected()) + log.Print(res.(mysql.Result).AllLastInsertIds()) +}) +``` + ##### `parseTime` ``` diff --git a/auth.go b/auth.go index e758e6d00..f6b157a12 100644 --- a/auth.go +++ b/auth.go @@ -346,7 +346,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case 1: switch authData[0] { case cachingSha2PasswordFastAuthSuccess: - if err = mc.readResultOK(); err == nil { + if err = mc.resultUnchanged().readResultOK(); err == nil { return nil // auth successful } @@ -397,7 +397,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { return err } } - return mc.readResultOK() + return mc.resultUnchanged().readResultOK() default: return ErrMalformPkt @@ -426,7 +426,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { if err != nil { return err } - return mc.readResultOK() + return mc.resultUnchanged().readResultOK() } default: diff --git a/connection.go b/connection.go index 14a972b40..631a1dc24 100644 --- a/connection.go +++ b/connection.go @@ -23,9 +23,8 @@ import ( type mysqlConn struct { buf buffer netConn net.Conn - rawConn net.Conn // underlying connection when netConn is TLS connection. - affectedRows uint64 - insertId uint64 + rawConn net.Conn // underlying connection when netConn is TLS connection. + result mysqlResult // managed by clearResult() and handleOkPacket(). cfg *Config connector *connector maxAllowedPacket int @@ -155,6 +154,7 @@ func (mc *mysqlConn) cleanup() { if err := mc.netConn.Close(); err != nil { mc.cfg.Logger.Print(err) } + mc.clearResult() } func (mc *mysqlConn) error() error { @@ -316,28 +316,25 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } query = prepared } - mc.affectedRows = 0 - mc.insertId = 0 err := mc.exec(query) if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, err + copied := mc.result + return &copied, err } return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { + handleOk := mc.clearResult() // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { return mc.markBadConn(err) } // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } @@ -354,7 +351,7 @@ func (mc *mysqlConn) exec(query string) error { } } - return mc.discardResults() + return handleOk.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { @@ -362,6 +359,8 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro } func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + handleOk := mc.clearResult() + if mc.closed.Load() { mc.cfg.Logger.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -382,7 +381,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) if err == nil { // Read Result var resLen int - resLen, err = mc.readResultSetHeaderPacket() + resLen, err = handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -410,12 +409,13 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command + handleOk := mc.clearResult() if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { return nil, err } // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -466,11 +466,12 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { } defer mc.finish() + handleOk := mc.clearResult() if err = mc.writeCommandPacket(comPing); err != nil { return mc.markBadConn(err) } - return mc.readResultOK() + return handleOk.readResultOK() } // BeginTx implements driver.ConnBeginTx interface diff --git a/driver_test.go b/driver_test.go index abf91a486..cd94c434e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2154,11 +2154,51 @@ func TestRejectReadOnly(t *testing.T) { } func TestPing(t *testing.T) { + ctx := context.Background() runTests(t, dsn, func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { dbt.fail("Ping", "Ping", err) } }) + + runTests(t, dsn, func(dbt *DBTest) { + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.fail("db", "Conn", err) + } + + // Check that affectedRows and insertIds are cleared after each call. + conn.Raw(func(conn interface{}) error { + c := conn.(*mysqlConn) + + // Issue a query that sets affectedRows and insertIds. + q, err := c.Query(`SELECT 1`, nil) + if err != nil { + dbt.fail("Conn", "Query", err) + } + if got, want := c.result.affectedRows, []int64{0}; !reflect.DeepEqual(got, want) { + dbt.Fatalf("bad affectedRows: got %v, want=%v", got, want) + } + if got, want := c.result.insertIds, []int64{0}; !reflect.DeepEqual(got, want) { + dbt.Fatalf("bad insertIds: got %v, want=%v", got, want) + } + q.Close() + + // Verify that Ping() clears both fields. + for i := 0; i < 2; i++ { + if err := c.Ping(ctx); err != nil { + dbt.fail("Pinger", "Ping", err) + } + if got, want := c.result.affectedRows, []int64(nil); !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + if got, want := c.result.insertIds, []int64(nil); !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + } + return nil + }) + }) } // See Issue #799 @@ -2378,6 +2418,42 @@ func TestMultiResultSetNoSelect(t *testing.T) { }) } +func TestExecMultipleResults(t *testing.T) { + ctx := context.Background() + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + CREATE TABLE test ( + id INT NOT NULL AUTO_INCREMENT, + value VARCHAR(255), + PRIMARY KEY (id) + )`) + conn, err := dbt.db.Conn(ctx) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + conn.Raw(func(conn interface{}) error { + ex := conn.(driver.Execer) + res, err := ex.Exec(` + INSERT INTO test (value) VALUES ('a'), ('b'); + INSERT INTO test (value) VALUES ('c'), ('d'), ('e'); + `, nil) + if err != nil { + t.Fatalf("insert statements failed: %v", err) + } + mres := res.(Result) + if got, want := mres.AllRowsAffected(), []int64{2, 3}; !reflect.DeepEqual(got, want) { + t.Errorf("bad AllRowsAffected: got %v, want=%v", got, want) + } + // For INSERTs containing multiple rows, LAST_INSERT_ID() returns the + // first inserted ID, not the last. + if got, want := mres.AllLastInsertIds(), []int64{1, 3}; !reflect.DeepEqual(got, want) { + t.Errorf("bad AllLastInsertIds: got %v, want %v", got, want) + } + return nil + }) + }) +} + // tests if rows are set in a proper state if some results were ignored before // calling rows.NextResultSet. func TestSkipResults(t *testing.T) { @@ -2399,6 +2475,42 @@ func TestSkipResults(t *testing.T) { }) } +func TestQueryMultipleResults(t *testing.T) { + ctx := context.Background() + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + CREATE TABLE test ( + id INT NOT NULL AUTO_INCREMENT, + value VARCHAR(255), + PRIMARY KEY (id) + )`) + conn, err := dbt.db.Conn(ctx) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + conn.Raw(func(conn interface{}) error { + qr := conn.(driver.Queryer) + + c := conn.(*mysqlConn) + + // Demonstrate that repeated queries reset the affectedRows + for i := 0; i < 2; i++ { + _, err := qr.Query(` + INSERT INTO test (value) VALUES ('a'), ('b'); + INSERT INTO test (value) VALUES ('c'), ('d'), ('e'); + `, nil) + if err != nil { + t.Fatalf("insert statements failed: %v", err) + } + if got, want := c.result.affectedRows, []int64{2, 3}; !reflect.DeepEqual(got, want) { + t.Errorf("bad affectedRows: got %v, want=%v", got, want) + } + } + return nil + }) + }) +} + func TestPingContext(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/infile.go b/infile.go index 3279dcffd..cfd41914e 100644 --- a/infile.go +++ b/infile.go @@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) { const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP -func (mc *mysqlConn) handleInFileRequest(name string) (err error) { +func (mc *okHandler) handleInFileRequest(name string) (err error) { var rdr io.Reader var data []byte packetSize := defaultPacketSize @@ -154,7 +154,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + if ioErr := mc.conn().writePacket(data[:4+n]); ioErr != nil { return ioErr } } @@ -168,7 +168,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if data == nil { data = make([]byte, 4) } - if ioErr := mc.writePacket(data[:4]); ioErr != nil { + if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } @@ -177,6 +177,6 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { return mc.readResultOK() } - mc.readPacket() + mc.conn().readPacket() return err } diff --git a/packets.go b/packets.go index c10072c94..1a7f2c376 100644 --- a/packets.go +++ b/packets.go @@ -511,7 +511,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { switch data[0] { case iOK: - return nil, "", mc.handleOkPacket(data) + // resultUnchanged, since auth happens before any queries or + // commands have been executed. + return nil, "", mc.resultUnchanged().handleOkPacket(data) case iAuthMoreData: return data[1:], "", err @@ -535,8 +537,8 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { } // Returns error if Packet is not an 'Result OK'-Packet -func (mc *mysqlConn) readResultOK() error { - data, err := mc.readPacket() +func (mc *okHandler) readResultOK() error { + data, err := mc.conn().readPacket() if err != nil { return err } @@ -544,13 +546,17 @@ func (mc *mysqlConn) readResultOK() error { if data[0] == iOK { return mc.handleOkPacket(data) } - return mc.handleErrorPacket(data) + return mc.conn().handleErrorPacket(data) } // Result Set Header Packet // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset -func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { - data, err := mc.readPacket() +func (mc *okHandler) readResultSetHeaderPacket() (int, error) { + // handleOkPacket replaces both values; other cases leave the values unchanged. + mc.result.affectedRows = append(mc.result.affectedRows, 0) + mc.result.insertIds = append(mc.result.insertIds, 0) + + data, err := mc.conn().readPacket() if err == nil { switch data[0] { @@ -558,7 +564,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { return 0, mc.handleOkPacket(data) case iERR: - return 0, mc.handleErrorPacket(data) + return 0, mc.conn().handleErrorPacket(data) case iLocalInFile: return 0, mc.handleInFileRequest(string(data[1:])) @@ -623,18 +629,61 @@ func readStatus(b []byte) statusFlag { return statusFlag(b[0]) | statusFlag(b[1])<<8 } +// Returns an instance of okHandler for codepaths where mysqlConn.result doesn't +// need to be cleared first (e.g. during authentication, or while additional +// resultsets are being fetched.) +func (mc *mysqlConn) resultUnchanged() *okHandler { + return (*okHandler)(mc) +} + +// okHandler represents the state of the connection when mysqlConn.result has +// been prepared for processing of OK packets. +// +// To correctly populate mysqlConn.result (updated by handleOkPacket()), all +// callpaths must either: +// +// 1. first clear it using clearResult(), or +// 2. confirm that they don't need to (by calling resultUnchanged()). +// +// Both return an instance of type *okHandler. +type okHandler mysqlConn + +// Exposees the underlying type's methods. +func (mc *okHandler) conn() *mysqlConn { + return (*mysqlConn)(mc) +} + +// clearResult clears the connection's stored affectedRows and insertIds +// fields. +// +// It returns a handler that can process OK responses. +func (mc *mysqlConn) clearResult() *okHandler { + mc.result = mysqlResult{} + return (*okHandler)(mc) +} + // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet -func (mc *mysqlConn) handleOkPacket(data []byte) error { +func (mc *okHandler) handleOkPacket(data []byte) error { var n, m int + var affectedRows, insertId uint64 // 0x00 [1 byte] // Affected rows [Length Coded Binary] - mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + affectedRows, _, n = readLengthEncodedInteger(data[1:]) // Insert id [Length Coded Binary] - mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + insertId, _, m = readLengthEncodedInteger(data[1+n:]) + + // Update for the current statement result (only used by + // readResultSetHeaderPacket). + if len(mc.result.affectedRows) > 0 { + mc.result.affectedRows[len(mc.result.affectedRows)-1] = int64(affectedRows) + } + if len(mc.result.insertIds) > 0 { + mc.result.insertIds[len(mc.result.insertIds)-1] = int64(insertId) + } // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) @@ -1165,7 +1214,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { return mc.writePacket(data) } -func (mc *mysqlConn) discardResults() error { +// For each remaining resultset in the stream, discards its rows and updates +// mc.affectedRows and mc.insertIds. +func (mc *okHandler) discardResults() error { for mc.status&statusMoreResultsExists != 0 { resLen, err := mc.readResultSetHeaderPacket() if err != nil { @@ -1173,11 +1224,11 @@ func (mc *mysqlConn) discardResults() error { } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.conn().readUntilEOF(); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.conn().readUntilEOF(); err != nil { return err } } diff --git a/result.go b/result.go index c6438d034..36a432e81 100644 --- a/result.go +++ b/result.go @@ -8,15 +8,44 @@ package mysql +import "database/sql/driver" + +// Result exposes data not available through *connection.Result. +// +// This is accessible by executing statements using sql.Conn.Raw() and +// downcasting the returned result: +// +// res, err := rawConn.Exec(...) +// res.(mysql.Result).AllRowsAffected() +// +type Result interface { + driver.Result + // AllRowsAffected returns a slice containing the affected rows for each + // executed statement. + AllRowsAffected() []int64 + // AllLastInsertIds returns a slice containing the last inserted ID for each + // executed statement. + AllLastInsertIds() []int64 +} + type mysqlResult struct { - affectedRows int64 - insertId int64 + // One entry in both slices is created for every executed statement result. + affectedRows []int64 + insertIds []int64 } func (res *mysqlResult) LastInsertId() (int64, error) { - return res.insertId, nil + return res.insertIds[len(res.insertIds)-1], nil } func (res *mysqlResult) RowsAffected() (int64, error) { - return res.affectedRows, nil + return res.affectedRows[len(res.affectedRows)-1], nil +} + +func (res *mysqlResult) AllLastInsertIds() []int64 { + return append([]int64{}, res.insertIds...) // defensive copy +} + +func (res *mysqlResult) AllRowsAffected() []int64 { + return append([]int64{}, res.affectedRows...) // defensive copy } diff --git a/rows.go b/rows.go index 888bdb5f0..63d0ed2d5 100644 --- a/rows.go +++ b/rows.go @@ -123,7 +123,8 @@ func (rows *mysqlRows) Close() (err error) { err = mc.readUntilEOF() } if err == nil { - if err = mc.discardResults(); err != nil { + handleOk := mc.clearResult() + if err = handleOk.discardResults(); err != nil { return err } } @@ -160,7 +161,9 @@ func (rows *mysqlRows) nextResultSet() (int, error) { return 0, io.EOF } rows.rs = resultSet{} - return rows.mc.readResultSetHeaderPacket() + // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to + // nextResultSet. + return rows.mc.resultUnchanged().readResultSetHeaderPacket() } func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { diff --git a/statement.go b/statement.go index d8b3975a5..31e7799c4 100644 --- a/statement.go +++ b/statement.go @@ -61,12 +61,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } mc := stmt.mc - - mc.affectedRows = 0 - mc.insertId = 0 + handleOk := stmt.mc.clearResult() // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -83,14 +81,12 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } } - if err := mc.discardResults(); err != nil { + if err := handleOk.discardResults(); err != nil { return nil, err } - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + copied := mc.result + return &copied, nil } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { @@ -111,7 +107,8 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { mc := stmt.mc // Read Result - resLen, err := mc.readResultSetHeaderPacket() + handleOk := stmt.mc.clearResult() + resLen, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } From 8365b948403b6a9d0724518c2f722e09d4561794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Fri, 2 Jun 2023 17:28:47 +0200 Subject: [PATCH 032/123] doc: add link to NewConnector from FormatDSN (#1442) Advise to use NewConnector instead of FormatDSN because roundtripping is known to not work well. See https://github.com/go-sql-driver/mysql/issues/1410#issuecomment-1510866931 --- dsn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dsn.go b/dsn.go index 693aa4e5a..380ca9570 100644 --- a/dsn.go +++ b/dsn.go @@ -177,6 +177,8 @@ func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) { // FormatDSN formats the given Config into a DSN string which can be passed to // the driver. +// +// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config]. func (cfg *Config) FormatDSN() string { var buf bytes.Buffer From 65ed3c5d4007ad7ea74c33e78b953b82a9ed80ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Fri, 2 Jun 2023 17:30:26 +0200 Subject: [PATCH 033/123] Add fuzz test for FormatDSN (#1444) Run (go 1.18+): go test -fuzz FuzzFormatDSN Note: invalid host:addr values are currently ignored as they are known to break (ParseDSN doesn't strictly check address format). --- dsn_fuzz_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 dsn_fuzz_test.go diff --git a/dsn_fuzz_test.go b/dsn_fuzz_test.go new file mode 100644 index 000000000..04c56ad45 --- /dev/null +++ b/dsn_fuzz_test.go @@ -0,0 +1,47 @@ +//go:build go1.18 +// +build go1.18 + +package mysql + +import ( + "net" + "testing" +) + +func FuzzFormatDSN(f *testing.F) { + for _, test := range testDSNs { // See dsn_test.go + f.Add(test.in) + } + + f.Fuzz(func(t *testing.T, dsn1 string) { + // Do not waste resources + if len(dsn1) > 1000 { + t.Skip("ignore: too long") + } + + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Skipf("invalid DSN: %v", err) + } + + dsn2 := cfg1.FormatDSN() + if dsn2 == dsn1 { + return + } + + // Skip known cases of bad config that are not strictly checked by ParseDSN + if _, _, err := net.SplitHostPort(cfg1.Addr); err != nil { + t.Skipf("invalid addr %q: %v", cfg1.Addr, err) + } + + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Fatalf("%q rewritten as %q: %v", dsn1, dsn2, err) + } + + dsn3 := cfg2.FormatDSN() + if dsn3 != dsn2 { + t.Errorf("%q rewritten as %q", dsn2, dsn3) + } + }) +} From cf948e4a9df2e97c1b0e3d068a52b5e2d53485a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Tue, 13 Jun 2023 06:24:06 +0200 Subject: [PATCH 034/123] TestDSNReformat: add more roundtrip checks (#1443) Add more roundtrip checks for ParseDSN/FormatDSN. --- dsn_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/dsn_test.go b/dsn_test.go index a729d0ef8..be50102de 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -130,6 +130,11 @@ func TestDSNReformat(t *testing.T) { res1 := fmt.Sprintf("%+v", cfg1) dsn2 := cfg1.FormatDSN() + if dsn2 != dsn1 { + // Just log + t.Logf("%d. %q reformated as %q", i, dsn1, dsn2) + } + cfg2, err := ParseDSN(dsn2) if err != nil { t.Error(err.Error()) @@ -141,6 +146,11 @@ func TestDSNReformat(t *testing.T) { if res1 != res2 { t.Errorf("%d. %q does not match %q", i, res2, res1) } + + dsn3 := cfg2.FormatDSN() + if dsn3 != dsn2 { + t.Errorf("%d. %q does not match %q", i, dsn2, dsn3) + } }) } } From 943264b76442d87ceea460ae7745208c8143f098 Mon Sep 17 00:00:00 2001 From: Achille Date: Mon, 12 Jun 2023 23:39:30 -0700 Subject: [PATCH 035/123] ignore errors returned by SetKeepAlive (#1448) Signed-off-by: Achille Roussel --- connector.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/connector.go b/connector.go index 6acf3dd50..7e0b16734 100644 --- a/connector.go +++ b/connector.go @@ -100,10 +100,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err + c.cfg.Logger.Print(err) } } From 564dee9b80ffc1e406b8b91e2215d29919730ae2 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 16 Jun 2023 10:33:09 +0900 Subject: [PATCH 036/123] CI: use staticcheck (#1449) --- .github/workflows/test.yml | 8 ++++++++ auth.go | 2 +- driver_test.go | 11 +++++++---- errors.go | 2 +- infile.go | 4 ++-- nulltime.go | 2 +- 6 files changed, 20 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b2ab5e82a..3122c0e17 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,6 +11,14 @@ env: MYSQL_TEST_CONCURRENT: 1 jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: dominikh/staticcheck-action@v1.3.0 + with: + version: "2023.1.3" + list: runs-on: ubuntu-latest outputs: diff --git a/auth.go b/auth.go index f6b157a12..d2ab0103d 100644 --- a/auth.go +++ b/auth.go @@ -382,7 +382,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { // parse public key block, rest := pem.Decode(data[1:]) if block == nil { - return fmt.Errorf("No Pem data found, data: %s", rest) + return fmt.Errorf("no pem data found, data: %s", rest) } pkix, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { diff --git a/driver_test.go b/driver_test.go index cd94c434e..c937b8416 100644 --- a/driver_test.go +++ b/driver_test.go @@ -346,8 +346,8 @@ func TestMultiQuery(t *testing.T) { rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") if rows.Next() { rows.Scan(&out) - if 5 != out { - dbt.Errorf("5 != %d", out) + if out != 5 { + dbt.Errorf("expected 5, got %d", out) } if rows.Next() { @@ -1293,7 +1293,7 @@ func TestLoadData(t *testing.T) { _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") if err == nil { dbt.Fatal("load non-existent Reader didn't fail") - } else if err.Error() != "Reader 'doesnotexist' is not registered" { + } else if err.Error() != "reader 'doesnotexist' is not registered" { dbt.Fatal(err.Error()) } }) @@ -1401,6 +1401,7 @@ func TestReuseClosedConnection(t *testing.T) { if err != nil { t.Fatalf("error preparing statement: %s", err.Error()) } + //lint:ignore SA1019 this is a test _, err = stmt.Exec(nil) if err != nil { t.Fatalf("error executing statement: %s", err.Error()) @@ -1415,6 +1416,7 @@ func TestReuseClosedConnection(t *testing.T) { t.Errorf("panic after reusing a closed connection: %v", err) } }() + //lint:ignore SA1019 this is a test _, err = stmt.Exec(nil) if err != nil && err != driver.ErrBadConn { t.Errorf("unexpected error '%s', expected '%s'", @@ -2432,6 +2434,7 @@ func TestExecMultipleResults(t *testing.T) { t.Fatalf("failed to connect: %v", err) } conn.Raw(func(conn interface{}) error { + //lint:ignore SA1019 this is a test ex := conn.(driver.Execer) res, err := ex.Exec(` INSERT INTO test (value) VALUES ('a'), ('b'); @@ -2489,8 +2492,8 @@ func TestQueryMultipleResults(t *testing.T) { t.Fatalf("failed to connect: %v", err) } conn.Raw(func(conn interface{}) error { + //lint:ignore SA1019 this is a test qr := conn.(driver.Queryer) - c := conn.(*mysqlConn) // Demonstrate that repeated queries reset the affectedRows diff --git a/errors.go b/errors.go index 5680b6c05..a9a3060c9 100644 --- a/errors.go +++ b/errors.go @@ -21,7 +21,7 @@ var ( ErrMalformPkt = errors.New("malformed packet") ErrNoTLS = errors.New("TLS requested but server does not support TLS") ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") - ErrNativePassword = errors.New("this user requires mysql native password authentication.") + ErrNativePassword = errors.New("this user requires mysql native password authentication") ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") ErrUnknownPlugin = errors.New("this authentication plugin is not supported") ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") diff --git a/infile.go b/infile.go index cfd41914e..0c8af9f11 100644 --- a/infile.go +++ b/infile.go @@ -116,10 +116,10 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { defer deferredClose(&err, cl) } } else { - err = fmt.Errorf("Reader '%s' is ", name) + err = fmt.Errorf("reader '%s' is ", name) } } else { - err = fmt.Errorf("Reader '%s' is not registered", name) + err = fmt.Errorf("reader '%s' is not registered", name) } } else { // File name = strings.Trim(name, `"`) diff --git a/nulltime.go b/nulltime.go index 36c8a42c5..7d381d5c2 100644 --- a/nulltime.go +++ b/nulltime.go @@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) (err error) { } nt.Valid = false - return fmt.Errorf("Can't convert %T to time.Time", value) + return fmt.Errorf("can't convert %T to time.Time", value) } // Value implements the driver Valuer interface. From 5d4a83127cf18cadc447807c320666de5367cc4d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 3 Jul 2023 15:50:22 +0900 Subject: [PATCH 037/123] Parse numbers on text protocol too (#1452) --- driver_test.go | 87 ++++++++++++++++++++++++++++++++++---------------- packets.go | 39 +++++++++++++++++----- 2 files changed, 90 insertions(+), 36 deletions(-) diff --git a/driver_test.go b/driver_test.go index c937b8416..2748870b7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -148,29 +148,18 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { defer db2.Close() } - dsn3 := dsn + "&multiStatements=true" - var db3 *sql.DB - if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { - db3, err = sql.Open("mysql", dsn3) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db3.Close() - } - - dbt := &DBTest{t, db} - dbt2 := &DBTest{t, db2} - dbt3 := &DBTest{t, db3} for _, test := range tests { - test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") + t.Run("default", func(t *testing.T) { + dbt := &DBTest{t, db} + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + }) if db2 != nil { - test(dbt2) - dbt2.db.Exec("DROP TABLE IF EXISTS test") - } - if db3 != nil { - test(dbt3) - dbt3.db.Exec("DROP TABLE IF EXISTS test") + t.Run("interpolateParams", func(t *testing.T) { + dbt2 := &DBTest{t, db2} + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + }) } } } @@ -316,6 +305,48 @@ func TestCRUD(t *testing.T) { }) } +// TestNumbers test that selecting numeric columns. +// Both of textRows and binaryRows should return same type and value. +func TestNumbersToAny(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE `test` (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " + + "i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)") + dbt.mustExec("INSERT INTO `test` VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)") + + // Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true. + rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM `test` WHERE id=?", 1) + if !rows.Next() { + dbt.Fatal("no data") + } + var b, i8, i16, i32, i64, f32, f64 any + err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64) + if err != nil { + dbt.Fatal(err) + } + if b.(int64) != 1 { + dbt.Errorf("b != 1") + } + if i8.(int64) != 127 { + dbt.Errorf("i8 != 127") + } + if i16.(int64) != 32767 { + dbt.Errorf("i16 != 32767") + } + if i32.(int64) != 2147483647 { + dbt.Errorf("i32 != 2147483647") + } + if i64.(int64) != 9223372036854775807 { + dbt.Errorf("i64 != 9223372036854775807") + } + if f32.(float32) != 1.25 { + dbt.Errorf("f32 != 1.25") + } + if f64.(float64) != 2.5 { + dbt.Errorf("f64 != 2.5") + } + }) +} + func TestMultiQuery(t *testing.T) { runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { // Create Table @@ -1808,13 +1839,13 @@ func TestConcurrent(t *testing.T) { } runTests(t, dsn, func(dbt *DBTest) { - var version string - if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { - dbt.Fatalf("%s", err.Error()) - } - if strings.Contains(strings.ToLower(version), "mariadb") { - t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`) - } + // var version string + // if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { + // dbt.Fatal(err) + // } + // if strings.Contains(strings.ToLower(version), "mariadb") { + // t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`) + // } var max int err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) diff --git a/packets.go b/packets.go index 1a7f2c376..66635c55b 100644 --- a/packets.go +++ b/packets.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "math" + "strconv" "time" ) @@ -834,7 +835,8 @@ func (rows *textRows) readRow(dest []driver.Value) error { for i := range dest { // Read bytes and convert to string - dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + var buf []byte + buf, isNull, n, err = readLengthEncodedString(data[pos:]) pos += n if err != nil { @@ -846,19 +848,40 @@ func (rows *textRows) readRow(dest []driver.Value) error { continue } - if !mc.parseTime { - continue - } - - // Parse time field switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: - if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { - return err + if mc.parseTime { + dest[i], err = parseDateTime(buf, mc.cfg.Loc) + } else { + dest[i] = buf + } + + case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong: + dest[i], err = strconv.ParseInt(string(buf), 10, 32) + + case fieldTypeLongLong: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + dest[i], err = strconv.ParseUint(string(buf), 10, 64) + } else { + dest[i], err = strconv.ParseInt(string(buf), 10, 64) } + + case fieldTypeFloat: + var d float64 + d, err = strconv.ParseFloat(string(buf), 32) + dest[i] = float32(d) + + case fieldTypeDouble: + dest[i], err = strconv.ParseFloat(string(buf), 64) + + default: + dest[i] = buf + } + if err != nil { + return err } } From 0b18dac46f7f10d00411ab6fb10b8d6e4522c2d9 Mon Sep 17 00:00:00 2001 From: Daemonxiao <35677990+Daemonxiao@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:52:35 +0800 Subject: [PATCH 038/123] Add Daemonxiao to AUTHORS (#1459) --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 7e4fac5a1..29e08b0ca 100644 --- a/AUTHORS +++ b/AUTHORS @@ -26,6 +26,7 @@ Carlos Nieto Chris Kirkland Chris Moos Craig Wilson +Daemonxiao <735462752 at qq.com> Daniel Montoya Daniel Nichter Daniël van Eeden From 2c81c69ebe815b611383d18002074e073bed745a Mon Sep 17 00:00:00 2001 From: i7a7467 <61368544+i7a7467@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:51:54 +0900 Subject: [PATCH 039/123] update docs link about load data local (#1468) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4eade6853..6ef19966c 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ Default: false ``` `allowAllFiles=true` disables the file allowlist for `LOAD DATA LOCAL INFILE` and allows *all* files. -[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html) +[*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local) ##### `allowCleartextPasswords` @@ -509,7 +509,7 @@ For this feature you need direct access to the package. Therefore you must chang import "github.com/go-sql-driver/mysql" ``` -Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). +Files must be explicitly allowed by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the allowlist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-local)). To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore. From e503d8d2c01d622d312e4b044fc2c19948d4663f Mon Sep 17 00:00:00 2001 From: Netzer7 <58796038+Netzer7@users.noreply.github.com> Date: Mon, 7 Aug 2023 00:34:14 -0700 Subject: [PATCH 040/123] Update README.md (#1464) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6ef19966c..18fcc0276 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ Valid Values: Default: none ``` -Sets the charset used for client-server interaction (`"SET NAMES "`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). +Sets the charset used for client-server interaction (`"SET NAMES "`). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset fails. This enables for example support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`). See also [Unicode Support](#unicode-support). From 7cf548287682c36ebce3b7966f2693d58094bd5a Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 9 Aug 2023 20:35:39 +0900 Subject: [PATCH 041/123] add Go 1.21 and MySQL 8.1 to the build matrix (#1472) * add Go 1.21 and MySQL 8.1 to the build matrix * bump shogo82148/actions-setup-mysql v1.21.0 --- .github/workflows/test.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3122c0e17..b25c9e389 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,12 +31,14 @@ jobs: import os go = [ # Keep the most recent production release at the top - '1.20', + '1.21', # Older production releases + '1.20', '1.19', '1.18', ] mysql = [ + '8.1', '8.0', '5.7', '5.6', @@ -75,7 +77,7 @@ jobs: - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - - uses: shogo82148/actions-setup-mysql@v1.16.0 + - uses: shogo82148/actions-setup-mysql@v1.21.0 with: mysql-version: ${{ matrix.mysql }} user: ${{ env.MYSQL_TEST_USER }} From 43e9bef05581335f84d246aba6211af1b5133aae Mon Sep 17 00:00:00 2001 From: Pyry Kontio Date: Sat, 2 Sep 2023 03:35:23 +0900 Subject: [PATCH 042/123] Improve DSN docstsrings (#1475) --- dsn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsn.go b/dsn.go index 380ca9570..f5b184e3f 100644 --- a/dsn.go +++ b/dsn.go @@ -36,8 +36,8 @@ var ( type Config struct { User string // Username Passwd string // Password (requires User) - Net string // Network type - Addr string // Network address (requires Net) + Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") + Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix") DBName string // Database name Params map[string]string // Connection parameters ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs From 78e0387dba9f2894f3ee6004b98c49b9b11bf367 Mon Sep 17 00:00:00 2001 From: ShenFeng312 <49786112+ShenFeng312@users.noreply.github.com> Date: Wed, 20 Sep 2023 11:55:24 +0800 Subject: [PATCH 043/123] packet: remove length check (#1481) Fix #1478 --- packets.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/packets.go b/packets.go index 66635c55b..4e27004aa 100644 --- a/packets.go +++ b/packets.go @@ -572,12 +572,9 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, error) { } // column count - num, _, n := readLengthEncodedInteger(data) - if n-len(data) == 0 { - return int(num), nil - } - - return 0, ErrMalformPkt + num, _, _ := readLengthEncodedInteger(data) + // ignore remaining data in the packet. see #1478. + return int(num), nil } return 0, err } From 22e750b046938b5c13375da56a5f85ae9ce10e0b Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 28 Sep 2023 20:16:32 +0900 Subject: [PATCH 044/123] README: fix markup error (#1480) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 18fcc0276..9257c1fd2 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ This has the same effect as an empty DSN string: ``` -`dbname` is escaped by [PathEscape()]()https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes: +`dbname` is escaped by [PathEscape()](https://pkg.go.dev/net/url#PathEscape) since v1.8.0. If your database name is `dbname/withslash`, it becomes: ``` /dbname%2Fwithslash From 19171b59bf90e6bf7a5bdf979e5e24a84b328b8a Mon Sep 17 00:00:00 2001 From: Oliver Bone Date: Sat, 30 Sep 2023 20:33:48 +0100 Subject: [PATCH 045/123] Close connection on ErrPktSync and ErrPktSyncMul (#1473) An `ErrPktSync` or `ErrPktSyncMul` error always means that a packet header has been read, but since the sequence ID was not correct then the packet payload has not been read. This results in the connection being left in a broken state, since any future operations will always result in a "busy buffer" error. Keeping such connections alive leads to them being repeatedly returned to the pool in this state, which can in turn result in a large number of failures due to these "busy buffer" errors. This commit fixes this problem by simply closing the connection before returning either `ErrPktSync` or `ErrPktSyncMul`. This ensures that the connection won't be returned to the pool, preventing it from causing any further errors. --- AUTHORS | 1 + packets.go | 1 + packets_test.go | 52 ++++++++++++++++++++++++++----------------------- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/AUTHORS b/AUTHORS index 29e08b0ca..dec27daca 100644 --- a/AUTHORS +++ b/AUTHORS @@ -77,6 +77,7 @@ Maciej Zimnoch Michael Woolnough Nathanial Murphy Nicola Peduzzi +Oliver Bone Olivier Mengué oscarzhao Paul Bonser diff --git a/packets.go b/packets.go index 4e27004aa..0994d41a3 100644 --- a/packets.go +++ b/packets.go @@ -44,6 +44,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { + mc.Close() if data[3] > mc.sequence { return nil, ErrPktSyncMul } diff --git a/packets_test.go b/packets_test.go index f429087e9..56c455188 100644 --- a/packets_test.go +++ b/packets_test.go @@ -133,30 +133,34 @@ func TestReadPacketSingleByte(t *testing.T) { } func TestReadPacketWrongSequenceID(t *testing.T) { - conn := new(mockConn) - mc := &mysqlConn{ - buf: newBuffer(conn), - } - - // too low sequence id - conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} - conn.maxReads = 1 - mc.sequence = 1 - _, err := mc.readPacket() - if err != ErrPktSync { - t.Errorf("expected ErrPktSync, got %v", err) - } - - // reset - conn.reads = 0 - mc.sequence = 0 - mc.buf = newBuffer(conn) - - // too high sequence id - conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} - _, err = mc.readPacket() - if err != ErrPktSyncMul { - t.Errorf("expected ErrPktSyncMul, got %v", err) + for _, testCase := range []struct { + ClientSequenceID byte + ServerSequenceID byte + ExpectedErr error + }{ + { + ClientSequenceID: 1, + ServerSequenceID: 0, + ExpectedErr: ErrPktSync, + }, + { + ClientSequenceID: 0, + ServerSequenceID: 0x42, + ExpectedErr: ErrPktSyncMul, + }, + } { + conn, mc := newRWMockConn(testCase.ClientSequenceID) + + conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} + _, err := mc.readPacket() + if err != testCase.ExpectedErr { + t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) + } + + // connection should not be returned to the pool in this state + if mc.IsValid() { + t.Errorf("expected IsValid() to be false") + } } } From e5a2abc9cca895ca44570b171ff1f2f976d5921d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Wed, 4 Oct 2023 21:24:11 +0300 Subject: [PATCH 046/123] Spelling, grammar, and link fixes (#1485) --- CHANGELOG.md | 6 +++--- README.md | 2 +- auth.go | 4 ++-- driver_test.go | 10 +++++----- dsn_test.go | 2 +- packets.go | 6 +++--- packets_test.go | 4 ++-- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5166e4adb..213215c8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -162,7 +162,7 @@ New Features: - Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249) - Support for returning table alias on Columns() (#289, #359, #382) - - Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318, #490) + - Placeholder interpolation, can be activated with the DSN parameter `interpolateParams=true` (#309, #318, #490) - Support for uint64 parameters with high bit set (#332, #345) - Cleartext authentication plugin support (#327) - Exported ParseDSN function and the Config struct (#403, #419, #429) @@ -206,7 +206,7 @@ Changes: - Also exported the MySQLWarning type - mysqlConn.Close returns the first error encountered instead of ignoring all errors - writePacket() automatically writes the packet size to the header - - readPacket() uses an iterative approach instead of the recursive approach to merge splitted packets + - readPacket() uses an iterative approach instead of the recursive approach to merge split packets New Features: @@ -254,7 +254,7 @@ Bugfixes: - Fixed MySQL 4.1 support: MySQL 4.1 sends packets with lengths which differ from the specification - Convert to DB timezone when inserting `time.Time` - - Splitted packets (more than 16MB) are now merged correctly + - Split packets (more than 16MB) are now merged correctly - Fixed false positive `io.EOF` errors when the data was fully read - Avoid panics on reuse of closed connections - Fixed empty string producing false nil values diff --git a/README.md b/README.md index 9257c1fd2..fff8969f3 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ Passwords can consist of any character. Escaping is **not** necessary. #### Protocol See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available. -In general you should use an Unix domain socket if available and TCP otherwise for best performance. +In general you should use a Unix domain socket if available and TCP otherwise for best performance. #### Address For TCP and UDP networks, addresses have the form `host[:port]`. diff --git a/auth.go b/auth.go index d2ab0103d..bab282bd2 100644 --- a/auth.go +++ b/auth.go @@ -338,7 +338,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { switch plugin { - // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + // https://dev.mysql.com/blog-archive/preparing-your-community-connector-for-mysql-8-part-2-sha256/ case "caching_sha2_password": switch len(authData) { case 0: @@ -376,7 +376,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { } if data[0] != iAuthMoreData { - return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication") + return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication") } // parse public key diff --git a/driver_test.go b/driver_test.go index 2748870b7..dd3d73141 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1198,7 +1198,7 @@ func TestLongData(t *testing.T) { dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) } if rows.Next() { - dbt.Error("LONGBLOB: unexpexted row") + dbt.Error("LONGBLOB: unexpected row") } } else { dbt.Fatalf("LONGBLOB: no data") @@ -1217,7 +1217,7 @@ func TestLongData(t *testing.T) { dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) } if rows.Next() { - dbt.Error("LONGBLOB: unexpexted row") + dbt.Error("LONGBLOB: unexpected row") } } else { if err = rows.Err(); err != nil { @@ -1293,7 +1293,7 @@ func TestLoadData(t *testing.T) { dbt.Fatalf("unexpected row count: got %d, want 0", count) } - // Then fille File with data and try to load it + // Then fill File with data and try to load it file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") file.Close() dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) @@ -1899,7 +1899,7 @@ func TestConcurrent(t *testing.T) { }(i) } - // wait until all conections are open + // wait until all connections are open wg.Wait() if fatalError != "" { @@ -1948,7 +1948,7 @@ func TestCustomDial(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - // our custom dial function which justs wraps net.Dial here + // our custom dial function which just wraps net.Dial here RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { var d net.Dialer return d.DialContext(ctx, prot, addr) diff --git a/dsn_test.go b/dsn_test.go index be50102de..8a6a0c10e 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -132,7 +132,7 @@ func TestDSNReformat(t *testing.T) { dsn2 := cfg1.FormatDSN() if dsn2 != dsn1 { // Just log - t.Logf("%d. %q reformated as %q", i, dsn1, dsn2) + t.Logf("%d. %q reformatted as %q", i, dsn1, dsn2) } cfg2, err := ParseDSN(dsn2) diff --git a/packets.go b/packets.go index 0994d41a3..a1aaf20ee 100644 --- a/packets.go +++ b/packets.go @@ -240,7 +240,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // reserved (all [00]) [10 bytes] pos += 1 + 2 + 2 + 1 + 10 - // second part of the password cipher [mininum 13 bytes], + // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) // // The web documentation is ambiguous about the length. However, @@ -538,7 +538,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { } } -// Returns error if Packet is not an 'Result OK'-Packet +// Returns error if Packet is not a 'Result OK'-Packet func (mc *okHandler) readResultOK() error { data, err := mc.conn().readPacket() if err != nil { @@ -647,7 +647,7 @@ func (mc *mysqlConn) resultUnchanged() *okHandler { // Both return an instance of type *okHandler. type okHandler mysqlConn -// Exposees the underlying type's methods. +// Exposes the underlying type's methods. func (mc *okHandler) conn() *mysqlConn { return (*mysqlConn)(mc) } diff --git a/packets_test.go b/packets_test.go index 56c455188..e86ec5848 100644 --- a/packets_test.go +++ b/packets_test.go @@ -188,7 +188,7 @@ func TestReadPacketSplit(t *testing.T) { data[4] = 0x11 data[maxPacketSize+3] = 0x22 - // 2nd packet has payload length 0 and squence id 1 + // 2nd packet has payload length 0 and sequence id 1 // 00 00 00 01 data[pkt2ofs+3] = 0x01 @@ -220,7 +220,7 @@ func TestReadPacketSplit(t *testing.T) { data[pkt2ofs+4] = 0x33 data[pkt2ofs+maxPacketSize+3] = 0x44 - // 3rd packet has payload length 0 and squence id 2 + // 3rd packet has payload length 0 and sequence id 2 // 00 00 00 02 data[pkt3ofs+3] = 0x02 From 37980127edfb00edd1ba2eb397a33fdea2828828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Skytt=C3=A4?= Date: Thu, 5 Oct 2023 11:44:35 +0300 Subject: [PATCH 047/123] use strings.Cut (#1486) --- connector.go | 9 ++++----- dsn.go | 8 ++++---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/connector.go b/connector.go index 7e0b16734..ba3be71e7 100644 --- a/connector.go +++ b/connector.go @@ -38,13 +38,12 @@ func encodeConnectionAttributes(textAttributes string) string { // user-defined connection attributes for _, connAttr := range strings.Split(textAttributes, ",") { - attr := strings.SplitN(connAttr, ":", 2) - if len(attr) != 2 { + k, v, found := strings.Cut(connAttr, ":") + if !found { continue } - for _, v := range attr { - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) - } + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, k) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v) } return string(connAttrsBuf) diff --git a/dsn.go b/dsn.go index f5b184e3f..50c7ec413 100644 --- a/dsn.go +++ b/dsn.go @@ -390,13 +390,13 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // Values must be url.QueryEscape'ed func parseDSNParams(cfg *Config, params string) (err error) { for _, v := range strings.Split(params, "&") { - param := strings.SplitN(v, "=", 2) - if len(param) != 2 { + key, value, found := strings.Cut(v, "=") + if !found { continue } // cfg params - switch value := param[1]; param[0] { + switch key { // Disable INFILE allowlist / enable all files case "allowAllFiles": var isBool bool @@ -577,7 +577,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { cfg.Params = make(map[string]string) } - if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + if cfg.Params[key], err = url.QueryUnescape(value); err != nil { return } } From 5f74bcbcf0550e74cf0ac0170e5dd9f87683a355 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 9 Oct 2023 18:44:08 +0900 Subject: [PATCH 048/123] move stale connection check to ResetSession() (#1496) When ResetSession was added, it was called when the connection is put into the pool. Thet is why we had only set `mc.reset` flag on ResetSession(). In Go 1.15, this behavior was changed. (golang/go@971f8a2) ResetSession is called when the connection is checked out from the pool. So we can call checkConnLiveness() directly from ResetSession. --- connection.go | 27 +++++++++++++++++++++++++-- packets.go | 28 ---------------------------- 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/connection.go b/connection.go index 631a1dc24..660b2b0e0 100644 --- a/connection.go +++ b/connection.go @@ -34,7 +34,6 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool - reset bool // set when the Go SQL package calls ResetSession // for context support (Go 1.8+) watching bool @@ -646,7 +645,31 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { if mc.closed.Load() { return driver.ErrBadConn } - mc.reset = true + + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.cfg.CheckConnLiveness { + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) + } + if err == nil { + err = connCheck(conn) + } + if err != nil { + mc.cfg.Logger.Print("closing bad idle connection: ", err) + return driver.ErrBadConn + } + } + return nil } diff --git a/packets.go b/packets.go index a1aaf20ee..0127232ee 100644 --- a/packets.go +++ b/packets.go @@ -98,34 +98,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { return ErrPktTooLarge } - // Perform a stale connection check. We only perform this check for - // the first query on a connection that has been checked out of the - // connection pool: a fresh connection from the pool is more likely - // to be stale, and it has not performed any previous writes that - // could cause data corruption, so it's safe to return ErrBadConn - // if the check fails. - if mc.reset { - mc.reset = false - conn := mc.netConn - if mc.rawConn != nil { - conn = mc.rawConn - } - var err error - if mc.cfg.CheckConnLiveness { - if mc.cfg.ReadTimeout != 0 { - err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) - } - if err == nil { - err = connCheck(conn) - } - } - if err != nil { - mc.cfg.Logger.Print("closing bad idle connection: ", err) - mc.Close() - return driver.ErrBadConn - } - } - for { var size int if pktLen >= maxPacketSize { From 9c633df1f62eadfdc840840a0f229ea59cc15c33 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Tue, 10 Oct 2023 17:39:46 +0900 Subject: [PATCH 049/123] fix race condition of TestConcurrent (#1490) * fix race condition of TestConcurrent * run tests with the '-race' option --- .github/workflows/test.yml | 2 +- driver_test.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b25c9e389..8e1cb9bc3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,7 +96,7 @@ jobs: - name: test run: | - go test -v '-covermode=count' '-coverprofile=coverage.out' + go test -v '-race' '-covermode=atomic' '-coverprofile=coverage.out' - name: Send coverage uses: shogo82148/actions-goveralls@v1 diff --git a/driver_test.go b/driver_test.go index dd3d73141..74f15c2d2 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1872,7 +1872,6 @@ func TestConcurrent(t *testing.T) { defer wg.Done() tx, err := dbt.db.Begin() - atomic.AddInt32(&remaining, -1) if err != nil { if err.Error() != "Error 1040: Too many connections" { @@ -1882,7 +1881,7 @@ func TestConcurrent(t *testing.T) { } // keep the connection busy until all connections are open - for remaining > 0 { + for atomic.AddInt32(&remaining, -1) > 0 { if _, err = tx.Exec("DO 1"); err != nil { fatalf("error on conn %d: %s", id, err.Error()) return From 278a0b9e6b34ccc52aa213681836a79336714d34 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Tue, 10 Oct 2023 17:48:58 +0900 Subject: [PATCH 050/123] mark fail, mustExec and mustQuery as test helpers (#1488) --- driver_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/driver_test.go b/driver_test.go index 74f15c2d2..f46d38df6 100644 --- a/driver_test.go +++ b/driver_test.go @@ -165,6 +165,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { } func (dbt *DBTest) fail(method, query string, err error) { + dbt.Helper() if len(query) > 300 { query = "[query too large to print]" } @@ -172,6 +173,7 @@ func (dbt *DBTest) fail(method, query string, err error) { } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + dbt.Helper() res, err := dbt.db.Exec(query, args...) if err != nil { dbt.fail("exec", query, err) @@ -180,6 +182,7 @@ func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) } func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + dbt.Helper() rows, err := dbt.db.Query(query, args...) if err != nil { dbt.fail("query", query, err) From 1e6b8d7df47928193f2b1a04b5f7f06907187508 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Thu, 19 Oct 2023 07:33:23 +0200 Subject: [PATCH 051/123] Remove obsolete fuzz.go (#1498) fuzz.go (added in #1097) uses gofuzz. But #1444 added a better fuzzer that uses Go builtin fuzzing. Closes #1445. --- fuzz.go | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 fuzz.go diff --git a/fuzz.go b/fuzz.go deleted file mode 100644 index 3a4ec25a9..000000000 --- a/fuzz.go +++ /dev/null @@ -1,25 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. -// -// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build gofuzz -// +build gofuzz - -package mysql - -import ( - "database/sql" -) - -func Fuzz(data []byte) int { - db, err := sql.Open("mysql", string(data)) - if err != nil { - return 0 - } - db.Close() - return 1 -} From 62c29ce0b1b8f84567de97ca0d32cebd53f05aa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Tue, 24 Oct 2023 10:05:53 +0200 Subject: [PATCH 052/123] Allow to change (or disable) the default driver name for registration (#1499) A link variable now allows to change or disable the name of the driver that is automatically registered with database/sql: Change driver name: go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom" Disable driver registration (set driverName to empty string): go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=" In the same way, a variable overridable at link time is also provided to override the driver name used in the test suite. This allows to run our test suite on another driver. go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom" driverName is propagated to driverNameTest unless driverNameTest is explicitely defined. --- benchmark_test.go | 8 ++++---- driver.go | 8 +++++++- driver_test.go | 28 +++++++++++++++++++--------- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index fc70df60d..a4ecc0a63 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -48,7 +48,7 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { func initDB(b *testing.B, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -105,7 +105,7 @@ func BenchmarkExec(b *testing.B) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) db.SetMaxIdleConns(concurrencyLevel) defer db.Close() @@ -151,7 +151,7 @@ func BenchmarkRoundtripTxt(b *testing.B) { sampleString := string(sample) b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() b.StartTimer() var result string @@ -184,7 +184,7 @@ func BenchmarkRoundtripBin(b *testing.B) { sample, min, max := initRoundtripBenchmarks() b.ReportAllocs() tb := (*TB)(b) - db := tb.checkDB(sql.Open("mysql", dsn)) + db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() stmt := tb.checkStmt(db.Prepare("SELECT ?")) defer stmt.Close() diff --git a/driver.go b/driver.go index 0ed8fa1c5..45528b920 100644 --- a/driver.go +++ b/driver.go @@ -90,8 +90,14 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return c.Connect(context.Background()) } +// This variable can be replaced with -ldflags like below: +// go build "-ldflags=-X github.com/go-sql-driver/mysql.driverName=custom" +var driverName = "mysql" + func init() { - sql.Register("mysql", &MySQLDriver{}) + if driverName != "" { + sql.Register(driverName, &MySQLDriver{}) + } } // NewConnector returns new driver.Connector. diff --git a/driver_test.go b/driver_test.go index f46d38df6..13e07e753 100644 --- a/driver_test.go +++ b/driver_test.go @@ -31,6 +31,16 @@ import ( "time" ) +// This variable can be replaced with -ldflags like below: +// go test "-ldflags=-X github.com/go-sql-driver/mysql.driverNameTest=custom" +var driverNameTest string + +func init() { + if driverNameTest == "" { + driverNameTest = driverName + } +} + // Ensure that all the driver interfaces are implemented var ( _ driver.Rows = &binaryRows{} @@ -111,7 +121,7 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT dsn += "&multiStatements=true" var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) + db, err = sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -130,7 +140,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { t.Skipf("MySQL server not running on %s", netAddr) } - db, err := sql.Open("mysql", dsn) + db, err := sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -141,7 +151,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { dsn2 := dsn + "&interpolateParams=true" var db2 *sql.DB if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { - db2, err = sql.Open("mysql", dsn2) + db2, err = sql.Open(driverNameTest, dsn2) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1917,7 +1927,7 @@ func testDialError(t *testing.T, dialErr error, expectErr error) { return nil, dialErr }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -1956,7 +1966,7 @@ func TestCustomDial(t *testing.T) { return d.DialContext(ctx, prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -2054,7 +2064,7 @@ func TestUnixSocketAuthFail(t *testing.T) { } t.Logf("socket: %s", socket) badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) - db, err := sql.Open("mysql", badDSN) + db, err := sql.Open(driverNameTest, badDSN) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -2243,7 +2253,7 @@ func TestEmptyPassword(t *testing.T) { } dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) - db, err := sql.Open("mysql", dsn) + db, err := sql.Open(driverNameTest, dsn) if err == nil { defer db.Close() err = db.Ping() @@ -3210,7 +3220,7 @@ func TestConnectorObeysDialTimeouts(t *testing.T) { return d.DialContext(ctx, prot, addr) }) - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + db, err := sql.Open(driverNameTest, fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } @@ -3375,7 +3385,7 @@ func TestConnectionAttributes(t *testing.T) { var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { - db, err = sql.Open("mysql", dsn) + db, err = sql.Open(driverNameTest, dsn) if err != nil { t.Fatalf("error connecting: %s", err.Error()) } From c175348d98a9a245462ade75c6fde69424eb6fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Mengu=C3=A9?= Date: Tue, 24 Oct 2023 10:08:26 +0200 Subject: [PATCH 053/123] testing: expose testing.TB in DBTest instead of full *testing.T (#1500) Reduce the methods exposed by DBTest to the subset of testing.T exposed in the testing.TB interface. --- driver_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/driver_test.go b/driver_test.go index 13e07e753..f256011a7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -92,7 +92,7 @@ func init() { } type DBTest struct { - *testing.T + testing.TB db *sql.DB } From 18b74e415dc148b486af13faa300fdefe26e484f Mon Sep 17 00:00:00 2001 From: Vaibhav Panvalkar <42548559+panvalkar1994@users.noreply.github.com> Date: Tue, 7 Nov 2023 17:27:05 +0530 Subject: [PATCH 054/123] symbol removed from installation command (#1510) Co-authored-by: panvalkar1994 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index fff8969f3..ac79890a7 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ## Installation Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: ```bash -$ go get -u github.com/go-sql-driver/mysql +go get -u github.com/go-sql-driver/mysql ``` Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. From b2e2ccbf16565d9706a2ffe77aafb21fb545a8d5 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Tue, 14 Nov 2023 19:17:17 +0800 Subject: [PATCH 055/123] QueryUnescape DSN ConnectionAttribute value (#1470) --- AUTHORS | 2 ++ driver_test.go | 18 +++++++++++++++--- dsn.go | 6 +++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/AUTHORS b/AUTHORS index dec27daca..c84293100 100644 --- a/AUTHORS +++ b/AUTHORS @@ -109,6 +109,7 @@ Xiangyu Hu Xiaobing Jiang Xiuming Chen Xuehong Chan +Zhang Xiang Zhenye Xie Zhixin Wen Ziheng Lyu @@ -127,6 +128,7 @@ InfoSum Ltd. Keybase Inc. Multiplay Ltd. Percona LLC +PingCAP Inc. Pivotal Inc. Stripe Inc. Zendesk Inc. diff --git a/driver_test.go b/driver_test.go index f256011a7..8c02f6d1c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3379,9 +3379,10 @@ func TestConnectionAttributes(t *testing.T) { attr1 := "attr1" value1 := "value1" - attr2 := "foo" - value2 := "boo" - dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2) + attr2 := "fo/o" + value2 := "bo/o" + dsn += "&connectionAttributes=" + url.QueryEscape(fmt.Sprintf("%s:%s,%s:%s", attr1, value1, attr2, value2)) + var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { @@ -3407,6 +3408,17 @@ func TestConnectionAttributes(t *testing.T) { } rows.Close() + rows = dbt.mustQuery(queryString, attr1) + if rows.Next() { + rows.Scan(&attrValue) + if attrValue != value1 { + dbt.Errorf("expected %q, got %q", value1, attrValue) + } + } else { + dbt.Errorf("no data") + } + rows.Close() + rows = dbt.mustQuery(queryString, attr2) if rows.Next() { rows.Scan(&attrValue) diff --git a/dsn.go b/dsn.go index 50c7ec413..ef0608636 100644 --- a/dsn.go +++ b/dsn.go @@ -569,7 +569,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Connection attributes case "connectionAttributes": - cfg.ConnectionAttributes = value + connectionAttributes, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid connectionAttributes value: %v", err) + } + cfg.ConnectionAttributes = connectionAttributes default: // lazy init From a4c260b40eeb51bd823d8b04d0e0e8d072e56adf Mon Sep 17 00:00:00 2001 From: Aidan <97376271+keeplearning20221@users.noreply.github.com> Date: Wed, 15 Nov 2023 18:40:52 +0800 Subject: [PATCH 056/123] fix hangup when error in multi resultsets (#1462) Fix #1361 Co-authored-by: Inada Naoki --- AUTHORS | 1 + driver_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ rows.go | 8 +++++++- 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index c84293100..c7e159603 100644 --- a/AUTHORS +++ b/AUTHORS @@ -13,6 +13,7 @@ Aaron Hopkins Achille Roussel +Aidan Alex Snast Alexey Palazhchenko Andrew Reid diff --git a/driver_test.go b/driver_test.go index 8c02f6d1c..ab780f04c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3430,3 +3430,44 @@ func TestConnectionAttributes(t *testing.T) { } rows.Close() } + +func TestErrorInMultiResult(t *testing.T) { + // https://github.com/go-sql-driver/mysql/issues/1361 + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + query := ` +CREATE PROCEDURE test_proc1() +BEGIN + SELECT 1,2; + SELECT 3,4; + SIGNAL SQLSTATE '10000' SET MESSAGE_TEXT = "some error", MYSQL_ERRNO = 10000; +END; +` + runCallCommand(dbt, query, "test_proc1") +} + +func runCallCommand(dbt *DBTest, query, name string) { + dbt.mustExec(fmt.Sprintf("DROP PROCEDURE IF EXISTS %s", name)) + dbt.mustExec(query) + defer dbt.mustExec("DROP PROCEDURE " + name) + rows, err := dbt.db.Query(fmt.Sprintf("CALL %s", name)) + if err != nil { + return + } + defer rows.Close() + + for rows.Next() { + } + for rows.NextResultSet() { + for rows.Next() { + } + } +} diff --git a/rows.go b/rows.go index 63d0ed2d5..81fa6062c 100644 --- a/rows.go +++ b/rows.go @@ -163,7 +163,13 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - return rows.mc.resultUnchanged().readResultSetHeaderPacket() + resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() + if err != nil { + // Clean up about multi-results flag + rows.rs.done = true + rows.mc.status = rows.mc.status & (^statusMoreResultsExists) + } + return resLen, err } func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { From 98d72897bab37633105da6dce698ce074fd19995 Mon Sep 17 00:00:00 2001 From: Jason Ng Date: Thu, 23 Nov 2023 21:01:24 +0800 Subject: [PATCH 057/123] Add default connection attribute '_server_host' (#1506) The `_server_host` connection attribute is supported in MariaDB (Connector/C) https://mariadb.com/kb/en/mysql_optionsv/#connection-attribute-options --- AUTHORS | 2 ++ connector.go | 21 +++++++------- connector_test.go | 7 ++--- const.go | 1 + driver.go | 9 ++---- driver_test.go | 71 ++++++++++++++++++++++++----------------------- packets.go | 16 +++++------ packets_test.go | 5 +--- 8 files changed, 64 insertions(+), 68 deletions(-) diff --git a/AUTHORS b/AUTHORS index c7e159603..2caa7d706 100644 --- a/AUTHORS +++ b/AUTHORS @@ -50,6 +50,7 @@ INADA Naoki Jacek Szwec James Harr Janek Vedock +Jason Ng Jean-Yves Pellé Jeff Hodges Jeffrey Charles @@ -131,6 +132,7 @@ Multiplay Ltd. Percona LLC PingCAP Inc. Pivotal Inc. +Shattered Silicon Ltd. Stripe Inc. Zendesk Inc. Dolthub Inc. diff --git a/connector.go b/connector.go index ba3be71e7..3cef7963f 100644 --- a/connector.go +++ b/connector.go @@ -11,7 +11,6 @@ package mysql import ( "context" "database/sql/driver" - "fmt" "net" "os" "strconv" @@ -23,8 +22,8 @@ type connector struct { encodedAttributes string // Encoded connection attributes. } -func encodeConnectionAttributes(textAttributes string) string { - connAttrsBuf := make([]byte, 0, 251) +func encodeConnectionAttributes(cfg *Config) string { + connAttrsBuf := make([]byte, 0) // default connection attributes connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) @@ -35,9 +34,14 @@ func encodeConnectionAttributes(textAttributes string) string { connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + serverHost, _, _ := net.SplitHostPort(cfg.Addr) + if serverHost != "" { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost) + } // user-defined connection attributes - for _, connAttr := range strings.Split(textAttributes, ",") { + for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") { k, v, found := strings.Cut(connAttr, ":") if !found { continue @@ -49,15 +53,12 @@ func encodeConnectionAttributes(textAttributes string) string { return string(connAttrsBuf) } -func newConnector(cfg *Config) (*connector, error) { - encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) - if len(encodedAttributes) > 250 { - return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes) - } +func newConnector(cfg *Config) *connector { + encodedAttributes := encodeConnectionAttributes(cfg) return &connector{ cfg: cfg, encodedAttributes: encodedAttributes, - }, nil + } } // Connect implements driver.Connector interface. diff --git a/connector_test.go b/connector_test.go index bedb44ce2..82d8c5989 100644 --- a/connector_test.go +++ b/connector_test.go @@ -8,16 +8,13 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector, err := newConnector(&Config{ + connector := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, }) - if err != nil { - t.Fatal(err) - } - _, err = connector.Connect(context.Background()) + _, err := connector.Connect(context.Background()) if err == nil { t.Fatal("error expected") } diff --git a/const.go b/const.go index 0f2621a6f..22526e031 100644 --- a/const.go +++ b/const.go @@ -26,6 +26,7 @@ const ( connAttrPlatform = "_platform" connAttrPlatformValue = runtime.GOARCH connAttrPid = "_pid" + connAttrServerHost = "_server_host" ) // MySQL constants documentation: diff --git a/driver.go b/driver.go index 45528b920..105316b81 100644 --- a/driver.go +++ b/driver.go @@ -83,10 +83,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c, err := newConnector(cfg) - if err != nil { - return nil, err - } + c := newConnector(cfg) return c.Connect(context.Background()) } @@ -108,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return newConnector(cfg) + return newConnector(cfg), nil } // OpenConnector implements driver.DriverContext. @@ -117,5 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return newConnector(cfg) + return newConnector(cfg), nil } diff --git a/driver_test.go b/driver_test.go index ab780f04c..efbff1792 100644 --- a/driver_test.go +++ b/driver_test.go @@ -24,6 +24,7 @@ import ( "os" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -3377,12 +3378,30 @@ func TestConnectionAttributes(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - attr1 := "attr1" - value1 := "value1" - attr2 := "fo/o" - value2 := "bo/o" - dsn += "&connectionAttributes=" + url.QueryEscape(fmt.Sprintf("%s:%s,%s:%s", attr1, value1, attr2, value2)) + defaultAttrs := []string{ + connAttrClientName, + connAttrOS, + connAttrPlatform, + connAttrPid, + connAttrServerHost, + } + host, _, _ := net.SplitHostPort(addr) + defaultAttrValues := []string{ + connAttrClientNameValue, + connAttrOSValue, + connAttrPlatformValue, + strconv.Itoa(os.Getpid()), + host, + } + + customAttrs := []string{"attr1", "fo/o"} + customAttrValues := []string{"value1", "bo/o"} + customAttrStrs := make([]string, len(customAttrs)) + for i := range customAttrs { + customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i]) + } + dsn += "&connectionAttributes=" + url.QueryEscape(strings.Join(customAttrStrs, ",")) var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { @@ -3395,40 +3414,24 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} - var attrValue string - queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" - rows := dbt.mustQuery(queryString, connAttrClientName) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != connAttrClientNameValue { - dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) - } - } else { - dbt.Errorf("no data") - } - rows.Close() + queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" + rows := dbt.mustQuery(queryString) + defer rows.Close() - rows = dbt.mustQuery(queryString, attr1) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != value1 { - dbt.Errorf("expected %q, got %q", value1, attrValue) - } - } else { - dbt.Errorf("no data") + rowsMap := make(map[string]string) + for rows.Next() { + var attrName, attrValue string + rows.Scan(&attrName, &attrValue) + rowsMap[attrName] = attrValue } - rows.Close() - rows = dbt.mustQuery(queryString, attr2) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != value2 { - dbt.Errorf("expected %q, got %q", value2, attrValue) + connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...) + expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...) + for i := range connAttrs { + if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] { + dbt.Errorf("expected %q, got %q", expectedAttrValues[i], gotValue) } - } else { - dbt.Errorf("no data") } - rows.Close() } func TestErrorInMultiResult(t *testing.T) { diff --git a/packets.go b/packets.go index 0127232ee..49e6bb058 100644 --- a/packets.go +++ b/packets.go @@ -292,15 +292,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += n + 1 } - // 1 byte to store length of all key-values - // NOTE: Actually, this is length encoded integer. - // But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer - // doesn't support buffer size more than 4096 bytes. - // TODO(methane): Rewrite buffer management. - pktLen += 1 + len(mc.connector.encodedAttributes) + // encode length of the connection attributes + var connAttrsLEIBuf [9]byte + connAttrsLen := len(mc.connector.encodedAttributes) + connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) + pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) // Calculate packet length and get buffer with that size - data, err := mc.buf.takeSmallBuffer(pktLen + 4) + data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection mc.cfg.Logger.Print(err) @@ -380,8 +379,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos++ // Connection Attributes - data[pos] = byte(len(mc.connector.encodedAttributes)) - pos++ + pos += copy(data[pos:], connAttrsLEI) pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) // Send Auth packet diff --git a/packets_test.go b/packets_test.go index e86ec5848..fa4683eab 100644 --- a/packets_test.go +++ b/packets_test.go @@ -96,10 +96,7 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) - connector, err := newConnector(NewConfig()) - if err != nil { - panic(err) - } + connector := newConnector(NewConfig()) mc := &mysqlConn{ buf: newBuffer(conn), cfg: connector.cfg, From d9f43839450e9361c16685ea24f0bce0da1935b7 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 12 Dec 2023 14:21:53 +0900 Subject: [PATCH 058/123] fix fragile test (#1522) --- driver_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/driver_test.go b/driver_test.go index efbff1792..87892a09a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -128,6 +128,8 @@ func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBT } defer db.Close() } + // Previous test may be skipped without dropping the test table + db.Exec("DROP TABLE IF EXISTS test") dbt := &DBTest{t, db} for _, test := range tests { @@ -147,6 +149,7 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { } defer db.Close() + // Previous test may be skipped without dropping the test table db.Exec("DROP TABLE IF EXISTS test") dsn2 := dsn + "&interpolateParams=true" From fc589cbaba22032382488393c72b9b3b5366917c Mon Sep 17 00:00:00 2001 From: Gusted Date: Tue, 12 Dec 2023 10:26:35 +0100 Subject: [PATCH 059/123] Add client_ed25519 authentication (#1518) Implements the necessary client code for [ed25519 authentication](https://mariadb.com/kb/en/authentication-plugin-ed25519/). This patch uses filippo.io/edwards25519 to implement the crypto bits. The standard library `crypto/ed25519` cannot be used as MariaDB chose a scheme that is simply not compatible with what the standard library provides. --- AUTHORS | 1 + auth.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++ auth_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ driver_test.go | 10 +++++----- go.mod | 2 ++ go.sum | 2 ++ 6 files changed, 108 insertions(+), 5 deletions(-) create mode 100644 go.sum diff --git a/AUTHORS b/AUTHORS index 2caa7d706..954e7ac7a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -39,6 +39,7 @@ Evan Elias Evan Shaw Frederick Mayle Gustavo Kristic +Gusted Hajime Nakagami Hanno Braun Henri Yandell diff --git a/auth.go b/auth.go index bab282bd2..658259b24 100644 --- a/auth.go +++ b/auth.go @@ -13,10 +13,13 @@ import ( "crypto/rsa" "crypto/sha1" "crypto/sha256" + "crypto/sha512" "crypto/x509" "encoding/pem" "fmt" "sync" + + "filippo.io/edwards25519" ) // server pub keys registry @@ -225,6 +228,44 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } +// authEd25519 does ed25519 authentication used by MariaDB. +func authEd25519(scramble []byte, password string) ([]byte, error) { + // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c + // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 + h := sha512.Sum512([]byte(password)) + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + return nil, err + } + A := (&edwards25519.Point{}).ScalarBaseMult(s) + + mh := sha512.New() + mh.Write(h[32:]) + mh.Write(scramble) + messageDigest := mh.Sum(nil) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + return nil, err + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + kh.Write(R.Bytes()) + kh.Write(A.Bytes()) + kh.Write(scramble) + hramDigest := kh.Sum(nil) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + return nil, err + } + + S := k.MultiplyAdd(k, s, r) + + return append(R.Bytes(), S.Bytes()...), nil +} + func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) if err != nil { @@ -290,6 +331,12 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) return enc, err + case "client_ed25519": + if len(authData) != 32 { + return nil, ErrMalformPkt + } + return authEd25519(authData, mc.cfg.Passwd) + default: mc.cfg.Logger.Print("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin diff --git a/auth_test.go b/auth_test.go index 3ce0ea6e0..8caed1fff 100644 --- a/auth_test.go +++ b/auth_test.go @@ -1328,3 +1328,54 @@ func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { t.Errorf("got unexpected data: %v", conn.written) } } + +// Derived from https://github.com/MariaDB/server/blob/6b2287fff23fbdc362499501c562f01d0d2db52e/plugin/auth_ed25519/ed25519-t.c +func TestEd25519Auth(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "foobar" + + authData := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + plugin := "client_ed25519" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{ + 232, 61, 201, 63, 67, 63, 51, 53, 86, 73, 238, 35, 170, 117, 146, + 214, 26, 17, 35, 9, 8, 132, 245, 141, 48, 99, 66, 58, 36, 228, 48, + 84, 115, 254, 187, 168, 88, 162, 249, 57, 35, 85, 79, 238, 167, 106, + 68, 117, 56, 135, 171, 47, 20, 14, 133, 79, 15, 229, 124, 160, 176, + 100, 138, 14, + } + if writtenAuthRespLen != 64 { + t.Fatalf("expected 64 bytes from client, got %d", writtenAuthRespLen) + } + if !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("auth response did not match expected value:\n%v\n%v", writtenAuthResp, expectedAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} diff --git a/driver_test.go b/driver_test.go index 87892a09a..97fd5a17a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -165,14 +165,14 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { for _, test := range tests { t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} + defer dbt.db.Exec("DROP TABLE IF EXISTS test") test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} + defer dbt2.db.Exec("DROP TABLE IF EXISTS test") test(dbt2) - dbt2.db.Exec("DROP TABLE IF EXISTS test") }) } } @@ -3181,14 +3181,14 @@ func TestRawBytesAreNotModified(t *testing.T) { rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) if err != nil { - t.Fatal(err) + dbt.Fatal(err) } var b int var raw sql.RawBytes for rows.Next() { if err := rows.Scan(&b, &raw); err != nil { - t.Fatal(err) + dbt.Fatal(err) } before := string(raw) @@ -3198,7 +3198,7 @@ func TestRawBytesAreNotModified(t *testing.T) { after := string(raw) if before != after { - t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + dbt.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } } rows.Close() diff --git a/go.mod b/go.mod index 77bbb8dbf..4629714c0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/go-sql-driver/mysql go 1.18 + +require filippo.io/edwards25519 v1.1.0 diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..359ca94b4 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= From 2cdf62442f2edb873d1270897d994fc83b78f118 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 13 Dec 2023 15:21:30 +0900 Subject: [PATCH 060/123] Fix sql.RawBytes corruption issue (#1523) --- driver_test.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/driver_test.go b/driver_test.go index 97fd5a17a..d7359085d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3183,25 +3183,26 @@ func TestRawBytesAreNotModified(t *testing.T) { if err != nil { dbt.Fatal(err) } + defer rows.Close() var b int var raw sql.RawBytes - for rows.Next() { - if err := rows.Scan(&b, &raw); err != nil { - dbt.Fatal(err) - } + if !rows.Next() { + dbt.Fatal("expected at least one row") + } + if err := rows.Scan(&b, &raw); err != nil { + dbt.Fatal(err) + } - before := string(raw) - // Ensure cancelling the query does not corrupt the contents of `raw` - cancel() - time.Sleep(time.Microsecond * 100) - after := string(raw) + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) - if before != after { - dbt.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) - } + if before != after { + dbt.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) } - rows.Close() }() } }) From d4517c5d905ccd3cc1e750f592edfa88d774d908 Mon Sep 17 00:00:00 2001 From: jennifersp <44716627+jennifersp@users.noreply.github.com> Date: Wed, 13 Dec 2023 00:50:21 -0800 Subject: [PATCH 061/123] Support ENUM and SET type in DatabaseTypeName() (#1520) --- AUTHORS | 5 +++-- driver_test.go | 2 ++ fields.go | 5 +++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 954e7ac7a..0ada02d86 100644 --- a/AUTHORS +++ b/AUTHORS @@ -21,6 +21,7 @@ Animesh Ray Arne Hormann Ariel Mashraki Asta Xie +Brian Hendriks Bulat Gaifullin Caine Jette Carlos Nieto @@ -55,6 +56,7 @@ Jason Ng Jean-Yves Pellé Jeff Hodges Jeffrey Charles +Jennifer Purevsuren Jerome Meyer Jiajia Zhong Jian Zhen @@ -116,13 +118,13 @@ Zhang Xiang Zhenye Xie Zhixin Wen Ziheng Lyu -Brian Hendriks # Organizations Barracuda Networks, Inc. Counting Ltd. DigitalOcean Inc. +Dolthub Inc. dyves labs AG Facebook Inc. GitHub Inc. @@ -136,4 +138,3 @@ Pivotal Inc. Shattered Silicon Ltd. Stripe Inc. Zendesk Inc. -Dolthub Inc. diff --git a/driver_test.go b/driver_test.go index d7359085d..8ec1be412 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3007,6 +3007,8 @@ func TestRowsColumnTypes(t *testing.T) { {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, + {"enum", "ENUM('', 'v1', 'v2')", "ENUM", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v2'"}, [3]interface{}{ns(""), ns("v1"), ns("v2")}}, + {"set", "set('', 'v1', 'v2')", "SET", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v1,v2'"}, [3]interface{}{ns(""), ns("v1"), ns("v1,v2")}}, } schema := "" diff --git a/fields.go b/fields.go index 30f31cbfb..2a397b245 100644 --- a/fields.go +++ b/fields.go @@ -77,6 +77,11 @@ func (mf *mysqlField) typeDatabaseName() string { } return "SMALLINT" case fieldTypeString: + if mf.flags&flagEnum != 0 { + return "ENUM" + } else if mf.flags&flagSet != 0 { + return "SET" + } if mf.charSet == binaryCollationID { return "BINARY" } From 0004702b931d3429afb3e16df444ed80be24d1f4 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Wed, 13 Dec 2023 20:25:41 +0900 Subject: [PATCH 062/123] Parallelize test (#1525) * Refactor test cleanup in driver_test.go * parallelize TestEmptyQuery and TestCRUD * parallelize TestNumbersToAny * parallelize TestInt * parallelize TestFloat32 * parallelize TestFloat64 * parallelize TestFloat64Placeholder * parallelize TestString * parallelize TestRawBytes * parallelize TestRawMessage * parallelize TestValuer * parallelize TestValuerWithValidation * parallelize TestTimestampMicros * parallelize TestNULL * parallelize TestUint64 * parallelize TestLongData * parallelize TestContextCancelExec * parallelize TestPingContext * parallelize TestContextCancelQuery * parallelize TestContextCancelQueryRow * Revert "parallelize TestLongData" This reverts commit a360be7a110bb6372bed8cf7bc467e3c2dae3c66. * parallelize TestContextCancelPrepare * parallelize TestContextCancelStmtExec * parallelize TestContextCancelStmtQuery * parallelize TestContextCancelBegin * parallelize TestContextBeginIsolationLevel * parallelize TestContextBeginReadOnly * parallelize TestValuerWithValueReceiverGivenNilValue * parallelize TestRawBytesAreNotModified * parallelize TestFoundRows * parallelize TestRowsClose * parallelize TestCloseStmtBeforeRows * parallelize TestStmtMultiRows * Revert "parallelize TestRawBytesAreNotModified" This reverts commit 91622f05d44481dd9867eeaaf382da239afe3925. * parallelize TestStaleConnectionChecks * parallelize TestFailingCharset * parallelize TestColumnsWithAlias * parallelize TestRawBytesResultExceedsBuffer * parallelize TestUnixSocketAuthFail * parallelize TestSkipResults * Add parallel flag to go test command * Revert "parallelize TestUnixSocketAuthFail" This reverts commit b3df7bd130a21294a45c3733f1d2541b15582111. --- .github/workflows/test.yml | 2 +- conncheck_test.go | 2 +- driver_test.go | 332 ++++++++++++++++++++++--------------- 3 files changed, 198 insertions(+), 138 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8e1cb9bc3..aae421196 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,7 +96,7 @@ jobs: - name: test run: | - go test -v '-race' '-covermode=atomic' '-coverprofile=coverage.out' + go test -v '-race' '-covermode=atomic' '-coverprofile=coverage.out' -parallel 10 - name: Send coverage uses: shogo82148/actions-goveralls@v1 diff --git a/conncheck_test.go b/conncheck_test.go index f7e025680..6b60cb7d6 100644 --- a/conncheck_test.go +++ b/conncheck_test.go @@ -17,7 +17,7 @@ import ( ) func TestStaleConnectionChecks(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { dbt.mustExec("SET @@SESSION.wait_timeout = 2") if err := dbt.db.Ping(); err != nil { diff --git a/driver_test.go b/driver_test.go index 8ec1be412..6bdb78c78 100644 --- a/driver_test.go +++ b/driver_test.go @@ -11,6 +11,7 @@ package mysql import ( "bytes" "context" + "crypto/rand" "crypto/tls" "database/sql" "database/sql/driver" @@ -149,8 +150,9 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { } defer db.Close() - // Previous test may be skipped without dropping the test table - db.Exec("DROP TABLE IF EXISTS test") + cleanup := func() { + db.Exec("DROP TABLE IF EXISTS test") + } dsn2 := dsn + "&interpolateParams=true" var db2 *sql.DB @@ -163,21 +165,80 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { } for _, test := range tests { + test := test t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} - defer dbt.db.Exec("DROP TABLE IF EXISTS test") + t.Cleanup(cleanup) test(dbt) }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} - defer dbt2.db.Exec("DROP TABLE IF EXISTS test") + t.Cleanup(cleanup) test(dbt2) }) } } } +// runTestsParallel runs the tests in parallel with a separate database connection for each test. +func runTestsParallel(t *testing.T, dsn string, tests ...func(dbt *DBTest, tableName string)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + newTableName := func(t *testing.T) string { + t.Helper() + var buf [8]byte + if _, err := rand.Read(buf[:]); err != nil { + t.Fatal(err) + } + return fmt.Sprintf("test_%x", buf[:]) + } + + t.Parallel() + for _, test := range tests { + test := test + + t.Run("default", func(t *testing.T) { + t.Parallel() + + tableName := newTableName(t) + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + t.Cleanup(func() { + db.Exec("DROP TABLE IF EXISTS " + tableName) + db.Close() + }) + + dbt := &DBTest{t, db} + test(dbt, tableName) + }) + + dsn2 := dsn + "&interpolateParams=true" + if _, err := ParseDSN(dsn2); err == errInvalidDSNUnsafeCollation { + t.Run("interpolateParams", func(t *testing.T) { + t.Parallel() + + tableName := newTableName(t) + db, err := sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + t.Cleanup(func() { + db.Exec("DROP TABLE IF EXISTS " + tableName) + db.Close() + }) + + dbt := &DBTest{t, db} + test(dbt, tableName) + }) + } + } +} + func (dbt *DBTest) fail(method, query string, err error) { dbt.Helper() if len(query) > 300 { @@ -216,7 +277,7 @@ func maybeSkip(t *testing.T, err error, skipErrno uint16) { } func TestEmptyQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { // just a comment, no query rows := dbt.mustQuery("--") defer rows.Close() @@ -228,20 +289,20 @@ func TestEmptyQuery(t *testing.T) { } func TestCRUD(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { // Create Table - dbt.mustExec("CREATE TABLE test (value BOOL)") + dbt.mustExec("CREATE TABLE " + tbl + " (value BOOL)") // Test for unexpected data var out bool - rows := dbt.mustQuery("SELECT * FROM test") + rows := dbt.mustQuery("SELECT * FROM " + tbl) if rows.Next() { dbt.Error("unexpected data in empty table") } rows.Close() // Create Data - res := dbt.mustExec("INSERT INTO test VALUES (1)") + res := dbt.mustExec("INSERT INTO " + tbl + " VALUES (1)") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -259,7 +320,7 @@ func TestCRUD(t *testing.T) { } // Read - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if true != out { @@ -275,7 +336,7 @@ func TestCRUD(t *testing.T) { rows.Close() // Update - res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + res = dbt.mustExec("UPDATE "+tbl+" SET value = ? WHERE value = ?", false, true) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -285,7 +346,7 @@ func TestCRUD(t *testing.T) { } // Check Update - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if false != out { @@ -301,7 +362,7 @@ func TestCRUD(t *testing.T) { rows.Close() // Delete - res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + res = dbt.mustExec("DELETE FROM "+tbl+" WHERE value = ?", false) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -311,7 +372,7 @@ func TestCRUD(t *testing.T) { } // Check for unexpected rows - res = dbt.mustExec("DELETE FROM test") + res = dbt.mustExec("DELETE FROM " + tbl) count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -325,13 +386,13 @@ func TestCRUD(t *testing.T) { // TestNumbers test that selecting numeric columns. // Both of textRows and binaryRows should return same type and value. func TestNumbersToAny(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE `test` (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " + + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " + "i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)") - dbt.mustExec("INSERT INTO `test` VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)") // Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true. - rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM `test` WHERE id=?", 1) + rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM "+tbl+" WHERE id=?", 1) if !rows.Next() { dbt.Fatal("no data") } @@ -410,7 +471,7 @@ func TestMultiQuery(t *testing.T) { } func TestInt(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} in := int64(42) var out int64 @@ -418,11 +479,11 @@ func TestInt(t *testing.T) { // SIGNED for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -433,16 +494,16 @@ func TestInt(t *testing.T) { } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } // UNSIGNED ZEROFILL for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + " ZEROFILL)") - dbt.mustExec("INSERT INTO test VALUES (?)", in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -453,21 +514,21 @@ func TestInt(t *testing.T) { } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestFloat32(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [2]string{"FLOAT", "DOUBLE"} in := float32(42.23) var out float32 var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")") + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -477,21 +538,21 @@ func TestFloat32(t *testing.T) { dbt.Errorf("%s: no data", v) } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestFloat64(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [2]string{"FLOAT", "DOUBLE"} var expected float64 = 42.23 var out float64 var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (42.23)") - rows = dbt.mustQuery("SELECT value FROM test") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if expected != out { @@ -501,21 +562,21 @@ func TestFloat64(t *testing.T) { dbt.Errorf("%s: no data", v) } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestFloat64Placeholder(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [2]string{"FLOAT", "DOUBLE"} var expected float64 = 42.23 var out float64 var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") - dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") - rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + dbt.mustExec("CREATE TABLE " + tbl + " (id int, value " + v + ")") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM "+tbl+" WHERE id = ?", 1) if rows.Next() { rows.Scan(&out) if expected != out { @@ -525,24 +586,24 @@ func TestFloat64Placeholder(t *testing.T) { dbt.Errorf("%s: no data", v) } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } }) } func TestString(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" var out string var rows *sql.Rows for _, v := range types { - dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ") CHARACTER SET utf8") - dbt.mustExec("INSERT INTO test VALUES (?)", in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in != out { @@ -553,11 +614,11 @@ func TestString(t *testing.T) { } rows.Close() - dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("DROP TABLE IF EXISTS " + tbl) } // BLOB - dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + dbt.mustExec("CREATE TABLE " + tbl + " (id int, value BLOB) CHARACTER SET utf8") id := 2 in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + @@ -568,9 +629,9 @@ func TestString(t *testing.T) { "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." - dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?, ?)", id, in) - err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + err := dbt.db.QueryRow("SELECT value FROM "+tbl+" WHERE id = ?", id).Scan(&out) if err != nil { dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { @@ -580,7 +641,7 @@ func TestString(t *testing.T) { } func TestRawBytes(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { v1 := []byte("aaa") v2 := []byte("bbb") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) @@ -609,7 +670,7 @@ func TestRawBytes(t *testing.T) { } func TestRawMessage(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { v1 := json.RawMessage("{}") v2 := json.RawMessage("[]") rows := dbt.mustQuery("SELECT ?, ?", v1, v2) @@ -640,14 +701,14 @@ func (tv testValuer) Value() (driver.Value, error) { } func TestValuer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { in := testValuer{"a_value"} var out string var rows *sql.Rows - dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") - dbt.mustExec("INSERT INTO test VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM test") + dbt.mustExec("CREATE TABLE " + tbl + " (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM " + tbl) if rows.Next() { rows.Scan(&out) if in.value != out { @@ -657,8 +718,6 @@ func TestValuer(t *testing.T) { dbt.Errorf("Valuer: no data") } rows.Close() - - dbt.mustExec("DROP TABLE IF EXISTS test") }) } @@ -675,15 +734,15 @@ func (tv testValuerWithValidation) Value() (driver.Value, error) { } func TestValuerWithValidation(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { in := testValuerWithValidation{"a_value"} var out string var rows *sql.Rows - dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") - dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + dbt.mustExec("CREATE TABLE " + tbl + " (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in) - rows = dbt.mustQuery("SELECT value FROM testValuer") + rows = dbt.mustQuery("SELECT value FROM " + tbl) defer rows.Close() if rows.Next() { @@ -695,19 +754,17 @@ func TestValuerWithValidation(t *testing.T) { dbt.Errorf("Valuer: no data") } - if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + if _, err := dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", testValuerWithValidation{""}); err == nil { dbt.Errorf("Failed to check valuer error") } - if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + if _, err := dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", nil); err != nil { dbt.Errorf("Failed to check nil") } - if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + if _, err := dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", map[string]bool{}); err == nil { dbt.Errorf("Failed to check not valuer") } - - dbt.mustExec("DROP TABLE IF EXISTS testValuer") }) } @@ -941,7 +998,7 @@ func TestTimestampMicros(t *testing.T) { f0 := format[:19] f1 := format[:21] f6 := format[:26] - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { // check if microseconds are supported. // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width // and not precision. @@ -956,7 +1013,7 @@ func TestTimestampMicros(t *testing.T) { return } _, err := dbt.db.Exec(` - CREATE TABLE test ( + CREATE TABLE ` + tbl + ` ( value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' @@ -965,10 +1022,10 @@ func TestTimestampMicros(t *testing.T) { if err != nil { dbt.Error(err) } - defer dbt.mustExec("DROP TABLE IF EXISTS test") - dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + defer dbt.mustExec("DROP TABLE IF EXISTS " + tbl) + dbt.mustExec("INSERT INTO "+tbl+" SET value0=?, value1=?, value6=?", f0, f1, f6) var res0, res1, res6 string - rows := dbt.mustQuery("SELECT * FROM test") + rows := dbt.mustQuery("SELECT * FROM " + tbl) defer rows.Close() if !rows.Next() { dbt.Errorf("test contained no selectable values") @@ -990,7 +1047,7 @@ func TestTimestampMicros(t *testing.T) { } func TestNULL(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { nullStmt, err := dbt.db.Prepare("SELECT NULL") if err != nil { dbt.Fatal(err) @@ -1122,12 +1179,12 @@ func TestNULL(t *testing.T) { } // Insert NULL - dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + dbt.mustExec("CREATE TABLE " + tbl + " (dummmy1 int, value int, dummy2 int)") - dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + dbt.mustExec("INSERT INTO "+tbl+" VALUES (?, ?, ?)", 1, nil, 2) var out interface{} - rows := dbt.mustQuery("SELECT * FROM test") + rows := dbt.mustQuery("SELECT * FROM " + tbl) defer rows.Close() if rows.Next() { rows.Scan(&out) @@ -1151,7 +1208,7 @@ func TestUint64(t *testing.T) { shigh = int64(uhigh) stop = ^shigh ) - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) if err != nil { dbt.Fatal(err) @@ -1347,12 +1404,12 @@ func TestLoadData(t *testing.T) { }) } -func TestFoundRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") +func TestFoundRows1(t *testing.T) { + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO " + tbl + " (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + res := dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 0") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1360,7 +1417,7 @@ func TestFoundRows(t *testing.T) { if count != 2 { dbt.Fatalf("Expected 2 affected rows, got %d", count) } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + res = dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 1") count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1369,11 +1426,14 @@ func TestFoundRows(t *testing.T) { dbt.Fatalf("Expected 2 affected rows, got %d", count) } }) - runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") - dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") +} + +func TestFoundRows2(t *testing.T) { + runTestsParallel(t, dsn+"&clientFoundRows=true", func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO " + tbl + " (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") - res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + res := dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 0") count, err := res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1381,7 +1441,7 @@ func TestFoundRows(t *testing.T) { if count != 2 { dbt.Fatalf("Expected 2 matched rows, got %d", count) } - res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + res = dbt.mustExec("UPDATE " + tbl + " SET data = 1 WHERE id = 1") count, err = res.RowsAffected() if err != nil { dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) @@ -1507,7 +1567,7 @@ func TestCharset(t *testing.T) { } func TestFailingCharset(t *testing.T) { - runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + runTestsParallel(t, dsn+"&charset=none", func(dbt *DBTest, _ string) { // run query to really establish connection... _, err := dbt.db.Exec("SELECT 1") if err == nil { @@ -1556,7 +1616,7 @@ func TestCollation(t *testing.T) { } func TestColumnsWithAlias(t *testing.T) { - runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + runTestsParallel(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest, _ string) { rows := dbt.mustQuery("SELECT 1 AS A") defer rows.Close() cols, _ := rows.Columns() @@ -1580,7 +1640,7 @@ func TestColumnsWithAlias(t *testing.T) { } func TestRawBytesResultExceedsBuffer(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { // defaultBufSize from buffer.go expected := strings.Repeat("abc", defaultBufSize) @@ -1639,7 +1699,7 @@ func TestTimezoneConversion(t *testing.T) { // Special cases func TestRowsClose(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { rows, err := dbt.db.Query("SELECT 1") if err != nil { dbt.Fatal(err) @@ -1664,7 +1724,7 @@ func TestRowsClose(t *testing.T) { // dangling statements // http://code.google.com/p/go/issues/detail?id=3865 func TestCloseStmtBeforeRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { stmt, err := dbt.db.Prepare("SELECT 1") if err != nil { dbt.Fatal(err) @@ -1705,7 +1765,7 @@ func TestCloseStmtBeforeRows(t *testing.T) { // It is valid to have multiple Rows for the same Stmt // http://code.google.com/p/go/issues/detail?id=3734 func TestStmtMultiRows(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") if err != nil { dbt.Fatal(err) @@ -2507,7 +2567,7 @@ func TestExecMultipleResults(t *testing.T) { // tests if rows are set in a proper state if some results were ignored before // calling rows.NextResultSet. func TestSkipResults(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { rows := dbt.mustQuery("SELECT 1, 2") defer rows.Close() @@ -2562,7 +2622,7 @@ func TestQueryMultipleResults(t *testing.T) { } func TestPingContext(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { ctx, cancel := context.WithCancel(context.Background()) cancel() if err := dbt.db.PingContext(ctx); err != context.Canceled { @@ -2572,8 +2632,8 @@ func TestPingContext(t *testing.T) { } func TestContextCancelExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) // Delay execution for just a bit until db.ExecContext has begun. @@ -2581,7 +2641,7 @@ func TestContextCancelExec(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2593,7 +2653,7 @@ func TestContextCancelExec(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2601,14 +2661,14 @@ func TestContextCancelExec(t *testing.T) { } // Context is already canceled, so error should come before execution. - if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (1)"); err == nil { dbt.Error("expected error") } else if err.Error() != "context canceled" { dbt.Fatalf("unexpected error: %s", err) } // The second insert query will fail, so the table has no changes. - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { @@ -2618,8 +2678,8 @@ func TestContextCancelExec(t *testing.T) { } func TestContextCancelQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) // Delay execution for just a bit until db.ExecContext has begun. @@ -2627,7 +2687,7 @@ func TestContextCancelQuery(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2639,7 +2699,7 @@ func TestContextCancelQuery(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2647,12 +2707,12 @@ func TestContextCancelQuery(t *testing.T) { } // Context is already canceled, so error should come before execution. - if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO "+tbl+" VALUES (1)"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } // The second insert query will fail, so the table has no changes. - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { @@ -2662,12 +2722,12 @@ func TestContextCancelQuery(t *testing.T) { } func TestContextCancelQueryRow(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") - dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1), (2), (3)") ctx, cancel := context.WithCancel(context.Background()) - rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM "+tbl) if err != nil { dbt.Fatalf("%s", err.Error()) } @@ -2695,7 +2755,7 @@ func TestContextCancelQueryRow(t *testing.T) { } func TestContextCancelPrepare(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { + runTestsParallel(t, dsn, func(dbt *DBTest, _ string) { ctx, cancel := context.WithCancel(context.Background()) cancel() if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { @@ -2705,10 +2765,10 @@ func TestContextCancelPrepare(t *testing.T) { } func TestContextCancelStmtExec(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))") if err != nil { dbt.Fatalf("unexpected error: %v", err) } @@ -2730,7 +2790,7 @@ func TestContextCancelStmtExec(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2740,10 +2800,10 @@ func TestContextCancelStmtExec(t *testing.T) { } func TestContextCancelStmtQuery(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) - stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))") if err != nil { dbt.Fatalf("unexpected error: %v", err) } @@ -2765,7 +2825,7 @@ func TestContextCancelStmtQuery(t *testing.T) { // Check how many times the query is executed. var v int - if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM " + tbl).Scan(&v); err != nil { dbt.Fatalf("%s", err.Error()) } if v != 1 { // TODO: need to kill the query, and v should be 0. @@ -2779,8 +2839,8 @@ func TestContextCancelBegin(t *testing.T) { t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`) } - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) conn, err := dbt.db.Conn(ctx) if err != nil { @@ -2797,7 +2857,7 @@ func TestContextCancelBegin(t *testing.T) { // This query will be canceled. startTime := time.Now() - if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + if _, err := tx.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (SLEEP(1))"); err != context.Canceled { dbt.Errorf("expected context.Canceled, got %v", err) } if d := time.Since(startTime); d > 500*time.Millisecond { @@ -2835,8 +2895,8 @@ func TestContextCancelBegin(t *testing.T) { } func TestContextBeginIsolationLevel(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2854,13 +2914,13 @@ func TestContextBeginIsolationLevel(t *testing.T) { dbt.Fatal(err) } - _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)") + _, err = tx1.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (1)") if err != nil { dbt.Fatal(err) } var v int - row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tbl) if err := row.Scan(&v); err != nil { dbt.Fatal(err) } @@ -2874,7 +2934,7 @@ func TestContextBeginIsolationLevel(t *testing.T) { dbt.Fatal(err) } - row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tbl) if err := row.Scan(&v); err != nil { dbt.Fatal(err) } @@ -2887,8 +2947,8 @@ func TestContextBeginIsolationLevel(t *testing.T) { } func TestContextBeginReadOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (v INTEGER)") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2903,14 +2963,14 @@ func TestContextBeginReadOnly(t *testing.T) { } // INSERT queries fail in a READ ONLY transaction. - _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + _, err = tx.ExecContext(ctx, "INSERT INTO "+tbl+" VALUES (1)") if _, ok := err.(*MySQLError); !ok { dbt.Errorf("expected MySQLError, got %v", err) } // SELECT queries can be executed. var v int - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tbl) if err := row.Scan(&v); err != nil { dbt.Fatal(err) } @@ -3147,9 +3207,9 @@ func TestRowsColumnTypes(t *testing.T) { } func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") - dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { + dbt.mustExec("CREATE TABLE " + tbl + " (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO "+tbl+" VALUES (?)", (*testValuer)(nil)) // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() }) } From c48c0e7da17e8fc06133e431ce7c10e7a3e94f06 Mon Sep 17 00:00:00 2001 From: shi yuhang <52435083+shiyuhang0@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:47:16 +0800 Subject: [PATCH 063/123] Fix unsigned int overflow (#1530) --- driver_test.go | 15 +++++++++------ packets.go | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/driver_test.go b/driver_test.go index 6bdb78c78..5934caab6 100644 --- a/driver_test.go +++ b/driver_test.go @@ -388,16 +388,16 @@ func TestCRUD(t *testing.T) { func TestNumbersToAny(t *testing.T) { runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) { dbt.mustExec("CREATE TABLE " + tbl + " (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " + - "i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)") - dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)") + "i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE, iu32 INT UNSIGNED)") + dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5, 4294967295)") - // Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true. - rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM "+tbl+" WHERE id=?", 1) + // Use binaryRows for interpolateParams=false and textRows for interpolateParams=true. + rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64, iu32 FROM "+tbl+" WHERE id=?", 1) if !rows.Next() { dbt.Fatal("no data") } - var b, i8, i16, i32, i64, f32, f64 any - err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64) + var b, i8, i16, i32, i64, f32, f64, iu32 any + err := rows.Scan(&b, &i8, &i16, &i32, &i64, &f32, &f64, &iu32) if err != nil { dbt.Fatal(err) } @@ -422,6 +422,9 @@ func TestNumbersToAny(t *testing.T) { if f64.(float64) != 2.5 { dbt.Errorf("f64 != 2.5") } + if iu32.(int64) != 4294967295 { + dbt.Errorf("iu32 != 4294967295") + } }) } diff --git a/packets.go b/packets.go index 49e6bb058..94b46b10f 100644 --- a/packets.go +++ b/packets.go @@ -828,7 +828,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { } case fieldTypeTiny, fieldTypeShort, fieldTypeInt24, fieldTypeYear, fieldTypeLong: - dest[i], err = strconv.ParseInt(string(buf), 10, 32) + dest[i], err = strconv.ParseInt(string(buf), 10, 64) case fieldTypeLongLong: if rows.rs.columns[i].flags&flagUnsigned != 0 { From 743e263bab87912dfb61789f36c21d9685887c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paulius=20Lo=C5=BEys?= <42966213+PauliusLozys@users.noreply.github.com> Date: Wed, 31 Jan 2024 11:34:24 +0200 Subject: [PATCH 064/123] Introduce `timeTruncate` parameter for `time.Time` arguments (#1541) Co-authored-by: Inada Naoki --- AUTHORS | 1 + README.md | 9 ++++++ connection.go | 2 +- dsn.go | 12 +++++++ dsn_test.go | 3 ++ packets.go | 2 +- utils.go | 6 +++- utils_test.go | 89 ++++++++++++++++++++++++++++++++++++++------------- 8 files changed, 98 insertions(+), 26 deletions(-) diff --git a/AUTHORS b/AUTHORS index 0ada02d86..63ee516e5 100644 --- a/AUTHORS +++ b/AUTHORS @@ -86,6 +86,7 @@ Oliver Bone Olivier Mengué oscarzhao Paul Bonser +Paulius Lozys Peter Schultz Phil Porada Rebecca Chin diff --git a/README.md b/README.md index ac79890a7..018e1dd7c 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,15 @@ Note that this sets the location for time.Time values but does not change MySQL' Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`. +##### `timeTruncate` + +``` +Type: duration +Default: 0 +``` + +[Truncate time values](https://pkg.go.dev/time#Duration.Truncate) to the specified duration. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*. + ##### `maxAllowedPacket` ``` Type: decimal number diff --git a/connection.go b/connection.go index 660b2b0e0..99eb8a808 100644 --- a/connection.go +++ b/connection.go @@ -251,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf = append(buf, "'0000-00-00'"...) } else { buf = append(buf, '\'') - buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate) if err != nil { return "", err } diff --git a/dsn.go b/dsn.go index ef0608636..ce5d85ff0 100644 --- a/dsn.go +++ b/dsn.go @@ -48,6 +48,7 @@ type Config struct { pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig + TimeTruncate time.Duration // Truncate time.Time values to the specified duration Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout @@ -262,6 +263,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "parseTime", "true") } + if cfg.TimeTruncate > 0 { + writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.TimeTruncate.String()) + } + if cfg.ReadTimeout > 0 { writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String()) } @@ -502,6 +507,13 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // time.Time truncation + case "timeTruncate": + cfg.TimeTruncate, err = time.ParseDuration(value) + if err != nil { + return + } + // I/O read Timeout case "readTimeout": cfg.ReadTimeout, err = time.ParseDuration(value) diff --git a/dsn_test.go b/dsn_test.go index 8a6a0c10e..75cbda700 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -74,6 +74,9 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TimeTruncate: time.Hour}, }, } diff --git a/packets.go b/packets.go index 94b46b10f..e5a6e4727 100644 --- a/packets.go +++ b/packets.go @@ -1172,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b, err = appendDateTime(b, v.In(mc.cfg.Loc)) + b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate) if err != nil { return err } diff --git a/utils.go b/utils.go index a24197b93..cda24fe74 100644 --- a/utils.go +++ b/utils.go @@ -265,7 +265,11 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va return nil, fmt.Errorf("invalid DATETIME packet length %d", num) } -func appendDateTime(buf []byte, t time.Time) ([]byte, error) { +func appendDateTime(buf []byte, t time.Time, timeTruncate time.Duration) ([]byte, error) { + if timeTruncate > 0 { + t = t.Truncate(timeTruncate) + } + year, month, day := t.Date() hour, min, sec := t.Clock() nsec := t.Nanosecond() diff --git a/utils_test.go b/utils_test.go index 4e5fc3cb7..80aebddff 100644 --- a/utils_test.go +++ b/utils_test.go @@ -237,8 +237,10 @@ func TestIsolationLevelMapping(t *testing.T) { func TestAppendDateTime(t *testing.T) { tests := []struct { - t time.Time - str string + t time.Time + str string + timeTruncate time.Duration + expectedErr bool }{ { t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC), @@ -276,34 +278,75 @@ func TestAppendDateTime(t *testing.T) { t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), str: "0001-01-01", }, + // Truncated time + { + t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC), + str: "1234-05-06", + timeTruncate: time.Second, + }, + { + t: time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC), + str: "4567-12-31 12:00:00", + timeTruncate: time.Minute, + }, + { + t: time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC), + str: "2020-05-30 12:34:00", + timeTruncate: 0, + }, + { + t: time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC), + str: "2020-05-30 12:34:56", + timeTruncate: time.Second, + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC), + str: "2020-05-30 22:33:44", + timeTruncate: time.Second, + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC), + str: "2020-05-30 22:33:44.123", + timeTruncate: time.Millisecond, + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC), + str: "2020-05-30 22:33:44", + timeTruncate: time.Second, + }, + { + t: time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC), + str: "9999-12-31 23:59:59.999999999", + timeTruncate: 0, + }, + { + t: time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC), + str: "0001-01-01", + timeTruncate: 365 * 24 * time.Hour, + }, + // year out of range + { + t: time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC), + expectedErr: true, + }, + { + t: time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC), + expectedErr: true, + }, } for _, v := range tests { buf := make([]byte, 0, 32) - buf, _ = appendDateTime(buf, v.t) + buf, err := appendDateTime(buf, v.t, v.timeTruncate) + if err != nil { + if !v.expectedErr { + t.Errorf("appendDateTime(%v) returned an errror: %v", v.t, err) + } + continue + } if str := string(buf); str != v.str { t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str) } } - - // year out of range - { - v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) - buf := make([]byte, 0, 32) - _, err := appendDateTime(buf, v) - if err == nil { - t.Error("want an error") - return - } - } - { - v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC) - buf := make([]byte, 0, 32) - _, err := appendDateTime(buf, v) - if err == nil { - t.Error("want an error") - return - } - } } func TestParseDateTime(t *testing.T) { From f019727e4706bf9c4f60579382f6e72b94bd0305 Mon Sep 17 00:00:00 2001 From: crazycs Date: Mon, 5 Feb 2024 16:57:21 +0800 Subject: [PATCH 065/123] add TiDB support in README.md (#1333) Signed-off-by: crazycs520 Co-authored-by: Inada Naoki --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 018e1dd7c..9d0d806ef 100644 --- a/README.md +++ b/README.md @@ -40,8 +40,16 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Optional placeholder interpolation ## Requirements - * Go 1.18 or higher. We aim to support the 3 latest versions of Go. - * MySQL (5.6+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) + +* Go 1.19 or higher. We aim to support the 3 latest versions of Go. +* MySQL (5.7+) and MariaDB (10.3+) are supported. +* [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. + * Do not ask questions about TiDB in our issue tracker or forum. + * [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang) + * [Forum](https://ask.pingcap.com/) +* go-mysql would work with Percona Server, Google CloudSQL or Sphinx (2.2.3+). + * Maintainers won't support them. Do not expect issues are investigated and resolved by maintainers. + * Investigate issues yourself and please send a pull request to fix it. --------------------------------------- From 097fe6e3ad83bbd7c84debe810aec4c4a533bcaa Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 5 Feb 2024 20:29:00 +0900 Subject: [PATCH 066/123] Update workflows (#1547) --- .github/workflows/codeql.yml | 8 ++++---- .github/workflows/test.yml | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d9d29a8b7..83a3d6ee8 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -24,18 +24,18 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} queries: +security-and-quality - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 with: category: "/language:${{ matrix.language }}" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aae421196..f5a115802 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,10 +14,10 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dominikh/staticcheck-action@v1.3.0 with: - version: "2023.1.3" + version: "2023.1.6" list: runs-on: ubuntu-latest @@ -73,11 +73,11 @@ jobs: fail-fast: false matrix: ${{ fromJSON(needs.list.outputs.matrix) }} steps: - - uses: actions/checkout@v3 - - uses: actions/setup-go@v4 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - - uses: shogo82148/actions-setup-mysql@v1.21.0 + - uses: shogo82148/actions-setup-mysql@v1 with: mysql-version: ${{ matrix.mysql }} user: ${{ env.MYSQL_TEST_USER }} From 6964272ffd13a41ad66383cd2ea738fded75ad06 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 7 Mar 2024 00:32:18 +0900 Subject: [PATCH 067/123] Make TimeTruncate functional option (#1552) --- connection.go | 2 +- dsn.go | 47 ++++++++++++++++++++++++++++++++++++++++------- dsn_test.go | 2 +- packets.go | 2 +- result.go | 5 ++--- 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/connection.go b/connection.go index 99eb8a808..c170114fe 100644 --- a/connection.go +++ b/connection.go @@ -251,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf = append(buf, "'0000-00-00'"...) } else { buf = append(buf, '\'') - buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate) + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) if err != nil { return "", err } diff --git a/dsn.go b/dsn.go index ce5d85ff0..d0fbf3bd9 100644 --- a/dsn.go +++ b/dsn.go @@ -34,6 +34,8 @@ var ( // If a new Config is created instead of being parsed from a DSN string, // the NewConfig function should be used, which sets default values. type Config struct { + // non boolean fields + User string // Username Passwd string // Password (requires User) Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp") @@ -45,15 +47,15 @@ type Config struct { Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed ServerPubKey string // Server public key name - pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig - TimeTruncate time.Duration // Truncate time.Time values to the specified duration Timeout time.Duration // Dial timeout ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout Logger Logger // Logger + // boolean fields + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE AllowCleartextPasswords bool // Allows the cleartext client side plugin AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS @@ -66,17 +68,48 @@ type Config struct { MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections + + // unexported fields. new options should be come here + + pubKey *rsa.PublicKey // Server public key + timeTruncate time.Duration // Truncate time.Time values to the specified duration } +// Functional Options Pattern +// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis +type Option func(*Config) error + // NewConfig creates a new Config and sets default values. func NewConfig() *Config { - return &Config{ + cfg := &Config{ Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, } + + return cfg +} + +// Apply applies the given options to the Config object. +func (c *Config) Apply(opts ...Option) error { + for _, opt := range opts { + err := opt(c) + if err != nil { + return err + } + } + return nil +} + +// TimeTruncate sets the time duration to truncate time.Time values in +// query parameters. +func TimeTruncate(d time.Duration) Option { + return func(cfg *Config) error { + cfg.timeTruncate = d + return nil + } } func (cfg *Config) Clone() *Config { @@ -263,8 +296,8 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "parseTime", "true") } - if cfg.TimeTruncate > 0 { - writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.TimeTruncate.String()) + if cfg.timeTruncate > 0 { + writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String()) } if cfg.ReadTimeout > 0 { @@ -509,9 +542,9 @@ func parseDSNParams(cfg *Config, params string) (err error) { // time.Time truncation case "timeTruncate": - cfg.TimeTruncate, err = time.ParseDuration(value) + cfg.timeTruncate, err = time.ParseDuration(value) if err != nil { - return + return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err) } // I/O read Timeout diff --git a/dsn_test.go b/dsn_test.go index 75cbda700..dd8cd935c 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -76,7 +76,7 @@ var testDSNs = []struct { &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TimeTruncate: time.Hour}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour}, }, } diff --git a/packets.go b/packets.go index e5a6e4727..3d6e5308c 100644 --- a/packets.go +++ b/packets.go @@ -1172,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if v.IsZero() { b = append(b, "0000-00-00"...) } else { - b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate) + b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) if err != nil { return err } diff --git a/result.go b/result.go index 36a432e81..d51631468 100644 --- a/result.go +++ b/result.go @@ -15,9 +15,8 @@ import "database/sql/driver" // This is accessible by executing statements using sql.Conn.Raw() and // downcasting the returned result: // -// res, err := rawConn.Exec(...) -// res.(mysql.Result).AllRowsAffected() -// +// res, err := rawConn.Exec(...) +// res.(mysql.Result).AllRowsAffected() type Result interface { driver.Result // AllRowsAffected returns a slice containing the affected rows for each From 33b7747a9144946e50399904d3f27ecc0f96c2b6 Mon Sep 17 00:00:00 2001 From: "Alessandro (Ale) Segala" <43508+ItalyPaleAle@users.noreply.github.com> Date: Sat, 9 Mar 2024 07:57:08 +0100 Subject: [PATCH 068/123] Add BeforeConnect callback to configuration object (#1469) This can be used to alter the connection options for each connection, right before it's established Co-authored-by: Inada Naoki --- AUTHORS | 1 + connector.go | 12 +++++++++++- driver_test.go | 34 ++++++++++++++++++++++++++++++++++ dsn.go | 14 ++++++++++++-- 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/AUTHORS b/AUTHORS index 63ee516e5..4021b96cc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -132,6 +132,7 @@ GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. +Microsoft Corp. Multiplay Ltd. Percona LLC PingCAP Inc. diff --git a/connector.go b/connector.go index 3cef7963f..a0ee62839 100644 --- a/connector.go +++ b/connector.go @@ -66,12 +66,22 @@ func newConnector(cfg *Config) *connector { func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { var err error + // Invoke beforeConnect if present, with a copy of the configuration + cfg := c.cfg + if c.cfg.beforeConnect != nil { + cfg = c.cfg.Clone() + err = c.cfg.beforeConnect(ctx, cfg) + if err != nil { + return nil, err + } + } + // New mysqlConn mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, closech: make(chan struct{}), - cfg: c.cfg, + cfg: cfg, connector: c, } mc.parseTime = mc.cfg.ParseTime diff --git a/driver_test.go b/driver_test.go index 5934caab6..001957244 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2044,6 +2044,40 @@ func TestCustomDial(t *testing.T) { } } +func TestBeforeConnect(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // dbname is set in the BeforeConnect handle + cfg, err := ParseDSN(fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, "_")) + if err != nil { + t.Fatalf("error parsing DSN: %v", err) + } + + cfg.Apply(BeforeConnect(func(ctx context.Context, c *Config) error { + c.DBName = dbname + return nil + })) + + connector, err := NewConnector(cfg) + if err != nil { + t.Fatalf("error creating connector: %v", err) + } + + db := sql.OpenDB(connector) + defer db.Close() + + var connectedDb string + err = db.QueryRow("SELECT DATABASE();").Scan(&connectedDb) + if err != nil { + t.Fatalf("error executing query: %v", err) + } + if connectedDb != dbname { + t.Fatalf("expected to connect to DB %s, but connected to %s instead", dbname, connectedDb) + } +} + func TestSQLInjection(t *testing.T) { createTest := func(arg string) func(dbt *DBTest) { return func(dbt *DBTest) { diff --git a/dsn.go b/dsn.go index d0fbf3bd9..65f5a0242 100644 --- a/dsn.go +++ b/dsn.go @@ -10,6 +10,7 @@ package mysql import ( "bytes" + "context" "crypto/rsa" "crypto/tls" "errors" @@ -71,8 +72,9 @@ type Config struct { // unexported fields. new options should be come here - pubKey *rsa.PublicKey // Server public key - timeTruncate time.Duration // Truncate time.Time values to the specified duration + beforeConnect func(context.Context, *Config) error // Invoked before a connection is established + pubKey *rsa.PublicKey // Server public key + timeTruncate time.Duration // Truncate time.Time values to the specified duration } // Functional Options Pattern @@ -112,6 +114,14 @@ func TimeTruncate(d time.Duration) Option { } } +// BeforeConnect sets the function to be invoked before a connection is established. +func BeforeConnect(fn func(context.Context, *Config) error) Option { + return func(cfg *Config) error { + cfg.beforeConnect = fn + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { From 3147497dd6a98708e5ee4da04f2a686b4d7979a7 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 11 Mar 2024 13:44:06 +0900 Subject: [PATCH 069/123] ci: update Go and MySQL versions (#1557) --- .github/workflows/test.yml | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f5a115802..c5b2aa313 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,22 +31,20 @@ jobs: import os go = [ # Keep the most recent production release at the top - '1.21', + '1.22', # Older production releases + '1.21', '1.20', - '1.19', - '1.18', ] mysql = [ - '8.1', '8.0', + '8.3', '5.7', - '5.6', - 'mariadb-10.11', - 'mariadb-10.6', + 'mariadb-11.3', + 'mariadb-11.1', + 'mariadb-10.11', # LTS + 'mariadb-10.6', # LTS 'mariadb-10.5', - 'mariadb-10.4', - 'mariadb-10.3', ] includes = [] @@ -64,7 +62,7 @@ jobs: } output = json.dumps(matrix, separators=(',', ':')) with open(os.environ["GITHUB_OUTPUT"], 'a', encoding="utf-8") as f: - f.write('matrix={0}\n'.format(output)) + print(f"matrix={output}", file=f) shell: python test: needs: list From 8a327a3575a42f7222f6e51263326d5a0eaecab0 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Mon, 11 Mar 2024 23:54:40 +0900 Subject: [PATCH 070/123] Drop support of go1.19 (#1558) * drop support of Go 1.19 * replace atomicBool by atomic.Bool * Update Go and MariaDB versions in README.md --- README.md | 4 +-- atomic_bool.go | 19 ------------ atomic_bool_go118.go | 47 ----------------------------- atomic_bool_test.go | 71 -------------------------------------------- connection.go | 3 +- go.mod | 2 +- 6 files changed, 5 insertions(+), 141 deletions(-) delete mode 100644 atomic_bool.go delete mode 100644 atomic_bool_go118.go delete mode 100644 atomic_bool_test.go diff --git a/README.md b/README.md index 9d0d806ef..c3204ef11 100644 --- a/README.md +++ b/README.md @@ -41,8 +41,8 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ## Requirements -* Go 1.19 or higher. We aim to support the 3 latest versions of Go. -* MySQL (5.7+) and MariaDB (10.3+) are supported. +* Go 1.20 or higher. We aim to support the 3 latest versions of Go. +* MySQL (5.7+) and MariaDB (10.5+) are supported. * [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. * Do not ask questions about TiDB in our issue tracker or forum. * [Document](https://docs.pingcap.com/tidb/v6.1/dev-guide-sample-application-golang) diff --git a/atomic_bool.go b/atomic_bool.go deleted file mode 100644 index 1b7e19f3e..000000000 --- a/atomic_bool.go +++ /dev/null @@ -1,19 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. -// -// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. -//go:build go1.19 -// +build go1.19 - -package mysql - -import "sync/atomic" - -/****************************************************************************** -* Sync utils * -******************************************************************************/ - -type atomicBool = atomic.Bool diff --git a/atomic_bool_go118.go b/atomic_bool_go118.go deleted file mode 100644 index 2e9a7f0b6..000000000 --- a/atomic_bool_go118.go +++ /dev/null @@ -1,47 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. -// -// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. -//go:build !go1.19 -// +build !go1.19 - -package mysql - -import "sync/atomic" - -/****************************************************************************** -* Sync utils * -******************************************************************************/ - -// atomicBool is an implementation of atomic.Bool for older version of Go. -// it is a wrapper around uint32 for usage as a boolean value with -// atomic access. -type atomicBool struct { - _ noCopy - value uint32 -} - -// Load returns whether the current boolean value is true -func (ab *atomicBool) Load() bool { - return atomic.LoadUint32(&ab.value) > 0 -} - -// Store sets the value of the bool regardless of the previous value -func (ab *atomicBool) Store(value bool) { - if value { - atomic.StoreUint32(&ab.value, 1) - } else { - atomic.StoreUint32(&ab.value, 0) - } -} - -// Swap sets the value of the bool and returns the old value. -func (ab *atomicBool) Swap(value bool) bool { - if value { - return atomic.SwapUint32(&ab.value, 1) > 0 - } - return atomic.SwapUint32(&ab.value, 0) > 0 -} diff --git a/atomic_bool_test.go b/atomic_bool_test.go deleted file mode 100644 index a3b4ea0e8..000000000 --- a/atomic_bool_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. -// -// Copyright 2022 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. -//go:build !go1.19 -// +build !go1.19 - -package mysql - -import ( - "testing" -) - -func TestAtomicBool(t *testing.T) { - var ab atomicBool - if ab.Load() { - t.Fatal("Expected value to be false") - } - - ab.Store(true) - if ab.value != 1 { - t.Fatal("Set(true) did not set value to 1") - } - if !ab.Load() { - t.Fatal("Expected value to be true") - } - - ab.Store(true) - if !ab.Load() { - t.Fatal("Expected value to be true") - } - - ab.Store(false) - if ab.value != 0 { - t.Fatal("Set(false) did not set value to 0") - } - if ab.Load() { - t.Fatal("Expected value to be false") - } - - ab.Store(false) - if ab.Load() { - t.Fatal("Expected value to be false") - } - if ab.Swap(false) { - t.Fatal("Expected the old value to be false") - } - if ab.Swap(true) { - t.Fatal("Expected the old value to be false") - } - if !ab.Load() { - t.Fatal("Expected value to be true") - } - - ab.Store(true) - if !ab.Load() { - t.Fatal("Expected value to be true") - } - if !ab.Swap(true) { - t.Fatal("Expected the old value to be true") - } - if !ab.Swap(false) { - t.Fatal("Expected the old value to be true") - } - if ab.Load() { - t.Fatal("Expected value to be false") - } -} diff --git a/connection.go b/connection.go index c170114fe..55e42eb18 100644 --- a/connection.go +++ b/connection.go @@ -17,6 +17,7 @@ import ( "net" "strconv" "strings" + "sync/atomic" "time" ) @@ -41,7 +42,7 @@ type mysqlConn struct { closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled - closed atomicBool // set when conn is closed, before closech is closed + closed atomic.Bool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established diff --git a/go.mod b/go.mod index 4629714c0..2eed53ebb 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/go-sql-driver/mysql -go 1.18 +go 1.20 require filippo.io/edwards25519 v1.1.0 From 35847bed632a869c89234080ebee1e7b78d140e6 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 16 Mar 2024 23:23:22 +0900 Subject: [PATCH 071/123] replace interface{} with any (#1560) --- README.md | 2 +- driver_test.go | 116 +++++++++++++++++++++++----------------------- errors.go | 4 +- fields.go | 2 +- nulltime.go | 2 +- nulltime_test.go | 2 +- statement.go | 2 +- statement_test.go | 4 +- 8 files changed, 67 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index c3204ef11..6c6abf9c4 100644 --- a/README.md +++ b/README.md @@ -326,7 +326,7 @@ It's possible to access the last inserted ID and number of affected rows for mul ```go conn, _ := db.Conn(ctx) -conn.Raw(func(conn interface{}) error { +conn.Raw(func(conn any) error { ex := conn.(driver.Execer) res, err := ex.Exec(` UPDATE point SET x = 1 WHERE y = 2; diff --git a/driver_test.go b/driver_test.go index 001957244..6b52650c2 100644 --- a/driver_test.go +++ b/driver_test.go @@ -247,7 +247,7 @@ func (dbt *DBTest) fail(method, query string, err error) { dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) } -func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { +func (dbt *DBTest) mustExec(query string, args ...any) (res sql.Result) { dbt.Helper() res, err := dbt.db.Exec(query, args...) if err != nil { @@ -256,7 +256,7 @@ func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) return res } -func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { +func (dbt *DBTest) mustQuery(query string, args ...any) (rows *sql.Rows) { dbt.Helper() rows, err := dbt.db.Query(query, args...) if err != nil { @@ -844,7 +844,7 @@ func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { dbt.Errorf("%s [%s]: %s", dbtype, mode, err) return } - var dst interface{} + var dst any err = rows.Scan(&dst) if err != nil { dbt.Errorf("%s [%s]: %s", dbtype, mode, err) @@ -875,7 +875,7 @@ func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { t.s, val.Format(tlayout), ) default: - fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) + fmt.Printf("%#v\n", []any{dbtype, tlayout, mode, t.s, t.t}) dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", dbtype, mode, val, val, @@ -1186,7 +1186,7 @@ func TestNULL(t *testing.T) { dbt.mustExec("INSERT INTO "+tbl+" VALUES (?, ?, ?)", 1, nil, 2) - var out interface{} + var out any rows := dbt.mustQuery("SELECT * FROM " + tbl) defer rows.Close() if rows.Next() { @@ -1894,7 +1894,7 @@ func TestPreparedManyCols(t *testing.T) { // create more parameters than fit into the buffer // which will take nil-values - params := make([]interface{}, numParams) + params := make([]any, numParams) rows, err := stmt.Query(params...) if err != nil { dbt.Fatal(err) @@ -1941,7 +1941,7 @@ func TestConcurrent(t *testing.T) { var fatalError string var once sync.Once - fatalf := func(s string, vals ...interface{}) { + fatalf := func(s string, vals ...any) { once.Do(func() { fatalError = fmt.Sprintf(s, vals...) }) @@ -2314,7 +2314,7 @@ func TestPing(t *testing.T) { } // Check that affectedRows and insertIds are cleared after each call. - conn.Raw(func(conn interface{}) error { + conn.Raw(func(conn any) error { c := conn.(*mysqlConn) // Issue a query that sets affectedRows and insertIds. @@ -2577,7 +2577,7 @@ func TestExecMultipleResults(t *testing.T) { if err != nil { t.Fatalf("failed to connect: %v", err) } - conn.Raw(func(conn interface{}) error { + conn.Raw(func(conn any) error { //lint:ignore SA1019 this is a test ex := conn.(driver.Execer) res, err := ex.Exec(` @@ -2635,7 +2635,7 @@ func TestQueryMultipleResults(t *testing.T) { if err != nil { t.Fatalf("failed to connect: %v", err) } - conn.Raw(func(conn interface{}) error { + conn.Raw(func(conn any) error { //lint:ignore SA1019 this is a test qr := conn.(driver.Queryer) c := conn.(*mysqlConn) @@ -3058,54 +3058,54 @@ func TestRowsColumnTypes(t *testing.T) { precision int64 // 0 if not ok scale int64 valuesIn [3]string - valuesOut [3]interface{} + valuesOut [3]any }{ - {"bit8null", "BIT(8)", "BIT", scanTypeBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{bx0, bNULL, bx42}}, - {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, - {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, - {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, - {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, - {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, - {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, - {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, - {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, - {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, - {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, - {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, - {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, - {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, - {"mediumuint", "MEDIUMINT UNSIGNED NOT NULL", "UNSIGNED MEDIUMINT", scanTypeUint32, false, 0, 0, [3]string{"0", "16777215", "42"}, [3]interface{}{uint32(0), uint32(16777215), uint32(42)}}, - {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, - {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, - {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, - {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, - {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, - {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, - {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeString, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.000000", "13.370000", "1234.123456"}}, - {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeNullString, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.000000"), nsNULL, ns("1234.123456")}}, - {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeString, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{"0.0000", "13.3700", "1234.1235"}}, - {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeNullString, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{ns("0.0000"), nsNULL, ns("1234.1235")}}, - {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeString, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{"0", "13", "-12345"}}, - {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeNullString, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{ns0, nsNULL, ns("-12345")}}, - {"char25null", "CHAR(25)", "CHAR", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, - {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, - {"binary4null", "BINARY(4)", "BINARY", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0pad4, bNULL, bTest}}, - {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, - {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}}, - {"tinytextnull", "TINYTEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, - {"blobnull", "BLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{b0, bNULL, bTest}}, - {"textnull", "TEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{ns0, nsNULL, nsTest}}, - {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, - {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, - {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{b0, bTest, b42}}, - {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{"0", "Test", "42"}}, - {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, - {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, - {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, - {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, - {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, - {"enum", "ENUM('', 'v1', 'v2')", "ENUM", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v2'"}, [3]interface{}{ns(""), ns("v1"), ns("v2")}}, - {"set", "set('', 'v1', 'v2')", "SET", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v1,v2'"}, [3]interface{}{ns(""), ns("v1"), ns("v1,v2")}}, + {"bit8null", "BIT(8)", "BIT", scanTypeBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]any{bx0, bNULL, bx42}}, + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]any{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]any{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]any{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]any{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]any{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]any{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]any{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]any{ni0, ni42, niNULL}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]any{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]any{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]any{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]any{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]any{uint64(0), uint64(65535), uint64(42)}}, + {"mediumuint", "MEDIUMINT UNSIGNED NOT NULL", "UNSIGNED MEDIUMINT", scanTypeUint32, false, 0, 0, [3]string{"0", "16777215", "42"}, [3]any{uint32(0), uint32(16777215), uint32(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]any{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]any{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]any{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]any{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]any{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]any{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeString, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]any{"0.000000", "13.370000", "1234.123456"}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeNullString, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]any{ns("0.000000"), nsNULL, ns("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeString, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]any{"0.0000", "13.3700", "1234.1235"}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeNullString, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]any{ns("0.0000"), nsNULL, ns("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeString, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]any{"0", "13", "-12345"}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeNullString, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]any{ns0, nsNULL, ns("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]any{ns0, nsNULL, nsTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]any{"0", "Test", "42"}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]any{b0pad4, bNULL, bTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]any{b0, bTest, b42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]any{b0, bNULL, bTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]any{ns0, nsNULL, nsTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]any{b0, bNULL, bTest}}, + {"textnull", "TEXT", "TEXT", scanTypeNullString, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]any{ns0, nsNULL, nsTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]any{b0, bTest, b42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]any{"0", "Test", "42"}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]any{b0, bTest, b42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeString, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]any{"0", "Test", "42"}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]any{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]any{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]any{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]any{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]any{uint16(2006), uint16(2000), uint16(1994)}}, + {"enum", "ENUM('', 'v1', 'v2')", "ENUM", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v2'"}, [3]any{ns(""), ns("v1"), ns("v2")}}, + {"set", "set('', 'v1', 'v2')", "SET", scanTypeNullString, true, 0, 0, [3]string{"''", "'v1'", "'v1,v2'"}, [3]any{ns(""), ns("v1"), ns("v1,v2")}}, } schema := "" @@ -3215,7 +3215,7 @@ func TestRowsColumnTypes(t *testing.T) { if t.Failed() { return } - values := make([]interface{}, len(tt)) + values := make([]any, len(tt)) for i := range values { values[i] = reflect.New(types[i]).Interface() } diff --git a/errors.go b/errors.go index a9a3060c9..a7ef88909 100644 --- a/errors.go +++ b/errors.go @@ -41,14 +41,14 @@ var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|lo // Logger is used to log critical error messages. type Logger interface { - Print(v ...interface{}) + Print(v ...any) } // NopLogger is a nop implementation of the Logger interface. type NopLogger struct{} // Print implements Logger interface. -func (nl *NopLogger) Print(_ ...interface{}) {} +func (nl *NopLogger) Print(_ ...any) {} // SetLogger is used to set the default logger for critical errors. // The initial logger is os.Stderr. diff --git a/fields.go b/fields.go index 2a397b245..286084247 100644 --- a/fields.go +++ b/fields.go @@ -134,7 +134,7 @@ var ( scanTypeString = reflect.TypeOf("") scanTypeNullString = reflect.TypeOf(sql.NullString{}) scanTypeBytes = reflect.TypeOf([]byte{}) - scanTypeUnknown = reflect.TypeOf(new(interface{})) + scanTypeUnknown = reflect.TypeOf(new(any)) ) type mysqlField struct { diff --git a/nulltime.go b/nulltime.go index 7d381d5c2..316a48aae 100644 --- a/nulltime.go +++ b/nulltime.go @@ -38,7 +38,7 @@ type NullTime sql.NullTime // Scan implements the Scanner interface. // The value type must be time.Time or string / []byte (formatted time-string), // otherwise Scan fails. -func (nt *NullTime) Scan(value interface{}) (err error) { +func (nt *NullTime) Scan(value any) (err error) { if value == nil { nt.Time, nt.Valid = time.Time{}, false return diff --git a/nulltime_test.go b/nulltime_test.go index a14ec0607..4f1d9029e 100644 --- a/nulltime_test.go +++ b/nulltime_test.go @@ -23,7 +23,7 @@ var ( func TestScanNullTime(t *testing.T) { var scanTests = []struct { - in interface{} + in any error bool valid bool time time.Time diff --git a/statement.go b/statement.go index 31e7799c4..d8b921b8e 100644 --- a/statement.go +++ b/statement.go @@ -141,7 +141,7 @@ type converter struct{} // implementation does not. This function should be kept in sync with // database/sql/driver defaultConverter.ConvertValue() except for that // deliberate difference. -func (c converter) ConvertValue(v interface{}) (driver.Value, error) { +func (c converter) ConvertValue(v any) (driver.Value, error) { if driver.IsValue(v) { return v, nil } diff --git a/statement_test.go b/statement_test.go index 2563ece55..15f9d7c33 100644 --- a/statement_test.go +++ b/statement_test.go @@ -77,7 +77,7 @@ func TestConvertPointer(t *testing.T) { } func TestConvertSignedIntegers(t *testing.T) { - values := []interface{}{ + values := []any{ int8(-42), int16(-42), int32(-42), @@ -106,7 +106,7 @@ func (u myUint64) Value() (driver.Value, error) { } func TestConvertUnsignedIntegers(t *testing.T) { - values := []interface{}{ + values := []any{ uint8(42), uint16(42), uint32(42), From 1a6477358cbbc917d5370c53d3e35a13b45aed19 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 16 Mar 2024 23:24:21 +0900 Subject: [PATCH 072/123] add wrapper method to call mc.cfg.Logger (#1563) --- auth.go | 2 +- connection.go | 23 ++++++++++++++--------- packets.go | 24 ++++++++++++------------ statement.go | 4 ++-- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/auth.go b/auth.go index 658259b24..74e1bd03e 100644 --- a/auth.go +++ b/auth.go @@ -338,7 +338,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { return authEd25519(authData, mc.cfg.Passwd) default: - mc.cfg.Logger.Print("unknown auth plugin:", plugin) + mc.log("unknown auth plugin:", plugin) return nil, ErrUnknownPlugin } } diff --git a/connection.go b/connection.go index 55e42eb18..5061b69ca 100644 --- a/connection.go +++ b/connection.go @@ -45,6 +45,11 @@ type mysqlConn struct { closed atomic.Bool // set when conn is closed, before closech is closed } +// Helper function to call per-connection logger. +func (mc *mysqlConn) log(v ...any) { + mc.cfg.Logger.Print(v...) +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -110,7 +115,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -152,7 +157,7 @@ func (mc *mysqlConn) cleanup() { return } if err := mc.netConn.Close(); err != nil { - mc.cfg.Logger.Print(err) + mc.log(err) } mc.clearResult() } @@ -169,14 +174,14 @@ func (mc *mysqlConn) error() error { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. - mc.cfg.Logger.Print(err) + mc.log(err) return nil, driver.ErrBadConn } @@ -210,7 +215,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf, err := mc.buf.takeCompleteBuffer() if err != nil { // can not take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return "", ErrInvalidConn } buf = buf[:0] @@ -302,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -362,7 +367,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) handleOk := mc.clearResult() if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -457,7 +462,7 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { - mc.cfg.Logger.Print(ErrInvalidConn) + mc.log(ErrInvalidConn) return driver.ErrBadConn } @@ -666,7 +671,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { err = connCheck(conn) } if err != nil { - mc.cfg.Logger.Print("closing bad idle connection: ", err) + mc.log("closing bad idle connection: ", err) return driver.ErrBadConn } } diff --git a/packets.go b/packets.go index 3d6e5308c..d727f00fe 100644 --- a/packets.go +++ b/packets.go @@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - mc.cfg.Logger.Print(err) + mc.log(err) mc.Close() return nil, ErrInvalidConn } @@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if pktLen == 0 { // there was no previous packet if prevData == nil { - mc.cfg.Logger.Print(ErrMalformPkt) + mc.log(ErrMalformPkt) mc.Close() return nil, ErrInvalidConn } @@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } - mc.cfg.Logger.Print(err) + mc.log(err) mc.Close() return nil, ErrInvalidConn } @@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) mc.cleanup() - mc.cfg.Logger.Print(ErrMalformPkt) + mc.log(ErrMalformPkt) } else { if cerr := mc.canceled.Value(); cerr != nil { return cerr @@ -144,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error { return errBadConnNoWrite } mc.cleanup() - mc.cfg.Logger.Print(err) + mc.log(err) } return ErrInvalidConn } @@ -302,7 +302,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -392,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { data, err := mc.buf.takeSmallBuffer(pktLen) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -412,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -431,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -452,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -994,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } if err != nil { // cannot take the buffer. Something must be wrong with the connection - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } @@ -1193,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.cfg.Logger.Print(err) + mc.log(err) return errBadConnNoWrite } } diff --git a/statement.go b/statement.go index d8b921b8e..0436f2240 100644 --- a/statement.go +++ b/statement.go @@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.Load() { - stmt.mc.cfg.Logger.Print(ErrInvalidConn) + stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -95,7 +95,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if stmt.mc.closed.Load() { - stmt.mc.cfg.Logger.Print(ErrInvalidConn) + stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command From d86c4527bae98ccd4e5060f72887520ce30eda5e Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 17 Mar 2024 13:30:21 +0900 Subject: [PATCH 073/123] fix race condition when context is canceled (#1562) Fix #1559. --- connection.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/connection.go b/connection.go index 5061b69ca..f3656f0e6 100644 --- a/connection.go +++ b/connection.go @@ -138,7 +138,7 @@ func (mc *mysqlConn) Close() (err error) { } mc.cleanup() - + mc.clearResult() return } @@ -153,13 +153,16 @@ func (mc *mysqlConn) cleanup() { // Makes cleanup idempotent close(mc.closech) - if mc.netConn == nil { + nc := mc.netConn + if nc == nil { return } - if err := mc.netConn.Close(); err != nil { + if err := nc.Close(); err != nil { mc.log(err) } - mc.clearResult() + // This function can be called from multiple goroutines. + // So we can not mc.clearResult() here. + // Caller should do it if they are in safe goroutine. } func (mc *mysqlConn) error() error { From d7ddb8b9e324830b1ede89c5fea090c824497c51 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Sat, 23 Mar 2024 00:57:24 +0900 Subject: [PATCH 074/123] Fix issue 1567 (#1570) ### Description closes https://github.com/go-sql-driver/mysql/issues/1567 When TLS is enabled, `mc.netConn` is rewritten after the TLS handshak as detailed here: https://github.com/go-sql-driver/mysql/blob/d86c4527bae98ccd4e5060f72887520ce30eda5e/packets.go#L355 Therefore, `mc.netConn` should not be accessed within the watcher goroutine. Instead, `mc.rawConn` should be initialized prior to invoking `mc.startWatcher`, and `mc.rawConn` should be used in lieu of `mc.netConn`. ### Checklist - [x] Code compiles correctly - [x] Created tests which fail without the change (if possible) - [x] All tests passing - [x] Extended the README / documentation, if necessary - [x] Added myself / the copyright holder to the AUTHORS file ## Summary by CodeRabbit - **Refactor** - Improved variable naming for better code readability and maintenance. - Enhanced network connection handling logic. - **New Features** - Updated TCP connection handling to better support TCP Keepalives. - **Tests** - Added a new test to address and verify the fix for a specific issue related to TLS, connection pooling, and round trip time estimation. --- connection.go | 6 +++--- connector.go | 2 +- driver_test.go | 33 +++++++++++++++++++++++++++++++++ packets.go | 1 - 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/connection.go b/connection.go index f3656f0e6..7b8abeb00 100644 --- a/connection.go +++ b/connection.go @@ -153,11 +153,11 @@ func (mc *mysqlConn) cleanup() { // Makes cleanup idempotent close(mc.closech) - nc := mc.netConn - if nc == nil { + conn := mc.rawConn + if conn == nil { return } - if err := nc.Close(); err != nil { + if err := conn.Close(); err != nil { mc.log(err) } // This function can be called from multiple goroutines. diff --git a/connector.go b/connector.go index a0ee62839..b67077596 100644 --- a/connector.go +++ b/connector.go @@ -102,10 +102,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { nd := net.Dialer{Timeout: mc.cfg.Timeout} mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) } - if err != nil { return nil, err } + mc.rawConn = mc.netConn // Enable TCP Keepalives on TCP connections if tc, ok := mc.netConn.(*net.TCPConn); ok { diff --git a/driver_test.go b/driver_test.go index 6b52650c2..4fd196d4b 100644 --- a/driver_test.go +++ b/driver_test.go @@ -20,6 +20,7 @@ import ( "io" "log" "math" + mrand "math/rand" "net" "net/url" "os" @@ -3577,3 +3578,35 @@ func runCallCommand(dbt *DBTest, query, name string) { } } } + +func TestIssue1567(t *testing.T) { + // enable TLS. + runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) { + // disable connection pooling. + // data race happens when new connection is created. + dbt.db.SetMaxIdleConns(0) + + // estimate round trip time. + start := time.Now() + if err := dbt.db.PingContext(context.Background()); err != nil { + t.Fatal(err) + } + rtt := time.Since(start) + if rtt <= 0 { + // In some environments, rtt may become 0, so set it to at least 1ms. + rtt = time.Millisecond + } + + count := 1000 + if testing.Short() { + count = 10 + } + + for i := 0; i < count; i++ { + timeout := time.Duration(mrand.Int63n(int64(rtt))) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + dbt.db.PingContext(ctx) + cancel() + } + }) +} diff --git a/packets.go b/packets.go index d727f00fe..90a34728b 100644 --- a/packets.go +++ b/packets.go @@ -351,7 +351,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if err := tlsConn.Handshake(); err != nil { return err } - mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn } From 8d421d9c69403dbea52832f311b6d49cff004dbd Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Tue, 30 Apr 2024 11:54:28 +0900 Subject: [PATCH 075/123] update changelog for releasing v1.8.1 (#1576) (#1577) cherry pick of https://github.com/shogo82148/mysql/commit/476df92ad2293daaba19414bd1495c1b2b6c0bad ## Summary by CodeRabbit - **Bug Fixes** - Addressed race conditions when the context is canceled. - **New Features** - Enhanced database connection with charset and collation settings. - Improved path escaping in database names. - Dropped support for Go versions 1.13-17. - Implemented parsing numbers over text protocol. - Introduced new configuration options for advanced usage. - **Enhancements** - Made logger configurable per connection. - Fixed handling of `mediumint unsigned` in `ColumnType.DatabaseTypeName`. - Added connection attributes for more detailed connection information. --- CHANGELOG.md | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 213215c8d..0c9bd9b10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,45 @@ +## Version 1.8.1 (2024-03-26) + +Bugfixes: + +- fix race condition when context is canceled in [#1562](https://github.com/go-sql-driver/mysql/pull/1562) and [#1570](https://github.com/go-sql-driver/mysql/pull/1570) + +## Version 1.8.0 (2024-03-09) + +Major Changes: + +- Use `SET NAMES charset COLLATE collation`. by @methane in [#1437](https://github.com/go-sql-driver/mysql/pull/1437) + - Older go-mysql-driver used `collation_id` in the handshake packet. But it caused collation mismatch in some situation. + - If you don't specify charset nor collation, go-mysql-driver sends `SET NAMES utf8mb4` for new connection. This uses server's default collation for utf8mb4. + - If you specify charset, go-mysql-driver sends `SET NAMES `. This uses the server's default collation for ``. + - If you specify collation and/or charset, go-mysql-driver sends `SET NAMES charset COLLATE collation`. +- PathEscape dbname in DSN. by @methane in [#1432](https://github.com/go-sql-driver/mysql/pull/1432) + - This is backward incompatible in rare case. Check your DSN. +- Drop Go 1.13-17 support by @methane in [#1420](https://github.com/go-sql-driver/mysql/pull/1420) + - Use Go 1.18+ +- Parse numbers on text protocol too by @methane in [#1452](https://github.com/go-sql-driver/mysql/pull/1452) + - When text protocol is used, go-mysql-driver passed bare `[]byte` to database/sql for avoid unnecessary allocation and conversion. + - If user specified `*any` to `Scan()`, database/sql passed the `[]byte` into the target variable. + - This confused users because most user doesn't know when text/binary protocol used. + - go-mysql-driver 1.8 converts integer/float values into int64/double even in text protocol. This doesn't increase allocation compared to `[]byte` and conversion cost is negatable. +- New options start using the Functional Option Pattern to avoid increasing technical debt in the Config object. Future version may introduce Functional Option for existing options, but not for now. + - Make TimeTruncate functional option by @methane in [1552](https://github.com/go-sql-driver/mysql/pull/1552) + - Add BeforeConnect callback to configuration object by @ItalyPaleAle in [#1469](https://github.com/go-sql-driver/mysql/pull/1469) + + +Other changes: + +- Adding DeregisterDialContext to prevent memory leaks with dialers we don't need anymore by @jypelle in https://github.com/go-sql-driver/mysql/pull/1422 +- Make logger configurable per connection by @frozenbonito in https://github.com/go-sql-driver/mysql/pull/1408 +- Fix ColumnType.DatabaseTypeName for mediumint unsigned by @evanelias in https://github.com/go-sql-driver/mysql/pull/1428 +- Add connection attributes by @Daemonxiao in https://github.com/go-sql-driver/mysql/pull/1389 +- Stop `ColumnTypeScanType()` from returning `sql.RawBytes` by @methane in https://github.com/go-sql-driver/mysql/pull/1424 +- Exec() now provides access to status of multiple statements. by @mherr-google in https://github.com/go-sql-driver/mysql/pull/1309 +- Allow to change (or disable) the default driver name for registration by @dolmen in https://github.com/go-sql-driver/mysql/pull/1499 +- Add default connection attribute '_server_host' by @oblitorum in https://github.com/go-sql-driver/mysql/pull/1506 +- QueryUnescape DSN ConnectionAttribute value by @zhangyangyu in https://github.com/go-sql-driver/mysql/pull/1470 +- Add client_ed25519 authentication by @Gusted in https://github.com/go-sql-driver/mysql/pull/1518 + ## Version 1.7.1 (2023-04-25) Changes: From 7939f5923ddca00fbfcaba7ab72eca484d5f9060 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 30 Apr 2024 22:26:36 +0900 Subject: [PATCH 076/123] update URL for protocol docs (#1580) --- packets.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packets.go b/packets.go index 90a34728b..cf3412ff6 100644 --- a/packets.go +++ b/packets.go @@ -21,8 +21,9 @@ import ( "time" ) -// Packets documentation: -// http://dev.mysql.com/doc/internals/en/client-server-protocol.html +// MySQL client/server protocol documentations. +// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html +// https://mariadb.com/kb/en/clientserver-protocol/ // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { From af8d7931954ec21a96df9610a99c09c2887f2ee7 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 30 Apr 2024 22:27:06 +0900 Subject: [PATCH 077/123] unify short name for mysqlConn in connection_test (#1581) --- connection_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/connection_test.go b/connection_test.go index 98c985ae1..c59cb6176 100644 --- a/connection_test.go +++ b/connection_test.go @@ -117,8 +117,8 @@ func TestInterpolateParamsUint64(t *testing.T) { func TestCheckNamedValue(t *testing.T) { value := driver.NamedValue{Value: ^uint64(0)} - x := &mysqlConn{} - err := x.CheckNamedValue(&value) + mc := &mysqlConn{} + err := mc.CheckNamedValue(&value) if err != nil { t.Fatal("uint64 high-bit not convertible", err) @@ -159,13 +159,13 @@ func TestCleanCancel(t *testing.T) { func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} - ms := &mysqlConn{ + mc := &mysqlConn{ netConn: nc, buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, } - err := ms.Ping(context.Background()) + err := mc.Ping(context.Background()) if err != driver.ErrBadConn { t.Errorf("expected driver.ErrBadConn, got %#v", err) @@ -174,7 +174,7 @@ func TestPingMarkBadConnection(t *testing.T) { func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} - ms := &mysqlConn{ + mc := &mysqlConn{ netConn: nc, buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, @@ -182,7 +182,7 @@ func TestPingErrInvalidConn(t *testing.T) { cfg: NewConfig(), } - err := ms.Ping(context.Background()) + err := mc.Ping(context.Background()) if err != ErrInvalidConn { t.Errorf("expected ErrInvalidConn, got %#v", err) From 2f7015e5c48d361a7dd188c01ae95379c7b9f6f9 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 4 Jun 2024 19:12:35 +0900 Subject: [PATCH 078/123] log: add "filename:line" prefix by ourself (#1589) go-sql-driver/mysql#1563 broke the filename:lineno prefix in the log message by introducing a helper function. This commit adds the "filename:line" prefix in the helper function instead of log.Lshortfile option to show correct filename:lineno. --- connection.go | 12 ++++++++++++ errors.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/connection.go b/connection.go index 7b8abeb00..a3dc09d2c 100644 --- a/connection.go +++ b/connection.go @@ -13,8 +13,10 @@ import ( "database/sql" "database/sql/driver" "encoding/json" + "fmt" "io" "net" + "runtime" "strconv" "strings" "sync/atomic" @@ -47,6 +49,16 @@ type mysqlConn struct { // Helper function to call per-connection logger. func (mc *mysqlConn) log(v ...any) { + _, filename, lineno, ok := runtime.Caller(1) + if ok { + pos := strings.LastIndexByte(filename, '/') + if pos != -1 { + filename = filename[pos+1:] + } + prefix := fmt.Sprintf("%s:%d ", filename, lineno) + v = append([]any{prefix}, v...) + } + mc.cfg.Logger.Print(v...) } diff --git a/errors.go b/errors.go index a7ef88909..238e480f3 100644 --- a/errors.go +++ b/errors.go @@ -37,7 +37,7 @@ var ( errBadConnNoWrite = errors.New("bad connection") ) -var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) +var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime)) // Logger is used to log critical error messages. type Logger interface { From 05325d8c2d8a3f5469086f2fd15552cc7960926c Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 11 Jun 2024 22:34:45 +0900 Subject: [PATCH 079/123] fix some write error handling (#1595) interpolateParams() returned ErrInvalidConn without closing the connection. Since database/sql doesn't understand ErrInvalidConn, there is a risk that database/sql reuse this connection and ErrInvalidConn is returned repeatedly. --- connection.go | 6 ++++-- packets.go | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index a3dc09d2c..bf102cdf9 100644 --- a/connection.go +++ b/connection.go @@ -230,8 +230,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin buf, err := mc.buf.takeCompleteBuffer() if err != nil { // can not take the buffer. Something must be wrong with the connection - mc.log(err) - return "", ErrInvalidConn + mc.cleanup() + // interpolateParams would be called before sending any query. + // So its safe to retry. + return "", driver.ErrBadConn } buf = buf[:0] argPos := 0 diff --git a/packets.go b/packets.go index cf3412ff6..033ef201e 100644 --- a/packets.go +++ b/packets.go @@ -117,6 +117,8 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Write packet if mc.writeTimeout > 0 { if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + mc.cleanup() + mc.log(err) return err } } From 9b8d28eff68e1b0dec9d45e9868796e7f7a9af49 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 11 Jun 2024 22:49:22 +0900 Subject: [PATCH 080/123] fix missing skip test when no DB is available (#1594) Fix `go test` fails when no DB is set up. --- driver_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/driver_test.go b/driver_test.go index 4fd196d4b..24d73c34f 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3539,6 +3539,9 @@ func TestConnectionAttributes(t *testing.T) { } func TestErrorInMultiResult(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } // https://github.com/go-sql-driver/mysql/issues/1361 var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { From 87443b94dfd43b6cab62182a30c0e7d9759bc18d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 14 Jun 2024 13:51:24 +0900 Subject: [PATCH 081/123] small code cleanup (#1598) * Go programmers familier with `if err != nil {}` than `if err == nil {}`. * Update some URLs about MySQL client/server protocol. --- connection.go | 43 +++++++++++++++++++++++-------------------- packets.go | 33 +++++++++++++++++---------------- 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/connection.go b/connection.go index bf102cdf9..462e7d134 100644 --- a/connection.go +++ b/connection.go @@ -400,31 +400,34 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } // Send command err := mc.writeCommandPacketStr(comQuery, query) - if err == nil { - // Read Result - var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() - if err == nil { - rows := new(textRows) - rows.mc = mc + if err != nil { + return nil, mc.markBadConn(err) + } - if resLen == 0 { - rows.rs.done = true + // Read Result + var resLen int + resLen, err = handleOk.readResultSetHeaderPacket() + if err != nil { + return nil, mc.markBadConn(err) + } - switch err := rows.NextResultSet(); err { - case nil, io.EOF: - return rows, nil - default: - return nil, err - } - } + rows := new(textRows) + rows.mc = mc - // Columns - rows.rs.columns, err = mc.readColumns(resLen) - return rows, err + if resLen == 0 { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err } } - return nil, mc.markBadConn(err) + + // Columns + rows.rs.columns, err = mc.readColumns(resLen) + return rows, err } // Gets the value of the given MySQL System Variable diff --git a/packets.go b/packets.go index 033ef201e..b90b14c5c 100644 --- a/packets.go +++ b/packets.go @@ -524,32 +524,33 @@ func (mc *okHandler) readResultOK() error { } // Result Set Header Packet -// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html func (mc *okHandler) readResultSetHeaderPacket() (int, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) data, err := mc.conn().readPacket() - if err == nil { - switch data[0] { - - case iOK: - return 0, mc.handleOkPacket(data) + if err != nil { + return 0, err + } - case iERR: - return 0, mc.conn().handleErrorPacket(data) + switch data[0] { + case iOK: + return 0, mc.handleOkPacket(data) - case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) - } + case iERR: + return 0, mc.conn().handleErrorPacket(data) - // column count - num, _, _ := readLengthEncodedInteger(data) - // ignore remaining data in the packet. see #1478. - return int(num), nil + case iLocalInFile: + return 0, mc.handleInFileRequest(string(data[1:])) } - return 0, err + + // column count + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html + num, _, _ := readLengthEncodedInteger(data) + // ignore remaining data in the packet. see #1478. + return int(num), nil } // Error Packet From 2f69712cd480487ecb7e513b2fe1e0e7fe138767 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 16 Jun 2024 10:18:42 +0900 Subject: [PATCH 082/123] fix unnecesssary allocation in infile.go (#1600) --- infile.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/infile.go b/infile.go index 0c8af9f11..cf892beae 100644 --- a/infile.go +++ b/infile.go @@ -95,7 +95,6 @@ const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead a func (mc *okHandler) handleInFileRequest(name string) (err error) { var rdr io.Reader - var data []byte packetSize := defaultPacketSize if mc.maxWriteSize < packetSize { packetSize = mc.maxWriteSize @@ -147,9 +146,11 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { } // send content packets + var data []byte + // if packetSize == 0, the Reader contains no data if err == nil && packetSize > 0 { - data := make([]byte, 4+packetSize) + data = make([]byte, 4+packetSize) var n int for err == nil { n, err = rdr.Read(data[4:]) From 52c1917d99904701db2b0e4f14baffa948009cd7 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 16 Jun 2024 10:20:06 +0900 Subject: [PATCH 083/123] remove unnecessary logs (#1599) Logging ErrInvalidConn when the connection already closed doesn't provide any help to users. Additonally, database/sql now uses Validator() to check connection liveness before calling query methods. So stop using `mc.log(ErrInvalidConn)` idiom. This PR includes some cleanup and documentation relating to `mc.markBadConn()`. --- connection.go | 21 +++++++++------------ errors.go | 2 +- statement.go | 2 -- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/connection.go b/connection.go index 462e7d134..2b19c9272 100644 --- a/connection.go +++ b/connection.go @@ -111,14 +111,13 @@ func (mc *mysqlConn) handleParams() (err error) { return } +// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn. +// This function is used to return driver.ErrBadConn only when safe to retry. func (mc *mysqlConn) markBadConn(err error) error { - if mc == nil { - return err - } - if err != errBadConnNoWrite { - return err + if err == errBadConnNoWrite { + return driver.ErrBadConn } - return driver.ErrBadConn + return err } func (mc *mysqlConn) Begin() (driver.Tx, error) { @@ -127,7 +126,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -189,7 +187,6 @@ func (mc *mysqlConn) error() error { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -324,7 +321,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -384,7 +380,6 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) handleOk := mc.clearResult() if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -408,7 +403,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) var resLen int resLen, err = handleOk.readResultSetHeaderPacket() if err != nil { - return nil, mc.markBadConn(err) + return nil, err } rows := new(textRows) @@ -482,7 +477,6 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return driver.ErrBadConn } @@ -704,3 +698,6 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { func (mc *mysqlConn) IsValid() bool { return !mc.closed.Load() } + +var _ driver.SessionResetter = &mysqlConn{} +var _ driver.Validator = &mysqlConn{} diff --git a/errors.go b/errors.go index 238e480f3..584617b11 100644 --- a/errors.go +++ b/errors.go @@ -32,7 +32,7 @@ var ( // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn - // to trigger a resend. + // to trigger a resend. Use mc.markBadConn(err) to do this. // See https://github.com/go-sql-driver/mysql/pull/302 errBadConnNoWrite = errors.New("bad connection") ) diff --git a/statement.go b/statement.go index 0436f2240..35b02bbeb 100644 --- a/statement.go +++ b/statement.go @@ -51,7 +51,6 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.Load() { - stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -95,7 +94,6 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if stmt.mc.closed.Load() { - stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command From 3484db1f68a7b493faffc08c1897360fdd7a67f9 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 29 Jun 2024 08:36:17 +0900 Subject: [PATCH 084/123] improve error handling in writePacket (#1601) * handle error before success case. * return io.ErrShortWrite if not all bytes were written but err is nil. * return err instead of ErrInvalidConn. --- connection_test.go | 6 ++++-- packets.go | 34 +++++++++++++++++----------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/connection_test.go b/connection_test.go index c59cb6176..6f8d2a6d7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -163,6 +163,8 @@ func TestPingMarkBadConnection(t *testing.T) { netConn: nc, buf: newBuffer(nc), maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + cfg: NewConfig(), } err := mc.Ping(context.Background()) @@ -184,8 +186,8 @@ func TestPingErrInvalidConn(t *testing.T) { err := mc.Ping(context.Background()) - if err != ErrInvalidConn { - t.Errorf("expected ErrInvalidConn, got %#v", err) + if err != nc.err { + t.Errorf("expected %#v, got %#v", nc.err, err) } } diff --git a/packets.go b/packets.go index b90b14c5c..df850fd41 100644 --- a/packets.go +++ b/packets.go @@ -124,32 +124,32 @@ func (mc *mysqlConn) writePacket(data []byte) error { } n, err := mc.netConn.Write(data[:4+size]) - if err == nil && n == 4+size { - mc.sequence++ - if size != maxPacketSize { - return nil - } - pktLen -= size - data = data[size:] - continue - } - - // Handle error - if err == nil { // n != len(data) - mc.cleanup() - mc.log(ErrMalformPkt) - } else { + if err != nil { if cerr := mc.canceled.Value(); cerr != nil { return cerr } + mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet + mc.log(err) return errBadConnNoWrite + } else { + return err } + } + if n != 4+size { + // io.Writer(b) must return a non-nil error if it cannot write len(b) bytes. + // The io.ErrShortWrite error is used to indicate that this rule has not been followed. mc.cleanup() - mc.log(err) + return io.ErrShortWrite + } + + mc.sequence++ + if size != maxPacketSize { + return nil } - return ErrInvalidConn + pktLen -= size + data = data[size:] } } From 9c20169374dba4e362a065b8d7183864ee076212 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Fri, 19 Jul 2024 06:02:06 +0200 Subject: [PATCH 085/123] Add support for new VECTOR type (#1609) MySQL 9.0.0 added support for the VECTOR type. This adds basic support so it can be handled at the protocol level. See also https://dev.mysql.com/doc/dev/mysql-server/latest/field__types_8h.html --- AUTHORS | 1 + const.go | 5 ++++- fields.go | 4 +++- packets.go | 3 ++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/AUTHORS b/AUTHORS index 4021b96cc..bab66a3b2 100644 --- a/AUTHORS +++ b/AUTHORS @@ -33,6 +33,7 @@ Daniel Montoya Daniel Nichter Daniël van Eeden Dave Protasowski +Dirkjan Bussink DisposaBoy Egor Smolyakov Erwan Martin diff --git a/const.go b/const.go index 22526e031..0cee9b2ee 100644 --- a/const.go +++ b/const.go @@ -125,7 +125,10 @@ const ( fieldTypeBit ) const ( - fieldTypeJSON fieldType = iota + 0xf5 + fieldTypeVector fieldType = iota + 0xf2 + fieldTypeInvalid + fieldTypeBool + fieldTypeJSON fieldTypeNewDecimal fieldTypeEnum fieldTypeSet diff --git a/fields.go b/fields.go index 286084247..be5cd809a 100644 --- a/fields.go +++ b/fields.go @@ -112,6 +112,8 @@ func (mf *mysqlField) typeDatabaseName() string { return "VARCHAR" case fieldTypeYear: return "YEAR" + case fieldTypeVector: + return "VECTOR" default: return "" } @@ -198,7 +200,7 @@ func (mf *mysqlField) scanType() reflect.Type { return scanTypeNullFloat case fieldTypeBit, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, - fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry: + fieldTypeBLOB, fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeVector: if mf.charSet == binaryCollationID { return scanTypeBytes } diff --git a/packets.go b/packets.go index df850fd41..ccdd532b3 100644 --- a/packets.go +++ b/packets.go @@ -1329,7 +1329,8 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, - fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeVector: var isNull bool var n int dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) From f6a18cf1ac3e6bc282f72874a3742469a99e5762 Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Fri, 19 Jul 2024 13:20:36 +0900 Subject: [PATCH 086/123] MySQL 9.0 and MariaDB 11.4 are released (#1610) --- .github/workflows/test.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c5b2aa313..df37eab59 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -37,14 +37,16 @@ jobs: '1.20', ] mysql = [ + '9.0', + '8.4', # LTS '8.0', - '8.3', '5.7', - 'mariadb-11.3', + 'mariadb-11.4', # LTS + 'mariadb-11.2', 'mariadb-11.1', 'mariadb-10.11', # LTS 'mariadb-10.6', # LTS - 'mariadb-10.5', + 'mariadb-10.5', # LTS ] includes = [] From 44553d64bcde78a5b58cb133a5cc708281c333e0 Mon Sep 17 00:00:00 2001 From: Chris Kirkland Date: Tue, 23 Jul 2024 21:45:26 -0500 Subject: [PATCH 087/123] doc: clarify connection close behavior of context (#1606) Updates the README to make it clear that `go-sql-driver/mysql` closes the current connection if the `context.Context` provided to `ExecContext`, `SelectContext`, etc. is cancelled or times out prior to the query returning. --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 6c6abf9c4..c83f4f74f 100644 --- a/README.md +++ b/README.md @@ -519,6 +519,9 @@ This driver supports the [`ColumnType` interface](https://golang.org/pkg/databas Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. +> [!IMPORTANT] +> The `QueryContext`, `ExecContext`, etc. variants provided by `database/sql` will cause the connection to be closed if the provided context is cancelled or timed out before the result is received by the driver. + ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): From c7276ee51ed3f9eeb720ab003e24f80303a7ce08 Mon Sep 17 00:00:00 2001 From: Nao Yokotsuka <32049413+yokonao@users.noreply.github.com> Date: Sun, 4 Aug 2024 16:52:29 +0900 Subject: [PATCH 088/123] Check mysqlConnector.canceled.Value when failed to TLS handshake (#1615) ### Description Check if the context is canceled when failed to TLS handshake. fix: #1614 ### Checklist - [x] Code compiles correctly - [x] Created tests which fail without the change (if possible) - [x] All tests passing - [x] Extended the README / documentation, if necessary - [x] Added myself / the copyright holder to the AUTHORS file ## Summary by CodeRabbit - **New Features** - Added Nao Yokotsuka to the contributors list for improved project documentation. - **Bug Fixes** - Enhanced error handling in the TLS handshake process to better manage cancellation requests, improving connection responsiveness. --- AUTHORS | 1 + packets.go | 3 +++ 2 files changed, 4 insertions(+) diff --git a/AUTHORS b/AUTHORS index bab66a3b2..287176fb4 100644 --- a/AUTHORS +++ b/AUTHORS @@ -81,6 +81,7 @@ Lunny Xiao Luke Scott Maciej Zimnoch Michael Woolnough +Nao Yokotsuka Nathanial Murphy Nicola Peduzzi Oliver Bone diff --git a/packets.go b/packets.go index ccdd532b3..5ca6491a8 100644 --- a/packets.go +++ b/packets.go @@ -352,6 +352,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Switch to TLS tlsConn := tls.Client(mc.netConn, mc.cfg.TLS) if err := tlsConn.Handshake(); err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } return err } mc.netConn = tlsConn From 2f1527670cb7207fd213f92c7120f9387fe256cf Mon Sep 17 00:00:00 2001 From: pengbanban Date: Mon, 5 Aug 2024 14:31:35 +0900 Subject: [PATCH 089/123] chore: fix comment (#1620) --- utils_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils_test.go b/utils_test.go index 80aebddff..42a88393c 100644 --- a/utils_test.go +++ b/utils_test.go @@ -339,7 +339,7 @@ func TestAppendDateTime(t *testing.T) { buf, err := appendDateTime(buf, v.t, v.timeTruncate) if err != nil { if !v.expectedErr { - t.Errorf("appendDateTime(%v) returned an errror: %v", v.t, err) + t.Errorf("appendDateTime(%v) returned an error: %v", v.t, err) } continue } From 00dc21a6243c02c1a84fc82d08a821c08fde4053 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 30 Aug 2024 14:38:05 +0900 Subject: [PATCH 090/123] allow unknown collation name (#1604) Fix #1603 --- collations.go | 2 +- connection.go | 43 +++++++++---------------------------------- connector.go | 19 +++++++++++++++++++ dsn.go | 11 ++++++++++- dsn_test.go | 6 +++--- packets.go | 20 +++++++++----------- 6 files changed, 51 insertions(+), 50 deletions(-) diff --git a/collations.go b/collations.go index 1cdf97b67..29b1aa43f 100644 --- a/collations.go +++ b/collations.go @@ -8,7 +8,7 @@ package mysql -const defaultCollation = "utf8mb4_general_ci" +const defaultCollationID = 45 // utf8mb4_general_ci const binaryCollationID = 63 // A list of available collations mapped to the internal ID. diff --git a/connection.go b/connection.go index 2b19c9272..ef6fc9e40 100644 --- a/connection.go +++ b/connection.go @@ -67,45 +67,20 @@ func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder for param, val := range mc.cfg.Params { - switch param { - // Charset: character_set_connection, character_set_client, character_set_results - case "charset": - charsets := strings.Split(val, ",") - for _, cs := range charsets { - // ignore errors here - a charset may not exist - if mc.cfg.Collation != "" { - err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) - } else { - err = mc.exec("SET NAMES " + cs) - } - if err == nil { - break - } - } - if err != nil { - return - } - - // Other system vars accumulated in a single SET command - default: - if cmdSet.Len() == 0 { - // Heuristic: 29 chars for each other key=value to reduce reallocations - cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) - cmdSet.WriteString("SET ") - } else { - cmdSet.WriteString(", ") - } - cmdSet.WriteString(param) - cmdSet.WriteString(" = ") - cmdSet.WriteString(val) + if cmdSet.Len() == 0 { + // Heuristic: 29 chars for each other key=value to reduce reallocations + cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.WriteString("SET ") + } else { + cmdSet.WriteString(", ") } + cmdSet.WriteString(param) + cmdSet.WriteString(" = ") + cmdSet.WriteString(val) } if cmdSet.Len() > 0 { err = mc.exec(cmdSet.String()) - if err != nil { - return - } } return diff --git a/connector.go b/connector.go index b67077596..62012dba3 100644 --- a/connector.go +++ b/connector.go @@ -180,6 +180,25 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.maxWriteSize = mc.maxAllowedPacket } + // Charset: character_set_connection, character_set_client, character_set_results + if len(mc.cfg.charsets) > 0 { + for _, cs := range mc.cfg.charsets { + // ignore errors here - a charset may not exist + if mc.cfg.Collation != "" { + err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) + } else { + err = mc.exec("SET NAMES " + cs) + } + if err == nil { + break + } + } + if err != nil { + mc.Close() + return nil, err + } + } + // Handle DSN Params err = mc.handleParams() if err != nil { diff --git a/dsn.go b/dsn.go index 65f5a0242..3c7a6e215 100644 --- a/dsn.go +++ b/dsn.go @@ -44,7 +44,8 @@ type Config struct { DBName string // Database name Params map[string]string // Connection parameters ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - Collation string // Connection collation + charsets []string // Connection charset. When set, this will be set in SET NAMES query + Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed ServerPubKey string // Server public key name @@ -282,6 +283,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") } + if charsets := cfg.charsets; len(charsets) > 0 { + writeDSNParam(&buf, &hasParam, "charset", strings.Join(charsets, ",")) + } + if col := cfg.Collation; col != "" { writeDSNParam(&buf, &hasParam, "collation", col) } @@ -501,6 +506,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { return errors.New("invalid bool value: " + value) } + // charset + case "charset": + cfg.charsets = strings.Split(value, ",") + // Collation case "collation": cfg.Collation = value diff --git a/dsn_test.go b/dsn_test.go index dd8cd935c..863d14824 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -31,13 +31,13 @@ var testDSNs = []struct { &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, }, { "user@unix(/path/to/socket)/dbname?charset=utf8", - &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", charsets: []string{"utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", charsets: []string{"utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, }, { "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", - &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", charsets: []string{"utf8mb4", "utf8"}, Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, }, { "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, Logger: defaultLogger, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, diff --git a/packets.go b/packets.go index 5ca6491a8..014a1deee 100644 --- a/packets.go +++ b/packets.go @@ -322,17 +322,15 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data[11] = 0x00 // Collation ID [1 byte] - cname := mc.cfg.Collation - if cname == "" { - cname = defaultCollation - } - var found bool - data[12], found = collations[cname] - if !found { - // Note possibility for false negatives: - // could be triggered although the collation is valid if the - // collations map does not contain entries the server supports. - return fmt.Errorf("unknown collation: %q", cname) + data[12] = defaultCollationID + if cname := mc.cfg.Collation; cname != "" { + colID, ok := collations[cname] + if ok { + data[12] = colID + } else if len(mc.cfg.charsets) > 0 { + // When cfg.charset is set, the collation is set by `SET NAMES COLLATE `. + return fmt.Errorf("unknown collation: %q", cname) + } } // Filler [23 bytes] (all 0x00) From 91ad4fb77b05cf5b4a413d2b4b67aa7dee6e9f60 Mon Sep 17 00:00:00 2001 From: Aaron Jheng Date: Sun, 10 Nov 2024 12:10:43 +0800 Subject: [PATCH 091/123] Specify a custom dial function per config (#1527) Specify a custom dial function per config instead of using RegisterDialContext. --- connector.go | 31 ++++++++++++++++++------------- dsn.go | 2 ++ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/connector.go b/connector.go index 62012dba3..769b3adc9 100644 --- a/connector.go +++ b/connector.go @@ -87,20 +87,25 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server - dialsLock.RLock() - dial, ok := dials[mc.cfg.Net] - dialsLock.RUnlock() - if ok { - dctx := ctx - if mc.cfg.Timeout > 0 { - var cancel context.CancelFunc - dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) - defer cancel() - } - mc.netConn, err = dial(dctx, mc.cfg.Addr) + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + + if c.cfg.DialFunc != nil { + mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr) } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + mc.netConn, err = dial(dctx, mc.cfg.Addr) + } else { + nd := net.Dialer{} + mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr) + } } if err != nil { return nil, err diff --git a/dsn.go b/dsn.go index 3c7a6e215..f391a8fc9 100644 --- a/dsn.go +++ b/dsn.go @@ -55,6 +55,8 @@ type Config struct { ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout Logger Logger // Logger + // DialFunc specifies the dial function for creating connections + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // boolean fields From fc64d3f08fb84395f911a6a23a266db92ac8a7e1 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 11 Nov 2024 11:14:04 +0900 Subject: [PATCH 092/123] ci: update Go and staticcheck versions (#1639) - Add Go 1.23 support - Remove Go 1.20 support - Update staticcheck action --- .github/workflows/test.yml | 8 +++----- README.md | 2 +- go.mod | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index df37eab59..b1c1f2b34 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,9 +15,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dominikh/staticcheck-action@v1.3.0 - with: - version: "2023.1.6" + - uses: dominikh/staticcheck-action@v1.3.1 list: runs-on: ubuntu-latest @@ -31,10 +29,10 @@ jobs: import os go = [ # Keep the most recent production release at the top - '1.22', + '1.23', # Older production releases + '1.22', '1.21', - '1.20', ] mysql = [ '9.0', diff --git a/README.md b/README.md index c83f4f74f..e9d9222ba 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ## Requirements -* Go 1.20 or higher. We aim to support the 3 latest versions of Go. +* Go 1.21 or higher. We aim to support the 3 latest versions of Go. * MySQL (5.7+) and MariaDB (10.5+) are supported. * [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. * Do not ask questions about TiDB in our issue tracker or forum. diff --git a/go.mod b/go.mod index 2eed53ebb..33c4dd5b1 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/go-sql-driver/mysql -go 1.20 +go 1.21 require filippo.io/edwards25519 v1.1.0 From f62f523d2458d82587f03e9357396a9c8a93fcba Mon Sep 17 00:00:00 2001 From: KratkyZobak Date: Mon, 11 Nov 2024 03:14:49 +0100 Subject: [PATCH 093/123] Fix auth errors when username/password are too long (#1482) (#1625) --- AUTHORS | 1 + packets.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 287176fb4..a98608504 100644 --- a/AUTHORS +++ b/AUTHORS @@ -51,6 +51,7 @@ ICHINOSE Shogo Ilia Cimpoes INADA Naoki Jacek Szwec +Jakub Adamus James Harr Janek Vedock Jason Ng diff --git a/packets.go b/packets.go index 014a1deee..eb4e0cefe 100644 --- a/packets.go +++ b/packets.go @@ -392,7 +392,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) - data, err := mc.buf.takeSmallBuffer(pktLen) + data, err := mc.buf.takeBuffer(pktLen) if err != nil { // cannot take the buffer. Something must be wrong with the connection mc.log(err) From 41a5fa29f2f73060c426547f762dc49b62e1f2a5 Mon Sep 17 00:00:00 2001 From: raffertyyu Date: Tue, 19 Nov 2024 12:09:49 +0800 Subject: [PATCH 094/123] Check if MySQL supports CLIENT_CONNECT_ATTRS before sending client attributes. (#1640) --- packets.go | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/packets.go b/packets.go index eb4e0cefe..a2e7ef95c 100644 --- a/packets.go +++ b/packets.go @@ -210,10 +210,13 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 3 // capability flags (upper 2 bytes) [2 bytes] + mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + pos += 2 // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += 11 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -261,9 +264,11 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | - clientConnectAttrs | + mc.flags&clientConnectAttrs | mc.flags&clientLongFlag + sendConnectAttrs := mc.flags&clientConnectAttrs != 0 + if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } @@ -296,10 +301,13 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // encode length of the connection attributes - var connAttrsLEIBuf [9]byte - connAttrsLen := len(mc.connector.encodedAttributes) - connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) - pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) + var connAttrsLEI []byte + if sendConnectAttrs { + var connAttrsLEIBuf [9]byte + connAttrsLen := len(mc.connector.encodedAttributes) + connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) + pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) + } // Calculate packet length and get buffer with that size data, err := mc.buf.takeBuffer(pktLen + 4) @@ -382,8 +390,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos++ // Connection Attributes - pos += copy(data[pos:], connAttrsLEI) - pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + if sendConnectAttrs { + pos += copy(data[pos:], connAttrsLEI) + pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + } // Send Auth packet return mc.writePacket(data[:pos]) From 9c8d6a5ddc5b4c2a658e77cb4d03583327901ca5 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 22 Nov 2024 00:48:25 +0900 Subject: [PATCH 095/123] Reduce "busy buffer" logs (#1641) Reduce the use of `errBadConnNoWrite` to improve maintainability. ResetSession() and IsValid() checks if the buffer is busy. This reduces the risk of busy buffer error during connection in use. In principle, the risk of this is zero. So I removed errBadConnNoWrite when checking the busy buffer. After this change, only `writePacke()` returns errBadConnNoWrite. Additionally, I do not send COM_QUIT when readPacket() encounter read error. It caused "busy buffer" error too and hide real errors. --- buffer.go | 5 +++++ connection.go | 10 +++++++--- packets.go | 44 ++++++++++++++------------------------------ 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/buffer.go b/buffer.go index 0774c5c8c..d3d009ccf 100644 --- a/buffer.go +++ b/buffer.go @@ -43,6 +43,11 @@ func newBuffer(nc net.Conn) buffer { } } +// busy returns true if the buffer contains some read data. +func (b *buffer) busy() bool { + return b.length > 0 +} + // flip replaces the active buffer with the background buffer // this is a delayed flip that simply increases the buffer counter; // the actual flip will be performed the next time we call `buffer.fill` diff --git a/connection.go b/connection.go index ef6fc9e40..c220a8360 100644 --- a/connection.go +++ b/connection.go @@ -121,10 +121,14 @@ func (mc *mysqlConn) Close() (err error) { if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } + mc.close() + return +} +// close closes the network connection and clear results without sending COM_QUIT. +func (mc *mysqlConn) close() { mc.cleanup() mc.clearResult() - return } // Closes the network connection and unsets internal variables. Do not call this @@ -637,7 +641,7 @@ func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.Load() { + if mc.closed.Load() || mc.buf.busy() { return driver.ErrBadConn } @@ -671,7 +675,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { - return !mc.closed.Load() + return !mc.closed.Load() && !mc.buf.busy() } var _ driver.SessionResetter = &mysqlConn{} diff --git a/packets.go b/packets.go index a2e7ef95c..4695fb81a 100644 --- a/packets.go +++ b/packets.go @@ -32,11 +32,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet header data, err := mc.buf.readNext(4) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } mc.log(err) - mc.Close() return nil, ErrInvalidConn } @@ -45,7 +45,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // check packet sync [8 bit] if data[3] != mc.sequence { - mc.Close() + mc.close() if data[3] > mc.sequence { return nil, ErrPktSyncMul } @@ -59,7 +59,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // there was no previous packet if prevData == nil { mc.log(ErrMalformPkt) - mc.Close() + mc.close() return nil, ErrInvalidConn } @@ -69,11 +69,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + mc.close() if cerr := mc.canceled.Value(); cerr != nil { return nil, cerr } mc.log(err) - mc.Close() return nil, ErrInvalidConn } @@ -125,10 +125,10 @@ func (mc *mysqlConn) writePacket(data []byte) error { n, err := mc.netConn.Write(data[:4+size]) if err != nil { + mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { return cerr } - mc.cleanup() if n == 0 && pktLen == len(data)-4 { // only for the first loop iteration when nothing was written yet mc.log(err) @@ -162,11 +162,6 @@ func (mc *mysqlConn) writePacket(data []byte) error { func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { data, err = mc.readPacket() if err != nil { - // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since - // in connection initialization we don't risk retrying non-idempotent actions. - if err == ErrInvalidConn { - return nil, "", driver.ErrBadConn - } return } @@ -312,9 +307,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // Calculate packet length and get buffer with that size data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // ClientFlags [32 bit] @@ -404,9 +398,8 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) data, err := mc.buf.takeBuffer(pktLen) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + mc.cleanup() + return err } // Add the auth data [EOF] @@ -424,9 +417,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -443,9 +434,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -464,9 +453,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // Add command byte @@ -1007,9 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In this case the len(data) == cap(data) which is used to optimise the flow below. } if err != nil { - // cannot take the buffer. Something must be wrong with the connection - mc.log(err) - return errBadConnNoWrite + return err } // command [1 byte] @@ -1207,8 +1192,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) if err = mc.buf.store(data); err != nil { - mc.log(err) - return errBadConnNoWrite + return err } } From 2df7a26b03e5f9a55bc31544bc9240ac5705e235 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 27 Nov 2024 12:41:28 +0900 Subject: [PATCH 096/123] stmt.Close() returns nil when double close (#1642) ErrBadConn needs special care to ensure it is safe to retry. To improve maintenance, I don't want to use the error where I don't have to. Additionally, update the old comment about Go's bug that had been fixed long time ago. --- statement.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index 35b02bbeb..35df85457 100644 --- a/statement.go +++ b/statement.go @@ -24,11 +24,12 @@ type mysqlStmt struct { func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.closed.Load() { - // driver.Stmt.Close can be called more than once, thus this function - // has to be idempotent. - // See also Issue #450 and golang/go#16019. - //errLog.Print(ErrInvalidConn) - return driver.ErrBadConn + // driver.Stmt.Close could be called more than once, thus this function + // had to be idempotent. See also Issue #450 and golang/go#16019. + // This bug has been fixed in Go 1.8. + // https://github.com/golang/go/commit/90b8a0ca2d0b565c7c7199ffcf77b15ea6b6db3a + // But we keep this function idempotent because it is safer. + return nil } err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) From 575e1b288d624fb14bf56532689f3ec1c1989149 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sun, 1 Dec 2024 10:08:42 +0900 Subject: [PATCH 097/123] stop double-buffering (#1643) Since we dropped Go 1.20 support, we do not need double buffering. This pull request stop double buffering and simplify buffer implementation a lot. Fix #1435 --- buffer.go | 118 ++++++++++++++++++++--------------------------------- packets.go | 4 +- rows.go | 7 ---- 3 files changed, 45 insertions(+), 84 deletions(-) diff --git a/buffer.go b/buffer.go index d3d009ccf..dd82c9313 100644 --- a/buffer.go +++ b/buffer.go @@ -22,47 +22,30 @@ const maxCachedBufSize = 256 * 1024 // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. -// This buffer is backed by two byte slices in a double-buffering scheme type buffer struct { - buf []byte // buf is a byte buffer who's length and capacity are equal. - nc net.Conn - idx int - length int - timeout time.Duration - dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer - flipcnt uint // flipccnt is the current buffer counter for double-buffering + buf []byte // read buffer. + cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. + nc net.Conn + timeout time.Duration } // newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { - fg := make([]byte, defaultBufSize) return buffer{ - buf: fg, - nc: nc, - dbuf: [2][]byte{fg, nil}, + cachedBuf: make([]byte, defaultBufSize), + nc: nc, } } -// busy returns true if the buffer contains some read data. +// busy returns true if the read buffer is not empty. func (b *buffer) busy() bool { - return b.length > 0 + return len(b.buf) > 0 } -// flip replaces the active buffer with the background buffer -// this is a delayed flip that simply increases the buffer counter; -// the actual flip will be performed the next time we call `buffer.fill` -func (b *buffer) flip() { - b.flipcnt += 1 -} - -// fill reads into the buffer until at least _need_ bytes are in it +// fill reads into the read buffer until at least _need_ bytes are in it. func (b *buffer) fill(need int) error { - n := b.length - // fill data into its double-buffering target: if we've called - // flip on this buffer, we'll be copying to the background buffer, - // and then filling it with network data; otherwise we'll just move - // the contents of the current buffer to the front before filling it - dest := b.dbuf[b.flipcnt&1] + // we'll move the contents of the current buffer to dest before filling it. + dest := b.cachedBuf // grow buffer if necessary to fit the whole packet. if need > len(dest) { @@ -72,18 +55,13 @@ func (b *buffer) fill(need int) error { // if the allocated buffer is not too large, move it to backing storage // to prevent extra allocations on applications that perform large reads if len(dest) <= maxCachedBufSize { - b.dbuf[b.flipcnt&1] = dest + b.cachedBuf = dest } } - // if we're filling the fg buffer, move the existing data to the start of it. - // if we're filling the bg buffer, copy over the data - if n > 0 { - copy(dest[:n], b.buf[b.idx:]) - } - - b.buf = dest - b.idx = 0 + // move the existing data to the start of the buffer. + n := len(b.buf) + copy(dest[:n], b.buf) for { if b.timeout > 0 { @@ -92,44 +70,39 @@ func (b *buffer) fill(need int) error { } } - nn, err := b.nc.Read(b.buf[n:]) + nn, err := b.nc.Read(dest[n:]) n += nn - switch err { - case nil: - if n < need { - continue - } - b.length = n - return nil + if err == nil && n < need { + continue + } - case io.EOF: - if n >= need { - b.length = n - return nil - } - return io.ErrUnexpectedEOF + b.buf = dest[:n] - default: - return err + if err == io.EOF { + if n < need { + err = io.ErrUnexpectedEOF + } else { + err = nil + } } + return err } } // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read func (b *buffer) readNext(need int) ([]byte, error) { - if b.length < need { + if len(b.buf) < need { // refill if err := b.fill(need); err != nil { return nil, err } } - offset := b.idx - b.idx += need - b.length -= need - return b.buf[offset:b.idx], nil + data := b.buf[:need] + b.buf = b.buf[need:] + return data, nil } // takeBuffer returns a buffer with the requested size. @@ -137,18 +110,18 @@ func (b *buffer) readNext(need int) ([]byte, error) { // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. func (b *buffer) takeBuffer(length int) ([]byte, error) { - if b.length > 0 { + if b.busy() { return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= cap(b.buf) { - return b.buf[:length], nil + if length <= len(b.cachedBuf) { + return b.cachedBuf[:length], nil } - if length < maxPacketSize { - b.buf = make([]byte, length) - return b.buf, nil + if length < maxCachedBufSize { + b.cachedBuf = make([]byte, length) + return b.cachedBuf, nil } // buffer is larger than we want to store. @@ -159,10 +132,10 @@ func (b *buffer) takeBuffer(length int) ([]byte, error) { // known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { - if b.length > 0 { + if b.busy() { return nil, ErrBusyBuffer } - return b.buf[:length], nil + return b.cachedBuf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. @@ -170,18 +143,15 @@ func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { // cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. func (b *buffer) takeCompleteBuffer() ([]byte, error) { - if b.length > 0 { + if b.busy() { return nil, ErrBusyBuffer } - return b.buf, nil + return b.cachedBuf, nil } // store stores buf, an updated buffer, if its suitable to do so. -func (b *buffer) store(buf []byte) error { - if b.length > 0 { - return ErrBusyBuffer - } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { - b.buf = buf[:cap(buf)] +func (b *buffer) store(buf []byte) { + if cap(buf) <= maxCachedBufSize && cap(buf) > cap(b.cachedBuf) { + b.cachedBuf = buf[:cap(buf)] } - return nil } diff --git a/packets.go b/packets.go index 4695fb81a..736e4418c 100644 --- a/packets.go +++ b/packets.go @@ -1191,9 +1191,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - if err = mc.buf.store(data); err != nil { - return err - } + mc.buf.store(data) // allow this buffer to be reused } pos += len(paramValues) diff --git a/rows.go b/rows.go index 81fa6062c..df98417b8 100644 --- a/rows.go +++ b/rows.go @@ -111,13 +111,6 @@ func (rows *mysqlRows) Close() (err error) { return err } - // flip the buffer for this connection if we need to drain it. - // note that for a successful query (i.e. one where rows.next() - // has been called until it returns false), `rows.mc` will be nil - // by the time the user calls `(*Rows).Close`, so we won't reach this - // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 - mc.buf.flip() - // Remove unread packets from stream if !rows.rs.done { err = mc.readUntilEOF() From c9f41c074062d5ab9aeb5e44adeac3a7d85fbc4e Mon Sep 17 00:00:00 2001 From: Minh Quang Date: Sun, 15 Dec 2024 10:37:13 +0700 Subject: [PATCH 098/123] fix typo in comment (#1647) Fix #1646 --- AUTHORS | 1 + connection.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index a98608504..361c6b647 100644 --- a/AUTHORS +++ b/AUTHORS @@ -92,6 +92,7 @@ Paul Bonser Paulius Lozys Peter Schultz Phil Porada +Minh Quang Rebecca Chin Reed Allman Richard Wilkes diff --git a/connection.go b/connection.go index c220a8360..acc627086 100644 --- a/connection.go +++ b/connection.go @@ -435,7 +435,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { return nil, err } -// finish is called when the query has canceled. +// cancel is called when the query has canceled. func (mc *mysqlConn) cancel(err error) { mc.canceled.Set(err) mc.cleanup() From 3348e573da4c1d7186ae7d6eabd4d7333bd486a0 Mon Sep 17 00:00:00 2001 From: Joe Mann Date: Thu, 19 Dec 2024 04:14:14 +0100 Subject: [PATCH 099/123] Implement zlib compression (#1487) Implemented the SQL compression protocol. This new feature is enabled by: * Adding `compress=true` in DSN. * `cfg.Apply(Compress(True))` Co-authored-by: Brigitte Lamarche Co-authored-by: Julien Schmidt Co-authored-by: Jeffrey Charles Co-authored-by: Jeff Hodges Co-authored-by: Daniel Montoya Co-authored-by: Justin Li Co-authored-by: Dave Stubbs Co-authored-by: Linh Tran Tuan Co-authored-by: Robert R. Russell Co-authored-by: INADA Naoki Co-authored-by: Kieron Woodhouse Co-authored-by: Alexey Palazhchenko Co-authored-by: Reed Allman Co-authored-by: Joe Mann --- .github/workflows/test.yml | 2 +- AUTHORS | 2 + README.md | 11 ++ benchmark_test.go | 28 +++-- buffer.go | 26 ++--- compress.go | 214 +++++++++++++++++++++++++++++++++++++ compress_test.go | 119 +++++++++++++++++++++ connection.go | 43 +++++++- connection_test.go | 14 +-- connector.go | 10 +- const.go | 2 + driver_test.go | 45 ++++++-- dsn.go | 24 ++++- infile.go | 1 + packets.go | 113 ++++++++++++-------- packets_test.go | 24 +++-- utils.go | 12 +++ 17 files changed, 581 insertions(+), 109 deletions(-) create mode 100644 compress.go create mode 100644 compress_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b1c1f2b34..2e07fea91 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,7 +83,7 @@ jobs: my-cnf: | innodb_log_file_size=256MB innodb_buffer_pool_size=512MB - max_allowed_packet=16MB + max_allowed_packet=48MB ; TestConcurrent fails if max_connections is too large max_connections=50 local_infile=1 diff --git a/AUTHORS b/AUTHORS index 361c6b647..cbcc90f51 100644 --- a/AUTHORS +++ b/AUTHORS @@ -21,6 +21,7 @@ Animesh Ray Arne Hormann Ariel Mashraki Asta Xie +B Lamarche Brian Hendriks Bulat Gaifullin Caine Jette @@ -62,6 +63,7 @@ Jennifer Purevsuren Jerome Meyer Jiajia Zhong Jian Zhen +Joe Mann Joshua Prunier Julien Lefevre Julien Schmidt diff --git a/README.md b/README.md index e9d9222ba..da4593ccf 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Secure `LOAD DATA LOCAL INFILE` support with file allowlisting and `io.Reader` support * Optional `time.Time` parsing * Optional placeholder interpolation + * Supports zlib compression. ## Requirements @@ -267,6 +268,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `compress` + +``` +Type: bool +Valid Values: true, false +Default: false +``` + +Toggles zlib compression. false by default. + ##### `interpolateParams` ``` diff --git a/benchmark_test.go b/benchmark_test.go index a4ecc0a63..5c9a046b5 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -46,9 +46,13 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, queries ...string) *sql.DB { +func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { tb := (*TB)(b) - db := tb.checkDB(sql.Open(driverNameTest, dsn)) + comprStr := "" + if useCompression { + comprStr = "&compress=1" + } + db := tb.checkDB(sql.Open(driverNameTest, dsn+comprStr)) for _, query := range queries { if _, err := db.Exec(query); err != nil { b.Fatalf("error on %q: %v", query, err) @@ -60,10 +64,18 @@ func initDB(b *testing.B, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { + benchmarkQueryHelper(b, false) +} + +func BenchmarkQueryCompression(b *testing.B) { + benchmarkQueryHelper(b, true) +} + +func benchmarkQueryHelper(b *testing.B, compr bool) { tb := (*TB)(b) b.StopTimer() b.ReportAllocs() - db := initDB(b, + db := initDB(b, compr, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -222,7 +234,7 @@ func BenchmarkInterpolation(b *testing.B) { }, maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, - buf: newBuffer(nil), + buf: newBuffer(), } args := []driver.Value{ @@ -269,7 +281,7 @@ func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkQueryContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -305,7 +317,7 @@ func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { } func BenchmarkExecContext(b *testing.B) { - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", `INSERT INTO foo VALUES (1, "one")`, @@ -323,7 +335,7 @@ func BenchmarkExecContext(b *testing.B) { // "size=" means size of each blobs. func BenchmarkQueryRawBytes(b *testing.B) { var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS bench_rawbytes", "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", ) @@ -376,7 +388,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { // BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. func BenchmarkReceiveMassiveRows(b *testing.B) { // Setup -- prepare 10000 rows. - db := initDB(b, + db := initDB(b, false, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") defer db.Close() diff --git a/buffer.go b/buffer.go index dd82c9313..a65324315 100644 --- a/buffer.go +++ b/buffer.go @@ -10,13 +10,16 @@ package mysql import ( "io" - "net" - "time" ) const defaultBufSize = 4096 const maxCachedBufSize = 256 * 1024 +// readerFunc is a function that compatible with io.Reader. +// We use this function type instead of io.Reader because we want to +// just pass mc.readWithTimeout. +type readerFunc func([]byte) (int, error) + // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. @@ -25,15 +28,12 @@ const maxCachedBufSize = 256 * 1024 type buffer struct { buf []byte // read buffer. cachedBuf []byte // buffer that will be reused. len(cachedBuf) <= maxCachedBufSize. - nc net.Conn - timeout time.Duration } // newBuffer allocates and returns a new buffer. -func newBuffer(nc net.Conn) buffer { +func newBuffer() buffer { return buffer{ cachedBuf: make([]byte, defaultBufSize), - nc: nc, } } @@ -43,7 +43,7 @@ func (b *buffer) busy() bool { } // fill reads into the read buffer until at least _need_ bytes are in it. -func (b *buffer) fill(need int) error { +func (b *buffer) fill(need int, r readerFunc) error { // we'll move the contents of the current buffer to dest before filling it. dest := b.cachedBuf @@ -64,13 +64,7 @@ func (b *buffer) fill(need int) error { copy(dest[:n], b.buf) for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { - return err - } - } - - nn, err := b.nc.Read(dest[n:]) + nn, err := r(dest[n:]) n += nn if err == nil && n < need { @@ -92,10 +86,10 @@ func (b *buffer) fill(need int) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { +func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) { if len(b.buf) < need { // refill - if err := b.fill(need); err != nil { + if err := b.fill(need, r); err != nil { return nil, err } } diff --git a/compress.go b/compress.go new file mode 100644 index 000000000..fa42772ac --- /dev/null +++ b/compress.go @@ -0,0 +1,214 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sync" +) + +var ( + zrPool *sync.Pool // Do not use directly. Use zDecompress() instead. + zwPool *sync.Pool // Do not use directly. Use zCompress() instead. +) + +func init() { + zrPool = &sync.Pool{ + New: func() any { return nil }, + } + zwPool = &sync.Pool{ + New: func() any { + zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2) + if err != nil { + panic(err) // compress/zlib return non-nil error only if level is invalid + } + return zw + }, + } +} + +func zDecompress(src []byte, dst *bytes.Buffer) (int, error) { + br := bytes.NewReader(src) + var zr io.ReadCloser + var err error + + if a := zrPool.Get(); a == nil { + if zr, err = zlib.NewReader(br); err != nil { + return 0, err + } + } else { + zr = a.(io.ReadCloser) + if err := zr.(zlib.Resetter).Reset(br, nil); err != nil { + return 0, err + } + } + + n, _ := dst.ReadFrom(zr) // ignore err because zr.Close() will return it again. + err = zr.Close() // zr.Close() may return chuecksum error. + zrPool.Put(zr) + return int(n), err +} + +func zCompress(src []byte, dst io.Writer) error { + zw := zwPool.Get().(*zlib.Writer) + zw.Reset(dst) + if _, err := zw.Write(src); err != nil { + return err + } + err := zw.Close() + zwPool.Put(zw) + return err +} + +type compIO struct { + mc *mysqlConn + buff bytes.Buffer +} + +func newCompIO(mc *mysqlConn) *compIO { + return &compIO{ + mc: mc, + } +} + +func (c *compIO) reset() { + c.buff.Reset() +} + +func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { + for c.buff.Len() < need { + if err := c.readCompressedPacket(r); err != nil { + return nil, err + } + } + data := c.buff.Next(need) + return data[:need:need], nil // prevent caller writes into c.buff +} + +func (c *compIO) readCompressedPacket(r readerFunc) error { + header, err := c.mc.buf.readNext(7, r) // size of compressed header + if err != nil { + return err + } + _ = header[6] // bounds check hint to compiler; guaranteed by readNext + + // compressed header structure + comprLength := getUint24(header[0:3]) + compressionSequence := uint8(header[3]) + uncompressedLength := getUint24(header[4:7]) + if debug { + fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", + comprLength, uncompressedLength, compressionSequence, c.mc.sequence) + } + // Do not return ErrPktSync here. + // Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes) + // before receiving all packets from client. In this case, seqnr is younger than expected. + // NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it. + if debug && compressionSequence != c.mc.sequence { + fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v", + c.mc.sequence, compressionSequence) + } + c.mc.sequence = compressionSequence + 1 + c.mc.compressSequence = c.mc.sequence + + comprData, err := c.mc.buf.readNext(comprLength, r) + if err != nil { + return err + } + + // if payload is uncompressed, its length will be specified as zero, and its + // true length is contained in comprLength + if uncompressedLength == 0 { + c.buff.Write(comprData) + return nil + } + + // use existing capacity in bytesBuf if possible + c.buff.Grow(uncompressedLength) + nread, err := zDecompress(comprData, &c.buff) + if err != nil { + return err + } + if nread != uncompressedLength { + return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d", + uncompressedLength, nread) + } + return nil +} + +const minCompressLength = 150 +const maxPayloadLen = maxPacketSize - 4 + +// writePackets sends one or some packets with compression. +// Use this instead of mc.netConn.Write() when mc.compress is true. +func (c *compIO) writePackets(packets []byte) (int, error) { + totalBytes := len(packets) + blankHeader := make([]byte, 7) + buf := &c.buff + + for len(packets) > 0 { + payloadLen := min(maxPayloadLen, len(packets)) + payload := packets[:payloadLen] + uncompressedLen := payloadLen + + buf.Reset() + buf.Write(blankHeader) // Buffer.Write() never returns error + + // If payload is less than minCompressLength, don't compress. + if uncompressedLen < minCompressLength { + buf.Write(payload) + uncompressedLen = 0 + } else { + err := zCompress(payload, buf) + if debug && err != nil { + fmt.Printf("zCompress error: %v", err) + } + // do not compress if compressed data is larger than uncompressed data + // I intentionally miss 7 byte header in the buf; zCompress must compress more than 7 bytes. + if err != nil || buf.Len() >= uncompressedLen { + buf.Reset() + buf.Write(blankHeader) + buf.Write(payload) + uncompressedLen = 0 + } + } + + if n, err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil { + // To allow returning ErrBadConn when sending really 0 bytes, we sum + // up compressed bytes that is returned by underlying Write(). + return totalBytes - len(packets) + n, err + } + packets = packets[payloadLen:] + } + + return totalBytes, nil +} + +// writeCompressedPacket writes a compressed packet with header. +// data should start with 7 size space for header followed by payload. +func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, error) { + mc := c.mc + comprLength := len(data) - 7 + if debug { + fmt.Printf( + "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v", + comprLength, uncompressedLen, mc.compressSequence) + } + + // compression header + putUint24(data[0:3], comprLength) + data[3] = mc.compressSequence + putUint24(data[4:7], uncompressedLen) + + mc.compressSequence++ + return mc.writeWithTimeout(data) +} diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 000000000..030deaefa --- /dev/null +++ b/compress_test.go @@ -0,0 +1,119 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2024 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +func makeRandByteSlice(size int) []byte { + randBytes := make([]byte, size) + rand.Read(randBytes) + return randBytes +} + +// compressHelper compresses uncompressedPacket and checks state variables +func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte { + conn := new(mockConn) + mc.netConn = conn + + err := mc.writePacket(append(make([]byte, 4), uncompressedPacket...)) + if err != nil { + t.Fatal(err) + } + + return conn.written +} + +// uncompressHelper uncompresses compressedPacket and checks state variables +func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte) []byte { + // mocking out buf variable + conn := new(mockConn) + conn.data = compressedPacket + mc.netConn = conn + + uncompressedPacket, err := mc.readPacket() + if err != nil { + if err != io.EOF { + t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error()) + } + } + return uncompressedPacket +} + +// roundtripHelper compresses then uncompresses uncompressedPacket and checks state variables +func roundtripHelper(t *testing.T, cSend *mysqlConn, cReceive *mysqlConn, uncompressedPacket []byte) []byte { + compressed := compressHelper(t, cSend, uncompressedPacket) + return uncompressHelper(t, cReceive, compressed) +} + +// TestRoundtrip tests two connections, where one is reading and the other is writing +func TestRoundtrip(t *testing.T) { + tests := []struct { + uncompressed []byte + desc string + }{ + {uncompressed: []byte("a"), + desc: "a"}, + {uncompressed: []byte("hello world"), + desc: "hello world"}, + {uncompressed: make([]byte, 100), + desc: "100 bytes"}, + {uncompressed: make([]byte, 32768), + desc: "32768 bytes"}, + {uncompressed: make([]byte, 330000), + desc: "33000 bytes"}, + {uncompressed: makeRandByteSlice(10), + desc: "10 rand bytes", + }, + {uncompressed: makeRandByteSlice(100), + desc: "100 rand bytes", + }, + {uncompressed: makeRandByteSlice(32768), + desc: "32768 rand bytes", + }, + {uncompressed: bytes.Repeat(makeRandByteSlice(100), 10000), + desc: "100 rand * 10000 repeat bytes", + }, + } + + _, cSend := newRWMockConn(0) + cSend.compress = true + cSend.compIO = newCompIO(cSend) + _, cReceive := newRWMockConn(0) + cReceive.compress = true + cReceive.compIO = newCompIO(cReceive) + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + cSend.resetSequence() + cReceive.resetSequence() + + uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed) + if len(uncompressed) != len(test.uncompressed) { + t.Errorf("uncompressed size is unexpected. expected %d but got %d", + len(test.uncompressed), len(uncompressed)) + } + if !bytes.Equal(uncompressed, test.uncompressed) { + t.Errorf("roundtrip failed") + } + if cSend.sequence != cReceive.sequence { + t.Errorf("inconsistent sequence number: send=%v recv=%v", + cSend.sequence, cReceive.sequence) + } + if cSend.compressSequence != cReceive.compressSequence { + t.Errorf("inconsistent compress sequence number: send=%v recv=%v", + cSend.compressSequence, cReceive.compressSequence) + } + }) + } +} diff --git a/connection.go b/connection.go index acc627086..3e455a3ff 100644 --- a/connection.go +++ b/connection.go @@ -28,15 +28,17 @@ type mysqlConn struct { netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). + compIO *compIO cfg *Config connector *connector maxAllowedPacket int maxWriteSize int - writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 + compressSequence uint8 parseTime bool + compress bool // for context support (Go 1.8+) watching bool @@ -62,6 +64,43 @@ func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } +func (mc *mysqlConn) readWithTimeout(b []byte) (int, error) { + to := mc.cfg.ReadTimeout + if to > 0 { + if err := mc.netConn.SetReadDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Read(b) +} + +func (mc *mysqlConn) writeWithTimeout(b []byte) (int, error) { + to := mc.cfg.WriteTimeout + if to > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(to)); err != nil { + return 0, err + } + } + return mc.netConn.Write(b) +} + +func (mc *mysqlConn) resetSequence() { + mc.sequence = 0 + mc.compressSequence = 0 +} + +// syncSequence must be called when finished writing some packet and before start reading. +func (mc *mysqlConn) syncSequence() { + // Syncs compressionSequence to sequence. + // This is not documented but done in `net_flush()` in MySQL and MariaDB. + // https://github.com/mariadb-corporation/mariadb-connector-c/blob/8228164f850b12353da24df1b93a1e53cc5e85e9/libmariadb/ma_net.c#L170-L171 + // https://github.com/mysql/mysql-server/blob/824e2b4064053f7daf17d7f3f84b7a3ed92e5fb4/sql-common/net_serv.cc#L293 + if mc.compress { + mc.sequence = mc.compressSequence + mc.compIO.reset() + } +} + // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder @@ -147,7 +186,7 @@ func (mc *mysqlConn) cleanup() { return } if err := conn.Close(); err != nil { - mc.log(err) + mc.log("closing connection:", err) } // This function can be called from multiple goroutines. // So we can not mc.clearResult() here. diff --git a/connection_test.go b/connection_test.go index 6f8d2a6d7..f7740898e 100644 --- a/connection_test.go +++ b/connection_test.go @@ -19,7 +19,7 @@ import ( func TestInterpolateParams(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -39,7 +39,7 @@ func TestInterpolateParams(t *testing.T) { func TestInterpolateParamsJSONRawMessage(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -66,7 +66,7 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) { func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -83,7 +83,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { // https://github.com/go-sql-driver/mysql/pull/490 func TestInterpolateParamsPlaceholderInString(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -99,7 +99,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) { func TestInterpolateParamsUint64(t *testing.T) { mc := &mysqlConn{ - buf: newBuffer(nil), + buf: newBuffer(), maxAllowedPacket: maxPacketSize, cfg: &Config{ InterpolateParams: true, @@ -161,7 +161,7 @@ func TestPingMarkBadConnection(t *testing.T) { nc := badConnection{err: errors.New("boom")} mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), @@ -178,7 +178,7 @@ func TestPingErrInvalidConn(t *testing.T) { nc := badConnection{err: errors.New("failed to write"), n: 10} mc := &mysqlConn{ netConn: nc, - buf: newBuffer(nc), + buf: newBuffer(), maxAllowedPacket: defaultMaxAllowedPacket, closech: make(chan struct{}), cfg: NewConfig(), diff --git a/connector.go b/connector.go index 769b3adc9..a4f3655ef 100644 --- a/connector.go +++ b/connector.go @@ -127,11 +127,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } defer mc.finish() - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout + mc.buf = newBuffer() // Reading Handshake Initialization Packet authData, plugin, err := mc.readHandshakePacket() @@ -170,6 +166,10 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + mc.compress = true + mc.compIO = newCompIO(mc) + } if mc.cfg.MaxAllowedPacket > 0 { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { diff --git a/const.go b/const.go index 0cee9b2ee..4aadcd642 100644 --- a/const.go +++ b/const.go @@ -11,6 +11,8 @@ package mysql import "runtime" const ( + debug = false // for debugging. Set true only in development. + defaultAuthPlugin = "mysql_native_password" defaultMaxAllowedPacket = 64 << 20 // 64 MiB. See https://github.com/go-sql-driver/mysql/issues/1355 minProtocolVersion = 10 diff --git a/driver_test.go b/driver_test.go index 24d73c34f..58b3cb38d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -147,12 +147,11 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { db, err := sql.Open(driverNameTest, dsn) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn, err) } defer db.Close() - - cleanup := func() { - db.Exec("DROP TABLE IF EXISTS test") + if err = db.Ping(); err != nil { + t.Fatalf("connecting %q: %s", dsn, err) } dsn2 := dsn + "&interpolateParams=true" @@ -160,25 +159,46 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { db2, err = sql.Open(driverNameTest, dsn2) if err != nil { - t.Fatalf("error connecting: %s", err.Error()) + t.Fatalf("connecting %q: %s", dsn2, err) } defer db2.Close() } + dsn3 := dsn + "&compress=true" + var db3 *sql.DB + db3, err = sql.Open(driverNameTest, dsn3) + if err != nil { + t.Fatalf("connecting %q: %s", dsn3, err) + } + defer db3.Close() + + cleanupSql := "DROP TABLE IF EXISTS test" + for _, test := range tests { test := test t.Run("default", func(t *testing.T) { dbt := &DBTest{t, db} - t.Cleanup(cleanup) + t.Cleanup(func() { + db.Exec(cleanupSql) + }) test(dbt) }) if db2 != nil { t.Run("interpolateParams", func(t *testing.T) { dbt2 := &DBTest{t, db2} - t.Cleanup(cleanup) + t.Cleanup(func() { + db2.Exec(cleanupSql) + }) test(dbt2) }) } + t.Run("compress", func(t *testing.T) { + dbt3 := &DBTest{t, db3} + t.Cleanup(func() { + db3.Exec(cleanupSql) + }) + test(dbt3) + }) } } @@ -958,12 +978,16 @@ func TestDateTime(t *testing.T) { var err error rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) if err == nil { - rows.Scan(µsecsSupported) + if rows.Next() { + rows.Scan(µsecsSupported) + } rows.Close() } rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) if err == nil { - rows.Scan(&zeroDateSupported) + if rows.Next() { + rows.Scan(&zeroDateSupported) + } rows.Close() } for _, setups := range testcases { @@ -1265,8 +1289,7 @@ func TestLongData(t *testing.T) { var rows *sql.Rows // Long text data - const nonDataQueryLen = 28 // length query w/o value - inS := in[:maxAllowedPacketSize-nonDataQueryLen] + inS := in[:maxAllowedPacketSize-100] dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") rows = dbt.mustQuery("SELECT value FROM test") defer rows.Close() diff --git a/dsn.go b/dsn.go index f391a8fc9..9b560b735 100644 --- a/dsn.go +++ b/dsn.go @@ -73,7 +73,10 @@ type Config struct { ParseTime bool // Parse time values to time.Time RejectReadOnly bool // Reject read-only connections - // unexported fields. new options should be come here + // unexported fields. new options should be come here. + // boolean first. alphabetical order. + + compress bool // Enable zlib compression beforeConnect func(context.Context, *Config) error // Invoked before a connection is established pubKey *rsa.PublicKey // Server public key @@ -93,7 +96,6 @@ func NewConfig() *Config { AllowNativePasswords: true, CheckConnLiveness: true, } - return cfg } @@ -125,6 +127,14 @@ func BeforeConnect(fn func(context.Context, *Config) error) Option { } } +// EnableCompress sets the compression mode. +func EnableCompression(yes bool) Option { + return func(cfg *Config) error { + cfg.compress = yes + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { @@ -297,6 +307,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } + if cfg.compress { + writeDSNParam(&buf, &hasParam, "compress", "true") + } + if cfg.InterpolateParams { writeDSNParam(&buf, &hasParam, "interpolateParams", "true") } @@ -525,7 +539,11 @@ func parseDSNParams(cfg *Config, params string) (err error) { // Compression case "compress": - return errors.New("compression not implemented yet") + var isBool bool + cfg.compress, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } // Enable client side placeholder substitution case "interpolateParams": diff --git a/infile.go b/infile.go index cf892beae..555ef71ad 100644 --- a/infile.go +++ b/infile.go @@ -172,6 +172,7 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { if ioErr := mc.conn().writePacket(data[:4]); ioErr != nil { return ioErr } + mc.conn().syncSequence() // read OK packet if err == nil { diff --git a/packets.go b/packets.go index 736e4418c..e4d2820ed 100644 --- a/packets.go +++ b/packets.go @@ -28,9 +28,16 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte + invalidSequence := false + + readNext := mc.buf.readNext + if mc.compress { + readNext = mc.compIO.readNext + } + for { // read packet header - data, err := mc.buf.readNext(4) + data, err := readNext(4, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -41,17 +48,29 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // packet length [24 bit] - pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - - // check packet sync [8 bit] - if data[3] != mc.sequence { - mc.close() - if data[3] > mc.sequence { - return nil, ErrPktSyncMul + pktLen := getUint24(data[:3]) + seq := data[3] + + if mc.compress { + // MySQL and MariaDB doesn't check packet nr in compressed packet. + if debug && seq != mc.compressSequence { + fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v", + mc.compressSequence, seq) + } + mc.compressSequence = seq + 1 + } else { + // check packet sync [8 bit] + if seq != mc.sequence { + mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq)) + // For large packets, we stop reading as soon as sync error. + if len(prevData) > 0 { + mc.close() + return nil, ErrPktSyncMul + } + invalidSequence = true } - return nil, ErrPktSync + mc.sequence++ } - mc.sequence++ // packets with length 0 terminate a previous packet which is a // multiple of (2^24)-1 bytes long @@ -62,12 +81,11 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.close() return nil, ErrInvalidConn } - return prevData, nil } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = readNext(pktLen, mc.readWithTimeout) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -80,11 +98,18 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // return data if this was the last packet if pktLen < maxPacketSize { // zero allocations for non-split packets - if prevData == nil { - return data, nil + if prevData != nil { + data = append(prevData, data...) } - - return append(prevData, data...), nil + if invalidSequence { + mc.close() + // return sync error only for regular packet. + // error packets may have wrong sequence number. + if data[0] != iERR { + return nil, ErrPktSync + } + } + return data, nil } prevData = append(prevData, data...) @@ -94,36 +119,26 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // Write packet buffer 'data' func (mc *mysqlConn) writePacket(data []byte) error { pktLen := len(data) - 4 - if pktLen > mc.maxAllowedPacket { return ErrPktTooLarge } + writeFunc := mc.writeWithTimeout + if mc.compress { + writeFunc = mc.compIO.writePackets + } + for { - var size int - if pktLen >= maxPacketSize { - data[0] = 0xff - data[1] = 0xff - data[2] = 0xff - size = maxPacketSize - } else { - data[0] = byte(pktLen) - data[1] = byte(pktLen >> 8) - data[2] = byte(pktLen >> 16) - size = pktLen - } + size := min(maxPacketSize, pktLen) + putUint24(data[:3], size) data[3] = mc.sequence // Write packet - if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { - mc.cleanup() - mc.log(err) - return err - } + if debug { + fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence) } - n, err := mc.netConn.Write(data[:4+size]) + n, err := writeFunc(data[:4+size]) if err != nil { mc.cleanup() if cerr := mc.canceled.Value(); cerr != nil { @@ -267,7 +282,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string if mc.cfg.ClientFoundRows { clientFlags |= clientFoundRows } - + if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + clientFlags |= clientCompress + } // To enable TLS / SSL if mc.cfg.TLS != nil { clientFlags |= clientSSL @@ -358,7 +375,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return err } mc.netConn = tlsConn - mc.buf.nc = tlsConn } // User [null terminated string] @@ -413,7 +429,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1) if err != nil { @@ -429,7 +445,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() pktLen := 1 + len(arg) data, err := mc.buf.takeBuffer(pktLen + 4) @@ -444,12 +460,14 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence - mc.sequence = 0 + mc.resetSequence() data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) if err != nil { @@ -932,7 +950,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { pktLen = dataOffset + argLen } - stmt.mc.sequence = 0 + stmt.mc.resetSequence() // Add command byte [1 byte] data[4] = comStmtSendLongData @@ -953,11 +971,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { continue } return err - } // Reset Packet Sequence - stmt.mc.sequence = 0 + stmt.mc.resetSequence() return nil } @@ -982,7 +999,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Reset packet-sequence - mc.sequence = 0 + mc.resetSequence() var data []byte var err error @@ -1198,7 +1215,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + err = mc.writePacket(data) + mc.syncSequence() + return err } // For each remaining resultset in the stream, discards its rows and updates diff --git a/packets_test.go b/packets_test.go index fa4683eab..694b0564c 100644 --- a/packets_test.go +++ b/packets_test.go @@ -98,7 +98,7 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) connector := newConnector(NewConfig()) mc := &mysqlConn{ - buf: newBuffer(conn), + buf: newBuffer(), cfg: connector.cfg, connector: connector, netConn: conn, @@ -112,7 +112,9 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { func TestReadPacketSingleByte(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} @@ -143,12 +145,12 @@ func TestReadPacketWrongSequenceID(t *testing.T) { { ClientSequenceID: 0, ServerSequenceID: 0x42, - ExpectedErr: ErrPktSyncMul, + ExpectedErr: ErrPktSync, }, } { conn, mc := newRWMockConn(testCase.ClientSequenceID) - conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0xff} + conn.data = []byte{0x01, 0x00, 0x00, testCase.ServerSequenceID, 0x22} _, err := mc.readPacket() if err != testCase.ExpectedErr { t.Errorf("expected %v, got %v", testCase.ExpectedErr, err) @@ -164,7 +166,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), + cfg: NewConfig(), } data := make([]byte, maxPacketSize*2+4*3) @@ -269,7 +273,8 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), closech: make(chan struct{}), cfg: NewConfig(), } @@ -285,7 +290,7 @@ func TestReadPacketFail(t *testing.T) { // reset conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read header conn.closed = true @@ -298,7 +303,7 @@ func TestReadPacketFail(t *testing.T) { conn.closed = false conn.reads = 0 mc.sequence = 0 - mc.buf = newBuffer(conn) + mc.buf = newBuffer() // fail to read body conn.maxReads = 1 @@ -313,7 +318,8 @@ func TestReadPacketFail(t *testing.T) { func TestRegression801(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + netConn: conn, + buf: newBuffer(), cfg: new(Config), sequence: 42, closech: make(chan struct{}), diff --git a/utils.go b/utils.go index cda24fe74..d902f3b60 100644 --- a/utils.go +++ b/utils.go @@ -490,6 +490,18 @@ func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { * Convert from and to bytes * ******************************************************************************/ +// 24bit integer: used for packet headers. + +func putUint24(data []byte, n int) { + data[2] = byte(n >> 16) + data[1] = byte(n >> 8) + data[0] = byte(n) +} + +func getUint24(data []byte) int { + return int(data[2])<<16 | int(data[1])<<8 | int(data[0]) +} + func uint64ToBytes(n uint64) []byte { return []byte{ byte(n), From b335ed33d6a10803949fb71bbd7e0974c5be38b2 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 20 Dec 2024 14:13:32 +0900 Subject: [PATCH 100/123] use binary.LittleEndian (#1651) Recent Go does inlinine functions well. Using `LittleEndian.Put*` would better for readability and minimize bound check. Additionally, Go 1.19 introduced `LittleEndian.Append*`. It reduce more code. --- packets.go | 72 ++++++++---------------------------------------------- utils.go | 29 ++++++---------------- 2 files changed, 17 insertions(+), 84 deletions(-) diff --git a/packets.go b/packets.go index e4d2820ed..f3860c5f8 100644 --- a/packets.go +++ b/packets.go @@ -329,16 +329,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // ClientFlags [32 bit] - data[4] = byte(clientFlags) - data[5] = byte(clientFlags >> 8) - data[6] = byte(clientFlags >> 16) - data[7] = byte(clientFlags >> 24) + binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) // MaxPacketSize [32 bit] (none) - data[8] = 0x00 - data[9] = 0x00 - data[10] = 0x00 - data[11] = 0x00 + binary.LittleEndian.PutUint32(data[8:], 0) // Collation ID [1 byte] data[12] = defaultCollationID @@ -478,10 +472,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data[4] = command // Add arg [32 bit] - data[5] = byte(arg) - data[6] = byte(arg >> 8) - data[7] = byte(arg >> 16) - data[8] = byte(arg >> 24) + binary.LittleEndian.PutUint32(data[5:], arg) // Send CMD packet return mc.writePacket(data) @@ -955,14 +946,10 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[4] = comStmtSendLongData // Add stmtID [32 bit] - data[5] = byte(stmt.id) - data[6] = byte(stmt.id >> 8) - data[7] = byte(stmt.id >> 16) - data[8] = byte(stmt.id >> 24) + binary.LittleEndian.PutUint32(data[5:], stmt.id) // Add paramID [16 bit] - data[9] = byte(paramID) - data[10] = byte(paramID >> 8) + binary.LittleEndian.PutUint16(data[9:], uint16(paramID)) // Send CMD packet err := stmt.mc.writePacket(data[:4+pktLen]) @@ -1018,19 +1005,13 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[4] = comStmtExecute // statement_id [4 bytes] - data[5] = byte(stmt.id) - data[6] = byte(stmt.id >> 8) - data[7] = byte(stmt.id >> 16) - data[8] = byte(stmt.id >> 24) + binary.LittleEndian.PutUint32(data[5:], stmt.id) // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] data[9] = 0x00 // iteration_count (uint32(1)) [4 bytes] - data[10] = 0x01 - data[11] = 0x00 - data[12] = 0x00 - data[13] = 0x00 + binary.LittleEndian.PutUint32(data[10:], 1) if len(args) > 0 { pos := minPktLen @@ -1084,50 +1065,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case int64: paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 - - if cap(paramValues)-len(paramValues)-8 >= 0 { - paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - uint64(v), - ) - } else { - paramValues = append(paramValues, - uint64ToBytes(uint64(v))..., - ) - } + paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v)) case uint64: paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x80 // type is unsigned - - if cap(paramValues)-len(paramValues)-8 >= 0 { - paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - uint64(v), - ) - } else { - paramValues = append(paramValues, - uint64ToBytes(uint64(v))..., - ) - } + paramValues = binary.LittleEndian.AppendUint64(paramValues, uint64(v)) case float64: paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 - - if cap(paramValues)-len(paramValues)-8 >= 0 { - paramValues = paramValues[:len(paramValues)+8] - binary.LittleEndian.PutUint64( - paramValues[len(paramValues)-8:], - math.Float64bits(v), - ) - } else { - paramValues = append(paramValues, - uint64ToBytes(math.Float64bits(v))..., - ) - } + paramValues = binary.LittleEndian.AppendUint64(paramValues, math.Float64bits(v)) case bool: paramTypes[i+i] = byte(fieldTypeTiny) diff --git a/utils.go b/utils.go index d902f3b60..44f43ef7b 100644 --- a/utils.go +++ b/utils.go @@ -502,19 +502,6 @@ func getUint24(data []byte) int { return int(data[2])<<16 | int(data[1])<<8 | int(data[0]) } -func uint64ToBytes(n uint64) []byte { - return []byte{ - byte(n), - byte(n >> 8), - byte(n >> 16), - byte(n >> 24), - byte(n >> 32), - byte(n >> 40), - byte(n >> 48), - byte(n >> 56), - } -} - func uint64ToString(n uint64) []byte { var a [20]byte i := 20 @@ -598,18 +585,15 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { // 252: value of following 2 case 0xfc: - return uint64(b[1]) | uint64(b[2])<<8, false, 3 + return uint64(binary.LittleEndian.Uint16(b[1:])), false, 3 // 253: value of following 3 case 0xfd: - return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 + return uint64(getUint24(b[1:])), false, 4 // 254: value of following 8 case 0xfe: - return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | - uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | - uint64(b[7])<<48 | uint64(b[8])<<56, - false, 9 + return uint64(binary.LittleEndian.Uint64(b[1:])), false, 9 } // 0-250: value of first byte @@ -623,13 +607,14 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte { return append(b, byte(n)) case n <= 0xffff: - return append(b, 0xfc, byte(n), byte(n>>8)) + b = append(b, 0xfc) + return binary.LittleEndian.AppendUint16(b, uint16(n)) case n <= 0xffffff: return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) } - return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), - byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) + b = append(b, 0xfe) + return binary.LittleEndian.AppendUint64(b, n) } func appendLengthEncodedString(b []byte, s string) []byte { From 7403860363ca112af503b4612568c3096fecb466 Mon Sep 17 00:00:00 2001 From: Artur Melanchyk Date: Tue, 24 Dec 2024 05:10:11 +0100 Subject: [PATCH 101/123] Make fileRegister a set (#1653) --- AUTHORS | 1 + infile.go | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/AUTHORS b/AUTHORS index cbcc90f51..a38395797 100644 --- a/AUTHORS +++ b/AUTHORS @@ -20,6 +20,7 @@ Andrew Reid Animesh Ray Arne Hormann Ariel Mashraki +Artur Melanchyk Asta Xie B Lamarche Brian Hendriks diff --git a/infile.go b/infile.go index 555ef71ad..453ae091e 100644 --- a/infile.go +++ b/infile.go @@ -17,7 +17,7 @@ import ( ) var ( - fileRegister map[string]bool + fileRegister map[string]struct{} fileRegisterLock sync.RWMutex readerRegister map[string]func() io.Reader readerRegisterLock sync.RWMutex @@ -37,10 +37,10 @@ func RegisterLocalFile(filePath string) { fileRegisterLock.Lock() // lazy map init if fileRegister == nil { - fileRegister = make(map[string]bool) + fileRegister = make(map[string]struct{}) } - fileRegister[strings.Trim(filePath, `"`)] = true + fileRegister[strings.Trim(filePath, `"`)] = struct{}{} fileRegisterLock.Unlock() } @@ -123,9 +123,9 @@ func (mc *okHandler) handleInFileRequest(name string) (err error) { } else { // File name = strings.Trim(name, `"`) fileRegisterLock.RLock() - fr := fileRegister[name] + _, exists := fileRegister[name] fileRegisterLock.RUnlock() - if mc.cfg.AllowAllFiles || fr { + if mc.cfg.AllowAllFiles || exists { var file *os.File var fi os.FileInfo From 255d1ad98f1d3be99661d2a8c0a7a91418acbc8d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 22 Jan 2025 14:59:24 +0900 Subject: [PATCH 102/123] better max_allowed_packet parsing (#1661) Remove `stringToInt()` and use `strconv.Atoi` instead. --- connector.go | 8 +++++++- utils.go | 10 ---------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/connector.go b/connector.go index a4f3655ef..bc1d46afc 100644 --- a/connector.go +++ b/connector.go @@ -11,6 +11,7 @@ package mysql import ( "context" "database/sql/driver" + "fmt" "net" "os" "strconv" @@ -179,7 +180,12 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.Close() return nil, err } - mc.maxAllowedPacket = stringToInt(maxap) - 1 + n, err := strconv.Atoi(string(maxap)) + if err != nil { + mc.Close() + return nil, fmt.Errorf("invalid max_allowed_packet value (%q): %w", maxap, err) + } + mc.maxAllowedPacket = n - 1 } if mc.maxAllowedPacket < maxPacketSize { mc.maxWriteSize = mc.maxAllowedPacket diff --git a/utils.go b/utils.go index 44f43ef7b..8716c26c5 100644 --- a/utils.go +++ b/utils.go @@ -524,16 +524,6 @@ func uint64ToString(n uint64) []byte { return a[i:] } -// treats string value as unsigned integer representation -func stringToInt(b []byte) int { - val := 0 - for i := range b { - val *= 10 - val += int(b[i] - 0x30) - } - return val -} - // returns the string read as a bytes slice, whether the value is NULL, // the number of bytes read and an error, in case the string is longer than // the input slice From 85c6311943c82f1300077b2d0e94687106ab61e7 Mon Sep 17 00:00:00 2001 From: Kolbe Kegel Date: Thu, 23 Jan 2025 01:32:31 -0800 Subject: [PATCH 103/123] Add error 1290/ER_READ_ONLY_MODE to rejectReadOnly handling (#1660) --- packets.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packets.go b/packets.go index f3860c5f8..9951bdf80 100644 --- a/packets.go +++ b/packets.go @@ -574,7 +574,8 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) - if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // 1836: ER_READ_ONLY_MODE + if (errno == 1792 || errno == 1290 || errno == 1836) && mc.cfg.RejectReadOnly { // Oops; we are connected to a read-only connection, and won't be able // to issue any write statements. Since RejectReadOnly is configured, // we throw away this connection hoping this one would have write From 341a5a5246835b2ac4b8d36bb12a9dfad70663f4 Mon Sep 17 00:00:00 2001 From: Bes Dollma <143414965+bdollma-te@users.noreply.github.com> Date: Wed, 29 Jan 2025 07:59:01 +0200 Subject: [PATCH 104/123] Fix auth_switch_request packet handling auth_data contains last NUL. Fix #1666 Signed-off-by: Bes Dollma (bdollma) --- AUTHORS | 2 ++ auth_test.go | 24 ++++++++++++------------ packets.go | 3 +++ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/AUTHORS b/AUTHORS index a38395797..123b5dc50 100644 --- a/AUTHORS +++ b/AUTHORS @@ -23,6 +23,7 @@ Ariel Mashraki Artur Melanchyk Asta Xie B Lamarche +Bes Dollma Brian Hendriks Bulat Gaifullin Caine Jette @@ -146,4 +147,5 @@ PingCAP Inc. Pivotal Inc. Shattered Silicon Ltd. Stripe Inc. +ThousandEyes Zendesk Inc. diff --git a/auth_test.go b/auth_test.go index 8caed1fff..46e1e3b4e 100644 --- a/auth_test.go +++ b/auth_test.go @@ -734,9 +734,9 @@ func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { expectedReply := []byte{ // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, + 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, + 56, 103, 217, 196, 176, 124, 60, 253, 41, 195, 10, 205, 190, 177, 206, 63, + 118, 211, 69, } if !bytes.Equal(conn.written, expectedReply) { t.Errorf("got unexpected data: %v", conn.written) @@ -803,9 +803,9 @@ func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { expectedReplyPrefix := []byte{ // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, + 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, + 56, 103, 217, 196, 176, 124, 60, 253, 41, 195, 10, 205, 190, 177, 206, 63, + 118, 211, 69, // 2. Packet: Pub Key Request 1, 0, 0, 5, 2, @@ -848,9 +848,9 @@ func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { expectedReplyPrefix := []byte{ // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, + 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, + 56, 103, 217, 196, 176, 124, 60, 253, 41, 195, 10, 205, 190, 177, 206, 63, + 118, 211, 69, // 2. Packet: Encrypted Password 0, 1, 0, 5, // [changing bytes] @@ -891,9 +891,9 @@ func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { expectedReply := []byte{ // 1. Packet: Hash - 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, - 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, - 153, 9, 130, + 32, 0, 0, 3, 219, 72, 64, 97, 56, 197, 167, 203, 64, 236, 168, 80, 223, + 56, 103, 217, 196, 176, 124, 60, 253, 41, 195, 10, 205, 190, 177, 206, 63, + 118, 211, 69, // 2. Packet: Cleartext password 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, diff --git a/packets.go b/packets.go index 9951bdf80..4b8362160 100644 --- a/packets.go +++ b/packets.go @@ -510,6 +510,9 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { } plugin := string(data[1:pluginEndIndex]) authData := data[pluginEndIndex+1:] + if len(authData) > 0 && authData[len(authData)-1] == 0 { + authData = authData[:len(authData)-1] + } return authData, plugin, nil default: // Error otherwise From 5d1bb8a9cf03422554dd52abf5eba89b8ca11307 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 18 Feb 2025 12:05:50 +0900 Subject: [PATCH 105/123] fix flaky test. (#1663) TestIssue1567 fails by max_connections error. This makes our CI unhappy. https://github.com/go-sql-driver/mysql/actions/runs/12904961433/job/35984402310 --- driver_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/driver_test.go b/driver_test.go index 58b3cb38d..00e828657 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3608,6 +3608,12 @@ func runCallCommand(dbt *DBTest, query, name string) { func TestIssue1567(t *testing.T) { // enable TLS. runTests(t, dsn+"&tls=skip-verify", func(dbt *DBTest) { + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + // disable connection pooling. // data race happens when new connection is created. dbt.db.SetMaxIdleConns(0) @@ -3627,6 +3633,9 @@ func TestIssue1567(t *testing.T) { if testing.Short() { count = 10 } + if count > max { + count = max + } for i := 0; i < count; i++ { timeout := time.Duration(mrand.Int63n(int64(rtt))) From 58941dd8a7888cf3d593d7bb182120e42168eac9 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 18 Feb 2025 12:37:58 +0900 Subject: [PATCH 106/123] release v1.9.0 (#1662) --- CHANGELOG.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c9bd9b10..d8c3aac1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,28 @@ +# Changelog + +## v1.9.0 (2025-02-18) + +### Major Changes + +- Implement zlib compression. (#1487) +- Supported Go version is updated to Go 1.21+. (#1639) +- Add support for VECTOR type introduced in MySQL 9.0. (#1609) +- Config object can have custom dial function. (#1527) + +### Bugfixes + +- Fix auth errors when username/password are too long. (#1625) +- Check if MySQL supports CLIENT_CONNECT_ATTRS before sending client attributes. (#1640) +- Fix auth switch request handling. (#1666) + +### Other changes + +- Add "filename:line" prefix to log in go-mysql. Custom loggers now show it. (#1589) +- Improve error handling. It reduces the "busy buffer" errors. (#1595, #1601, #1641) +- Use `strconv.Atoi` to parse max_allowed_packet. (#1661) +- `rejectReadOnly` option now handles ER_READ_ONLY_MODE (1290) error too. (#1660) + + ## Version 1.8.1 (2024-03-26) Bugfixes: From c87981610c07572d94be59d39550be1e3b1b5bb3 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 10 Mar 2025 11:33:49 +0900 Subject: [PATCH 107/123] add Charset() option (#1679) Fix #1664. --- dsn.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/dsn.go b/dsn.go index 9b560b735..9bafab441 100644 --- a/dsn.go +++ b/dsn.go @@ -44,7 +44,6 @@ type Config struct { DBName string // Database name Params map[string]string // Connection parameters ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs - charsets []string // Connection charset. When set, this will be set in SET NAMES query Collation string // Connection collation. When set, this will be set in SET NAMES COLLATE query Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed @@ -81,6 +80,7 @@ type Config struct { beforeConnect func(context.Context, *Config) error // Invoked before a connection is established pubKey *rsa.PublicKey // Server public key timeTruncate time.Duration // Truncate time.Time values to the specified duration + charsets []string // Connection charset. When set, this will be set in SET NAMES query } // Functional Options Pattern @@ -135,6 +135,21 @@ func EnableCompression(yes bool) Option { } } +// Charset sets the connection charset and collation. +// +// charset is the connection charset. +// collation is the connection collation. It can be null or empty string. +// +// When collation is not specified, `SET NAMES ` command is sent when the connection is established. +// When collation is specified, `SET NAMES COLLATE ` command is sent when the connection is established. +func Charset(charset, collation string) Option { + return func(cfg *Config) error { + cfg.charsets = []string{charset} + cfg.Collation = collation + return nil + } +} + func (cfg *Config) Clone() *Config { cp := *cfg if cp.TLS != nil { From 88ff88b5915d34bde2b2c59991c586abb8ea9eca Mon Sep 17 00:00:00 2001 From: Bogdan Constantinescu Date: Mon, 10 Mar 2025 04:48:22 +0200 Subject: [PATCH 108/123] Fix FormatDSN missing ConnectionAttributes (#1619) Fix #1618 --- AUTHORS | 1 + dsn.go | 4 ++++ dsn_test.go | 6 +++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 123b5dc50..510b869b7 100644 --- a/AUTHORS +++ b/AUTHORS @@ -24,6 +24,7 @@ Artur Melanchyk Asta Xie B Lamarche Bes Dollma +Bogdan Constantinescu Brian Hendriks Bulat Gaifullin Caine Jette diff --git a/dsn.go b/dsn.go index 9bafab441..ecf62567a 100644 --- a/dsn.go +++ b/dsn.go @@ -322,6 +322,10 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") } + if cfg.ConnectionAttributes != "" { + writeDSNParam(&buf, &hasParam, "connectionAttributes", url.QueryEscape(cfg.ConnectionAttributes)) + } + if cfg.compress { writeDSNParam(&buf, &hasParam, "compress", "true") } diff --git a/dsn_test.go b/dsn_test.go index 863d14824..436f77992 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -77,6 +77,9 @@ var testDSNs = []struct { }, { "user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h", &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour}, +}, { + "foo:bar@tcp(192.168.1.50:3307)/baz?timeout=10s&connectionAttributes=program_name:MySQLGoDriver%2FTest,program_version:1.2.3", + &Config{User: "foo", Passwd: "bar", Net: "tcp", Addr: "192.168.1.50:3307", DBName: "baz", Loc: time.UTC, Timeout: 10 * time.Second, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, ConnectionAttributes: "program_name:MySQLGoDriver/Test,program_version:1.2.3"}, }, } @@ -109,7 +112,8 @@ func TestDSNParserInvalid(t *testing.T) { "User:pass@tcp(1.2.3.4:3306)", // no trailing slash "net()/", // unknown default addr "user:pass@tcp(127.0.0.1:3306)/db/name", // invalid dbname - "user:password@/dbname?allowFallbackToPlaintext=PREFERRED", // wrong bool flag + "user:password@/dbname?allowFallbackToPlaintext=PREFERRED", // wrong bool flag + "user:password@/dbname?connectionAttributes=attr1:/unescaped/value", // unescaped //"/dbname?arg=/some/unescaped/path", } From b84ac5af9c77b13e4c6417e484cbed087cd672f3 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 14 Mar 2025 22:05:00 +0900 Subject: [PATCH 109/123] go.mod: fix go version format (#1682) As of Go 1.21, toolchain versions must use the 1.N.P syntax. https://go.dev/doc/toolchain#version --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 33c4dd5b1..187aff179 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/go-sql-driver/mysql -go 1.21 +go 1.21.0 require filippo.io/edwards25519 v1.1.0 From 1fbafa8082dab81e2c2e8caeb55d569dfeafcf94 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 21 Mar 2025 11:04:43 +0900 Subject: [PATCH 110/123] release v1.9.1 (#1683) --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8c3aac1e..a1b23c66b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v1.9.1 (2025-03-21) + +### Major Changes + +* Add Charset() option. (#1679) + +### Bugfixes + +* go.mod: fix go version format (#1682) +* Fix FormatDSN missing ConnectionAttributes (#1619) + ## v1.9.0 (2025-02-18) ### Major Changes From 21ef4c694538530b2e4b43d1a197402ed22e9749 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 7 Apr 2025 20:18:01 +0900 Subject: [PATCH 111/123] release v1.9.2 (#1693) --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a1b23c66b..66189edaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## v1.9.2 (2025-04-07) + +v1.9.2 is a re-release of v1.9.1 due to a release process issue; no changes were made to the content. + + ## v1.9.1 (2025-03-21) ### Major Changes From c84f49d1dbeb3f3acc40748569cf08e0066a42ad Mon Sep 17 00:00:00 2001 From: ICHINOSE Shogo Date: Tue, 8 Apr 2025 12:45:31 +0900 Subject: [PATCH 112/123] add Go 1.24 to the test matrix (#1681) --- .github/workflows/test.yml | 4 ++-- README.md | 2 +- go.mod | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2e07fea91..0dc207445 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,10 +29,10 @@ jobs: import os go = [ # Keep the most recent production release at the top - '1.23', + '1.24', # Older production releases + '1.23', '1.22', - '1.21', ] mysql = [ '9.0', diff --git a/README.md b/README.md index da4593ccf..65dd898d8 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac ## Requirements -* Go 1.21 or higher. We aim to support the 3 latest versions of Go. +* Go 1.22 or higher. We aim to support the 3 latest versions of Go. * MySQL (5.7+) and MariaDB (10.5+) are supported. * [TiDB](https://github.com/pingcap/tidb) is supported by PingCAP. * Do not ask questions about TiDB in our issue tracker or forum. diff --git a/go.mod b/go.mod index 187aff179..f17666dc8 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/go-sql-driver/mysql -go 1.21.0 +go 1.22.0 require filippo.io/edwards25519 v1.1.0 From 879eb117f443f98e8ea7289d423a3448211dcffe Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Fri, 11 Apr 2025 19:58:34 +0900 Subject: [PATCH 113/123] modernize for Go 1.22 (#1695) $ go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix ./... --- benchmark_test.go | 8 ++++---- connection_test.go | 2 +- driver_test.go | 21 +++++++++++---------- dsn.go | 2 +- infile.go | 5 +---- packets.go | 5 +---- result.go | 6 ++++-- utils.go | 8 ++++---- 8 files changed, 27 insertions(+), 30 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 5c9a046b5..912e54140 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -93,7 +93,7 @@ func benchmarkQueryHelper(b *testing.B, compr bool) { defer wg.Wait() b.StartTimer() - for i := 0; i < concurrencyLevel; i++ { + for range concurrencyLevel { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -130,7 +130,7 @@ func BenchmarkExec(b *testing.B) { defer wg.Wait() b.StartTimer() - for i := 0; i < concurrencyLevel; i++ { + for range concurrencyLevel { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -345,7 +345,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { for i := range blob { blob[i] = 42 } - for i := 0; i < 100; i++ { + for i := range 100 { _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) if err != nil { b.Fatal(err) @@ -401,7 +401,7 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { } for i := 0; i < 10000; i += 100 { args := make([]any, 200) - for j := 0; j < 100; j++ { + for j := range 100 { args[j*2] = i + j args[j*2+1] = sval } diff --git a/connection_test.go b/connection_test.go index f7740898e..440ecbff7 100644 --- a/connection_test.go +++ b/connection_test.go @@ -141,7 +141,7 @@ func TestCleanCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - for i := 0; i < 3; i++ { // Repeat same behavior + for range 3 { // Repeat same behavior err := mc.Ping(ctx) if err != context.Canceled { t.Errorf("expected context.Canceled, got %#v", err) diff --git a/driver_test.go b/driver_test.go index 00e828657..477770c67 100644 --- a/driver_test.go +++ b/driver_test.go @@ -26,6 +26,7 @@ import ( "os" "reflect" "runtime" + "slices" "strconv" "strings" "sync" @@ -1926,7 +1927,7 @@ func TestPreparedManyCols(t *testing.T) { rows.Close() // Create 0byte string which we can't send via STMT_LONG_DATA. - for i := 0; i < numParams; i++ { + for i := range numParams { params[i] = "" } rows, err = stmt.Query(params...) @@ -1971,7 +1972,7 @@ func TestConcurrent(t *testing.T) { }) } - for i := 0; i < max; i++ { + for i := range max { go func(id int) { defer wg.Done() @@ -2355,7 +2356,7 @@ func TestPing(t *testing.T) { q.Close() // Verify that Ping() clears both fields. - for i := 0; i < 2; i++ { + for range 2 { if err := c.Ping(ctx); err != nil { dbt.fail("Pinger", "Ping", err) } @@ -2558,7 +2559,7 @@ func TestMultiResultSet(t *testing.T) { } defer stmt.Close() - for j := 0; j < 2; j++ { + for j := range 2 { rows, err := stmt.Query() if err != nil { dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) @@ -2665,7 +2666,7 @@ func TestQueryMultipleResults(t *testing.T) { c := conn.(*mysqlConn) // Demonstrate that repeated queries reset the affectedRows - for i := 0; i < 2; i++ { + for range 2 { _, err := qr.Query(` INSERT INTO test (value) VALUES ('a'), ('b'); INSERT INTO test (value) VALUES ('c'), ('d'), ('e'); @@ -3293,11 +3294,11 @@ func TestRawBytesAreNotModified(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") - for i := 0; i < insertRows; i++ { + for i := range insertRows { dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) } - for i := 0; i < contextRaceIterations; i++ { + for i := range contextRaceIterations { func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -3552,8 +3553,8 @@ func TestConnectionAttributes(t *testing.T) { rowsMap[attrName] = attrValue } - connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...) - expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...) + connAttrs := slices.Concat(defaultAttrs, customAttrs) + expectedAttrValues := slices.Concat(defaultAttrValues, customAttrValues) for i := range connAttrs { if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] { dbt.Errorf("expected %q, got %q", expectedAttrValues[i], gotValue) @@ -3637,7 +3638,7 @@ func TestIssue1567(t *testing.T) { count = max } - for i := 0; i < count; i++ { + for range count { timeout := time.Duration(mrand.Int63n(int64(rtt))) ctx, cancel := context.WithTimeout(context.Background(), timeout) dbt.db.PingContext(ctx) diff --git a/dsn.go b/dsn.go index ecf62567a..89556bfba 100644 --- a/dsn.go +++ b/dsn.go @@ -414,7 +414,7 @@ func ParseDSN(dsn string) (cfg *Config, err error) { if dsn[j] == '@' { // username[:password] // Find the first ':' in dsn[:j] - for k = 0; k < j; k++ { + for k = 0; k < j; k++ { // We cannot use k = range j here, because we use dsn[:k] below if dsn[k] == ':' { cfg.Passwd = dsn[k+1 : j] break diff --git a/infile.go b/infile.go index 453ae091e..597b5e7f6 100644 --- a/infile.go +++ b/infile.go @@ -95,10 +95,7 @@ const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead a func (mc *okHandler) handleInFileRequest(name string) (err error) { var rdr io.Reader - packetSize := defaultPacketSize - if mc.maxWriteSize < packetSize { - packetSize = mc.maxWriteSize - } + packetSize := min(mc.maxWriteSize, defaultPacketSize) if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader // The server might return an an absolute path. See issue #355. diff --git a/packets.go b/packets.go index 4b8362160..a497a50a7 100644 --- a/packets.go +++ b/packets.go @@ -984,10 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc := stmt.mc // Determine threshold dynamically to avoid packet size shortage. - longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) - if longDataSize < 64 { - longDataSize = 64 - } + longDataSize := max(mc.maxAllowedPacket/(stmt.paramCount+1), 64) // Reset packet-sequence mc.resetSequence() diff --git a/result.go b/result.go index d51631468..82dc0f9b6 100644 --- a/result.go +++ b/result.go @@ -8,6 +8,8 @@ package mysql +import "slices" + import "database/sql/driver" // Result exposes data not available through *connection.Result. @@ -42,9 +44,9 @@ func (res *mysqlResult) RowsAffected() (int64, error) { } func (res *mysqlResult) AllLastInsertIds() []int64 { - return append([]int64{}, res.insertIds...) // defensive copy + return slices.Clone(res.insertIds) // defensive copy } func (res *mysqlResult) AllRowsAffected() []int64 { - return append([]int64{}, res.affectedRows...) // defensive copy + return slices.Clone(res.affectedRows) // defensive copy } diff --git a/utils.go b/utils.go index 8716c26c5..b041804df 100644 --- a/utils.go +++ b/utils.go @@ -182,7 +182,7 @@ func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { func parseByteYear(b []byte) (int, error) { year, n := 0, 1000 - for i := 0; i < 4; i++ { + for i := range 4 { v, err := bToi(b[i]) if err != nil { return 0, err @@ -207,7 +207,7 @@ func parseByte2Digits(b1, b2 byte) (int, error) { func parseByteNanoSec(b []byte) (int, error) { ns, digit := 0, 100000 // max is 6-digits - for i := 0; i < len(b); i++ { + for i := range b { v, err := bToi(b[i]) if err != nil { return 0, err @@ -678,7 +678,7 @@ func escapeStringBackslash(buf []byte, v string) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) - for i := 0; i < len(v); i++ { + for i := range len(v) { c := v[i] switch c { case '\x00': @@ -746,7 +746,7 @@ func escapeStringQuotes(buf []byte, v string) []byte { pos := len(buf) buf = reserveBuffer(buf, len(v)*2) - for i := 0; i < len(v); i++ { + for i := range len(v) { c := v[i] if c == '\'' { buf[pos+1] = '\'' From f433f1f9c1c0680fa967648bf16cbbe98a3ce8f6 Mon Sep 17 00:00:00 2001 From: Diego Dupin Date: Mon, 31 Mar 2025 18:04:08 +0200 Subject: [PATCH 114/123] test stability improvement. * ensuring performance schema is enabled when testing some performance schema results * Added logic to check if the default collation is overridden by the server character_set_collations * ensure using IANA timezone in test, since tzinfo depending on system won't have deprecated tz like "US/Central" and "US/Pacific" --- driver_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/driver_test.go b/driver_test.go index 477770c67..bb8aee7e2 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1631,13 +1631,46 @@ func TestCollation(t *testing.T) { } runTests(t, tdsn, func(dbt *DBTest) { + // see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // when character_set_collations is set for the charset, it overrides the default collation + // so we need to check if the default collation is overridden + forceExpected := expected + var defaultCollations string + err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations) + if err == nil { + // Query succeeded, need to check if we should override expected collation + collationMap := make(map[string]string) + pairs := strings.Split(defaultCollations, ",") + for _, pair := range pairs { + parts := strings.Split(pair, "=") + if len(parts) == 2 { + collationMap[parts[0]] = parts[1] + } + } + + // Get charset prefix from expected collation + parts := strings.Split(expected, "_") + if len(parts) > 0 { + charset := parts[0] + if newCollation, ok := collationMap[charset]; ok { + forceExpected = newCollation + } + } + } + var got string if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { dbt.Fatal(err) } if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) + if forceExpected != expected { + if got != forceExpected { + dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got) + } + } else { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } } }) } @@ -1686,7 +1719,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"} // Regression test for timezone handling tzTest := func(dbt *DBTest) { @@ -1694,8 +1727,8 @@ func TestTimezoneConversion(t *testing.T) { dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + newYorkTz, _ := time.LoadLocation("America/New_York") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -1714,7 +1747,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } } @@ -3542,6 +3575,15 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} + var varName string + var varValue string + err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } + if varValue != "ON" { + t.Skipf("Performance schema is not enabled. skipping") + } queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" rows := dbt.mustQuery(queryString) defer rows.Close() From 8a2f8734b358651b93303ed4f5db4a0773e94eb4 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 21 Apr 2025 14:29:01 +0900 Subject: [PATCH 115/123] simplify collation tests --- driver_test.go | 64 +++++++++++--------------------------------------- 1 file changed, 14 insertions(+), 50 deletions(-) diff --git a/driver_test.go b/driver_test.go index bb8aee7e2..ec0f28772 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1610,10 +1610,12 @@ func TestCollation(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - defaultCollation := "utf8mb4_general_ci" + // MariaDB may override collation specified by handshake with `character_set_collations` variable. + // https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // https://mariadb.com/kb/en/server-system-variables/#character_set_collations + // utf8mb4_general_ci, utf8mb3_general_ci will be overridden by default MariaDB. + // Collations other than charasets default are not overridden. So utf8mb4_unicode_ci is safe. testCollations := []string{ - "", // do not set - defaultCollation, // driver default "latin1_general_ci", "binary", "utf8mb4_unicode_ci", @@ -1621,57 +1623,19 @@ func TestCollation(t *testing.T) { } for _, collation := range testCollations { - var expected, tdsn string - if collation != "" { - tdsn = dsn + "&collation=" + collation - expected = collation - } else { - tdsn = dsn - expected = defaultCollation - } - - runTests(t, tdsn, func(dbt *DBTest) { - // see https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation - // when character_set_collations is set for the charset, it overrides the default collation - // so we need to check if the default collation is overridden - forceExpected := expected - var defaultCollations string - err := dbt.db.QueryRow("SELECT @@character_set_collations").Scan(&defaultCollations) - if err == nil { - // Query succeeded, need to check if we should override expected collation - collationMap := make(map[string]string) - pairs := strings.Split(defaultCollations, ",") - for _, pair := range pairs { - parts := strings.Split(pair, "=") - if len(parts) == 2 { - collationMap[parts[0]] = parts[1] - } - } + t.Run(collation, func(t *testing.T) { + tdsn := dsn + "&collation=" + collation + expected := collation - // Get charset prefix from expected collation - parts := strings.Split(expected, "_") - if len(parts) > 0 { - charset := parts[0] - if newCollation, ok := collationMap[charset]; ok { - forceExpected = newCollation - } + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) } - } - - var got string - if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { - dbt.Fatal(err) - } - - if got != expected { - if forceExpected != expected { - if got != forceExpected { - dbt.Fatalf("expected forced connection collation %s but got %s", forceExpected, got) - } - } else { + if got != expected { dbt.Fatalf("expected connection collation %s but got %s", expected, got) } - } + }) }) } } From c786d41ac467d545d7c767b896034cacfe33765d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Mon, 21 Apr 2025 14:32:09 +0900 Subject: [PATCH 116/123] add Diego Dupin to AUTHORS --- AUTHORS | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS b/AUTHORS index 510b869b7..a261819f2 100644 --- a/AUTHORS +++ b/AUTHORS @@ -37,6 +37,7 @@ Daniel Montoya Daniel Nichter Daniël van Eeden Dave Protasowski +Diego Dupin Dirkjan Bussink DisposaBoy Egor Smolyakov From e02b809d44edf544dae1e9c243725573256417c2 Mon Sep 17 00:00:00 2001 From: elonnzhang <49381087+elonnzhang@users.noreply.github.com> Date: Mon, 21 Apr 2025 15:25:06 +0800 Subject: [PATCH 117/123] ColumnTypeScanType() returns sql.Null[uint64] for bigint unsigned (#1612) Co-authored-by: elonnzhang Co-authored-by: Inada Naoki --- AUTHORS | 7 ++++--- fields.go | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/AUTHORS b/AUTHORS index a261819f2..4f64dca12 100644 --- a/AUTHORS +++ b/AUTHORS @@ -18,8 +18,8 @@ Alex Snast Alexey Palazhchenko Andrew Reid Animesh Ray -Arne Hormann Ariel Mashraki +Arne Hormann Artur Melanchyk Asta Xie B Lamarche @@ -65,6 +65,7 @@ Jeff Hodges Jeffrey Charles Jennifer Purevsuren Jerome Meyer +Jiabin Zhang Jiajia Zhong Jian Zhen Joe Mann @@ -84,10 +85,11 @@ Linh Tran Tuan Lion Yang Luca Looz Lucas Liu -Lunny Xiao Luke Scott +Lunny Xiao Maciej Zimnoch Michael Woolnough +Minh Quang Nao Yokotsuka Nathanial Murphy Nicola Peduzzi @@ -98,7 +100,6 @@ Paul Bonser Paulius Lozys Peter Schultz Phil Porada -Minh Quang Rebecca Chin Reed Allman Richard Wilkes diff --git a/fields.go b/fields.go index be5cd809a..25a166283 100644 --- a/fields.go +++ b/fields.go @@ -128,6 +128,7 @@ var ( scanTypeInt64 = reflect.TypeOf(int64(0)) scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullUint = reflect.TypeOf(sql.Null[uint64]{}) scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) scanTypeUint8 = reflect.TypeOf(uint8(0)) scanTypeUint16 = reflect.TypeOf(uint16(0)) @@ -185,6 +186,9 @@ func (mf *mysqlField) scanType() reflect.Type { } return scanTypeInt64 } + if mf.flags&flagUnsigned != 0 { + return scanTypeNullUint + } return scanTypeNullInt case fieldTypeFloat: From 7da50ff71ba333796e05c76b8fd134a7e6240d06 Mon Sep 17 00:00:00 2001 From: brad-defined <77982333+brad-defined@users.noreply.github.com> Date: Mon, 21 Apr 2025 06:42:34 -0400 Subject: [PATCH 118/123] Transaction Commit/Rollback returns conn's cached error, if present (#1691) If a transaction connection has a cached error, return it instead of ErrInvalidConn during Commit/Rollback operations. Fix #1690 --- AUTHORS | 2 ++ transaction.go | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 4f64dca12..53841ef51 100644 --- a/AUTHORS +++ b/AUTHORS @@ -25,6 +25,7 @@ Asta Xie B Lamarche Bes Dollma Bogdan Constantinescu +Brad Higgins Brian Hendriks Bulat Gaifullin Caine Jette @@ -135,6 +136,7 @@ Ziheng Lyu Barracuda Networks, Inc. Counting Ltd. +Defined Networking Inc. DigitalOcean Inc. Dolthub Inc. dyves labs AG diff --git a/transaction.go b/transaction.go index 4a4b61001..8c502f49e 100644 --- a/transaction.go +++ b/transaction.go @@ -13,18 +13,32 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.closed.Load() { + if tx.mc == nil { return ErrInvalidConn } + if tx.mc.closed.Load() { + err = tx.mc.error() + if err == nil { + err = ErrInvalidConn + } + return + } err = tx.mc.exec("COMMIT") tx.mc = nil return } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.closed.Load() { + if tx.mc == nil { return ErrInvalidConn } + if tx.mc.closed.Load() { + err = tx.mc.error() + if err == nil { + err = ErrInvalidConn + } + return + } err = tx.mc.exec("ROLLBACK") tx.mc = nil return From f7d94ecd2d71490d39bb5715186c8bbbdd512e7b Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 23 Apr 2025 17:28:06 +0900 Subject: [PATCH 119/123] add BenchmarkReceive10kRowsCompress (#1704) * Rename BenchmarkReceiveMassiveRows to BenchmarkReceive10kRows * Add BenchmarkReceive10kRowsCompress that run BenchmarkReceiveMassiveRows with compression * Other tiny benchmark improvements. --- .github/workflows/test.yml | 4 ++ benchmark_test.go | 81 ++++++++++++++++++++++---------------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0dc207445..d035c12a8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,6 +96,10 @@ jobs: run: | go test -v '-race' '-covermode=atomic' '-coverprofile=coverage.out' -parallel 10 + - name: benchmark + run: | + go test -run '^$' -bench . + - name: Send coverage uses: shogo82148/actions-goveralls@v1 with: diff --git a/benchmark_test.go b/benchmark_test.go index 912e54140..1c3f64d32 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -46,10 +46,10 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { +func initDB(b *testing.B, compress bool, queries ...string) *sql.DB { tb := (*TB)(b) comprStr := "" - if useCompression { + if compress { comprStr = "&compress=1" } db := tb.checkDB(sql.Open(driverNameTest, dsn+comprStr)) @@ -64,16 +64,15 @@ func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { - benchmarkQueryHelper(b, false) + benchmarkQuery(b, false) } -func BenchmarkQueryCompression(b *testing.B) { - benchmarkQueryHelper(b, true) +func BenchmarkQueryCompressed(b *testing.B) { + benchmarkQuery(b, true) } -func benchmarkQueryHelper(b *testing.B, compr bool) { +func benchmarkQuery(b *testing.B, compr bool) { tb := (*TB)(b) - b.StopTimer() b.ReportAllocs() db := initDB(b, compr, "DROP TABLE IF EXISTS foo", @@ -93,7 +92,7 @@ func benchmarkQueryHelper(b *testing.B, compr bool) { defer wg.Wait() b.StartTimer() - for range concurrencyLevel { + for i := 0; i < concurrencyLevel; i++ { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -115,8 +114,6 @@ func benchmarkQueryHelper(b *testing.B, compr bool) { func BenchmarkExec(b *testing.B) { tb := (*TB)(b) - b.StopTimer() - b.ReportAllocs() db := tb.checkDB(sql.Open(driverNameTest, dsn)) db.SetMaxIdleConns(concurrencyLevel) defer db.Close() @@ -128,9 +125,11 @@ func BenchmarkExec(b *testing.B) { var wg sync.WaitGroup wg.Add(concurrencyLevel) defer wg.Wait() - b.StartTimer() - for range concurrencyLevel { + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < concurrencyLevel; i++ { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -158,14 +157,15 @@ func initRoundtripBenchmarks() ([]byte, int, int) { } func BenchmarkRoundtripTxt(b *testing.B) { - b.StopTimer() sample, min, max := initRoundtripBenchmarks() sampleString := string(sample) - b.ReportAllocs() tb := (*TB)(b) db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() - b.StartTimer() + + b.ReportAllocs() + b.ResetTimer() + var result string for i := 0; i < b.N; i++ { length := min + i @@ -192,15 +192,15 @@ func BenchmarkRoundtripTxt(b *testing.B) { } func BenchmarkRoundtripBin(b *testing.B) { - b.StopTimer() sample, min, max := initRoundtripBenchmarks() - b.ReportAllocs() tb := (*TB)(b) db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() stmt := tb.checkStmt(db.Prepare("SELECT ?")) defer stmt.Close() - b.StartTimer() + + b.ReportAllocs() + b.ResetTimer() var result sql.RawBytes for i := 0; i < b.N; i++ { length := min + i @@ -345,7 +345,7 @@ func BenchmarkQueryRawBytes(b *testing.B) { for i := range blob { blob[i] = 42 } - for i := range 100 { + for i := 0; i < 100; i++ { _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) if err != nil { b.Fatal(err) @@ -385,10 +385,9 @@ func BenchmarkQueryRawBytes(b *testing.B) { } } -// BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. -func BenchmarkReceiveMassiveRows(b *testing.B) { +func benchmark10kRows(b *testing.B, compress bool) { // Setup -- prepare 10000 rows. - db := initDB(b, false, + db := initDB(b, compress, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") defer db.Close() @@ -399,11 +398,14 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { b.Errorf("failed to prepare query: %v", err) return } + + args := make([]any, 200) + for i := 1; i < 200; i+=2 { + args[i] = sval + } for i := 0; i < 10000; i += 100 { - args := make([]any, 200) - for j := range 100 { + for j := 0; j < 100; j++ { args[j*2] = i + j - args[j*2+1] = sval } _, err := stmt.Exec(args...) if err != nil { @@ -413,30 +415,43 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { } stmt.Close() - // Use b.Run() to skip expensive setup. + // benchmark function called several times with different b.N. + // it means heavy setup is called multiple times. + // Use b.Run() to run expensive setup only once. + // Go 1.24 introduced b.Loop() for this purpose. But we keep this + // benchmark compatible with Go 1.20. b.Run("query", func(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { rows, err := db.Query(`SELECT id, val FROM foo`) if err != nil { b.Errorf("failed to select: %v", err) return } + // rows.Scan() escapes arguments. So these variables must be defined + // before loop. + var i int + var s sql.RawBytes for rows.Next() { - var i int - var s sql.RawBytes - err = rows.Scan(&i, &s) - if err != nil { + if err := rows.Scan(&i, &s); err != nil { b.Errorf("failed to scan: %v", err) - _ = rows.Close() + rows.Close() return } } if err = rows.Err(); err != nil { b.Errorf("failed to read rows: %v", err) } - _ = rows.Close() + rows.Close() } }) } + +// BenchmarkReceive10kRows measures performance of receiving large number of rows. +func BenchmarkReceive10kRows(b *testing.B) { + benchmark10kRows(b, false) +} + +func BenchmarkReceive10kRowsCompressed(b *testing.B) { + benchmark10kRows(b, true) +} From 0fd55eb45dae058584cc7a2c80f721dadc991117 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 24 Apr 2025 18:43:48 +0900 Subject: [PATCH 120/123] optimize readPacket (#1705) Avoid unnecessary allocation. --- buffer.go | 18 ++++++++---------- compress.go | 12 ++++++------ packets.go | 17 ++++++++++++++--- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/buffer.go b/buffer.go index a65324315..f895e87b3 100644 --- a/buffer.go +++ b/buffer.go @@ -42,6 +42,11 @@ func (b *buffer) busy() bool { return len(b.buf) > 0 } +// len returns how many bytes in the read buffer. +func (b *buffer) len() int { + return len(b.buf) +} + // fill reads into the read buffer until at least _need_ bytes are in it. func (b *buffer) fill(need int, r readerFunc) error { // we'll move the contents of the current buffer to dest before filling it. @@ -86,17 +91,10 @@ func (b *buffer) fill(need int, r readerFunc) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) { - if len(b.buf) < need { - // refill - if err := b.fill(need, r); err != nil { - return nil, err - } - } - - data := b.buf[:need] +func (b *buffer) readNext(need int) []byte { + data := b.buf[:need:need] b.buf = b.buf[need:] - return data, nil + return data } // takeBuffer returns a buffer with the requested size. diff --git a/compress.go b/compress.go index fa42772ac..e247a65ac 100644 --- a/compress.go +++ b/compress.go @@ -84,9 +84,9 @@ func (c *compIO) reset() { c.buff.Reset() } -func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { +func (c *compIO) readNext(need int) ([]byte, error) { for c.buff.Len() < need { - if err := c.readCompressedPacket(r); err != nil { + if err := c.readCompressedPacket(); err != nil { return nil, err } } @@ -94,8 +94,8 @@ func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { return data[:need:need], nil // prevent caller writes into c.buff } -func (c *compIO) readCompressedPacket(r readerFunc) error { - header, err := c.mc.buf.readNext(7, r) // size of compressed header +func (c *compIO) readCompressedPacket() error { + header, err := c.mc.readNext(7) if err != nil { return err } @@ -103,7 +103,7 @@ func (c *compIO) readCompressedPacket(r readerFunc) error { // compressed header structure comprLength := getUint24(header[0:3]) - compressionSequence := uint8(header[3]) + compressionSequence := header[3] uncompressedLength := getUint24(header[4:7]) if debug { fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", @@ -120,7 +120,7 @@ func (c *compIO) readCompressedPacket(r readerFunc) error { c.mc.sequence = compressionSequence + 1 c.mc.compressSequence = c.mc.sequence - comprData, err := c.mc.buf.readNext(comprLength, r) + comprData, err := c.mc.readNext(comprLength) if err != nil { return err } diff --git a/packets.go b/packets.go index a497a50a7..e6e1704b3 100644 --- a/packets.go +++ b/packets.go @@ -25,19 +25,30 @@ import ( // https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html // https://mariadb.com/kb/en/clientserver-protocol/ +// read n bytes from mc.buf +func (mc *mysqlConn) readNext(n int) ([]byte, error) { + if mc.buf.len() < n { + err := mc.buf.fill(n, mc.readWithTimeout) + if err != nil { + return nil, err + } + } + return mc.buf.readNext(n), nil +} + // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte invalidSequence := false - readNext := mc.buf.readNext + readNext := mc.readNext if mc.compress { readNext = mc.compIO.readNext } for { // read packet header - data, err := readNext(4, mc.readWithTimeout) + data, err := readNext(4) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -85,7 +96,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = readNext(pktLen, mc.readWithTimeout) + data, err = readNext(pktLen) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { From 6e944e126d91bd7fa99d2e0159134e8cbac42f12 Mon Sep 17 00:00:00 2001 From: diego dupin Date: Sat, 26 Apr 2025 02:08:52 +0200 Subject: [PATCH 121/123] MariaDB Metadata skipping and DEPRECATE_EOF (#1708) [MariaDB metadata skipping](https://mariadb.com/kb/en/mariadb-protocol-differences-with-mysql/#prepare-statement-skipping-metadata). With this change, MariaDB server won't send metadata when they have not changed, saving client parsing metadata and network. This feature rely on these changes: * extended capabilities support * EOF packet deprecation makes current implementation to be revised A benchmark BenchmarkReceiveMetadata has been added to show the difference. --- benchmark_test.go | 59 ++++++++- connection.go | 29 +++-- connector.go | 6 +- const.go | 19 ++- packets.go | 297 +++++++++++++++++++++++++++------------------- packets_test.go | 10 +- rows.go | 6 +- statement.go | 30 ++++- 8 files changed, 309 insertions(+), 147 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index 1c3f64d32..b246f4ac3 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -129,7 +129,7 @@ func BenchmarkExec(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < concurrencyLevel; i++ { + for i := 0; i < concurrencyLevel; i++ { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -400,7 +400,7 @@ func benchmark10kRows(b *testing.B, compress bool) { } args := make([]any, 200) - for i := 1; i < 200; i+=2 { + for i := 1; i < 200; i += 2 { args[i] = sval } for i := 0; i < 10000; i += 100 { @@ -455,3 +455,58 @@ func BenchmarkReceive10kRows(b *testing.B) { func BenchmarkReceive10kRowsCompressed(b *testing.B) { benchmark10kRows(b, true) } + +// BenchmarkReceiveMetadata measures performance of receiving lots of metadata compare to data in rows +func BenchmarkReceiveMetadata(b *testing.B) { + tb := (*TB)(b) + + // Create a table with 1000 integer fields + createTableQuery := "CREATE TABLE large_integer_table (" + for i := 0; i < 1000; i++ { + createTableQuery += fmt.Sprintf("col_%d INT", i) + if i < 999 { + createTableQuery += ", " + } + } + createTableQuery += ")" + + // Initialize database + db := initDB(b, false, + "DROP TABLE IF EXISTS large_integer_table", + createTableQuery, + "INSERT INTO large_integer_table VALUES ("+ + strings.Repeat("0,", 999)+"0)", // Insert a row of zeros + ) + defer db.Close() + + b.Run("query", func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + + // Create a slice to scan all columns + values := make([]any, 1000) + valuePtrs := make([]any, 1000) + for j := range values { + valuePtrs[j] = &values[j] + } + + b.ReportAllocs() + b.ResetTimer() + + // Prepare a SELECT query to retrieve metadata + stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) + defer stmt.Close() + + // Benchmark metadata retrieval + for range b.N { + rows := tb.checkRows(stmt.Query()) + + rows.Next() + // Scan the row + err := rows.Scan(valuePtrs...) + tb.check(err) + + rows.Close() + } + }) +} diff --git a/connection.go b/connection.go index 3e455a3ff..58c763fad 100644 --- a/connection.go +++ b/connection.go @@ -33,7 +33,8 @@ type mysqlConn struct { connector *connector maxAllowedPacket int maxWriteSize int - flags clientFlag + capabilities capabilityFlag + extCapabilities extendedCapabilityFlag status statusFlag sequence uint8 compressSequence uint8 @@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + if err = mc.skipColumns(stmt.paramCount); err != nil { return nil, err } } if columnCount > 0 { - err = mc.readUntilEOF() + if mc.extCapabilities&clientCacheMetadata != 0 { + if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(int(columnCount)); err != nil { + return nil, err + } + } } } @@ -370,19 +379,19 @@ func (mc *mysqlConn) exec(query string) error { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipRows(); err != nil { return err } } @@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) // Read Result var resLen int - resLen, err = handleOk.readResultSetHeaderPacket() + resLen, _, err = handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -453,7 +462,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, _, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc @@ -461,14 +470,14 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.skipColumns(resLen); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + return dest[0].([]byte), mc.skipRows() } } return nil, err diff --git a/connector.go b/connector.go index bc1d46afc..dca473fa7 100644 --- a/connector.go +++ b/connector.go @@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.buf = newBuffer() // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtCapabilities, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err @@ -153,6 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } } + mc.initCapabilities(serverCapabilities, serverExtCapabilities, mc.cfg) if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err @@ -167,7 +168,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return nil, err } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { + // compression is enabled after auth, not right after sending handshake response. + if mc.capabilities&clientCompress > 0 { mc.compress = true mc.compIO = newCompIO(mc) } diff --git a/const.go b/const.go index 4aadcd642..311e92eaf 100644 --- a/const.go +++ b/const.go @@ -42,11 +42,12 @@ const ( iERR byte = 0xff ) -// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags -type clientFlag uint32 +// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html +// https://mariadb.com/kb/en/connection/#capabilities +type capabilityFlag uint32 const ( - clientLongPassword clientFlag = 1 << iota + clientMySQL capabilityFlag = 1 << iota clientFoundRows clientLongFlag clientConnectWithDB @@ -73,6 +74,18 @@ const ( clientDeprecateEOF ) +// https://mariadb.com/kb/en/connection/#capabilities +type extendedCapabilityFlag uint32 + +const ( + progressIndicator extendedCapabilityFlag = 1 << iota + clientComMulti + clientStmtBulkOperations + clientExtendedMetadata + clientCacheMetadata + clientUnitBulkResult +) + const ( comQuit byte = iota + 1 comInitDB diff --git a/packets.go b/packets.go index e6e1704b3..1319f9e64 100644 --- a/packets.go +++ b/packets.go @@ -184,20 +184,22 @@ func (mc *mysqlConn) writePacket(data []byte) error { ******************************************************************************/ // Handshake Initialization Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html +// https://mariadb.com/kb/en/connection/#initial-handshake-packet +func (mc *mysqlConn) readHandshakePacket() (data []byte, capabilities capabilityFlag, extendedCapabilities extendedCapabilityFlag, plugin string, err error) { data, err = mc.readPacket() if err != nil { return } if data[0] == iERR { - return nil, "", mc.handleErrorPacket(data) + err = mc.handleErrorPacket(data) + return } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, "", fmt.Errorf( + return nil, 0, 0, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -215,15 +217,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro pos += 8 + 1 // capability flags (lower 2 bytes) [2 bytes] - mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) - if mc.flags&clientProtocol41 == 0 { - return nil, "", ErrOldProtocol + capabilities = capabilityFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if capabilities&clientProtocol41 == 0 { + return nil, capabilities, 0, "", ErrOldProtocol } - if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil { + if capabilities&clientSSL == 0 && mc.cfg.TLS != nil { if mc.cfg.AllowFallbackToPlaintext { mc.cfg.TLS = nil } else { - return nil, "", ErrNoTLS + return nil, capabilities, 0, "", ErrNoTLS } } pos += 2 @@ -233,11 +235,16 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // status flags [2 bytes] pos += 3 // capability flags (upper 2 bytes) [2 bytes] - mc.flags |= clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + capabilities |= capabilityFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 pos += 2 // length of auth-plugin-data [1 byte] - // reserved (all [00]) [10 bytes] - pos += 11 + // reserved (all [00]) [6 bytes] + pos += 7 + if capabilities&clientMySQL == 0 { + // MariaDB server extended flag + extendedCapabilities = extendedCapabilityFlag(binary.LittleEndian.Uint32(data[pos : pos+4])) + } + pos += 4 // second part of the password cipher [minimum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -265,82 +272,72 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // make a memory safe copy of the cipher slice var b [20]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], capabilities, extendedCapabilities, plugin, nil } // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) - return b[:], plugin, nil + return b[:], capabilities, 0, plugin, nil } -// Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { - // Adjust client flags based on server support - clientFlags := clientProtocol41 | - clientSecureConn | - clientLongPassword | - clientTransactions | - clientLocalFiles | - clientPluginAuth | - clientMultiResults | - mc.flags&clientConnectAttrs | - mc.flags&clientLongFlag - - sendConnectAttrs := mc.flags&clientConnectAttrs != 0 - - if mc.cfg.ClientFoundRows { - clientFlags |= clientFoundRows +// initCapabilities initializes the capabilities based on server support and configuration +func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverExtCapabilities extendedCapabilityFlag, cfg *Config) { + clientCapabilities := + clientMySQL | + clientLongFlag | + clientProtocol41 | + clientSecureConn | + clientTransactions | + clientPluginAuthLenEncClientData | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + clientConnectAttrs | + clientDeprecateEOF + + if cfg.ClientFoundRows { + clientCapabilities |= clientFoundRows } - if mc.cfg.compress && mc.flags&clientCompress == clientCompress { - clientFlags |= clientCompress + if cfg.compress { + clientCapabilities |= clientCompress } // To enable TLS / SSL if mc.cfg.TLS != nil { - clientFlags |= clientSSL + clientCapabilities |= clientSSL } if mc.cfg.MultiStatements { - clientFlags |= clientMultiStatements + clientCapabilities |= clientMultiStatements } - - // encode length of the auth plugin data - var authRespLEIBuf [9]byte - authRespLen := len(authResp) - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) - if len(authRespLEI) > 1 { - // if the length can not be written in 1 byte, it must be written as a - // length encoded integer - clientFlags |= clientPluginAuthLenEncClientData + if n := len(cfg.DBName); n > 0 { + clientCapabilities |= clientConnectWithDB } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - - // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { - clientFlags |= clientConnectWithDB - pktLen += n + 1 - } + // only keep client capabilities that server have + mc.capabilities = clientCapabilities & serverCapabilities - // encode length of the connection attributes - var connAttrsLEI []byte - if sendConnectAttrs { - var connAttrsLEIBuf [9]byte - connAttrsLen := len(mc.connector.encodedAttributes) - connAttrsLEI = appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) - pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) - } + // set MariaDB extended clientCacheMetadata capability if server support it + mc.extCapabilities = clientCacheMetadata & serverExtCapabilities +} - // Calculate packet length and get buffer with that size - data, err := mc.buf.takeBuffer(pktLen + 4) +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { + // packet header 4 + // capabilities 4 + // maxPacketSize 4 + // collation id 1 + // filler 23 + data, err := mc.buf.takeSmallBuffer(4*3 + 24) if err != nil { mc.cleanup() return err } + _ = data[4*3+23] // boundery check - // ClientFlags [32 bit] - binary.LittleEndian.PutUint32(data[4:], uint32(clientFlags)) + // clientCapabilities [32 bit] + binary.LittleEndian.PutUint32(data[4:], uint32(mc.capabilities)) // MaxPacketSize [32 bit] (none) binary.LittleEndian.PutUint32(data[8:], 0) @@ -358,16 +355,26 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } // Filler [23 bytes] (all 0x00) + // or filler 19bytes + mariadb extCapabilities pos := 13 - for ; pos < 13+23; pos++ { - data[pos] = 0 + if mc.capabilities&clientMySQL == 0 { + for ; pos < 13+19; pos++ { + data[pos] = 0 + } + // MariaDB Extended Capabilities + binary.LittleEndian.PutUint32(data[13+19:], uint32(mc.extCapabilities)) + } else { + for ; pos < 13+23; pos++ { + data[pos] = 0 + } } // SSL Connection Request Packet - // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_ssl_request.html + // https://mariadb.com/kb/en/connection/#sslrequest-packet if mc.cfg.TLS != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(data); err != nil { return err } @@ -384,34 +391,32 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string // User [null terminated string] if len(mc.cfg.User) > 0 { - pos += copy(data[pos:], mc.cfg.User) + data = append(data, mc.cfg.User...) } - data[pos] = 0x00 - pos++ + data = append(data, 0) // Auth Data [length encoded integer] - pos += copy(data[pos:], authRespLEI) - pos += copy(data[pos:], authResp) + data = appendLengthEncodedInteger(data, uint64(len(authResp))) + data = append(data, authResp...) - // Databasename [null terminated string] - if len(mc.cfg.DBName) > 0 { - pos += copy(data[pos:], mc.cfg.DBName) - data[pos] = 0x00 - pos++ + // Database name [null terminated string] + if mc.capabilities&clientConnectWithDB != 0 { + data = append(data, mc.cfg.DBName...) + data = append(data, 0) } - pos += copy(data[pos:], plugin) - data[pos] = 0x00 - pos++ + data = append(data, plugin...) + data = append(data, 0) // Connection Attributes - if sendConnectAttrs { - pos += copy(data[pos:], connAttrsLEI) - pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) + if mc.capabilities&clientConnectAttrs != 0 { + connAttrsLen := len(mc.connector.encodedAttributes) + data = appendLengthEncodedInteger(data, uint64(connAttrsLen)) + data = append(data, mc.connector.encodedAttributes...) } // Send Auth packet - return mc.writePacket(data[:pos]) + return mc.writePacket(data) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse @@ -546,32 +551,37 @@ func (mc *okHandler) readResultOK() error { // Result Set Header Packet // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response.html -func (mc *okHandler) readResultSetHeaderPacket() (int, error) { +func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { // handleOkPacket replaces both values; other cases leave the values unchanged. mc.result.affectedRows = append(mc.result.affectedRows, 0) mc.result.insertIds = append(mc.result.insertIds, 0) data, err := mc.conn().readPacket() if err != nil { - return 0, err + return 0, false, err } switch data[0] { case iOK: - return 0, mc.handleOkPacket(data) + return 0, false, mc.handleOkPacket(data) case iERR: - return 0, mc.conn().handleErrorPacket(data) + return 0, false, mc.conn().handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, false, mc.handleInFileRequest(string(data[1:])) } // column count // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html - num, _, _ := readLengthEncodedInteger(data) + // https://mariadb.com/kb/en/result-set-packets/#column-count-packet + num, _, len := readLengthEncodedInteger(data) + + if mc.extCapabilities&clientCacheMetadata != 0 { + return int(num), data[len] == 0x01, nil + } // ignore remaining data in the packet. see #1478. - return int(num), nil + return int(num), true, nil } // Error Packet @@ -695,20 +705,12 @@ func (mc *okHandler) handleOkPacket(data []byte) error { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := range count { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { - if i == count { - return columns, nil - } - return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) - } - // Catalog pos, err := skipLengthEncodedString(data) if err != nil { @@ -781,13 +783,13 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Decimals [uint8] columns[i].decimals = data[pos] - //pos++ + } - // Default value [len coded binary] - //if pos < len(data) { - // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) - //} + // skip EOF packet if client does not support deprecateEOF + if err := mc.skipEof(); err != nil { + return nil, err } + return columns, nil } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -805,9 +807,20 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + // text row packets may starts with LengthEncodedString. + // In such case, 0xFE can mean string larger than 0xffffff. + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le + if data[0] == iEOF && len(data) <= 0xffffff { + if mc.capabilities&clientDeprecateEOF == 0 { + // Deprecated EOF packet + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_eof_packet.html + mc.status = readStatus(data[3:]) + } else { + // Ok Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows + _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id + mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -881,8 +894,34 @@ func (rows *textRows) readRow(dest []driver.Value) error { return nil } -// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) skipPackets(n int) error { + for i := 0; i < n; i++ { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +// skips EOF packet after n * ColumnDefinition packets when clientDeprecateEOF is not set +func (mc *mysqlConn) skipEof() error { + if mc.capabilities&clientDeprecateEOF == 0 { + if _, err := mc.readPacket(); err != nil { + return err + } + } + return nil +} + +func (mc *mysqlConn) skipColumns(n int) error { + if err := mc.skipPackets(n); err != nil { + return err + } + return mc.skipEof() +} + +// Reads Packets until EOF-Packet or an Error appears. +func (mc *mysqlConn) skipRows() error { for { data, err := mc.readPacket() if err != nil { @@ -893,10 +932,20 @@ func (mc *mysqlConn) readUntilEOF() error { case iERR: return mc.handleErrorPacket(data) case iEOF: - if len(data) == 5 { - mc.status = readStatus(data[3:]) + // text row packets may starts with LengthEncodedString. + // In such case, 0xFE can mean string larger than 0xffffff. + if len(data) <= 0xffffff { + if mc.capabilities&clientDeprecateEOF == 0 { + // EOF packet + mc.status = readStatus(data[3:]) + } else { + // OK packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) // affected_rows + _, _, m := readLengthEncodedInteger(data[1+n:]) // last_insert_id + mc.status = readStatus(data[1+n+m:]) + } + return nil } - return nil } } } @@ -1184,17 +1233,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // mc.affectedRows and mc.insertIds. func (mc *okHandler) discardResults() error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, _, err := mc.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipColumns(resLen); err != nil { return err } // rows - if err := mc.conn().readUntilEOF(); err != nil { + if err := mc.conn().skipRows(); err != nil { return err } } @@ -1211,9 +1260,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + // EOF/OK Packet + if data[0] == iEOF { + if rows.mc.capabilities&clientDeprecateEOF == 0 { + // EOF packet + rows.mc.status = readStatus(data[3:]) + } else { + // OK Packet with an 0xFE header + _, _, n := readLengthEncodedInteger(data[1:]) + _, _, m := readLengthEncodedInteger(data[1+n:]) + rows.mc.status = readStatus(data[1+n+m:]) + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil diff --git a/packets_test.go b/packets_test.go index 694b0564c..b487051e2 100644 --- a/packets_test.go +++ b/packets_test.go @@ -332,11 +332,19 @@ func TestRegression801(t *testing.T) { 112, 97, 115, 115, 119, 111, 114, 100} conn.maxReads = 1 - authData, pluginName, err := mc.readHandshakePacket() + authData, serverCapabilities, serverExtendedCapabilities, pluginName, err := mc.readHandshakePacket() if err != nil { t.Fatalf("got error: %v", err) } + if serverCapabilities != 2148530143 { + t.Fatalf("expected serverCapabilities to be 2148530143, got %v", serverCapabilities) + } + + if serverExtendedCapabilities != 0 { + t.Fatalf("expected serverExtendedCapabilities to be 0, got %v", serverExtendedCapabilities) + } + if pluginName != "mysql_native_password" { t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) } diff --git a/rows.go b/rows.go index df98417b8..e41fda6f4 100644 --- a/rows.go +++ b/rows.go @@ -113,7 +113,7 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.skipRows() } if err == nil { handleOk := mc.clearResult() @@ -143,7 +143,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.skipRows(); err != nil { return 0, err } rows.rs.done = true @@ -156,7 +156,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { rows.rs = resultSet{} // rows.mc.affectedRows and rows.mc.insertIds accumulate on each call to // nextResultSet. - resLen, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() + resLen, _, err := rows.mc.resultUnchanged().readResultSetHeaderPacket() if err != nil { // Clean up about multi-results flag rows.rs.done = true diff --git a/statement.go b/statement.go index 35df85457..0f6c65a37 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,7 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int + columns []mysqlField } func (stmt *mysqlStmt) Close() error { @@ -64,19 +65,26 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { handleOk := stmt.mc.clearResult() // Read Result - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { - return nil, err + if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { + // we can not skip column metadata because next stmt.Query() may use it. + if stmt.columns, err = mc.readColumns(resLen); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(resLen); err != nil { + return nil, err + } } // Rows - if err := mc.readUntilEOF(); err != nil { + if err = mc.skipRows(); err != nil { return nil, err } } @@ -107,7 +115,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { // Read Result handleOk := stmt.mc.clearResult() - resLen, err := handleOk.readResultSetHeaderPacket() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() if err != nil { return nil, err } @@ -116,7 +124,17 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc - rows.rs.columns, err = mc.readColumns(resLen) + if metadataFollows { + if rows.rs.columns, err = mc.readColumns(resLen); err != nil { + return nil, err + } + stmt.columns = rows.rs.columns + } else { + if err = mc.skipEof(); err != nil { + return nil, err + } + rows.rs.columns = stmt.columns + } } else { rows.rs.done = true From 2356566e5123327d6fe49bc46e90e4efcd4a93b1 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Tue, 29 Apr 2025 11:43:41 +0900 Subject: [PATCH 122/123] Optimization: statements reuse previous column name (#1711) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #1708 added `[]mysqlField` cache to stmt. It was used only for MariaDB cached metadata. This commit allows MySQL to also benefit from the metadata cache. If the column names are the same as the cached metadata, it reuses them instead of allocating new strings. goos: darwin goarch: arm64 pkg: github.com/go-sql-driver/mysql cpu: Apple M1 Pro │ master.txt │ reuse.txt │ │ sec/op │ sec/op vs base │ ReceiveMetadata-8 1.273m ± 2% 1.269m ± 2% ~ (p=1.000 n=10) │ master.txt │ reuse.txt │ │ B/op │ B/op vs base │ ReceiveMetadata-8 88.17Ki ± 0% 80.39Ki ± 0% -8.82% (p=0.000 n=10) │ master.txt │ reuse.txt │ │ allocs/op │ allocs/op vs base │ ReceiveMetadata-8 1015.00 ± 0% 16.00 ± 0% -98.42% (p=0.000 n=10) --- benchmark_test.go | 5 ++--- connection.go | 4 ++-- packets.go | 19 ++++++++++++++++--- rows.go | 4 ++-- statement.go | 4 ++-- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/benchmark_test.go b/benchmark_test.go index b246f4ac3..e735776d3 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -490,14 +490,13 @@ func BenchmarkReceiveMetadata(b *testing.B) { valuePtrs[j] = &values[j] } - b.ReportAllocs() - b.ResetTimer() - // Prepare a SELECT query to retrieve metadata stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1")) defer stmt.Close() // Benchmark metadata retrieval + b.ReportAllocs() + b.ResetTimer() for range b.N { rows := tb.checkRows(stmt.Query()) diff --git a/connection.go b/connection.go index 58c763fad..5648e47d8 100644 --- a/connection.go +++ b/connection.go @@ -231,7 +231,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if columnCount > 0 { if mc.extCapabilities&clientCacheMetadata != 0 { - if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil { + if stmt.columns, err = mc.readColumns(int(columnCount), nil); err != nil { return nil, err } } else { @@ -448,7 +448,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) } // Columns - rows.rs.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen, nil) return rows, err } diff --git a/packets.go b/packets.go index 1319f9e64..b8f061263 100644 --- a/packets.go +++ b/packets.go @@ -702,8 +702,11 @@ func (mc *okHandler) handleOkPacket(data []byte) error { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { +func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, error) { columns := make([]mysqlField, count) + if len(old) != count { + old = nil + } for i := range count { data, err := mc.readPacket() @@ -731,7 +734,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { return nil, err } pos += n - columns[i].tableName = string(tableName) + if old != nil && old[i].tableName == string(tableName) { + // avoid allocating new string + columns[i].tableName = old[i].tableName + } else { + columns[i].tableName = string(tableName) + } } else { n, err = skipLengthEncodedString(data[pos:]) if err != nil { @@ -752,7 +760,12 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } - columns[i].name = string(name) + if old != nil && old[i].name == string(name) { + // avoid allocating new string + columns[i].name = old[i].name + } else { + columns[i].name = string(name) + } pos += n // Original name [len coded string] diff --git a/rows.go b/rows.go index e41fda6f4..190e75f9b 100644 --- a/rows.go +++ b/rows.go @@ -186,7 +186,7 @@ func (rows *binaryRows) NextResultSet() error { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(resLen, nil) return err } @@ -208,7 +208,7 @@ func (rows *textRows) NextResultSet() (err error) { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(resLen, nil) return err } diff --git a/statement.go b/statement.go index 0f6c65a37..2db8960e5 100644 --- a/statement.go +++ b/statement.go @@ -74,7 +74,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Columns if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { // we can not skip column metadata because next stmt.Query() may use it. - if stmt.columns, err = mc.readColumns(resLen); err != nil { + if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { return nil, err } } else { @@ -125,7 +125,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if resLen > 0 { rows.mc = mc if metadataFollows { - if rows.rs.columns, err = mc.readColumns(resLen); err != nil { + if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { return nil, err } stmt.columns = rows.rs.columns From af56fba75c83ccdef7925a9aeaa01729c4f47e52 Mon Sep 17 00:00:00 2001 From: demouth <1133178+demouth@users.noreply.github.com> Date: Thu, 8 May 2025 15:38:58 +0900 Subject: [PATCH 123/123] update outdated MySQL internals documentation links (#1714) --- AUTHORS | 1 + auth.go | 2 +- connector.go | 2 +- const.go | 2 +- packets.go | 22 +++++++++++----------- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/AUTHORS b/AUTHORS index 53841ef51..05e71df48 100644 --- a/AUTHORS +++ b/AUTHORS @@ -38,6 +38,7 @@ Daniel Montoya Daniel Nichter Daniël van Eeden Dave Protasowski +Demouth Diego Dupin Dirkjan Bussink DisposaBoy diff --git a/auth.go b/auth.go index 74e1bd03e..610044fc1 100644 --- a/auth.go +++ b/auth.go @@ -305,7 +305,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { if !mc.cfg.AllowNativePasswords { return nil, ErrNativePassword } - // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // https://dev.mysql.com/doc/dev/mysql-server/8.4.5/page_protocol_connection_phase_authentication_methods_native_password_authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) return authResp, nil diff --git a/connector.go b/connector.go index dca473fa7..db2bd7cf9 100644 --- a/connector.go +++ b/connector.go @@ -162,7 +162,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { // Handle response to auth packet, switch methods if possible if err = mc.handleAuthResult(authData, plugin); err != nil { // Authentication failed and MySQL has already closed the connection - // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // (https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html#sect_protocol_connection_phase_fast_path_fails). // Do not send COM_QUIT, just cleanup and return the error. mc.cleanup() return nil, err diff --git a/const.go b/const.go index 311e92eaf..6f0cdf303 100644 --- a/const.go +++ b/const.go @@ -32,7 +32,7 @@ const ( ) // MySQL constants documentation: -// http://dev.mysql.com/doc/internals/en/client-server-protocol.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html const ( iOK byte = 0x00 diff --git a/packets.go b/packets.go index b8f061263..11835b1b1 100644 --- a/packets.go +++ b/packets.go @@ -322,7 +322,7 @@ func (mc *mysqlConn) initCapabilities(serverCapabilities capabilityFlag, serverE } // Client Authentication Packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_response.html func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // packet header 4 // capabilities 4 @@ -419,7 +419,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string return mc.writePacket(data) } -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_response.html func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) data, err := mc.buf.takeBuffer(pktLen) @@ -517,7 +517,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { case iEOF: if len(data) == 1 { - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_old_auth_switch_request.html return nil, "mysql_old_password", nil } pluginEndIndex := bytes.IndexByte(data, 0x00) @@ -585,7 +585,7 @@ func (mc *okHandler) readResultSetHeaderPacket() (int, bool, error) { } // Error Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_err_packet.html func (mc *mysqlConn) handleErrorPacket(data []byte) error { if data[0] != iERR { return ErrMalformPkt @@ -667,7 +667,7 @@ func (mc *mysqlConn) clearResult() *okHandler { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_ok_packet.html func (mc *okHandler) handleOkPacket(data []byte) error { var n, m int var affectedRows, insertId uint64 @@ -701,7 +701,7 @@ func (mc *okHandler) handleOkPacket(data []byte) error { } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_column_definition.html#sect_protocol_com_query_response_text_resultset_column_definition_41 func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, error) { columns := make([]mysqlField, count) if len(old) != count { @@ -806,7 +806,7 @@ func (mc *mysqlConn) readColumns(count int, old []mysqlField) ([]mysqlField, err } // Read Packets as Field Packets until EOF-Packet or an Error appears -// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset_row.html func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc @@ -968,7 +968,7 @@ func (mc *mysqlConn) skipRows() error { ******************************************************************************/ // Prepare Result Packets -// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { data, err := stmt.mc.readPacket() if err == nil { @@ -995,7 +995,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { return 0, err } -// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_send_long_data.html func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -1043,7 +1043,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { } // Execute Prepared Statement -// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( @@ -1264,7 +1264,7 @@ func (mc *okHandler) discardResults() error { return nil } -// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row func (rows *binaryRows) readRow(dest []driver.Value) error { data, err := rows.mc.readPacket() if err != nil {