diff --git a/cluster.go b/cluster.go index 8709f9031..b5463f434 100644 --- a/cluster.go +++ b/cluster.go @@ -673,6 +673,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { opt: opt, nodes: newClusterNodes(opt), }, + ctx: context.Background(), } c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) @@ -690,10 +691,7 @@ func (c *ClusterClient) init() { } func (c *ClusterClient) Context() context.Context { - if c.ctx != nil { - return c.ctx - } - return context.Background() + return c.ctx } func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient { @@ -702,6 +700,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient { } clone := *c clone.ctx = ctx + clone.init() return &clone } @@ -732,7 +731,7 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd { func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd { cmd := NewCmd(args...) - c.ProcessContext(ctx, cmd) + _ = c.ProcessContext(ctx, cmd) return cmd } @@ -1035,7 +1034,7 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) { } func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { - return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline) + return c.hooks.processPipeline(ctx, cmds, c._processPipeline) } func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error { diff --git a/cluster_test.go b/cluster_test.go index ad3a6be1c..c8554c9b9 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "fmt" "net" "strconv" @@ -241,6 +242,14 @@ var _ = Describe("ClusterClient", func() { var client *redis.ClusterClient assertClusterClient := func() { + It("supports WithContext", func() { + c, cancel := context.WithCancel(context.Background()) + cancel() + + err := client.WithContext(c).Ping().Err() + Expect(err).To(MatchError("context canceled")) + }) + It("should GET/SET/DEL", func() { err := client.Get("A").Err() Expect(err).To(Equal(redis.Nil)) diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 82e2d5d95..a461c72f7 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "fmt" "testing" "time" @@ -39,7 +40,7 @@ func BenchmarkPoolGetPut(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(context.Background()) if err != nil { b.Fatal(err) } @@ -81,7 +82,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(context.Background()) if err != nil { b.Fatal(err) } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 0a32e0311..0f730ebd0 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -250,38 +250,38 @@ func (p *ConnPool) getTurn() { } func (p *ConnPool) waitTurn(ctx context.Context) error { - var done <-chan struct{} - if ctx != nil { - done = ctx.Done() + select { + case <-ctx.Done(): + return ctx.Err() + default: } select { - case <-done: - return ctx.Err() case p.queue <- struct{}{}: return nil default: - timer := timers.Get().(*time.Timer) - timer.Reset(p.opt.PoolTimeout) + } - select { - case <-done: - if !timer.Stop() { - <-timer.C - } - timers.Put(timer) - return ctx.Err() - case p.queue <- struct{}{}: - if !timer.Stop() { - <-timer.C - } - timers.Put(timer) - return nil - case <-timer.C: - timers.Put(timer) - atomic.AddUint32(&p.stats.Timeouts, 1) - return ErrPoolTimeout + timer := timers.Get().(*time.Timer) + timer.Reset(p.opt.PoolTimeout) + + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + timers.Put(timer) + return ctx.Err() + case p.queue <- struct{}{}: + if !timer.Stop() { + <-timer.C } + timers.Put(timer) + return nil + case <-timer.C: + timers.Put(timer) + atomic.AddUint32(&p.stats.Timeouts, 1) + return ErrPoolTimeout } } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 18d027802..e023348c8 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -1,6 +1,7 @@ package pool_test import ( + "context" "sync" "testing" "time" @@ -12,6 +13,7 @@ import ( ) var _ = Describe("ConnPool", func() { + c := context.Background() var connPool *pool.ConnPool BeforeEach(func() { @@ -30,13 +32,13 @@ var _ = Describe("ConnPool", func() { It("should unblock client when conn is removed", func() { // Reserve one connection. - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) // Reserve all other connections. var cns []*pool.Conn for i := 0; i < 9; i++ { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) cns = append(cns, cn) } @@ -47,7 +49,7 @@ var _ = Describe("ConnPool", func() { defer GinkgoRecover() started <- true - _, err := connPool.Get(nil) + _, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) done <- true @@ -80,6 +82,7 @@ var _ = Describe("ConnPool", func() { }) var _ = Describe("MinIdleConns", func() { + c := context.Background() const poolSize = 100 var minIdleConns int var connPool *pool.ConnPool @@ -110,7 +113,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { var err error - cn, err = connPool.Get(nil) + cn, err = connPool.Get(c) Expect(err).NotTo(HaveOccurred()) Eventually(func() int { @@ -145,7 +148,7 @@ var _ = Describe("MinIdleConns", func() { perform(poolSize, func(_ int) { defer GinkgoRecover() - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) mu.Lock() cns = append(cns, cn) @@ -160,7 +163,7 @@ var _ = Describe("MinIdleConns", func() { It("Get is blocked", func() { done := make(chan struct{}) go func() { - connPool.Get(nil) + connPool.Get(c) close(done) }() @@ -247,6 +250,8 @@ var _ = Describe("MinIdleConns", func() { }) var _ = Describe("conns reaper", func() { + c := context.Background() + const idleTimeout = time.Minute const maxAge = time.Hour @@ -274,7 +279,7 @@ var _ = Describe("conns reaper", func() { // add stale connections staleConns = nil for i := 0; i < 3; i++ { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) switch typ { case "idle": @@ -288,7 +293,7 @@ var _ = Describe("conns reaper", func() { // add fresh connections for i := 0; i < 3; i++ { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) conns = append(conns, cn) } @@ -333,7 +338,7 @@ var _ = Describe("conns reaper", func() { for j := 0; j < 3; j++ { var freeCns []*pool.Conn for i := 0; i < 3; i++ { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) freeCns = append(freeCns, cn) @@ -342,7 +347,7 @@ var _ = Describe("conns reaper", func() { Expect(connPool.Len()).To(Equal(3)) Expect(connPool.IdleLen()).To(Equal(0)) - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) conns = append(conns, cn) @@ -370,6 +375,7 @@ var _ = Describe("conns reaper", func() { }) var _ = Describe("race", func() { + c := context.Background() var connPool *pool.ConnPool var C, N int @@ -396,7 +402,7 @@ var _ = Describe("race", func() { perform(C, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) if err == nil { connPool.Put(cn) @@ -404,7 +410,7 @@ var _ = Describe("race", func() { } }, func(id int) { for i := 0; i < N; i++ { - cn, err := connPool.Get(nil) + cn, err := connPool.Get(c) Expect(err).NotTo(HaveOccurred()) if err == nil { connPool.Remove(cn) diff --git a/pipeline.go b/pipeline.go index 51333a736..8b9147ea8 100644 --- a/pipeline.go +++ b/pipeline.go @@ -98,7 +98,7 @@ func (c *Pipeline) discard() error { // Exec always returns list of commands and error of the first failed // command if any. func (c *Pipeline) Exec() ([]Cmder, error) { - return c.ExecContext(nil) + return c.ExecContext(context.Background()) } func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) { diff --git a/pool_test.go b/pool_test.go index 3adcebbb3..2f3b10e1e 100644 --- a/pool_test.go +++ b/pool_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "time" "github.com/go-redis/redis" @@ -81,7 +82,7 @@ var _ = Describe("pool", func() { }) It("removes broken connections", func() { - cn, err := client.Pool().Get(nil) + cn, err := client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) client.Pool().Put(cn) diff --git a/redis.go b/redis.go index d9db7e108..33c580a50 100644 --- a/redis.go +++ b/redis.go @@ -32,10 +32,6 @@ type hooks struct { hooks []Hook } -func (hs *hooks) lazyCopy() { - hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] -} - func (hs *hooks) AddHook(hook Hook) { hs.hooks = append(hs.hooks, hook) } @@ -475,6 +471,7 @@ func NewClient(opt *Options) *Client { connPool: newConnPool(opt), }, }, + ctx: context.Background(), } c.init() @@ -486,10 +483,7 @@ func (c *Client) init() { } func (c *Client) Context() context.Context { - if c.ctx != nil { - return c.ctx - } - return context.Background() + return c.ctx } func (c *Client) WithContext(ctx context.Context) *Client { @@ -498,6 +492,7 @@ func (c *Client) WithContext(ctx context.Context) *Client { } clone := *c clone.ctx = ctx + clone.init() return &clone } diff --git a/redis_test.go b/redis_test.go index 801753508..eb1bc981e 100644 --- a/redis_test.go +++ b/redis_test.go @@ -24,6 +24,14 @@ var _ = Describe("Client", func() { client.Close() }) + It("supports WithContext", func() { + c, cancel := context.WithCancel(context.Background()) + cancel() + + err := client.WithContext(c).Ping().Err() + Expect(err).To(MatchError("context canceled")) + }) + It("should Stringer", func() { Expect(client.String()).To(Equal("Redis<:6380 db:15>")) }) @@ -129,7 +137,7 @@ var _ = Describe("Client", func() { It("processes custom commands", func() { cmd := redis.NewCmd("PING") - client.Process(cmd) + _ = client.Process(cmd) // Flush buffers. Expect(client.Echo("hello").Err()).NotTo(HaveOccurred()) @@ -147,7 +155,7 @@ var _ = Describe("Client", func() { }) // Put bad connection in the pool. - cn, err := client.Pool().Get(nil) + cn, err := client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{}) @@ -185,7 +193,7 @@ var _ = Describe("Client", func() { }) It("should update conn.UsedAt on read/write", func() { - cn, err := client.Pool().Get(nil) + cn, err := client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) Expect(cn.UsedAt).NotTo(BeZero()) createdAt := cn.UsedAt() @@ -198,7 +206,7 @@ var _ = Describe("Client", func() { err = client.Ping().Err() Expect(err).NotTo(HaveOccurred()) - cn, err = client.Pool().Get(nil) + cn, err = client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) Expect(cn).NotTo(BeNil()) Expect(cn.UsedAt().After(createdAt)).To(BeTrue()) diff --git a/ring.go b/ring.go index 93066ffa4..b9dc3e820 100644 --- a/ring.go +++ b/ring.go @@ -358,6 +358,7 @@ func NewRing(opt *RingOptions) *Ring { opt: opt, shards: newRingShards(opt), }, + ctx: context.Background(), } ring.init() @@ -379,10 +380,7 @@ func (c *Ring) init() { } func (c *Ring) Context() context.Context { - if c.ctx != nil { - return c.ctx - } - return context.Background() + return c.ctx } func (c *Ring) WithContext(ctx context.Context) *Ring { @@ -391,6 +389,7 @@ func (c *Ring) WithContext(ctx context.Context) *Ring { } clone := *c clone.ctx = ctx + clone.init() return &clone } @@ -401,7 +400,7 @@ func (c *Ring) Do(args ...interface{}) *Cmd { func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd { cmd := NewCmd(args...) - c.ProcessContext(ctx, cmd) + _ = c.ProcessContext(ctx, cmd) return cmd } diff --git a/ring_test.go b/ring_test.go index e44a40466..3c94dea02 100644 --- a/ring_test.go +++ b/ring_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "crypto/rand" "fmt" "net" @@ -41,6 +42,14 @@ var _ = Describe("Redis Ring", func() { Expect(ring.Close()).NotTo(HaveOccurred()) }) + It("supports WithContext", func() { + c, cancel := context.WithCancel(context.Background()) + cancel() + + err := ring.WithContext(c).Ping().Err() + Expect(err).To(MatchError("context canceled")) + }) + It("distributes keys", func() { setRingKeys() diff --git a/sentinel.go b/sentinel.go index 4c75de90f..2fb32e377 100644 --- a/sentinel.go +++ b/sentinel.go @@ -97,8 +97,9 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { onClose: failover.Close, }, }, + ctx: context.Background(), } - c.cmdable = c.Process + c.init() return &c } @@ -117,15 +118,13 @@ func NewSentinelClient(opt *Options) *SentinelClient { opt: opt, connPool: newConnPool(opt), }, + ctx: context.Background(), } return c } func (c *SentinelClient) Context() context.Context { - if c.ctx != nil { - return c.ctx - } - return context.Background() + return c.ctx } func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient { @@ -162,7 +161,7 @@ func (c *SentinelClient) pubSub() *PubSub { // measure latency. func (c *SentinelClient) Ping() *StringCmd { cmd := NewStringCmd("ping") - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -188,13 +187,13 @@ func (c *SentinelClient) PSubscribe(channels ...string) *PubSub { func (c *SentinelClient) GetMasterAddrByName(name string) *StringSliceCmd { cmd := NewStringSliceCmd("sentinel", "get-master-addr-by-name", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } func (c *SentinelClient) Sentinels(name string) *SliceCmd { cmd := NewSliceCmd("sentinel", "sentinels", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -202,7 +201,7 @@ func (c *SentinelClient) Sentinels(name string) *SliceCmd { // asking for agreement to other Sentinels. func (c *SentinelClient) Failover(name string) *StatusCmd { cmd := NewStatusCmd("sentinel", "failover", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -212,7 +211,7 @@ func (c *SentinelClient) Failover(name string) *StatusCmd { // already discovered and associated with the master. func (c *SentinelClient) Reset(pattern string) *IntCmd { cmd := NewIntCmd("sentinel", "reset", pattern) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -220,28 +219,28 @@ func (c *SentinelClient) Reset(pattern string) *IntCmd { // the current Sentinel state. func (c *SentinelClient) FlushConfig() *StatusCmd { cmd := NewStatusCmd("sentinel", "flushconfig") - c.Process(cmd) + _ = c.Process(cmd) return cmd } // Master shows the state and info of the specified master. func (c *SentinelClient) Master(name string) *StringStringMapCmd { cmd := NewStringStringMapCmd("sentinel", "master", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } // Masters shows a list of monitored masters and their state. func (c *SentinelClient) Masters() *SliceCmd { cmd := NewSliceCmd("sentinel", "masters") - c.Process(cmd) + _ = c.Process(cmd) return cmd } // Slaves shows a list of slaves for the specified master and their state. func (c *SentinelClient) Slaves(name string) *SliceCmd { cmd := NewSliceCmd("sentinel", "slaves", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -251,7 +250,7 @@ func (c *SentinelClient) Slaves(name string) *SliceCmd { // Sentinel deployment is ok. func (c *SentinelClient) CkQuorum(name string) *StringCmd { cmd := NewStringCmd("sentinel", "ckquorum", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -259,14 +258,14 @@ func (c *SentinelClient) CkQuorum(name string) *StringCmd { // name, ip, port, and quorum. func (c *SentinelClient) Monitor(name, ip, port, quorum string) *StringCmd { cmd := NewStringCmd("sentinel", "monitor", name, ip, port, quorum) - c.Process(cmd) + _ = c.Process(cmd) return cmd } // Set is used in order to change configuration parameters of a specific master. func (c *SentinelClient) Set(name, option, value string) *StringCmd { cmd := NewStringCmd("sentinel", "set", name, option, value) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -275,7 +274,7 @@ func (c *SentinelClient) Set(name, option, value string) *StringCmd { // the Sentinel. func (c *SentinelClient) Remove(name string) *StringCmd { cmd := NewStringCmd("sentinel", "remove", name) - c.Process(cmd) + _ = c.Process(cmd) return cmd } @@ -313,7 +312,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool { return c.pool } -func (c *sentinelFailover) dial(ctx context.Context, network, addr string) (net.Conn, error) { +func (c *sentinelFailover) dial(ctx context.Context, network, _ string) (net.Conn, error) { addr, err := c.MasterAddr() if err != nil { return nil, err @@ -396,7 +395,7 @@ func (c *sentinelFailover) getMasterAddr() string { c.masterName, err) c.mu.Lock() if c.sentinel == sentinel { - c.closeSentinel() + _ = c.closeSentinel() } c.mu.Unlock() return "" @@ -436,13 +435,13 @@ func (c *sentinelFailover) closeSentinel() error { var firstErr error err := c.pubsub.Close() - if err != nil && firstErr == err { + if err != nil && firstErr == nil { firstErr = err } c.pubsub = nil err = c.sentinel.Close() - if err != nil && firstErr == err { + if err != nil && firstErr == nil { firstErr = err } c.sentinel = nil diff --git a/tx.go b/tx.go index f6b8bbda9..bcaebde46 100644 --- a/tx.go +++ b/tx.go @@ -40,10 +40,7 @@ func (c *Tx) init() { } func (c *Tx) Context() context.Context { - if c.ctx != nil { - return c.ctx - } - return context.Background() + return c.ctx } func (c *Tx) WithContext(ctx context.Context) *Tx { diff --git a/tx_test.go b/tx_test.go index c5cb4b3c4..dc3139c62 100644 --- a/tx_test.go +++ b/tx_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "strconv" "sync" @@ -124,7 +125,7 @@ var _ = Describe("Tx", func() { It("should recover from bad connection", func() { // Put bad connection in the pool. - cn, err := client.Pool().Get(nil) + cn, err := client.Pool().Get(context.Background()) Expect(err).NotTo(HaveOccurred()) cn.SetNetConn(&badConn{})