Skip to content

Commit 3484db1

Browse files
authoredJun 28, 2024
improve error handling in writePacket (go-sql-driver#1601)
* handle error before success case. * return io.ErrShortWrite if not all bytes were written but err is nil. * return err instead of ErrInvalidConn.
1 parent 52c1917 commit 3484db1

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed
 

‎connection_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ func TestPingMarkBadConnection(t *testing.T) {
163163
netConn: nc,
164164
buf: newBuffer(nc),
165165
maxAllowedPacket: defaultMaxAllowedPacket,
166+
closech: make(chan struct{}),
167+
cfg: NewConfig(),
166168
}
167169

168170
err := mc.Ping(context.Background())
@@ -184,8 +186,8 @@ func TestPingErrInvalidConn(t *testing.T) {
184186

185187
err := mc.Ping(context.Background())
186188

187-
if err != ErrInvalidConn {
188-
t.Errorf("expected ErrInvalidConn, got %#v", err)
189+
if err != nc.err {
190+
t.Errorf("expected %#v, got %#v", nc.err, err)
189191
}
190192
}
191193

‎packets.go

+17-17
Original file line numberDiff line numberDiff line change
@@ -124,32 +124,32 @@ func (mc *mysqlConn) writePacket(data []byte) error {
124124
}
125125

126126
n, err := mc.netConn.Write(data[:4+size])
127-
if err == nil && n == 4+size {
128-
mc.sequence++
129-
if size != maxPacketSize {
130-
return nil
131-
}
132-
pktLen -= size
133-
data = data[size:]
134-
continue
135-
}
136-
137-
// Handle error
138-
if err == nil { // n != len(data)
139-
mc.cleanup()
140-
mc.log(ErrMalformPkt)
141-
} else {
127+
if err != nil {
142128
if cerr := mc.canceled.Value(); cerr != nil {
143129
return cerr
144130
}
131+
mc.cleanup()
145132
if n == 0 && pktLen == len(data)-4 {
146133
// only for the first loop iteration when nothing was written yet
134+
mc.log(err)
147135
return errBadConnNoWrite
136+
} else {
137+
return err
148138
}
139+
}
140+
if n != 4+size {
141+
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
142+
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
149143
mc.cleanup()
150-
mc.log(err)
144+
return io.ErrShortWrite
145+
}
146+
147+
mc.sequence++
148+
if size != maxPacketSize {
149+
return nil
151150
}
152-
return ErrInvalidConn
151+
pktLen -= size
152+
data = data[size:]
153153
}
154154
}
155155

0 commit comments

Comments
 (0)