Skip to content

Commit

Permalink
Replace Wrap* with hooks that support context
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 31, 2019
1 parent b902746 commit 8476dfe
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 349 deletions.
373 changes: 181 additions & 192 deletions cluster.go

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,14 @@ type baseCmd struct {

var _ Cmder = (*Cmd)(nil)

func (cmd *baseCmd) Err() error {
return cmd.err
func (cmd *baseCmd) Name() string {
if len(cmd._args) > 0 {
// Cmd name must be lower cased.
s := internal.ToLower(cmd.stringArg(0))
cmd._args[0] = s
return s
}
return ""
}

func (cmd *baseCmd) Args() []interface{} {
Expand All @@ -116,14 +122,8 @@ func (cmd *baseCmd) stringArg(pos int) string {
return s
}

func (cmd *baseCmd) Name() string {
if len(cmd._args) > 0 {
// Cmd name must be lower cased.
s := internal.ToLower(cmd.stringArg(0))
cmd._args[0] = s
return s
}
return ""
func (cmd *baseCmd) Err() error {
return cmd.err
}

func (cmd *baseCmd) readTimeout() *time.Duration {
Expand Down
52 changes: 31 additions & 21 deletions example_instrumentation_test.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,54 @@
package redis_test

import (
"context"
"fmt"

"github.com/go-redis/redis"
)

type redisHook struct{}

var _ redis.Hook = redisHook{}

func (redisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
fmt.Printf("starting processing: <%s>\n", cmd)
return ctx, nil
}

func (redisHook) AfterProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
fmt.Printf("finished processing: <%s>\n", cmd)
return ctx, nil
}

func (redisHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
fmt.Printf("pipeline starting processing: %v\n", cmds)
return ctx, nil
}

func (redisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
fmt.Printf("pipeline finished processing: %v\n", cmds)
return ctx, nil
}

func Example_instrumentation() {
redisdb := redis.NewClient(&redis.Options{
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
redisdb.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
return func(cmd redis.Cmder) error {
fmt.Printf("starting processing: <%s>\n", cmd)
err := old(cmd)
fmt.Printf("finished processing: <%s>\n", cmd)
return err
}
})
rdb.AddHook(redisHook{})

redisdb.Ping()
rdb.Ping()
// Output: starting processing: <ping: >
// finished processing: <ping: PONG>
}

func ExamplePipeline_instrumentation() {
redisdb := redis.NewClient(&redis.Options{
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
rdb.AddHook(redisHook{})

redisdb.WrapProcessPipeline(func(old func([]redis.Cmder) error) func([]redis.Cmder) error {
return func(cmds []redis.Cmder) error {
fmt.Printf("pipeline starting processing: %v\n", cmds)
err := old(cmds)
fmt.Printf("pipeline finished processing: %v\n", cmds)
return err
}
})

redisdb.Pipelined(func(pipe redis.Pipeliner) error {
rdb.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Ping()
pipe.Ping()
return nil
Expand Down
162 changes: 125 additions & 37 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,114 @@ func SetLogger(logger *log.Logger) {
internal.Logger = logger
}

//------------------------------------------------------------------------------

type Hook interface {
BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
AfterProcess(ctx context.Context, cmd Cmder) (context.Context, error)

BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
AfterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
}

type hooks struct {
hooks []Hook
}

func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
}

func (hs *hooks) copy() {
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}

func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error {
ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil {
return err
}

cmdErr := fn(cmd)

_, err = hs.afterProcess(ctx, cmd)
if err != nil {
return err
}

return cmdErr
}

func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcess(ctx, cmd)
if err != nil {
return nil, err
}
}
return ctx, nil
}

func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.AfterProcess(ctx, cmd)
if err != nil {
return nil, err
}
}
return ctx, nil
}

func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil {
return err
}

cmdsErr := fn(cmds)

_, err = hs.afterProcessPipeline(ctx, cmds)
if err != nil {
return err
}

return cmdsErr
}

func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
}
}
return ctx, nil
}

func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
for _, h := range hs.hooks {
var err error
ctx, err = h.AfterProcessPipeline(ctx, cmds)
if err != nil {
return nil, err
}
}
return ctx, nil
}

//------------------------------------------------------------------------------

type baseClient struct {
opt *Options
connPool pool.Pooler
limiter Limiter

process func(Cmder) error
processPipeline func([]Cmder) error
processTxPipeline func([]Cmder) error

onClose func() error // hook called when client is closed
}

func (c *baseClient) init() {
c.process = c.defaultProcess
c.processPipeline = c.defaultProcessPipeline
c.processTxPipeline = c.defaultProcessTxPipeline
}

func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}
Expand Down Expand Up @@ -159,22 +249,11 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
// Do creates a Cmd from the args and processes the cmd.
func (c *baseClient) Do(args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.Process(cmd)
_ = c.process(cmd)
return cmd
}

// WrapProcess wraps function that processes Redis commands.
func (c *baseClient) WrapProcess(
fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error,
) {
c.process = fn(c.process)
}

func (c *baseClient) Process(cmd Cmder) error {
return c.process(cmd)
}

func (c *baseClient) defaultProcess(cmd Cmder) error {
func (c *baseClient) process(cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
time.Sleep(c.retryBackoff(attempt))
Expand Down Expand Up @@ -249,18 +328,11 @@ func (c *baseClient) getAddr() string {
return c.opt.Addr
}

func (c *baseClient) WrapProcessPipeline(
fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
) {
c.processPipeline = fn(c.processPipeline)
c.processTxPipeline = fn(c.processTxPipeline)
}

func (c *baseClient) defaultProcessPipeline(cmds []Cmder) error {
func (c *baseClient) processPipeline(cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
}

func (c *baseClient) defaultProcessTxPipeline(cmds []Cmder) error {
func (c *baseClient) processTxPipeline(cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
}

Expand Down Expand Up @@ -388,6 +460,7 @@ type Client struct {
cmdable

ctx context.Context
hooks
}

// NewClient returns a client to the Redis Server specified by Options.
Expand All @@ -400,7 +473,6 @@ func NewClient(opt *Options) *Client {
connPool: newConnPool(opt),
},
}
c.baseClient.init()
c.init()

return &c
Expand All @@ -427,9 +499,22 @@ func (c *Client) WithContext(ctx context.Context) *Client {
}

func (c *Client) clone() *Client {
cp := *c
cp.init()
return &cp
clone := *c
clone.hooks.copy()
clone.init()
return &clone
}

func (c *Client) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.baseClient.process)
}

func (c *Client) processPipeline(cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline)
}

func (c *Client) processTxPipeline(cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline)
}

// Options returns read-only Options that were used to create the client.
Expand Down Expand Up @@ -547,11 +632,14 @@ func newConn(opt *Options, cn *pool.Conn) *Conn {
connPool: pool.NewSingleConnPool(cn),
},
}
c.baseClient.init()
c.statefulCmdable.setProcessor(c.Process)
return &c
}

func (c *Conn) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
}

func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
}
Expand Down
37 changes: 0 additions & 37 deletions redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,43 +224,6 @@ var _ = Describe("Client", func() {
Expect(err).NotTo(HaveOccurred())
Expect(got).To(Equal(bigVal))
})

It("should call WrapProcess", func() {
var fnCalled bool

client.WrapProcess(func(old func(redis.Cmder) error) func(redis.Cmder) error {
return func(cmd redis.Cmder) error {
fnCalled = true
return old(cmd)
}
})

Expect(client.Ping().Err()).NotTo(HaveOccurred())
Expect(fnCalled).To(BeTrue())
})

It("should call WrapProcess after WithContext", func() {
var fn1Called, fn2Called bool

client.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
return func(cmd redis.Cmder) error {
fn1Called = true
return old(cmd)
}
})

client2 := client.WithContext(client.Context())
client2.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
return func(cmd redis.Cmder) error {
fn2Called = true
return old(cmd)
}
})

Expect(client2.Ping().Err()).NotTo(HaveOccurred())
Expect(fn2Called).To(BeTrue())
Expect(fn1Called).To(BeTrue())
})
})

var _ = Describe("Client timeout", func() {
Expand Down
Loading

0 comments on commit 8476dfe

Please sign in to comment.