Skip to content

Commit 72e0ac3

Browse files
Add atomic wrappers for bool and error (go-sql-driver#612)
* Add atomic wrappers for bool and error Improves go-sql-driver#608 * Drop Go 1.2 and Go 1.3 support * "test" noCopy.Lock()
1 parent 5622634 commit 72e0ac3

9 files changed

+169
-47
lines changed

.travis.yml

-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
sudo: false
22
language: go
33
go:
4-
- 1.2
5-
- 1.3
64
- 1.4
75
- 1.5
86
- 1.6

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
3939
* Optional placeholder interpolation
4040

4141
## Requirements
42-
* Go 1.2 or higher
42+
* Go 1.4 or higher
4343
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
4444

4545
---------------------------------------
@@ -279,7 +279,7 @@ Default: false
279279

280280
`rejectreadOnly=true` causes the driver to reject read-only connections. This
281281
is for a possible race condition during an automatic failover, where the mysql
282-
client gets connected to a read-only replica after the failover.
282+
client gets connected to a read-only replica after the failover.
283283

284284
Note that this should be a fairly rare case, as an automatic failover normally
285285
happens when the primary is down, and the race condition shouldn't happen

connection.go

+14-34
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,15 @@ import (
1414
"net"
1515
"strconv"
1616
"strings"
17-
"sync"
18-
"sync/atomic"
1917
"time"
2018
)
2119

22-
// a copy of context.Context for Go 1.7 and later.
20+
// a copy of context.Context for Go 1.7 and earlier
2321
type mysqlContext interface {
2422
Done() <-chan struct{}
2523
Err() error
2624

27-
// They are defined in context.Context, but go-mysql-driver does not use them.
25+
// defined in context.Context, but not used in this driver:
2826
// Deadline() (deadline time.Time, ok bool)
2927
// Value(key interface{}) interface{}
3028
}
@@ -44,18 +42,13 @@ type mysqlConn struct {
4442
parseTime bool
4543
strict bool
4644

47-
// for context support (From Go 1.8)
45+
// for context support (Go 1.8+)
4846
watching bool
4947
watcher chan<- mysqlContext
5048
closech chan struct{}
5149
finished chan<- struct{}
52-
53-
// set non-zero when conn is closed, before closech is closed.
54-
// accessed atomically.
55-
closed int32
56-
57-
mu sync.Mutex // guards following fields
58-
canceledErr error // set non-nil if conn is canceled
50+
canceled atomicError // set non-nil if conn is canceled
51+
closed atomicBool // set when conn is closed, before closech is closed
5952
}
6053

6154
// Handles parameters set in DSN after the connection is established
@@ -89,7 +82,7 @@ func (mc *mysqlConn) handleParams() (err error) {
8982
}
9083

9184
func (mc *mysqlConn) Begin() (driver.Tx, error) {
92-
if mc.isBroken() {
85+
if mc.closed.IsSet() {
9386
errLog.Print(ErrInvalidConn)
9487
return nil, driver.ErrBadConn
9588
}
@@ -103,7 +96,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
10396

10497
func (mc *mysqlConn) Close() (err error) {
10598
// Makes Close idempotent
106-
if !mc.isBroken() {
99+
if !mc.closed.IsSet() {
107100
err = mc.writeCommandPacket(comQuit)
108101
}
109102

@@ -117,7 +110,7 @@ func (mc *mysqlConn) Close() (err error) {
117110
// is called before auth or on auth failure because MySQL will have already
118111
// closed the network connection.
119112
func (mc *mysqlConn) cleanup() {
120-
if atomic.SwapInt32(&mc.closed, 1) != 0 {
113+
if !mc.closed.TrySet(true) {
121114
return
122115
}
123116

@@ -131,13 +124,9 @@ func (mc *mysqlConn) cleanup() {
131124
}
132125
}
133126

134-
func (mc *mysqlConn) isBroken() bool {
135-
return atomic.LoadInt32(&mc.closed) != 0
136-
}
137-
138127
func (mc *mysqlConn) error() error {
139-
if mc.isBroken() {
140-
if err := mc.canceled(); err != nil {
128+
if mc.closed.IsSet() {
129+
if err := mc.canceled.Value(); err != nil {
141130
return err
142131
}
143132
return ErrInvalidConn
@@ -146,7 +135,7 @@ func (mc *mysqlConn) error() error {
146135
}
147136

148137
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
149-
if mc.isBroken() {
138+
if mc.closed.IsSet() {
150139
errLog.Print(ErrInvalidConn)
151140
return nil, driver.ErrBadConn
152141
}
@@ -300,7 +289,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
300289
}
301290

302291
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
303-
if mc.isBroken() {
292+
if mc.closed.IsSet() {
304293
errLog.Print(ErrInvalidConn)
305294
return nil, driver.ErrBadConn
306295
}
@@ -361,7 +350,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
361350
}
362351

363352
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
364-
if mc.isBroken() {
353+
if mc.closed.IsSet() {
365354
errLog.Print(ErrInvalidConn)
366355
return nil, driver.ErrBadConn
367356
}
@@ -436,19 +425,10 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
436425

437426
// finish is called when the query has canceled.
438427
func (mc *mysqlConn) cancel(err error) {
439-
mc.mu.Lock()
440-
mc.canceledErr = err
441-
mc.mu.Unlock()
428+
mc.canceled.Set(err)
442429
mc.cleanup()
443430
}
444431

445-
// canceled returns non-nil if the connection was closed due to context cancelation.
446-
func (mc *mysqlConn) canceled() error {
447-
mc.mu.Lock()
448-
defer mc.mu.Unlock()
449-
return mc.canceledErr
450-
}
451-
452432
// finish is called when the query has succeeded.
453433
func (mc *mysqlConn) finish() {
454434
if !mc.watching || mc.finished == nil {

connection_go18.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
// Ping implements driver.Pinger interface
2121
func (mc *mysqlConn) Ping(ctx context.Context) error {
22-
if mc.isBroken() {
22+
if mc.closed.IsSet() {
2323
errLog.Print(ErrInvalidConn)
2424
return driver.ErrBadConn
2525
}

packets.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
3030
// read packet header
3131
data, err := mc.buf.readNext(4)
3232
if err != nil {
33-
if cerr := mc.canceled(); cerr != nil {
33+
if cerr := mc.canceled.Value(); cerr != nil {
3434
return nil, cerr
3535
}
3636
errLog.Print(err)
@@ -66,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
6666
// read packet body [pktLen bytes]
6767
data, err = mc.buf.readNext(pktLen)
6868
if err != nil {
69-
if cerr := mc.canceled(); cerr != nil {
69+
if cerr := mc.canceled.Value(); cerr != nil {
7070
return nil, cerr
7171
}
7272
errLog.Print(err)
@@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
134134
mc.cleanup()
135135
errLog.Print(ErrMalformPkt)
136136
} else {
137-
if cerr := mc.canceled(); cerr != nil {
137+
if cerr := mc.canceled.Value(); cerr != nil {
138138
return cerr
139139
}
140140
mc.cleanup()

statement.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type mysqlStmt struct {
2323
}
2424

2525
func (stmt *mysqlStmt) Close() error {
26-
if stmt.mc == nil || stmt.mc.isBroken() {
26+
if stmt.mc == nil || stmt.mc.closed.IsSet() {
2727
// driver.Stmt.Close can be called more than once, thus this function
2828
// has to be idempotent.
2929
// See also Issue #450 and golang/go#16019.
@@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
4545
}
4646

4747
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
48-
if stmt.mc.isBroken() {
48+
if stmt.mc.closed.IsSet() {
4949
errLog.Print(ErrInvalidConn)
5050
return nil, driver.ErrBadConn
5151
}
@@ -93,7 +93,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
9393
}
9494

9595
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
96-
if stmt.mc.isBroken() {
96+
if stmt.mc.closed.IsSet() {
9797
errLog.Print(ErrInvalidConn)
9898
return nil, driver.ErrBadConn
9999
}

transaction.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type mysqlTx struct {
1313
}
1414

1515
func (tx *mysqlTx) Commit() (err error) {
16-
if tx.mc == nil || tx.mc.isBroken() {
16+
if tx.mc == nil || tx.mc.closed.IsSet() {
1717
return ErrInvalidConn
1818
}
1919
err = tx.mc.exec("COMMIT")
@@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
2222
}
2323

2424
func (tx *mysqlTx) Rollback() (err error) {
25-
if tx.mc == nil || tx.mc.isBroken() {
25+
if tx.mc == nil || tx.mc.closed.IsSet() {
2626
return ErrInvalidConn
2727
}
2828
err = tx.mc.exec("ROLLBACK")

utils.go

+64
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"fmt"
1717
"io"
1818
"strings"
19+
"sync/atomic"
1920
"time"
2021
)
2122

@@ -740,3 +741,66 @@ func escapeStringQuotes(buf []byte, v string) []byte {
740741

741742
return buf[:pos]
742743
}
744+
745+
/******************************************************************************
746+
* Sync utils *
747+
******************************************************************************/
748+
// noCopy may be embedded into structs which must not be copied
749+
// after the first use.
750+
//
751+
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
752+
// for details.
753+
type noCopy struct{}
754+
755+
// Lock is a no-op used by -copylocks checker from `go vet`.
756+
func (*noCopy) Lock() {}
757+
758+
// atomicBool is a wrapper around uint32 for usage as a boolean value with
759+
// atomic access.
760+
type atomicBool struct {
761+
_noCopy noCopy
762+
value uint32
763+
}
764+
765+
// IsSet returns wether the current boolean value is true
766+
func (ab *atomicBool) IsSet() bool {
767+
return atomic.LoadUint32(&ab.value) > 0
768+
}
769+
770+
// Set sets the value of the bool regardless of the previous value
771+
func (ab *atomicBool) Set(value bool) {
772+
if value {
773+
atomic.StoreUint32(&ab.value, 1)
774+
} else {
775+
atomic.StoreUint32(&ab.value, 0)
776+
}
777+
}
778+
779+
// TrySet sets the value of the bool and returns wether the value changed
780+
func (ab *atomicBool) TrySet(value bool) bool {
781+
if value {
782+
return atomic.SwapUint32(&ab.value, 1) == 0
783+
}
784+
return atomic.SwapUint32(&ab.value, 0) > 0
785+
}
786+
787+
// atomicBool is a wrapper for atomically accessed error values
788+
type atomicError struct {
789+
_noCopy noCopy
790+
value atomic.Value
791+
}
792+
793+
// Set sets the error value regardless of the previous value.
794+
// The value must not be nil
795+
func (ae *atomicError) Set(value error) {
796+
ae.value.Store(value)
797+
}
798+
799+
// Value returns the current error value
800+
func (ae *atomicError) Value() error {
801+
if v := ae.value.Load(); v != nil {
802+
// this will panic if the value doesn't implement the error interface
803+
return v.(error)
804+
}
805+
return nil
806+
}

utils_test.go

+80
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,83 @@ func TestEscapeQuotes(t *testing.T) {
195195
expect("foo''bar", "foo'bar") // affected
196196
expect("foo\"bar", "foo\"bar") // not affected
197197
}
198+
199+
func TestAtomicBool(t *testing.T) {
200+
var ab atomicBool
201+
if ab.IsSet() {
202+
t.Fatal("Expected value to be false")
203+
}
204+
205+
ab.Set(true)
206+
if ab.value != 1 {
207+
t.Fatal("Set(true) did not set value to 1")
208+
}
209+
if !ab.IsSet() {
210+
t.Fatal("Expected value to be true")
211+
}
212+
213+
ab.Set(true)
214+
if !ab.IsSet() {
215+
t.Fatal("Expected value to be true")
216+
}
217+
218+
ab.Set(false)
219+
if ab.value != 0 {
220+
t.Fatal("Set(false) did not set value to 0")
221+
}
222+
if ab.IsSet() {
223+
t.Fatal("Expected value to be false")
224+
}
225+
226+
ab.Set(false)
227+
if ab.IsSet() {
228+
t.Fatal("Expected value to be false")
229+
}
230+
if ab.TrySet(false) {
231+
t.Fatal("Expected TrySet(false) to fail")
232+
}
233+
if !ab.TrySet(true) {
234+
t.Fatal("Expected TrySet(true) to succeed")
235+
}
236+
if !ab.IsSet() {
237+
t.Fatal("Expected value to be true")
238+
}
239+
240+
ab.Set(true)
241+
if !ab.IsSet() {
242+
t.Fatal("Expected value to be true")
243+
}
244+
if ab.TrySet(true) {
245+
t.Fatal("Expected TrySet(true) to fail")
246+
}
247+
if !ab.TrySet(false) {
248+
t.Fatal("Expected TrySet(false) to succeed")
249+
}
250+
if ab.IsSet() {
251+
t.Fatal("Expected value to be false")
252+
}
253+
254+
ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
255+
}
256+
257+
func TestAtomicError(t *testing.T) {
258+
var ae atomicError
259+
if ae.Value() != nil {
260+
t.Fatal("Expected value to be nil")
261+
}
262+
263+
ae.Set(ErrMalformPkt)
264+
if v := ae.Value(); v != ErrMalformPkt {
265+
if v == nil {
266+
t.Fatal("Value is still nil")
267+
}
268+
t.Fatal("Error did not match")
269+
}
270+
ae.Set(ErrPktSync)
271+
if ae.Value() == ErrMalformPkt {
272+
t.Fatal("Error still matches old error")
273+
}
274+
if v := ae.Value(); v != ErrPktSync {
275+
t.Fatal("Error did not match")
276+
}
277+
}

0 commit comments

Comments
 (0)