diff --git a/cluster.go b/cluster.go index 29b0cfbd8..bf160ee70 100644 --- a/cluster.go +++ b/cluster.go @@ -60,7 +60,7 @@ func (c *ClusterClient) getClients() map[string]*Client { // Watch creates new transaction and marks the keys to be watched // for conditional execution of a transaction. -func (c *ClusterClient) Watch(keys ...string) (*Multi, error) { +func (c *ClusterClient) Watch(keys ...string) (*Tx, error) { addr := c.slotMasterAddr(hashtag.Slot(keys[0])) client, err := c.getClient(addr) if err != nil { diff --git a/commands_test.go b/commands_test.go index eb0d52372..dfa22ca07 100644 --- a/commands_test.go +++ b/commands_test.go @@ -4,9 +4,6 @@ import ( "encoding/json" "fmt" "reflect" - "strconv" - "sync" - "testing" "time" . "github.com/onsi/ginkgo" @@ -2551,63 +2548,6 @@ var _ = Describe("Commands", func() { }) - Describe("watch/unwatch", func() { - - It("should WatchUnwatch", func() { - var C, N = 10, 1000 - if testing.Short() { - N = 100 - } - - err := client.Set("key", "0", 0).Err() - Expect(err).NotTo(HaveOccurred()) - - wg := &sync.WaitGroup{} - for i := 0; i < C; i++ { - wg.Add(1) - - go func() { - defer GinkgoRecover() - defer wg.Done() - - multi := client.Multi() - defer multi.Close() - - for j := 0; j < N; j++ { - val, err := multi.Watch("key").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal("OK")) - - val, err = multi.Get("key").Result() - Expect(err).NotTo(HaveOccurred()) - Expect(val).NotTo(Equal(redis.Nil)) - - num, err := strconv.ParseInt(val, 10, 64) - Expect(err).NotTo(HaveOccurred()) - - cmds, err := multi.Exec(func() error { - multi.Set("key", strconv.FormatInt(num+1, 10), 0) - return nil - }) - if err == redis.TxFailedErr { - j-- - continue - } - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(1)) - Expect(cmds[0].Err()).NotTo(HaveOccurred()) - } - }() - } - wg.Wait() - - val, err := client.Get("key").Int64() - Expect(err).NotTo(HaveOccurred()) - Expect(val).To(Equal(int64(C * N))) - }) - - }) - Describe("Geo add and radius search", func() { BeforeEach(func() { geoAdd := client.GeoAdd( diff --git a/pool_test.go b/pool_test.go index 793ff5031..942b9b979 100644 --- a/pool_test.go +++ b/pool_test.go @@ -37,16 +37,18 @@ var _ = Describe("pool", func() { perform(1000, func(id int) { var ping *redis.StatusCmd - multi := client.Multi() - cmds, err := multi.Exec(func() error { - ping = multi.Ping() + tx, err := client.Watch() + Expect(err).NotTo(HaveOccurred()) + + cmds, err := tx.Exec(func() error { + ping = tx.Ping() return nil }) Expect(err).NotTo(HaveOccurred()) Expect(cmds).To(HaveLen(1)) Expect(ping.Err()).NotTo(HaveOccurred()) Expect(ping.Val()).To(Equal("PONG")) - Expect(multi.Close()).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) }) pool := client.Pool() diff --git a/race_test.go b/race_test.go index 1ce8430a8..968f76e45 100644 --- a/race_test.go +++ b/race_test.go @@ -208,4 +208,43 @@ var _ = Describe("races", func() { Expect(err).NotTo(HaveOccurred()) }) }) + + It("should Watch/Unwatch", func() { + err := client.Set("key", "0", 0).Err() + Expect(err).NotTo(HaveOccurred()) + + perform(C, func(id int) { + for i := 0; i < N; i++ { + tx, err := client.Watch("key") + Expect(err).NotTo(HaveOccurred()) + + val, err := tx.Get("key").Result() + Expect(err).NotTo(HaveOccurred()) + Expect(val).NotTo(Equal(redis.Nil)) + + num, err := strconv.ParseInt(val, 10, 64) + Expect(err).NotTo(HaveOccurred()) + + cmds, err := tx.Exec(func() error { + tx.Set("key", strconv.FormatInt(num+1, 10), 0) + return nil + }) + if err == redis.TxFailedErr { + i-- + continue + } + Expect(err).NotTo(HaveOccurred()) + Expect(cmds).To(HaveLen(1)) + Expect(cmds[0].Err()).NotTo(HaveOccurred()) + + err = tx.Close() + Expect(err).NotTo(HaveOccurred()) + } + }) + + val, err := client.Get("key").Int64() + Expect(err).NotTo(HaveOccurred()) + Expect(val).To(Equal(int64(C * N))) + }) + }) diff --git a/redis_test.go b/redis_test.go index cb8e3ac69..fc0d702df 100644 --- a/redis_test.go +++ b/redis_test.go @@ -66,11 +66,12 @@ var _ = Describe("Client", func() { }) It("should close multi without closing the client", func() { - multi := client.Multi() - Expect(multi.Close()).NotTo(HaveOccurred()) + tx, err := client.Watch() + Expect(err).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) - _, err := multi.Exec(func() error { - multi.Ping() + _, err = tx.Exec(func() error { + tx.Ping() return nil }) Expect(err).To(MatchError("redis: client is closed")) @@ -96,9 +97,10 @@ var _ = Describe("Client", func() { }) It("should close multi when client is closed", func() { - multi := client.Multi() + tx, err := client.Watch() + Expect(err).NotTo(HaveOccurred()) Expect(client.Close()).NotTo(HaveOccurred()) - Expect(multi.Close()).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) }) It("should close pipeline when client is closed", func() { diff --git a/multi.go b/tx.go similarity index 70% rename from multi.go rename to tx.go index 79b7cb6df..51c875742 100644 --- a/multi.go +++ b/tx.go @@ -9,13 +9,11 @@ import ( var errDiscard = errors.New("redis: Discard can be used only inside Exec") -// Multi implements Redis transactions as described in +// Tx implements Redis transactions as described in // http://redis.io/topics/transactions. It's NOT safe for concurrent use // by multiple goroutines, because Exec resets list of watched keys. // If you don't need WATCH it is better to use Pipeline. -// -// TODO(vmihailenco): rename to Tx and rework API -type Multi struct { +type Tx struct { commandable base *baseClient @@ -24,77 +22,78 @@ type Multi struct { closed bool } -// Watch creates new transaction and marks the keys to be watched -// for conditional execution of a transaction. -func (c *Client) Watch(keys ...string) (*Multi, error) { - tx := c.Multi() - if err := tx.Watch(keys...).Err(); err != nil { - tx.Close() - return nil, err - } - return tx, nil -} - -// Deprecated. Use Watch instead. -func (c *Client) Multi() *Multi { - multi := &Multi{ +func (c *Client) newTx() *Tx { + tx := &Tx{ base: &baseClient{ opt: c.opt, connPool: pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true), }, } - multi.commandable.process = multi.process - return multi + tx.commandable.process = tx.process + return tx +} + +// Watch creates new transaction and marks the keys to be watched +// for conditional execution of a transaction. +func (c *Client) Watch(keys ...string) (*Tx, error) { + tx := c.newTx() + if len(keys) > 0 { + if err := tx.Watch(keys...).Err(); err != nil { + tx.Close() + return nil, err + } + } + return tx, nil } -func (c *Multi) process(cmd Cmder) { - if c.cmds == nil { - c.base.process(cmd) +func (tx *Tx) process(cmd Cmder) { + if tx.cmds == nil { + tx.base.process(cmd) } else { - c.cmds = append(c.cmds, cmd) + tx.cmds = append(tx.cmds, cmd) } } -// Close closes the client, releasing any open resources. -func (c *Multi) Close() error { - c.closed = true - if err := c.Unwatch().Err(); err != nil { +// Close closes the transaction, releasing any open resources. +func (tx *Tx) Close() error { + tx.closed = true + if err := tx.Unwatch().Err(); err != nil { Logger.Printf("Unwatch failed: %s", err) } - return c.base.Close() + return tx.base.Close() } // Watch marks the keys to be watched for conditional execution // of a transaction. -func (c *Multi) Watch(keys ...string) *StatusCmd { +func (tx *Tx) Watch(keys ...string) *StatusCmd { args := make([]interface{}, 1+len(keys)) args[0] = "WATCH" for i, key := range keys { args[1+i] = key } cmd := NewStatusCmd(args...) - c.Process(cmd) + tx.Process(cmd) return cmd } // Unwatch flushes all the previously watched keys for a transaction. -func (c *Multi) Unwatch(keys ...string) *StatusCmd { +func (tx *Tx) Unwatch(keys ...string) *StatusCmd { args := make([]interface{}, 1+len(keys)) args[0] = "UNWATCH" for i, key := range keys { args[1+i] = key } cmd := NewStatusCmd(args...) - c.Process(cmd) + tx.Process(cmd) return cmd } // Discard discards queued commands. -func (c *Multi) Discard() error { - if c.cmds == nil { +func (tx *Tx) Discard() error { + if tx.cmds == nil { return errDiscard } - c.cmds = c.cmds[:1] + tx.cmds = tx.cmds[:1] return nil } @@ -107,19 +106,19 @@ func (c *Multi) Discard() error { // Exec always returns list of commands. If transaction fails // TxFailedErr is returned. Otherwise Exec returns error of the first // failed command or nil. -func (c *Multi) Exec(f func() error) ([]Cmder, error) { - if c.closed { +func (tx *Tx) Exec(f func() error) ([]Cmder, error) { + if tx.closed { return nil, pool.ErrClosed } - c.cmds = []Cmder{NewStatusCmd("MULTI")} + tx.cmds = []Cmder{NewStatusCmd("MULTI")} if err := f(); err != nil { return nil, err } - c.cmds = append(c.cmds, NewSliceCmd("EXEC")) + tx.cmds = append(tx.cmds, NewSliceCmd("EXEC")) - cmds := c.cmds - c.cmds = nil + cmds := tx.cmds + tx.cmds = nil if len(cmds) == 2 { return []Cmder{}, nil @@ -128,18 +127,18 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) { // Strip MULTI and EXEC commands. retCmds := cmds[1 : len(cmds)-1] - cn, err := c.base.conn() + cn, err := tx.base.conn() if err != nil { setCmdsErr(retCmds, err) return retCmds, err } - err = c.execCmds(cn, cmds) - c.base.putConn(cn, err, false) + err = tx.execCmds(cn, cmds) + tx.base.putConn(cn, err, false) return retCmds, err } -func (c *Multi) execCmds(cn *pool.Conn, cmds []Cmder) error { +func (tx *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error { err := writeCmd(cn, cmds...) if err != nil { setCmdsErr(cmds[1:len(cmds)-1], err) diff --git a/multi_test.go b/tx_test.go similarity index 77% rename from multi_test.go rename to tx_test.go index e76c2b33a..66ef6b084 100644 --- a/multi_test.go +++ b/tx_test.go @@ -10,7 +10,7 @@ import ( "gopkg.in/redis.v3" ) -var _ = Describe("Multi", func() { +var _ = Describe("Tx", func() { var client *redis.Client BeforeEach(func() { @@ -67,15 +67,16 @@ var _ = Describe("Multi", func() { }) It("should discard", func() { - multi := client.Multi() + tx, err := client.Watch("key1", "key2") + Expect(err).NotTo(HaveOccurred()) defer func() { - Expect(multi.Close()).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) }() - cmds, err := multi.Exec(func() error { - multi.Set("key1", "hello1", 0) - multi.Discard() - multi.Set("key2", "hello2", 0) + cmds, err := tx.Exec(func() error { + tx.Set("key1", "hello1", 0) + tx.Discard() + tx.Set("key2", "hello2", 0) return nil }) Expect(err).NotTo(HaveOccurred()) @@ -91,40 +92,31 @@ var _ = Describe("Multi", func() { }) It("should exec empty", func() { - multi := client.Multi() + tx, err := client.Watch() + Expect(err).NotTo(HaveOccurred()) defer func() { - Expect(multi.Close()).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) }() - cmds, err := multi.Exec(func() error { return nil }) + cmds, err := tx.Exec(func() error { return nil }) Expect(err).NotTo(HaveOccurred()) Expect(cmds).To(HaveLen(0)) - ping := multi.Ping() + ping := tx.Ping() Expect(ping.Err()).NotTo(HaveOccurred()) Expect(ping.Val()).To(Equal("PONG")) }) - It("should exec empty queue", func() { - multi := client.Multi() - defer func() { - Expect(multi.Close()).NotTo(HaveOccurred()) - }() - - cmds, err := multi.Exec(func() error { return nil }) - Expect(err).NotTo(HaveOccurred()) - Expect(cmds).To(HaveLen(0)) - }) - It("should exec bulks", func() { - multi := client.Multi() + tx, err := client.Watch() + Expect(err).NotTo(HaveOccurred()) defer func() { - Expect(multi.Close()).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) }() - cmds, err := multi.Exec(func() error { + cmds, err := tx.Exec(func() error { for i := int64(0); i < 20000; i++ { - multi.Incr("key") + tx.Incr("key") } return nil }) @@ -148,19 +140,20 @@ var _ = Describe("Multi", func() { err = client.Pool().Put(cn) Expect(err).NotTo(HaveOccurred()) - multi := client.Multi() + tx, err := client.Watch() + Expect(err).NotTo(HaveOccurred()) defer func() { - Expect(multi.Close()).NotTo(HaveOccurred()) + Expect(tx.Close()).NotTo(HaveOccurred()) }() - _, err = multi.Exec(func() error { - multi.Ping() + _, err = tx.Exec(func() error { + tx.Ping() return nil }) Expect(err).To(MatchError("bad connection")) - _, err = multi.Exec(func() error { - multi.Ping() + _, err = tx.Exec(func() error { + tx.Ping() return nil }) Expect(err).NotTo(HaveOccurred())