Skip to content

Commit

Permalink
Add test for ring and cluster hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Feb 14, 2020
1 parent 2e3402d commit 49a0c8c
Show file tree
Hide file tree
Showing 7 changed files with 433 additions and 44 deletions.
88 changes: 55 additions & 33 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ func (c *ClusterClient) _process(ctx context.Context, cmd Cmder) error {
_ = pipe.Close()
ask = false
} else {
lastErr = node.Client._process(ctx, cmd)
lastErr = node.Client.ProcessContext(ctx, cmd)
}

// If there is no error - we are done.
Expand Down Expand Up @@ -840,6 +840,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {

var wg sync.WaitGroup
errCh := make(chan error, 1)

for _, master := range state.Masters {
wg.Add(1)
go func(node *clusterNode) {
Expand All @@ -853,6 +854,7 @@ func (c *ClusterClient) ForEachMaster(fn func(client *Client) error) error {
}
}(master)
}

wg.Wait()

select {
Expand All @@ -873,6 +875,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {

var wg sync.WaitGroup
errCh := make(chan error, 1)

for _, slave := range state.Slaves {
wg.Add(1)
go func(node *clusterNode) {
Expand All @@ -886,6 +889,7 @@ func (c *ClusterClient) ForEachSlave(fn func(client *Client) error) error {
}
}(slave)
}

wg.Wait()

select {
Expand All @@ -906,6 +910,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {

var wg sync.WaitGroup
errCh := make(chan error, 1)

worker := func(node *clusterNode) {
defer wg.Done()
err := fn(node.Client)
Expand All @@ -927,6 +932,7 @@ func (c *ClusterClient) ForEachNode(fn func(client *Client) error) error {
}

wg.Wait()

select {
case err := <-errCh:
return err
Expand Down Expand Up @@ -1068,18 +1074,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
go func(node *clusterNode, cmds []Cmder) {
defer wg.Done()

err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
})
if err != nil {
return err
}

return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
})
})
err := c._processPipelineNode(ctx, node, cmds, failedCmds)
if err == nil {
return
}
Expand Down Expand Up @@ -1142,6 +1137,25 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
return true
}

func (c *ClusterClient) _processPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) error {
return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmds(wr, cmds)
})
if err != nil {
return err
}

return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
})
})
})
}

func (c *ClusterClient) pipelineReadCmds(
node *clusterNode, rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
) error {
Expand Down Expand Up @@ -1243,26 +1257,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
go func(node *clusterNode, cmds []Cmder) {
defer wg.Done()

err := node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
return err
}

return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
if err != nil {
moved, ask, addr := isMovedError(err)
if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
}
return err
}
return pipelineReadCmds(rd, cmds)
})
})
err := c._processTxPipelineNode(ctx, node, cmds, failedCmds)
if err == nil {
return
}
Expand Down Expand Up @@ -1296,6 +1291,33 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
return cmdsMap
}

func (c *ClusterClient) _processTxPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) error {
return node.Client.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
return node.Client.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
return err
}

return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
if err != nil {
moved, ask, addr := isMovedError(err)
if moved || ask {
return c.cmdsMoved(cmds, moved, ask, addr, failedCmds)
}
return err
}
return pipelineReadCmds(rd, cmds)
})
})
})
}

func (c *ClusterClient) txPipelineReadQueued(
rd *proto.Reader, cmds []Cmder, failedCmds *cmdsMap,
) error {
Expand Down
178 changes: 178 additions & 0 deletions cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,184 @@ var _ = Describe("ClusterClient", func() {
err := pubsub.Ping()
Expect(err).NotTo(HaveOccurred())
})

It("supports Process hook", func() {
var masters []*redis.Client

err := client.Ping().Err()
Expect(err).NotTo(HaveOccurred())

err = client.ForEachMaster(func(master *redis.Client) error {
masters = append(masters, master)
return master.Ping().Err()
})
Expect(err).NotTo(HaveOccurred())

var stack []string

clusterHook := &hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
stack = append(stack, "cluster.BeforeProcess")
return ctx, nil
},
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
Expect(cmd.String()).To(Equal("ping: PONG"))
stack = append(stack, "cluster.AfterProcess")
return nil
},
}
client.AddHook(clusterHook)

masterHook := &hook{
beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
Expect(cmd.String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcess")
return ctx, nil
},
afterProcess: func(ctx context.Context, cmd redis.Cmder) error {
Expect(cmd.String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcess")
return nil
},
}

for _, master := range masters {
master.AddHook(masterHook)
}

err = client.Ping().Err()
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcess",
"shard.BeforeProcess",
"shard.AfterProcess",
"cluster.AfterProcess",
}))

clusterHook.beforeProcess = nil
clusterHook.afterProcess = nil
masterHook.beforeProcess = nil
masterHook.afterProcess = nil
})

It("supports Pipeline hook", func() {
var masters []*redis.Client

err := client.Ping().Err()
Expect(err).NotTo(HaveOccurred())

err = client.ForEachMaster(func(master *redis.Client) error {
masters = append(masters, master)
return master.Ping().Err()
})
Expect(err).NotTo(HaveOccurred())

var stack []string

client.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "cluster.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "cluster.AfterProcessPipeline")
return nil
},
})

for _, master := range masters {
master.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
return nil
},
})
}

_, err = client.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcessPipeline",
"shard.BeforeProcessPipeline",
"shard.AfterProcessPipeline",
"cluster.AfterProcessPipeline",
}))
})

It("supports TxPipeline hook", func() {
var masters []*redis.Client

err := client.Ping().Err()
Expect(err).NotTo(HaveOccurred())

err = client.ForEachMaster(func(master *redis.Client) error {
masters = append(masters, master)
return master.Ping().Err()
})
Expect(err).NotTo(HaveOccurred())

var stack []string

client.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "cluster.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "cluster.AfterProcessPipeline")
return nil
},
})

for _, master := range masters {
master.AddHook(&hook{
beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: "))
stack = append(stack, "shard.BeforeProcessPipeline")
return ctx, nil
},
afterProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) error {
Expect(cmds).To(HaveLen(1))
Expect(cmds[0].String()).To(Equal("ping: PONG"))
stack = append(stack, "shard.AfterProcessPipeline")
return nil
},
})
}

_, err = client.TxPipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
return nil
})
Expect(err).NotTo(HaveOccurred())
Expect(stack).To(Equal([]string{
"cluster.BeforeProcessPipeline",
"shard.BeforeProcessPipeline",
"shard.AfterProcessPipeline",
"cluster.AfterProcessPipeline",
}))
})
}

Describe("ClusterClient", func() {
Expand Down
7 changes: 6 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type Cmder interface {
Name() string
Args() []interface{}
String() string
stringArg(int) string

readTimeout() *time.Duration
Expand Down Expand Up @@ -152,6 +153,10 @@ func NewCmd(args ...interface{}) *Cmd {
}
}

func (cmd *Cmd) String() string {
return cmdString(cmd, cmd.val)
}

func (cmd *Cmd) Val() interface{} {
return cmd.val
}
Expand All @@ -160,7 +165,7 @@ func (cmd *Cmd) Result() (interface{}, error) {
return cmd.val, cmd.err
}

func (cmd *Cmd) String() (string, error) {
func (cmd *Cmd) Text() (string, error) {
if cmd.err != nil {
return "", cmd.err
}
Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ func Example_customCommand() {
}

func Example_customCommand2() {
v, err := rdb.Do("get", "key_does_not_exist").String()
v, err := rdb.Do("get", "key_does_not_exist").Text()
fmt.Printf("%q %s", v, err)
// Output: "" redis: nil
}
Expand Down
Loading

0 comments on commit 49a0c8c

Please sign in to comment.