Skip to content

Commit

Permalink
Add ClusterPipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Apr 13, 2015
1 parent 5c951b3 commit 99fe911
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 104 deletions.
75 changes: 41 additions & 34 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@ func (c *ClusterClient) Close() error {
// ------------------------------------------------------------------------

// getClient returns a Client for a given address.
func (c *ClusterClient) getClient(addr string) *Client {
func (c *ClusterClient) getClient(addr string) (*Client, error) {
if addr == "" {
return c.randomClient()
}

c.clientsMx.RLock()
client, ok := c.clients[addr]
if ok {
c.clientsMx.RUnlock()
return client
return client, nil
}
c.clientsMx.RUnlock()

Expand All @@ -66,14 +70,24 @@ func (c *ClusterClient) getClient(addr string) *Client {
}
c.clientsMx.Unlock()

return client
return client, nil
}

func (c *ClusterClient) slotAddrs(slot int) []string {
c.slotsMx.RLock()
addrs := c.slots[slot]
c.slotsMx.RUnlock()
return addrs
}

// randomClient returns a Client for the first live node.
func (c *ClusterClient) randomClient() (client *Client, err error) {
for i := 0; i < 10; i++ {
n := rand.Intn(len(c.addrs))
client = c.getClient(c.addrs[n])
client, err = c.getClient(c.addrs[n])
if err != nil {
continue
}
err = client.Ping().Err()
if err == nil {
return client, nil
Expand All @@ -82,27 +96,22 @@ func (c *ClusterClient) randomClient() (client *Client, err error) {
return nil, err
}

// Process a command
func (c *ClusterClient) process(cmd Cmder) {
var client *Client
var ask bool

c.reloadIfDue()

slot := hashSlot(cmd.clusterKey())
c.slotsMx.RLock()
addrs := c.slots[slot]
c.slotsMx.RUnlock()

if len(addrs) > 0 {
client = c.getClient(addrs[0]) // First address is master.
} else {
var err error
client, err = c.randomClient()
if err != nil {
cmd.setErr(err)
return
}
var addr string
if addrs := c.slotAddrs(slot); len(addrs) > 0 {
addr = addrs[0] // First address is master.
}

client, err := c.getClient(addr)
if err != nil {
cmd.setErr(err)
return
}

for attempt := 0; attempt <= c.opt.getMaxRedirects(); attempt++ {
Expand Down Expand Up @@ -132,24 +141,22 @@ func (c *ClusterClient) process(cmd Cmder) {
continue
}

// Check the error message, return if unexpected
parts := strings.SplitN(err.Error(), " ", 3)
if len(parts) != 3 {
return
var moved bool
var addr string
moved, ask, addr = isMovedError(err)
if moved || ask {
if moved {
c.scheduleReload()
}
client, err = c.getClient(addr)
if err != nil {
return
}
cmd.reset()
continue
}

// Handle MOVE and ASK redirections, return on any other error
switch parts[0] {
case "MOVED":
c.scheduleReload()
client = c.getClient(parts[2])
case "ASK":
ask = true
client = c.getClient(parts[2])
default:
return
}
cmd.reset()
break
}
}

Expand Down
17 changes: 17 additions & 0 deletions cluster_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@ import (
. "github.com/onsi/gomega"
)

// GetSlot returns the cached slot addresses
func (c *ClusterClient) GetSlot(pos int) []string {
c.slotsMx.RLock()
defer c.slotsMx.RUnlock()

return c.slots[pos]
}

// SwapSlot swaps a slot's master/slave address
// for testing MOVED redirects
func (c *ClusterClient) SwapSlot(pos int) []string {
c.slotsMx.Lock()
defer c.slotsMx.Unlock()
c.slots[pos][0], c.slots[pos][1] = c.slots[pos][1], c.slots[pos][0]
return c.slots[pos]
}

var _ = Describe("ClusterClient", func() {
var subject *ClusterClient

Expand Down
128 changes: 128 additions & 0 deletions cluster_pipeline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package redis

// ClusterPipeline is not thread-safe.
type ClusterPipeline struct {
commandable

cmds []Cmder
cluster *ClusterClient
closed bool
}

// Pipeline creates a new pipeline which is able to execute commands
// against multiple shards.
func (c *ClusterClient) Pipeline() *ClusterPipeline {
pipe := &ClusterPipeline{
cluster: c,
cmds: make([]Cmder, 0, 10),
}
pipe.commandable.process = pipe.process
return pipe
}

func (c *ClusterPipeline) process(cmd Cmder) {
c.cmds = append(c.cmds, cmd)
}

// Close marks the pipeline as closed
func (c *ClusterPipeline) Close() error {
c.closed = true
return nil
}

// Discard resets the pipeline and discards queued commands
func (c *ClusterPipeline) Discard() error {
if c.closed {
return errClosed
}
c.cmds = c.cmds[:0]
return nil
}

func (c *ClusterPipeline) Exec() (cmds []Cmder, retErr error) {
if c.closed {
return nil, errClosed
}
if len(c.cmds) == 0 {
return []Cmder{}, nil
}

cmds = c.cmds
c.cmds = make([]Cmder, 0, 10)

cmdsMap := make(map[string][]Cmder)
for _, cmd := range cmds {
slot := hashSlot(cmd.clusterKey())
addrs := c.cluster.slotAddrs(slot)

var addr string
if len(addrs) > 0 {
addr = addrs[0] // First address is master.
}

cmdsMap[addr] = append(cmdsMap[addr], cmd)
}

for attempt := 0; attempt <= c.cluster.opt.getMaxRedirects(); attempt++ {
failedCmds := make(map[string][]Cmder)

for addr, cmds := range cmdsMap {
client, err := c.cluster.getClient(addr)
if err != nil {
setCmdsErr(cmds, err)
retErr = err
continue
}

cn, err := client.conn()
if err != nil {
setCmdsErr(cmds, err)
retErr = err
continue
}

failedCmds, err = c.execClusterCmds(cn, cmds, failedCmds)
if err != nil {
retErr = err
}
client.freeConn(cn, err)
}

cmdsMap = failedCmds
}

return cmds, retErr
}

func (c *ClusterPipeline) execClusterCmds(
cn *conn, cmds []Cmder, failedCmds map[string][]Cmder,
) (map[string][]Cmder, error) {
if err := cn.writeCmds(cmds...); err != nil {
setCmdsErr(cmds, err)
return failedCmds, err
}

var firstCmdErr error
for i, cmd := range cmds {
err := cmd.parseReply(cn.rd)
if err == nil {
continue
}
if isNetworkError(err) {
cmd.reset()
failedCmds[""] = append(failedCmds[""], cmds[i:]...)
break
} else if moved, ask, addr := isMovedError(err); moved {
c.cluster.scheduleReload()
cmd.reset()
failedCmds[addr] = append(failedCmds[addr], cmd)
} else if ask {
cmd.reset()
failedCmds[addr] = append(failedCmds[addr], NewCmd("ASKING"), cmd)
} else if firstCmdErr == nil {
firstCmdErr = err
}
}

return failedCmds, firstCmdErr
}
52 changes: 41 additions & 11 deletions cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package redis_test

import (
"math/rand"
"time"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

"gopkg.in/redis.v2"
)

Expand Down Expand Up @@ -181,22 +183,50 @@ var _ = Describe("Cluster", func() {
It("should follow redirects", func() {
Expect(client.Set("A", "VALUE", 0).Err()).NotTo(HaveOccurred())
Expect(redis.HashSlot("A")).To(Equal(6373))

// Slot 6373 is stored on the second node
defer func() {
scenario.masters()[1].ClusterFailover()
}()

slave := scenario.slaves()[1]
Expect(slave.ClusterFailover().Err()).NotTo(HaveOccurred())
Eventually(func() string {
return slave.Info().Val()
}, "10s", "200ms").Should(ContainSubstring("role:master"))
Expect(client.SwapSlot(6373)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"}))

val, err := client.Get("A").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("VALUE"))
Expect(client.GetSlot(6373)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"}))

val, err = client.Get("A").Result()
Expect(err).NotTo(HaveOccurred())
Expect(val).To(Equal("VALUE"))
Expect(client.GetSlot(6373)).To(Equal([]string{"127.0.0.1:8221", "127.0.0.1:8224"}))
})

It("should perform multi-pipelines", func() {
// Dummy command to load slots info.
Expect(client.Ping().Err()).NotTo(HaveOccurred())

slot := redis.HashSlot("A")
Expect(client.SwapSlot(slot)).To(Equal([]string{"127.0.0.1:8224", "127.0.0.1:8221"}))

pipe := client.Pipeline()
defer pipe.Close()

keys := []string{"A", "B", "C", "D", "E", "F", "G"}
for i, key := range keys {
pipe.Set(key, key+"_value", 0)
pipe.Expire(key, time.Duration(i+1)*time.Hour)
}
for _, key := range keys {
pipe.Get(key)
pipe.TTL(key)
}

cmds, err := pipe.Exec()
Expect(err).NotTo(HaveOccurred())
Expect(cmds).To(HaveLen(28))
Expect(cmds[14].(*redis.StringCmd).Val()).To(Equal("A_value"))
Expect(cmds[15].(*redis.DurationCmd).Val()).To(BeNumerically("~", 1*time.Hour, time.Second))
Expect(cmds[20].(*redis.StringCmd).Val()).To(Equal("D_value"))
Expect(cmds[21].(*redis.DurationCmd).Val()).To(BeNumerically("~", 4*time.Hour, time.Second))
Expect(cmds[26].(*redis.StringCmd).Val()).To(Equal("G_value"))
Expect(cmds[27].(*redis.DurationCmd).Val()).To(BeNumerically("~", 7*time.Hour, time.Second))
})

})
})

Expand Down
Loading

0 comments on commit 99fe911

Please sign in to comment.