Skip to content

Commit 6be42e0

Browse files
stevenhmethane
authored andcommittedNov 16, 2018
Improve buffer handling (go-sql-driver#890)
* Eliminate redundant size test in takeBuffer. * Change buffer takeXXX functions to return an error to make it explicit that they can fail. * Add missing error check in handleAuthResult. * Add buffer.store(..) method which can be used by external buffer consumers to update the raw buffer. * Fix some typos and unnecessary UTF-8 characters in comments. * Improve buffer function docs. * Add comments to explain some non-obvious behavior around buffer handling.
1 parent 369b5d6 commit 6be42e0

File tree

6 files changed

+72
-49
lines changed

6 files changed

+72
-49
lines changed
 

‎AUTHORS

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ Shuode Li <elemount at qq.com>
7373
Soroush Pour <me at soroushjp.com>
7474
Stan Putrya <root.vagner at gmail.com>
7575
Stanley Gunawan <gunawan.stanley at gmail.com>
76+
Steven Hartland <steven.hartland at multiplay.co.uk>
7677
Thomas Wodarek <wodarekwebpage at gmail.com>
7778
Tom Jenkinson <tom at tjenkinson.me>
7879
Xiangyu Hu <xiangyu.hu at outlook.com>
@@ -90,3 +91,4 @@ Keybase Inc.
9091
Percona LLC
9192
Pivotal Inc.
9293
Stripe Inc.
94+
Multiplay Ltd.

‎auth.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
360360
pubKey := mc.cfg.pubKey
361361
if pubKey == nil {
362362
// request public key from server
363-
data := mc.buf.takeSmallBuffer(4 + 1)
363+
data, err := mc.buf.takeSmallBuffer(4 + 1)
364+
if err != nil {
365+
return err
366+
}
364367
data[4] = cachingSha2PasswordRequestPublicKey
365368
mc.writePacket(data)
366369

367370
// parse public key
368-
data, err := mc.readPacket()
369-
if err != nil {
371+
if data, err = mc.readPacket(); err != nil {
370372
return err
371373
}
372374

‎buffer.go

+31-18
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@ const defaultBufSize = 4096
2222
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
2323
// Also highly optimized for this particular use case.
2424
type buffer struct {
25-
buf []byte
25+
buf []byte // buf is a byte buffer who's length and capacity are equal.
2626
nc net.Conn
2727
idx int
2828
length int
2929
timeout time.Duration
3030
}
3131

32+
// newBuffer allocates and returns a new buffer.
3233
func newBuffer(nc net.Conn) buffer {
33-
var b [defaultBufSize]byte
3434
return buffer{
35-
buf: b[:],
35+
buf: make([]byte, defaultBufSize),
3636
nc: nc,
3737
}
3838
}
@@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
105105
return b.buf[offset:b.idx], nil
106106
}
107107

108-
// returns a buffer with the requested size.
108+
// takeBuffer returns a buffer with the requested size.
109109
// If possible, a slice from the existing buffer is returned.
110110
// Otherwise a bigger buffer is made.
111111
// Only one buffer (total) can be used at a time.
112-
func (b *buffer) takeBuffer(length int) []byte {
112+
func (b *buffer) takeBuffer(length int) ([]byte, error) {
113113
if b.length > 0 {
114-
return nil
114+
return nil, ErrBusyBuffer
115115
}
116116

117117
// test (cheap) general case first
118-
if length <= defaultBufSize || length <= cap(b.buf) {
119-
return b.buf[:length]
118+
if length <= cap(b.buf) {
119+
return b.buf[:length], nil
120120
}
121121

122122
if length < maxPacketSize {
123123
b.buf = make([]byte, length)
124-
return b.buf
124+
return b.buf, nil
125125
}
126-
return make([]byte, length)
126+
127+
// buffer is larger than we want to store.
128+
return make([]byte, length), nil
127129
}
128130

129-
// shortcut which can be used if the requested buffer is guaranteed to be
130-
// smaller than defaultBufSize
131+
// takeSmallBuffer is shortcut which can be used if length is
132+
// known to be smaller than defaultBufSize.
131133
// Only one buffer (total) can be used at a time.
132-
func (b *buffer) takeSmallBuffer(length int) []byte {
134+
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
133135
if b.length > 0 {
134-
return nil
136+
return nil, ErrBusyBuffer
135137
}
136-
return b.buf[:length]
138+
return b.buf[:length], nil
137139
}
138140

139141
// takeCompleteBuffer returns the complete existing buffer.
140142
// This can be used if the necessary buffer size is unknown.
143+
// cap and len of the returned buffer will be equal.
141144
// Only one buffer (total) can be used at a time.
142-
func (b *buffer) takeCompleteBuffer() []byte {
145+
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
146+
if b.length > 0 {
147+
return nil, ErrBusyBuffer
148+
}
149+
return b.buf, nil
150+
}
151+
152+
// store stores buf, an updated buffer, if its suitable to do so.
153+
func (b *buffer) store(buf []byte) error {
143154
if b.length > 0 {
144-
return nil
155+
return ErrBusyBuffer
156+
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
157+
b.buf = buf[:cap(buf)]
145158
}
146-
return b.buf
159+
return nil
147160
}

‎connection.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
182182
return "", driver.ErrSkip
183183
}
184184

185-
buf := mc.buf.takeCompleteBuffer()
186-
if buf == nil {
185+
buf, err := mc.buf.takeCompleteBuffer()
186+
if err != nil {
187187
// can not take the buffer. Something must be wrong with the connection
188-
errLog.Print(ErrBusyBuffer)
188+
errLog.Print(err)
189189
return "", ErrInvalidConn
190190
}
191191
buf = buf[:0]

‎driver.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
5050

5151
// Open new Connection.
5252
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
53-
// the DSN string is formated
53+
// the DSN string is formatted
5454
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
5555
var err error
5656

‎packets.go

+30-24
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
5151
mc.sequence++
5252

5353
// packets with length 0 terminate a previous packet which is a
54-
// multiple of (2^24)1 bytes long
54+
// multiple of (2^24)-1 bytes long
5555
if pktLen == 0 {
5656
// there was no previous packet
5757
if prevData == nil {
@@ -286,10 +286,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
286286
}
287287

288288
// Calculate packet length and get buffer with that size
289-
data := mc.buf.takeSmallBuffer(pktLen + 4)
290-
if data == nil {
289+
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
290+
if err != nil {
291291
// cannot take the buffer. Something must be wrong with the connection
292-
errLog.Print(ErrBusyBuffer)
292+
errLog.Print(err)
293293
return errBadConnNoWrite
294294
}
295295

@@ -367,10 +367,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
367367
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
368368
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
369369
pktLen := 4 + len(authData)
370-
data := mc.buf.takeSmallBuffer(pktLen)
371-
if data == nil {
370+
data, err := mc.buf.takeSmallBuffer(pktLen)
371+
if err != nil {
372372
// cannot take the buffer. Something must be wrong with the connection
373-
errLog.Print(ErrBusyBuffer)
373+
errLog.Print(err)
374374
return errBadConnNoWrite
375375
}
376376

@@ -387,10 +387,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
387387
// Reset Packet Sequence
388388
mc.sequence = 0
389389

390-
data := mc.buf.takeSmallBuffer(4 + 1)
391-
if data == nil {
390+
data, err := mc.buf.takeSmallBuffer(4 + 1)
391+
if err != nil {
392392
// cannot take the buffer. Something must be wrong with the connection
393-
errLog.Print(ErrBusyBuffer)
393+
errLog.Print(err)
394394
return errBadConnNoWrite
395395
}
396396

@@ -406,10 +406,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
406406
mc.sequence = 0
407407

408408
pktLen := 1 + len(arg)
409-
data := mc.buf.takeBuffer(pktLen + 4)
410-
if data == nil {
409+
data, err := mc.buf.takeBuffer(pktLen + 4)
410+
if err != nil {
411411
// cannot take the buffer. Something must be wrong with the connection
412-
errLog.Print(ErrBusyBuffer)
412+
errLog.Print(err)
413413
return errBadConnNoWrite
414414
}
415415

@@ -427,10 +427,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
427427
// Reset Packet Sequence
428428
mc.sequence = 0
429429

430-
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
431-
if data == nil {
430+
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
431+
if err != nil {
432432
// cannot take the buffer. Something must be wrong with the connection
433-
errLog.Print(ErrBusyBuffer)
433+
errLog.Print(err)
434434
return errBadConnNoWrite
435435
}
436436

@@ -883,7 +883,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
883883
const minPktLen = 4 + 1 + 4 + 1 + 4
884884
mc := stmt.mc
885885

886-
// Determine threshould dynamically to avoid packet size shortage.
886+
// Determine threshold dynamically to avoid packet size shortage.
887887
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
888888
if longDataSize < 64 {
889889
longDataSize = 64
@@ -893,15 +893,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
893893
mc.sequence = 0
894894

895895
var data []byte
896+
var err error
896897

897898
if len(args) == 0 {
898-
data = mc.buf.takeBuffer(minPktLen)
899+
data, err = mc.buf.takeBuffer(minPktLen)
899900
} else {
900-
data = mc.buf.takeCompleteBuffer()
901+
data, err = mc.buf.takeCompleteBuffer()
902+
// In this case the len(data) == cap(data) which is used to optimise the flow below.
901903
}
902-
if data == nil {
904+
if err != nil {
903905
// cannot take the buffer. Something must be wrong with the connection
904-
errLog.Print(ErrBusyBuffer)
906+
errLog.Print(err)
905907
return errBadConnNoWrite
906908
}
907909

@@ -927,7 +929,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
927929
pos := minPktLen
928930

929931
var nullMask []byte
930-
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
932+
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
931933
// buffer has to be extended but we don't know by how much so
932934
// we depend on append after all data with known sizes fit.
933935
// We stop at that because we deal with a lot of columns here
@@ -936,10 +938,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
936938
copy(tmp[:pos], data[:pos])
937939
data = tmp
938940
nullMask = data[pos : pos+maskLen]
941+
// No need to clean nullMask as make ensures that.
939942
pos += maskLen
940943
} else {
941944
nullMask = data[pos : pos+maskLen]
942-
for i := 0; i < maskLen; i++ {
945+
for i := range nullMask {
943946
nullMask[i] = 0
944947
}
945948
pos += maskLen
@@ -1076,7 +1079,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
10761079
// In that case we must build the data packet with the new values buffer
10771080
if valuesCap != cap(paramValues) {
10781081
data = append(data[:pos], paramValues...)
1079-
mc.buf.buf = data
1082+
if err = mc.buf.store(data); err != nil {
1083+
errLog.Print(err)
1084+
return errBadConnNoWrite
1085+
}
10801086
}
10811087

10821088
pos += len(paramValues)

0 commit comments

Comments
 (0)