Skip to content

Commit

Permalink
Tweak transaction API.
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 2, 2016
1 parent 033a4de commit 092698e
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 163 deletions.
8 changes: 3 additions & 5 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@ func (c *ClusterClient) getClients() map[string]*Client {
return clients
}

// Watch creates new transaction and marks the keys to be watched
// for conditional execution of a transaction.
func (c *ClusterClient) Watch(keys ...string) (*Tx, error) {
func (c *ClusterClient) Watch(fn func(*Tx) error, keys ...string) error {
addr := c.slotMasterAddr(hashtag.Slot(keys[0]))
client, err := c.getClient(addr)
if err != nil {
return nil, err
return err
}
return client.Watch(keys...)
return client.Watch(fn, keys...)
}

// PoolStats returns accumulated connection pool stats.
Expand Down
25 changes: 11 additions & 14 deletions cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,21 +383,18 @@ var _ = Describe("Cluster", func() {

// Transactionally increments key using GET and SET commands.
incr = func(key string) error {
tx, err := client.Watch(key)
if err != nil {
err := client.Watch(func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64()
if err != nil && err != redis.Nil {
return err
}

_, err = tx.MultiExec(func() error {
tx.Set(key, strconv.FormatInt(n+1, 10), 0)
return nil
})
return err
}
defer tx.Close()

n, err := tx.Get(key).Int64()
if err != nil && err != redis.Nil {
return err
}

_, err = tx.Exec(func() error {
tx.Set(key, strconv.FormatInt(n+1, 10), 0)
return nil
})
}, key)
if err == redis.TxFailedErr {
return incr(key)
}
Expand Down
23 changes: 10 additions & 13 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,18 @@ func ExampleClient_Watch() {

// Transactionally increments key using GET and SET commands.
incr = func(key string) error {
tx, err := client.Watch(key)
if err != nil {
return err
}
defer tx.Close()
err := client.Watch(func(tx *redis.Tx) error {
n, err := tx.Get(key).Int64()
if err != nil && err != redis.Nil {
return err
}

n, err := tx.Get(key).Int64()
if err != nil && err != redis.Nil {
_, err = tx.MultiExec(func() error {
tx.Set(key, strconv.FormatInt(n+1, 10), 0)
return nil
})
return err
}

_, err = tx.Exec(func() error {
tx.Set(key, strconv.FormatInt(n+1, 10), 0)
return nil
})
}, key)
if err == redis.TxFailedErr {
return incr(key)
}
Expand Down
17 changes: 9 additions & 8 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,19 @@ var _ = Describe("pool", func() {
perform(1000, func(id int) {
var ping *redis.StatusCmd

tx, err := client.Watch()
Expect(err).NotTo(HaveOccurred())

cmds, err := tx.Exec(func() error {
ping = tx.Ping()
return nil
err := client.Watch(func(tx *redis.Tx) error {
cmds, err := tx.MultiExec(func() error {
ping = tx.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))
return err
})
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))

Expect(ping.Err()).NotTo(HaveOccurred())
Expect(ping.Val()).To(Equal("PONG"))
Expect(tx.Close()).NotTo(HaveOccurred())
})

pool := client.Pool()
Expand Down
30 changes: 13 additions & 17 deletions race_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,30 +215,26 @@ var _ = Describe("races", func() {

perform(C, func(id int) {
for i := 0; i < N; i++ {
tx, err := client.Watch("key")
Expect(err).NotTo(HaveOccurred())

val, err := tx.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).NotTo(Equal(redis.Nil))
err := client.Watch(func(tx *redis.Tx) error {
val, err := tx.Get("key").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).NotTo(Equal(redis.Nil))

num, err := strconv.ParseInt(val, 10, 64)
Expect(err).NotTo(HaveOccurred())
num, err := strconv.ParseInt(val, 10, 64)
Expect(err).NotTo(HaveOccurred())

cmds, err := tx.Exec(func() error {
tx.Set("key", strconv.FormatInt(num+1, 10), 0)
return nil
})
cmds, err := tx.MultiExec(func() error {
tx.Set("key", strconv.FormatInt(num+1, 10), 0)
return nil
})
Expect(cmds).To(HaveLen(1))
return err
}, "key")
if err == redis.TxFailedErr {
i--
continue
}
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].Err()).NotTo(HaveOccurred())

err = tx.Close()
Expect(err).NotTo(HaveOccurred())
}
})

Expand Down
24 changes: 8 additions & 16 deletions redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,15 @@ var _ = Describe("Client", func() {
Expect(client.Ping().Err()).NotTo(HaveOccurred())
})

It("should close multi without closing the client", func() {
tx, err := client.Watch()
Expect(err).NotTo(HaveOccurred())
Expect(tx.Close()).NotTo(HaveOccurred())

_, err = tx.Exec(func() error {
tx.Ping()
return nil
It("should close Tx without closing the client", func() {
err := client.Watch(func(tx *redis.Tx) error {
_, err := tx.MultiExec(func() error {
tx.Ping()
return nil
})
return err
})
Expect(err).To(MatchError("redis: client is closed"))
Expect(err).NotTo(HaveOccurred())

Expect(client.Ping().Err()).NotTo(HaveOccurred())
})
Expand All @@ -96,13 +95,6 @@ var _ = Describe("Client", func() {
Expect(pubsub.Close()).NotTo(HaveOccurred())
})

It("should close multi when client is closed", func() {
tx, err := client.Watch()
Expect(err).NotTo(HaveOccurred())
Expect(client.Close()).NotTo(HaveOccurred())
Expect(tx.Close()).NotTo(HaveOccurred())
})

It("should close pipeline when client is closed", func() {
pipeline := client.Pipeline()
Expect(client.Close()).NotTo(HaveOccurred())
Expand Down
27 changes: 16 additions & 11 deletions tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,19 @@ func (c *Client) newTx() *Tx {
return tx
}

// Watch creates new transaction and marks the keys to be watched
// for conditional execution of a transaction.
func (c *Client) Watch(keys ...string) (*Tx, error) {
func (c *Client) Watch(fn func(*Tx) error, keys ...string) error {
tx := c.newTx()
if len(keys) > 0 {
if err := tx.Watch(keys...).Err(); err != nil {
tx.Close()
return nil, err
tx.close()
return err
}
}
return tx, nil
retErr := fn(tx)
if err := tx.close(); err != nil && retErr == nil {
retErr = err
}
return retErr
}

func (tx *Tx) process(cmd Cmder) {
Expand All @@ -55,8 +57,11 @@ func (tx *Tx) process(cmd Cmder) {
}
}

// Close closes the transaction, releasing any open resources.
func (tx *Tx) Close() error {
// close closes the transaction, releasing any open resources.
func (tx *Tx) close() error {
if tx.closed {
return nil
}
tx.closed = true
if err := tx.Unwatch().Err(); err != nil {
internal.Logf("Unwatch failed: %s", err)
Expand Down Expand Up @@ -98,7 +103,7 @@ func (tx *Tx) Discard() error {
return nil
}

// Exec executes all previously queued commands in a transaction
// MultiExec executes all previously queued commands in a transaction
// and restores the connection state to normal.
//
// When using WATCH, EXEC will execute commands only if the watched keys
Expand All @@ -107,13 +112,13 @@ func (tx *Tx) Discard() error {
// Exec always returns list of commands. If transaction fails
// TxFailedErr is returned. Otherwise Exec returns error of the first
// failed command or nil.
func (tx *Tx) Exec(f func() error) ([]Cmder, error) {
func (tx *Tx) MultiExec(fn func() error) ([]Cmder, error) {
if tx.closed {
return nil, pool.ErrClosed
}

tx.cmds = []Cmder{NewStatusCmd("MULTI")}
if err := f(); err != nil {
if err := fn(); err != nil {
return nil, err
}
tx.cmds = append(tx.cmds, NewSliceCmd("EXEC"))
Expand Down
Loading

0 comments on commit 092698e

Please sign in to comment.