diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2e07fea9..207a2453 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,6 +96,10 @@ jobs: run: | go test -v '-race' '-covermode=atomic' '-coverprofile=coverage.out' -parallel 10 + - name: benchmark + run: | + go test -run '^$' -bench . + - name: Send coverage uses: shogo82148/actions-goveralls@v1 with: diff --git a/AUTHORS b/AUTHORS index 510b869b..ec346e20 100644 --- a/AUTHORS +++ b/AUTHORS @@ -25,6 +25,7 @@ Asta Xie B Lamarche Bes Dollma Bogdan Constantinescu +Brad Higgins Brian Hendriks Bulat Gaifullin Caine Jette @@ -37,6 +38,7 @@ Daniel Montoya Daniel Nichter Daniƫl van Eeden Dave Protasowski +Diego Dupin Dirkjan Bussink DisposaBoy Egor Smolyakov @@ -133,6 +135,7 @@ Ziheng Lyu Barracuda Networks, Inc. Counting Ltd. +Defined Networking Inc. DigitalOcean Inc. Dolthub Inc. dyves labs AG diff --git a/benchmark_test.go b/benchmark_test.go index 5c9a046b..1c3f64d3 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -46,10 +46,10 @@ func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { return stmt } -func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { +func initDB(b *testing.B, compress bool, queries ...string) *sql.DB { tb := (*TB)(b) comprStr := "" - if useCompression { + if compress { comprStr = "&compress=1" } db := tb.checkDB(sql.Open(driverNameTest, dsn+comprStr)) @@ -64,16 +64,15 @@ func initDB(b *testing.B, useCompression bool, queries ...string) *sql.DB { const concurrencyLevel = 10 func BenchmarkQuery(b *testing.B) { - benchmarkQueryHelper(b, false) + benchmarkQuery(b, false) } -func BenchmarkQueryCompression(b *testing.B) { - benchmarkQueryHelper(b, true) +func BenchmarkQueryCompressed(b *testing.B) { + benchmarkQuery(b, true) } -func benchmarkQueryHelper(b *testing.B, compr bool) { +func benchmarkQuery(b *testing.B, compr bool) { tb := (*TB)(b) - b.StopTimer() b.ReportAllocs() db := initDB(b, compr, "DROP TABLE IF EXISTS foo", @@ -115,8 +114,6 @@ func benchmarkQueryHelper(b *testing.B, compr bool) { func BenchmarkExec(b *testing.B) { tb := (*TB)(b) - b.StopTimer() - b.ReportAllocs() db := tb.checkDB(sql.Open(driverNameTest, dsn)) db.SetMaxIdleConns(concurrencyLevel) defer db.Close() @@ -128,9 +125,11 @@ func BenchmarkExec(b *testing.B) { var wg sync.WaitGroup wg.Add(concurrencyLevel) defer wg.Wait() - b.StartTimer() - for i := 0; i < concurrencyLevel; i++ { + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < concurrencyLevel; i++ { go func() { for { if atomic.AddInt64(&remain, -1) < 0 { @@ -158,14 +157,15 @@ func initRoundtripBenchmarks() ([]byte, int, int) { } func BenchmarkRoundtripTxt(b *testing.B) { - b.StopTimer() sample, min, max := initRoundtripBenchmarks() sampleString := string(sample) - b.ReportAllocs() tb := (*TB)(b) db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() - b.StartTimer() + + b.ReportAllocs() + b.ResetTimer() + var result string for i := 0; i < b.N; i++ { length := min + i @@ -192,15 +192,15 @@ func BenchmarkRoundtripTxt(b *testing.B) { } func BenchmarkRoundtripBin(b *testing.B) { - b.StopTimer() sample, min, max := initRoundtripBenchmarks() - b.ReportAllocs() tb := (*TB)(b) db := tb.checkDB(sql.Open(driverNameTest, dsn)) defer db.Close() stmt := tb.checkStmt(db.Prepare("SELECT ?")) defer stmt.Close() - b.StartTimer() + + b.ReportAllocs() + b.ResetTimer() var result sql.RawBytes for i := 0; i < b.N; i++ { length := min + i @@ -385,10 +385,9 @@ func BenchmarkQueryRawBytes(b *testing.B) { } } -// BenchmarkReceiveMassiveRows measures performance of receiving large number of rows. -func BenchmarkReceiveMassiveRows(b *testing.B) { +func benchmark10kRows(b *testing.B, compress bool) { // Setup -- prepare 10000 rows. - db := initDB(b, false, + db := initDB(b, compress, "DROP TABLE IF EXISTS foo", "CREATE TABLE foo (id INT PRIMARY KEY, val TEXT)") defer db.Close() @@ -399,11 +398,14 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { b.Errorf("failed to prepare query: %v", err) return } + + args := make([]any, 200) + for i := 1; i < 200; i+=2 { + args[i] = sval + } for i := 0; i < 10000; i += 100 { - args := make([]any, 200) for j := 0; j < 100; j++ { args[j*2] = i + j - args[j*2+1] = sval } _, err := stmt.Exec(args...) if err != nil { @@ -413,30 +415,43 @@ func BenchmarkReceiveMassiveRows(b *testing.B) { } stmt.Close() - // Use b.Run() to skip expensive setup. + // benchmark function called several times with different b.N. + // it means heavy setup is called multiple times. + // Use b.Run() to run expensive setup only once. + // Go 1.24 introduced b.Loop() for this purpose. But we keep this + // benchmark compatible with Go 1.20. b.Run("query", func(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { rows, err := db.Query(`SELECT id, val FROM foo`) if err != nil { b.Errorf("failed to select: %v", err) return } + // rows.Scan() escapes arguments. So these variables must be defined + // before loop. + var i int + var s sql.RawBytes for rows.Next() { - var i int - var s sql.RawBytes - err = rows.Scan(&i, &s) - if err != nil { + if err := rows.Scan(&i, &s); err != nil { b.Errorf("failed to scan: %v", err) - _ = rows.Close() + rows.Close() return } } if err = rows.Err(); err != nil { b.Errorf("failed to read rows: %v", err) } - _ = rows.Close() + rows.Close() } }) } + +// BenchmarkReceive10kRows measures performance of receiving large number of rows. +func BenchmarkReceive10kRows(b *testing.B) { + benchmark10kRows(b, false) +} + +func BenchmarkReceive10kRowsCompressed(b *testing.B) { + benchmark10kRows(b, true) +} diff --git a/buffer.go b/buffer.go index a6532431..f895e87b 100644 --- a/buffer.go +++ b/buffer.go @@ -42,6 +42,11 @@ func (b *buffer) busy() bool { return len(b.buf) > 0 } +// len returns how many bytes in the read buffer. +func (b *buffer) len() int { + return len(b.buf) +} + // fill reads into the read buffer until at least _need_ bytes are in it. func (b *buffer) fill(need int, r readerFunc) error { // we'll move the contents of the current buffer to dest before filling it. @@ -86,17 +91,10 @@ func (b *buffer) fill(need int, r readerFunc) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int, r readerFunc) ([]byte, error) { - if len(b.buf) < need { - // refill - if err := b.fill(need, r); err != nil { - return nil, err - } - } - - data := b.buf[:need] +func (b *buffer) readNext(need int) []byte { + data := b.buf[:need:need] b.buf = b.buf[need:] - return data, nil + return data } // takeBuffer returns a buffer with the requested size. diff --git a/compress.go b/compress.go index fa42772a..e247a65a 100644 --- a/compress.go +++ b/compress.go @@ -84,9 +84,9 @@ func (c *compIO) reset() { c.buff.Reset() } -func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { +func (c *compIO) readNext(need int) ([]byte, error) { for c.buff.Len() < need { - if err := c.readCompressedPacket(r); err != nil { + if err := c.readCompressedPacket(); err != nil { return nil, err } } @@ -94,8 +94,8 @@ func (c *compIO) readNext(need int, r readerFunc) ([]byte, error) { return data[:need:need], nil // prevent caller writes into c.buff } -func (c *compIO) readCompressedPacket(r readerFunc) error { - header, err := c.mc.buf.readNext(7, r) // size of compressed header +func (c *compIO) readCompressedPacket() error { + header, err := c.mc.readNext(7) if err != nil { return err } @@ -103,7 +103,7 @@ func (c *compIO) readCompressedPacket(r readerFunc) error { // compressed header structure comprLength := getUint24(header[0:3]) - compressionSequence := uint8(header[3]) + compressionSequence := header[3] uncompressedLength := getUint24(header[4:7]) if debug { fmt.Printf("uncompress cmplen=%v uncomplen=%v pkt_cmp_seq=%v expected_cmp_seq=%v\n", @@ -120,7 +120,7 @@ func (c *compIO) readCompressedPacket(r readerFunc) error { c.mc.sequence = compressionSequence + 1 c.mc.compressSequence = c.mc.sequence - comprData, err := c.mc.buf.readNext(comprLength, r) + comprData, err := c.mc.readNext(comprLength) if err != nil { return err } diff --git a/driver_test.go b/driver_test.go index 00e82865..46caa0e2 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1609,10 +1609,12 @@ func TestCollation(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - defaultCollation := "utf8mb4_general_ci" + // MariaDB may override collation specified by handshake with `character_set_collations` variable. + // https://mariadb.com/kb/en/setting-character-sets-and-collations/#changing-default-collation + // https://mariadb.com/kb/en/server-system-variables/#character_set_collations + // utf8mb4_general_ci, utf8mb3_general_ci will be overridden by default MariaDB. + // Collations other than charasets default are not overridden. So utf8mb4_unicode_ci is safe. testCollations := []string{ - "", // do not set - defaultCollation, // driver default "latin1_general_ci", "binary", "utf8mb4_unicode_ci", @@ -1620,24 +1622,19 @@ func TestCollation(t *testing.T) { } for _, collation := range testCollations { - var expected, tdsn string - if collation != "" { - tdsn = dsn + "&collation=" + collation - expected = collation - } else { - tdsn = dsn - expected = defaultCollation - } - - runTests(t, tdsn, func(dbt *DBTest) { - var got string - if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { - dbt.Fatal(err) - } + t.Run(collation, func(t *testing.T) { + tdsn := dsn + "&collation=" + collation + expected := collation - if got != expected { - dbt.Fatalf("expected connection collation %s but got %s", expected, got) - } + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) + } + if got != expected { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } + }) }) } } @@ -1685,7 +1682,7 @@ func TestRawBytesResultExceedsBuffer(t *testing.T) { } func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + zones := []string{"UTC", "America/New_York", "Asia/Hong_Kong", "Local"} // Regression test for timezone handling tzTest := func(dbt *DBTest) { @@ -1693,8 +1690,8 @@ func TestTimezoneConversion(t *testing.T) { dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") // Insert local time into database (should be converted) - usCentral, _ := time.LoadLocation("US/Central") - reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + newYorkTz, _ := time.LoadLocation("America/New_York") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(newYorkTz) dbt.mustExec("INSERT INTO test VALUE (?)", reftime) // Retrieve time from DB @@ -1713,7 +1710,7 @@ func TestTimezoneConversion(t *testing.T) { // Check that dates match if reftime.Unix() != dbTime.Unix() { dbt.Errorf("times do not match.\n") - dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(%v)=%v\n", newYorkTz, reftime) dbt.Errorf(" Now(UTC)=%v\n", dbTime) } } @@ -3541,6 +3538,15 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} + var varName string + var varValue string + err := dbt.db.QueryRow("SHOW VARIABLES LIKE 'performance_schema'").Scan(&varName, &varValue) + if err != nil { + t.Fatalf("error: %s", err.Error()) + } + if varValue != "ON" { + t.Skipf("Performance schema is not enabled. skipping") + } queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" rows := dbt.mustQuery(queryString) defer rows.Close() diff --git a/packets.go b/packets.go index 4b836216..15b000d6 100644 --- a/packets.go +++ b/packets.go @@ -25,19 +25,30 @@ import ( // https://dev.mysql.com/doc/dev/mysql-server/latest/PAGE_PROTOCOL.html // https://mariadb.com/kb/en/clientserver-protocol/ +// read n bytes from mc.buf +func (mc *mysqlConn) readNext(n int) ([]byte, error) { + if mc.buf.len() < n { + err := mc.buf.fill(n, mc.readWithTimeout) + if err != nil { + return nil, err + } + } + return mc.buf.readNext(n), nil +} + // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { var prevData []byte invalidSequence := false - readNext := mc.buf.readNext + readNext := mc.readNext if mc.compress { readNext = mc.compIO.readNext } for { // read packet header - data, err := readNext(4, mc.readWithTimeout) + data, err := readNext(4) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { @@ -85,7 +96,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = readNext(pktLen, mc.readWithTimeout) + data, err = readNext(pktLen) if err != nil { mc.close() if cerr := mc.canceled.Value(); cerr != nil { diff --git a/transaction.go b/transaction.go index 4a4b6100..8c502f49 100644 --- a/transaction.go +++ b/transaction.go @@ -13,18 +13,32 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.closed.Load() { + if tx.mc == nil { return ErrInvalidConn } + if tx.mc.closed.Load() { + err = tx.mc.error() + if err == nil { + err = ErrInvalidConn + } + return + } err = tx.mc.exec("COMMIT") tx.mc = nil return } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.closed.Load() { + if tx.mc == nil { return ErrInvalidConn } + if tx.mc.closed.Load() { + err = tx.mc.error() + if err == nil { + err = ErrInvalidConn + } + return + } err = tx.mc.exec("ROLLBACK") tx.mc = nil return