diff --git a/ring.go b/ring.go index a582aefc8..5956b71a5 100644 --- a/ring.go +++ b/ring.go @@ -361,7 +361,8 @@ func NewRing(opt *RingOptions) *Ring { ring.process = ring.defaultProcess ring.processPipeline = ring.defaultProcessPipeline - ring.cmdable.setProcessor(ring.Process) + + ring.init() for name, addr := range opt.Addrs { clopt := opt.clientOptions() @@ -374,6 +375,10 @@ func NewRing(opt *RingOptions) *Ring { return ring } +func (c *Ring) init() { + c.cmdable.setProcessor(c.Process) +} + func (c *Ring) Context() context.Context { if c.ctx != nil { return c.ctx @@ -392,6 +397,8 @@ func (c *Ring) WithContext(ctx context.Context) *Ring { func (c *Ring) clone() *Ring { cp := *c + cp.init() + return &cp } diff --git a/ring_test.go b/ring_test.go index 4ff089860..d498e0349 100644 --- a/ring_test.go +++ b/ring_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "crypto/rand" "fmt" "net" @@ -104,6 +105,27 @@ var _ = Describe("Redis Ring", func() { Expect(ringShard2.Info("keyspace").Val()).To(ContainSubstring("keys=100")) }) + It("propagates process for WithContext", func() { + var fromWrap []string + wrapper := func(oldProcess func(cmd redis.Cmder) error) func(cmd redis.Cmder) error { + return func(cmd redis.Cmder) error { + fromWrap = append(fromWrap, cmd.Name()) + + return oldProcess(cmd) + } + } + + ctx := context.Background() + ring = ring.WithContext(ctx) + ring.WrapProcess(wrapper) + + ring.Ping() + Expect(fromWrap).To(Equal([]string{"ping"})) + + ring.Ping() + Expect(fromWrap).To(Equal([]string{"ping", "ping"})) + }) + Describe("pipeline", func() { It("distributes keys", func() { pipe := ring.Pipeline()