Skip to content

Commit 04572b4

Browse files
committed
Merge pull request go-sql-driver#122 from go-sql-driver/badconn_close
Explicitly close connection on ErrBadConn
2 parents 61e7183 + 36cccb2 commit 04572b4

File tree

5 files changed

+30
-7
lines changed

5 files changed

+30
-7
lines changed

connection.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,16 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
108108
}
109109

110110
func (mc *mysqlConn) Close() (err error) {
111-
mc.writeCommandPacket(comQuit)
111+
// Makes Close idempotent
112+
if mc.netConn != nil {
113+
mc.writeCommandPacket(comQuit)
114+
mc.netConn.Close()
115+
mc.netConn = nil
116+
}
117+
112118
mc.cfg = nil
113119
mc.buf = nil
114-
mc.netConn.Close()
115-
mc.netConn = nil
120+
116121
return
117122
}
118123

packets.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
2828
data, err = mc.buf.readNext(4)
2929
if err != nil {
3030
errLog.Print(err.Error())
31+
mc.Close()
3132
return nil, driver.ErrBadConn
3233
}
3334

@@ -36,6 +37,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
3637

3738
if pktLen < 1 {
3839
errLog.Print(errMalformPkt.Error())
40+
mc.Close()
3941
return nil, driver.ErrBadConn
4042
}
4143

@@ -50,8 +52,7 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
5052
mc.sequence++
5153

5254
// Read packet body [pktLen bytes]
53-
data, err = mc.buf.readNext(pktLen)
54-
if err == nil {
55+
if data, err = mc.buf.readNext(pktLen); err == nil {
5556
if pktLen < maxPacketSize {
5657
return data, nil
5758
}
@@ -65,6 +66,9 @@ func (mc *mysqlConn) readPacket() (data []byte, err error) {
6566
return append(buf, data...), nil
6667
}
6768
}
69+
70+
// err case
71+
mc.Close()
6872
errLog.Print(err.Error())
6973
return nil, driver.ErrBadConn
7074
}

rows.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,15 @@ func (rows *mysqlRows) Columns() (columns []string) {
3737
func (rows *mysqlRows) Close() (err error) {
3838
// Remove unread packets from stream
3939
if !rows.eof {
40-
if rows.mc == nil {
40+
if rows.mc == nil || rows.mc.netConn == nil {
4141
return errInvalidConn
4242
}
4343

4444
err = rows.mc.readUntilEOF()
45+
46+
// explicitly set because readUntilEOF might return early in case of an
47+
// error
48+
rows.eof = true
4549
}
4650

4751
rows.mc = nil
@@ -54,7 +58,7 @@ func (rows *mysqlRows) Next(dest []driver.Value) (err error) {
5458
return io.EOF
5559
}
5660

57-
if rows.mc == nil {
61+
if rows.mc == nil || rows.mc.netConn == nil {
5862
return errInvalidConn
5963
}
6064

statement.go

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ type mysqlStmt struct {
2020
}
2121

2222
func (stmt *mysqlStmt) Close() (err error) {
23+
if stmt.mc == nil || stmt.mc.netConn == nil {
24+
return errInvalidConn
25+
}
26+
2327
err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
2428
stmt.mc = nil
2529
return

transaction.go

+6
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@ type mysqlTx struct {
1313
}
1414

1515
func (tx *mysqlTx) Commit() (err error) {
16+
if tx.mc == nil || tx.mc.netConn == nil {
17+
return errInvalidConn
18+
}
1619
err = tx.mc.exec("COMMIT")
1720
tx.mc = nil
1821
return
1922
}
2023

2124
func (tx *mysqlTx) Rollback() (err error) {
25+
if tx.mc == nil || tx.mc.netConn == nil {
26+
return errInvalidConn
27+
}
2228
err = tx.mc.exec("ROLLBACK")
2329
tx.mc = nil
2430
return

0 commit comments

Comments
 (0)