Skip to content

Commit 26471af

Browse files
authored
fewer driver.ErrBadConn to prevent repeated queries (go-sql-driver#302)
According to the database/sql/driver documentation, ErrBadConn should only be used when the database was not affected. The driver restarts the same query on a different connection, then. The mysql driver did not follow this advice, so queries were repeated if ErrBadConn is returned but a query succeeded. This is fixed by changing most ErrBadConn errors to ErrInvalidConn. The only valid returns of ErrBadConn are at the beginning of a database interaction when no data was sent to the database yet. Those valid cases are located the following funcs before attempting to write to the network or if 0 bytes were written: * Begin * BeginTx * Exec * ExecContext * Prepare * PrepareContext * Query * QueryContext Commit and Rollback could arguably also be on that list, but are left out as some engines like MyISAM are not supporting transactions. Tests in b/packets_test.go were changed because they simulate a read not preceded by a write to the db. This cannot happen as the client has to send the query first.
1 parent 21d7e97 commit 26471af

File tree

5 files changed

+46
-28
lines changed

5 files changed

+46
-28
lines changed

connection.go

+16-7
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ func (mc *mysqlConn) handleParams() (err error) {
8181
return
8282
}
8383

84+
func (mc *mysqlConn) markBadConn(err error) error {
85+
if mc == nil {
86+
return err
87+
}
88+
if err != errBadConnNoWrite {
89+
return err
90+
}
91+
return driver.ErrBadConn
92+
}
93+
8494
func (mc *mysqlConn) Begin() (driver.Tx, error) {
8595
if mc.closed.IsSet() {
8696
errLog.Print(ErrInvalidConn)
@@ -90,8 +100,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
90100
if err == nil {
91101
return &mysqlTx{mc}, err
92102
}
93-
94-
return nil, err
103+
return nil, mc.markBadConn(err)
95104
}
96105

97106
func (mc *mysqlConn) Close() (err error) {
@@ -142,7 +151,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
142151
// Send command
143152
err := mc.writeCommandPacketStr(comStmtPrepare, query)
144153
if err != nil {
145-
return nil, err
154+
return nil, mc.markBadConn(err)
146155
}
147156

148157
stmt := &mysqlStmt{
@@ -176,7 +185,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
176185
if buf == nil {
177186
// can not take the buffer. Something must be wrong with the connection
178187
errLog.Print(ErrBusyBuffer)
179-
return "", driver.ErrBadConn
188+
return "", ErrInvalidConn
180189
}
181190
buf = buf[:0]
182191
argPos := 0
@@ -314,14 +323,14 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
314323
insertId: int64(mc.insertId),
315324
}, err
316325
}
317-
return nil, err
326+
return nil, mc.markBadConn(err)
318327
}
319328

320329
// Internal function to execute commands
321330
func (mc *mysqlConn) exec(query string) error {
322331
// Send command
323332
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
324-
return err
333+
return mc.markBadConn(err)
325334
}
326335

327336
// Read Result
@@ -390,7 +399,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
390399
return rows, err
391400
}
392401
}
393-
return nil, err
402+
return nil, mc.markBadConn(err)
394403
}
395404

396405
// Gets the value of the given MySQL System Variable

errors.go

+6
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@ var (
3131
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
3232
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
3333
ErrBusyBuffer = errors.New("busy buffer")
34+
35+
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
36+
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
37+
// to trigger a resend.
38+
// See https://github.com/go-sql-driver/mysql/pull/302
39+
errBadConnNoWrite = errors.New("bad connection")
3440
)
3541

3642
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))

packets.go

+16-12
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3535
}
3636
errLog.Print(err)
3737
mc.Close()
38-
return nil, driver.ErrBadConn
38+
return nil, ErrInvalidConn
3939
}
4040

4141
// packet length [24 bit]
@@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
5757
if prevData == nil {
5858
errLog.Print(ErrMalformPkt)
5959
mc.Close()
60-
return nil, driver.ErrBadConn
60+
return nil, ErrInvalidConn
6161
}
6262

6363
return prevData, nil
@@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
7171
}
7272
errLog.Print(err)
7373
mc.Close()
74-
return nil, driver.ErrBadConn
74+
return nil, ErrInvalidConn
7575
}
7676

7777
// return data if this was the last packet
@@ -137,10 +137,14 @@ func (mc *mysqlConn) writePacket(data []byte) error {
137137
if cerr := mc.canceled.Value(); cerr != nil {
138138
return cerr
139139
}
140+
if n == 0 && pktLen == len(data)-4 {
141+
// only for the first loop iteration when nothing was written yet
142+
return errBadConnNoWrite
143+
}
140144
mc.cleanup()
141145
errLog.Print(err)
142146
}
143-
return driver.ErrBadConn
147+
return ErrInvalidConn
144148
}
145149
}
146150

@@ -274,7 +278,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
274278
if data == nil {
275279
// can not take the buffer. Something must be wrong with the connection
276280
errLog.Print(ErrBusyBuffer)
277-
return driver.ErrBadConn
281+
return errBadConnNoWrite
278282
}
279283

280284
// ClientFlags [32 bit]
@@ -362,7 +366,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
362366
if data == nil {
363367
// can not take the buffer. Something must be wrong with the connection
364368
errLog.Print(ErrBusyBuffer)
365-
return driver.ErrBadConn
369+
return errBadConnNoWrite
366370
}
367371

368372
// Add the scrambled password [null terminated string]
@@ -381,7 +385,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
381385
if data == nil {
382386
// can not take the buffer. Something must be wrong with the connection
383387
errLog.Print(ErrBusyBuffer)
384-
return driver.ErrBadConn
388+
return errBadConnNoWrite
385389
}
386390

387391
// Add the clear password [null terminated string]
@@ -404,7 +408,7 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
404408
if data == nil {
405409
// can not take the buffer. Something must be wrong with the connection
406410
errLog.Print(ErrBusyBuffer)
407-
return driver.ErrBadConn
411+
return errBadConnNoWrite
408412
}
409413

410414
// Add the scramble
@@ -425,7 +429,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
425429
if data == nil {
426430
// can not take the buffer. Something must be wrong with the connection
427431
errLog.Print(ErrBusyBuffer)
428-
return driver.ErrBadConn
432+
return errBadConnNoWrite
429433
}
430434

431435
// Add command byte
@@ -444,7 +448,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
444448
if data == nil {
445449
// can not take the buffer. Something must be wrong with the connection
446450
errLog.Print(ErrBusyBuffer)
447-
return driver.ErrBadConn
451+
return errBadConnNoWrite
448452
}
449453

450454
// Add command byte
@@ -465,7 +469,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
465469
if data == nil {
466470
// can not take the buffer. Something must be wrong with the connection
467471
errLog.Print(ErrBusyBuffer)
468-
return driver.ErrBadConn
472+
return errBadConnNoWrite
469473
}
470474

471475
// Add command byte
@@ -931,7 +935,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
931935
if data == nil {
932936
// can not take the buffer. Something must be wrong with the connection
933937
errLog.Print(ErrBusyBuffer)
934-
return driver.ErrBadConn
938+
return errBadConnNoWrite
935939
}
936940

937941
// command [1 byte]

packets_test.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
package mysql
1010

1111
import (
12-
"database/sql/driver"
1312
"errors"
1413
"net"
1514
"testing"
@@ -252,8 +251,8 @@ func TestReadPacketFail(t *testing.T) {
252251
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
253252
conn.maxReads = 1
254253
_, err := mc.readPacket()
255-
if err != driver.ErrBadConn {
256-
t.Errorf("expected ErrBadConn, got %v", err)
254+
if err != ErrInvalidConn {
255+
t.Errorf("expected ErrInvalidConn, got %v", err)
257256
}
258257

259258
// reset
@@ -264,8 +263,8 @@ func TestReadPacketFail(t *testing.T) {
264263
// fail to read header
265264
conn.closed = true
266265
_, err = mc.readPacket()
267-
if err != driver.ErrBadConn {
268-
t.Errorf("expected ErrBadConn, got %v", err)
266+
if err != ErrInvalidConn {
267+
t.Errorf("expected ErrInvalidConn, got %v", err)
269268
}
270269

271270
// reset
@@ -277,7 +276,7 @@ func TestReadPacketFail(t *testing.T) {
277276
// fail to read body
278277
conn.maxReads = 1
279278
_, err = mc.readPacket()
280-
if err != driver.ErrBadConn {
281-
t.Errorf("expected ErrBadConn, got %v", err)
279+
if err != ErrInvalidConn {
280+
t.Errorf("expected ErrInvalidConn, got %v", err)
282281
}
283282
}

statement.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
5252
// Send command
5353
err := stmt.writeExecutePacket(args)
5454
if err != nil {
55-
return nil, err
55+
return nil, stmt.mc.markBadConn(err)
5656
}
5757

5858
mc := stmt.mc
@@ -100,7 +100,7 @@ func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
100100
// Send command
101101
err := stmt.writeExecutePacket(args)
102102
if err != nil {
103-
return nil, err
103+
return nil, stmt.mc.markBadConn(err)
104104
}
105105

106106
mc := stmt.mc

0 commit comments

Comments
 (0)