
Commit df597a2

Vicent Martí authored and methane committed
buffer: Use a double-buffering scheme to prevent data races (go-sql-driver#943)
Fixes go-sql-driver#903
Co-Authored-By: vmg <[email protected]>
1 parent c0f6b44 commit df597a2

4 files changed: +151 −13 lines


benchmark_test.go (+54)

@@ -317,3 +317,57 @@ func BenchmarkExecContext(b *testing.B) {
 		})
 	}
 }
+
+// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
+// "size=" means the size of each blob.
+func BenchmarkQueryRawBytes(b *testing.B) {
+	var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
+	db := initDB(b,
+		"DROP TABLE IF EXISTS bench_rawbytes",
+		"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
+	)
+	defer db.Close()
+
+	blob := make([]byte, sizes[len(sizes)-1])
+	for i := range blob {
+		blob[i] = 42
+	}
+	for i := 0; i < 100; i++ {
+		_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
+		if err != nil {
+			b.Fatal(err)
+		}
+	}
+
+	for _, s := range sizes {
+		b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
+			db.SetMaxIdleConns(0)
+			db.SetMaxIdleConns(1)
+			b.ReportAllocs()
+			b.ResetTimer()
+
+			for j := 0; j < b.N; j++ {
+				rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
+				if err != nil {
+					b.Fatal(err)
+				}
+				nrows := 0
+				for rows.Next() {
+					var buf sql.RawBytes
+					err := rows.Scan(&buf)
+					if err != nil {
+						b.Fatal(err)
+					}
+					if len(buf) != s {
+						b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
+					}
+					nrows++
+				}
+				rows.Close()
+				if nrows != 100 {
+					b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
+				}
+			}
+		})
+	}
+}
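
For context on why the benchmark scans into sql.RawBytes rather than []byte: a RawBytes scan hands back a slice that aliases the driver's internal read buffer instead of copying each blob, and that slice is only valid until the next call to Next, Scan, or Close. The sketch below is not part of this commit; it just illustrates that access pattern against the bench_rawbytes table created above, with a placeholder DSN.

```go
package main

import (
	"database/sql"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	// Placeholder DSN; adjust user, password, and database for your setup.
	db, err := sql.Open("mysql", "user:password@/dbname")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query("SELECT val FROM bench_rawbytes")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		// raw aliases the driver's buffer: valid only until the next
		// Next/Scan/Close. Copy it if it must outlive this iteration.
		var raw sql.RawBytes
		if err := rows.Scan(&raw); err != nil {
			log.Fatal(err)
		}
		_ = len(raw)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```

The benchmark itself runs with the usual Go tooling, e.g. `go test -run '^$' -bench BenchmarkQueryRawBytes -benchmem`, against a reachable MySQL test database configured for the driver's test suite.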

buffer.go (+35, −13)

@@ -15,47 +15,69 @@ import (
 )
 
 const defaultBufSize = 4096
+const maxCachedBufSize = 256 * 1024
 
 // A buffer which is used for both reading and writing.
 // This is possible since communication on each connection is synchronous.
 // In other words, we can't write and read simultaneously on the same connection.
 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish
 // Also highly optimized for this particular use case.
+// This buffer is backed by two byte slices in a double-buffering scheme
 type buffer struct {
 	buf     []byte // buf is a byte buffer who's length and capacity are equal.
 	nc      net.Conn
 	idx     int
 	length  int
 	timeout time.Duration
+	dbuf    [2][]byte // dbuf is an array with the two byte slices that back this buffer
+	flipcnt uint      // flipcnt is the current buffer counter for double-buffering
 }
 
 // newBuffer allocates and returns a new buffer.
 func newBuffer(nc net.Conn) buffer {
+	fg := make([]byte, defaultBufSize)
 	return buffer{
-		buf: make([]byte, defaultBufSize),
-		nc:  nc,
+		buf:  fg,
+		nc:   nc,
+		dbuf: [2][]byte{fg, nil},
 	}
 }
 
+// flip replaces the active buffer with the background buffer
+// this is a delayed flip that simply increases the buffer counter;
+// the actual flip will be performed the next time we call `buffer.fill`
+func (b *buffer) flip() {
+	b.flipcnt += 1
+}
+
 // fill reads into the buffer until at least _need_ bytes are in it
 func (b *buffer) fill(need int) error {
 	n := b.length
+	// fill data into its double-buffering target: if we've called
+	// flip on this buffer, we'll be copying to the background buffer,
+	// and then filling it with network data; otherwise we'll just move
+	// the contents of the current buffer to the front before filling it
+	dest := b.dbuf[b.flipcnt&1]
+
+	// grow buffer if necessary to fit the whole packet.
+	if need > len(dest) {
+		// Round up to the next multiple of the default size
+		dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
 
-	// move existing data to the beginning
-	if n > 0 && b.idx > 0 {
-		copy(b.buf[0:n], b.buf[b.idx:])
+		// if the allocated buffer is not too large, move it to backing storage
+		// to prevent extra allocations on applications that perform large reads
+		if len(dest) <= maxCachedBufSize {
+			b.dbuf[b.flipcnt&1] = dest
+		}
 	}
 
-	// grow buffer if necessary
-	// TODO: let the buffer shrink again at some point
-	//       Maybe keep the org buf slice and swap back?
-	if need > len(b.buf) {
-		// Round up to the next multiple of the default size
-		newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
-		copy(newBuf, b.buf)
-		b.buf = newBuf
+	// if we're filling the fg buffer, move the existing data to the start of it.
+	// if we're filling the bg buffer, copy over the data
+	if n > 0 {
+		copy(dest[:n], b.buf[b.idx:])
 	}
 
+	b.buf = dest
 	b.idx = 0
 
 	for {
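
To make the scheme easier to follow outside the driver, here is a standalone toy version of the same idea (illustrative names, an io.Reader in place of net.Conn, and none of the driver's length/index bookkeeping): flip only bumps a counter, and the next fill writes into the other backing slice, so data handed out from the previous buffer is never overwritten by a later drain.

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// doubleBuf is a toy version of the driver's buffer: two backing slices and a
// counter whose parity selects which slice the next fill writes into.
type doubleBuf struct {
	buf     []byte    // currently active slice
	dbuf    [2][]byte // the two backing slices
	flipcnt uint      // parity picks the fill target
}

func newDoubleBuf(size int) *doubleBuf {
	fg := make([]byte, size)
	return &doubleBuf{buf: fg, dbuf: [2][]byte{fg, nil}}
}

// flip is delayed: it only changes which slice the next fill will use.
func (b *doubleBuf) flip() { b.flipcnt++ }

// fill reads need bytes into the slice selected by flipcnt and makes that
// slice the active buffer, growing it if required.
func (b *doubleBuf) fill(r io.Reader, need int) ([]byte, error) {
	dest := b.dbuf[b.flipcnt&1]
	if need > len(dest) {
		dest = make([]byte, need)
		b.dbuf[b.flipcnt&1] = dest
	}
	n, err := io.ReadFull(r, dest[:need])
	b.buf = dest
	return b.buf[:n], err
}

func main() {
	b := newDoubleBuf(8)

	first, _ := b.fill(strings.NewReader("AAAAAAAA"), 8) // fills slice 0
	b.flip()                                              // e.g. Rows.Close with unread packets
	second, _ := b.fill(strings.NewReader("BBBBBBBB"), 8) // fills slice 1

	// first still reads "AAAAAAAA": the second fill did not overwrite it.
	fmt.Println(string(first), string(second))
}
```

Two details in the diff keep the memory cost modest: the background slice starts out nil and is only allocated the first time a flipped buffer actually has to be filled, and slices grown past maxCachedBufSize (256 KiB) are not kept as backing storage, so they can be collected after the large read.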

driver_test.go (+55)

@@ -2938,3 +2938,58 @@ func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
 		// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
 	})
 }
+
+// TestRawBytesAreNotModified checks for a race condition that arises when a query context
+// is canceled while a user is calling rows.Scan. This is a more stringent test than the one
+// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using
+// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit
+// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers.
+func TestRawBytesAreNotModified(t *testing.T) {
+	const blob = "abcdefghijklmnop"
+	const contextRaceIterations = 20
+	const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row.
+	const insertRows = 4
+
+	var sqlBlobs = [2]string{
+		strings.Repeat(blob, blobSize/len(blob)),
+		strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)),
+	}
+
+	runTests(t, dsn, func(dbt *DBTest) {
+		dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8")
+		for i := 0; i < insertRows; i++ {
+			dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1])
+		}
+
+		for i := 0; i < contextRaceIterations; i++ {
+			func() {
+				ctx, cancel := context.WithCancel(context.Background())
+				defer cancel()
+
+				rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`)
+				if err != nil {
+					t.Fatal(err)
+				}
+
+				var b int
+				var raw sql.RawBytes
+				for rows.Next() {
+					if err := rows.Scan(&b, &raw); err != nil {
+						t.Fatal(err)
+					}
+
+					before := string(raw)
+					// Ensure cancelling the query does not corrupt the contents of `raw`
+					cancel()
+					time.Sleep(time.Microsecond * 100)
+					after := string(raw)
+
+					if before != after {
+						t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i)
+					}
+				}
+				rows.Close()
+			}()
+		}
+	})
+}
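
As a usage note: this regression test can be exercised on its own with `go test -run TestRawBytesAreNotModified -race`; the race detector is well suited to catching the unsynchronized buffer reuse described in go-sql-driver#903, and like the rest of the driver's tests it needs a reachable MySQL test database. The 100µs sleep after cancel() gives database/sql's cancellation goroutine time to close the rows and drain the connection before the contents of `raw` are compared again.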

rows.go (+7)

@@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) {
 		return err
 	}
 
+	// flip the buffer for this connection if we need to drain it.
+	// note that for a successful query (i.e. one where rows.next()
+	// has been called until it returns false), `rows.mc` will be nil
+	// by the time the user calls `(*Rows).Close`, so we won't reach this
+	// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
+	mc.buf.flip()
+
 	// Remove unread packets from stream
 	if !rows.rs.done {
 		err = mc.readUntilEOF()
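
For illustration (not part of the diff), this is the kind of caller-side pattern that reaches the new code path: the result set is abandoned before it is fully read, so Close has to drain the remaining packets, and thanks to the flip above that drain lands in the other backing slice rather than the one any outstanding sql.RawBytes still points into. The DSN and table name are placeholders.

```go
package main

import (
	"database/sql"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "user:password@/dbname") // placeholder DSN
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query("SELECT id, value FROM test") // placeholder table
	if err != nil {
		log.Fatal(err)
	}

	var (
		id  int
		raw sql.RawBytes
	)
	if rows.Next() {
		if err := rows.Scan(&id, &raw); err != nil {
			log.Fatal(err)
		}
	}

	// Closing with unread rows forces mysqlRows.Close to drain the stream.
	// The flip() added above means that drain fills the background buffer,
	// not the slice that still backs raw.
	rows.Close()
}
```

Per database/sql's contract, RawBytes is still only guaranteed valid until the next Next, Scan, or Close; the flip simply ensures the driver's own drain cannot race with a read that is still observing the old buffer, which is exactly the context-cancellation case exercised by TestRawBytesAreNotModified above.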
