From 6e2f60e97e02cc93f8a2bd92c8bd6d4e492b469a Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 1 Apr 2019 16:40:01 +0200 Subject: [PATCH 1/6] test: Implement TestContextCancelQueryWhileScan --- driver_test.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/driver_test.go b/driver_test.go index c35588a09..9c3d286ce 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2938,3 +2938,58 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() }) } + +// TestRawBytesAreNotModified checks for a race condition that arises when a query context +// is canceled while a user is calling rows.Scan. This is a more stringent test than the one +// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using +// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit +// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers. +func TestRawBytesAreNotModified(t *testing.T) { + const blob = "abcdefghijklmnop" + const contextRaceIterations = 20 + const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row. + const insertRows = 4 + + var sqlBlobs = [2]string{ + strings.Repeat(blob, blobSize/len(blob)), + strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + for i := 0; i < insertRows; i++ { + dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) + if err != nil { + t.Fatal(err) + } + + var b int + var raw sql.RawBytes + for rows.Next() { + if err := rows.Scan(&b, &raw); err != nil { + t.Fatal(err) + } + + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) + + if before != after { + t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + } + } + rows.Close() + }() + } + }) +} From 505cdc7eb3f304d8b045f8afef4a08b84d57cc7b Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Mon, 1 Apr 2019 17:12:34 +0200 Subject: [PATCH 2/6] buffer: Use a double-buffering scheme to prevent data races --- buffer.go | 48 +++++++++++++++++++++++++++++++++++------------- rows.go | 7 +++++++ 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/buffer.go b/buffer.go index 19486bd6f..1b8f3b235 100644 --- a/buffer.go +++ b/buffer.go @@ -15,47 +15,69 @@ import ( ) const defaultBufSize = 4096 +const maxCachedBufSize = 16 * 1024 // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. +// This buffer is backed by two byte slices in a double-buffering scheme type buffer struct { buf []byte // buf is a byte buffer who's length and capacity are equal. nc net.Conn idx int length int timeout time.Duration + dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer + flipcnt uint // flipccnt is the current buffer counter for double-buffering } // newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { + fg := make([]byte, defaultBufSize) return buffer{ - buf: make([]byte, defaultBufSize), - nc: nc, + buf: fg, + nc: nc, + dbuf: [2][]byte{fg, nil}, } } +// flip replaces the active buffer with the background buffer +// this is a delayed flip that simply increases the buffer counter; +// the actual flip will be performed the next time we call `buffer.fill` +func (b *buffer) flip() { + b.flipcnt += 1 +} + // fill reads into the buffer until at least _need_ bytes are in it func (b *buffer) fill(need int) error { n := b.length + // fill data into its double-buffering target: if we've called + // flip on this buffer, we'll be copying to the background buffer, + // and then filling it with network data; otherwise we'll just move + // the contents of the current buffer to the front before filling it + dest := b.dbuf[b.flipcnt&1] + + // grow buffer if necessary to fit the whole packet. + if need > len(dest) { + // Round up to the next multiple of the default size + dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - // move existing data to the beginning - if n > 0 && b.idx > 0 { - copy(b.buf[0:n], b.buf[b.idx:]) + // if the allocated buffer is not too large, move it to backing storage + // to prevent extra allocations on applications that perform large reads + if len(dest) <= maxCachedBufSize { + b.dbuf[b.flipcnt&1] = dest + } } - // grow buffer if necessary - // TODO: let the buffer shrink again at some point - // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { - // Round up to the next multiple of the default size - newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - copy(newBuf, b.buf) - b.buf = newBuf + // if we're filling the fg buffer, move the existing data to the start of it. + // if we're filling the bg buffer, copy over the data + if n > 0 { + copy(dest[0:n], b.buf[b.idx:]) } + b.buf = dest b.idx = 0 for { diff --git a/rows.go b/rows.go index d3b1e2822..888bdb5f0 100644 --- a/rows.go +++ b/rows.go @@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) { return err } + // flip the buffer for this connection if we need to drain it. + // note that for a successful query (i.e. one where rows.next() + // has been called until it returns false), `rows.mc` will be nil + // by the time the user calls `(*Rows).Close`, so we won't reach this + // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.flip() + // Remove unread packets from stream if !rows.rs.done { err = mc.readUntilEOF() From e4d8c6382f86247f9db140e6b90cbe3ca08fb333 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Wed, 3 Apr 2019 21:33:13 +0900 Subject: [PATCH 3/6] add benchmark for sql.RawBytes --- benchmark_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index 5828d40f9..869784c7a 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -317,3 +317,52 @@ func BenchmarkExecContext(b *testing.B) { }) } } + +// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes. +// "size=" means size of each blobs. +func BenchmarkQueryRawBytes(b *testing.B) { + var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} + db := initDB(b, + "DROP TABLE IF EXISTS bench_rawbytes", + "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", + ) + defer db.Close() + + blob := make([]byte, sizes[len(sizes)-1]) + for i := range blob { + blob[i] = 42 + } + for i := 0; i < 100; i++ { + _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) + if err != nil { + b.Fatal(err) + } + } + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + b.ReportAllocs() + for j := 0; j < b.N; j++ { + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + if err != nil { + b.Fatal(err) + } + nrows := 0 + for rows.Next() { + var buf sql.RawBytes + err := rows.Scan(&buf) + if err != nil { + b.Fatal(err) + } + if len(buf) != s { + b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + } + nrows++ + } + rows.Close() + if nrows != 100 { + b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) + } + } + }) + } +} From e1c9555ca99994c278f4e4598a5f61f92cfc06bd Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 4 Apr 2019 11:49:51 +0900 Subject: [PATCH 4/6] don't reuse connection over subbenchs --- benchmark_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/benchmark_test.go b/benchmark_test.go index 869784c7a..3e25a3bf2 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -338,9 +338,14 @@ func BenchmarkQueryRawBytes(b *testing.B) { b.Fatal(err) } } + for _, s := range sizes { b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) b.ReportAllocs() + b.ResetTimer() + for j := 0; j < b.N; j++ { rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) if err != nil { From 84ff6e3256f8d874045d9f76fa3b4d42965f10ab Mon Sep 17 00:00:00 2001 From: Vicent Marti Date: Thu, 4 Apr 2019 09:43:08 +0200 Subject: [PATCH 5/6] buffer: Increase maxCachedBufferSize --- buffer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buffer.go b/buffer.go index 1b8f3b235..fd22f22c4 100644 --- a/buffer.go +++ b/buffer.go @@ -15,7 +15,7 @@ import ( ) const defaultBufSize = 4096 -const maxCachedBufSize = 16 * 1024 +const maxCachedBufSize = 256 * 1024 // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. From 8af699773e8a41b3a9baeb48b0239796c9ccd165 Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Thu, 4 Apr 2019 12:01:10 +0200 Subject: [PATCH 6/6] buffer: do not use superfluous 0 Co-Authored-By: vmg --- buffer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buffer.go b/buffer.go index fd22f22c4..0774c5c8c 100644 --- a/buffer.go +++ b/buffer.go @@ -74,7 +74,7 @@ func (b *buffer) fill(need int) error { // if we're filling the fg buffer, move the existing data to the start of it. // if we're filling the bg buffer, copy over the data if n > 0 { - copy(dest[0:n], b.buf[b.idx:]) + copy(dest[:n], b.buf[b.idx:]) } b.buf = dest