Skip to content

Commit

Permalink
Fix WithContext and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jul 4, 2019
1 parent 73d3c18 commit 2cbb519
Show file tree
Hide file tree
Showing 14 changed files with 114 additions and 90 deletions.
11 changes: 5 additions & 6 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt: opt,
nodes: newClusterNodes(opt),
},
ctx: context.Background(),
}
c.state = newClusterStateHolder(c.loadState)
c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)
Expand All @@ -690,10 +691,7 @@ func (c *ClusterClient) init() {
}

func (c *ClusterClient) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
return c.ctx
}

func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
Expand All @@ -702,6 +700,7 @@ func (c *ClusterClient) WithContext(ctx context.Context) *ClusterClient {
}
clone := *c
clone.ctx = ctx
clone.init()
return &clone
}

Expand Down Expand Up @@ -732,7 +731,7 @@ func (c *ClusterClient) Do(args ...interface{}) *Cmd {

func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
c.ProcessContext(ctx, cmd)
_ = c.ProcessContext(ctx, cmd)
return cmd
}

Expand Down Expand Up @@ -1035,7 +1034,7 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
}

func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
}

func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
Expand Down
9 changes: 9 additions & 0 deletions cluster_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis_test

import (
"context"
"fmt"
"net"
"strconv"
Expand Down Expand Up @@ -241,6 +242,14 @@ var _ = Describe("ClusterClient", func() {
var client *redis.ClusterClient

assertClusterClient := func() {
It("supports WithContext", func() {
c, cancel := context.WithCancel(context.Background())
cancel()

err := client.WithContext(c).Ping().Err()
Expect(err).To(MatchError("context canceled"))
})

It("should GET/SET/DEL", func() {
err := client.Get("A").Err()
Expect(err).To(Equal(redis.Nil))
Expand Down
5 changes: 3 additions & 2 deletions internal/pool/bench_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pool_test

import (
"context"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -39,7 +40,7 @@ func BenchmarkPoolGetPut(b *testing.B) {

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -81,7 +82,7 @@ func BenchmarkPoolGetRemove(b *testing.B) {

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(context.Background())
if err != nil {
b.Fatal(err)
}
Expand Down
48 changes: 24 additions & 24 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,38 +250,38 @@ func (p *ConnPool) getTurn() {
}

func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{}
if ctx != nil {
done = ctx.Done()
select {
case <-ctx.Done():
return ctx.Err()
default:
}

select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
return nil
default:
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)
}

select {
case <-done:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)

select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
}
timers.Put(timer)
return nil
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return ErrPoolTimeout
}
}

Expand Down
30 changes: 18 additions & 12 deletions internal/pool/pool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pool_test

import (
"context"
"sync"
"testing"
"time"
Expand All @@ -12,6 +13,7 @@ import (
)

var _ = Describe("ConnPool", func() {
c := context.Background()
var connPool *pool.ConnPool

BeforeEach(func() {
Expand All @@ -30,13 +32,13 @@ var _ = Describe("ConnPool", func() {

It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())

// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
Expand All @@ -47,7 +49,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()

started <- true
_, err := connPool.Get(nil)
_, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
done <- true

Expand Down Expand Up @@ -80,6 +82,7 @@ var _ = Describe("ConnPool", func() {
})

var _ = Describe("MinIdleConns", func() {
c := context.Background()
const poolSize = 100
var minIdleConns int
var connPool *pool.ConnPool
Expand Down Expand Up @@ -110,7 +113,7 @@ var _ = Describe("MinIdleConns", func() {

BeforeEach(func() {
var err error
cn, err = connPool.Get(nil)
cn, err = connPool.Get(c)
Expect(err).NotTo(HaveOccurred())

Eventually(func() int {
Expand Down Expand Up @@ -145,7 +148,7 @@ var _ = Describe("MinIdleConns", func() {
perform(poolSize, func(_ int) {
defer GinkgoRecover()

cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
mu.Lock()
cns = append(cns, cn)
Expand All @@ -160,7 +163,7 @@ var _ = Describe("MinIdleConns", func() {
It("Get is blocked", func() {
done := make(chan struct{})
go func() {
connPool.Get(nil)
connPool.Get(c)
close(done)
}()

Expand Down Expand Up @@ -247,6 +250,8 @@ var _ = Describe("MinIdleConns", func() {
})

var _ = Describe("conns reaper", func() {
c := context.Background()

const idleTimeout = time.Minute
const maxAge = time.Hour

Expand Down Expand Up @@ -274,7 +279,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
staleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
switch typ {
case "idle":
Expand All @@ -288,7 +293,7 @@ var _ = Describe("conns reaper", func() {

// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
Expand Down Expand Up @@ -333,7 +338,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
Expand All @@ -342,7 +347,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.IdleLen()).To(Equal(0))

cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
Expand Down Expand Up @@ -370,6 +375,7 @@ var _ = Describe("conns reaper", func() {
})

var _ = Describe("race", func() {
c := context.Background()
var connPool *pool.ConnPool
var C, N int

Expand All @@ -396,15 +402,15 @@ var _ = Describe("race", func() {

perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Put(cn)
}
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get(nil)
cn, err := connPool.Get(c)
Expect(err).NotTo(HaveOccurred())
if err == nil {
connPool.Remove(cn)
Expand Down
2 changes: 1 addition & 1 deletion pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (c *Pipeline) discard() error {
// Exec always returns list of commands and error of the first failed
// command if any.
func (c *Pipeline) Exec() ([]Cmder, error) {
return c.ExecContext(nil)
return c.ExecContext(context.Background())
}

func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
Expand Down
3 changes: 2 additions & 1 deletion pool_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis_test

import (
"context"
"time"

"github.com/go-redis/redis"
Expand Down Expand Up @@ -81,7 +82,7 @@ var _ = Describe("pool", func() {
})

It("removes broken connections", func() {
cn, err := client.Pool().Get(nil)
cn, err := client.Pool().Get(context.Background())
Expect(err).NotTo(HaveOccurred())
cn.SetNetConn(&badConn{})
client.Pool().Put(cn)
Expand Down
11 changes: 3 additions & 8 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ type hooks struct {
hooks []Hook
}

func (hs *hooks) lazyCopy() {
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}

func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
}
Expand Down Expand Up @@ -475,6 +471,7 @@ func NewClient(opt *Options) *Client {
connPool: newConnPool(opt),
},
},
ctx: context.Background(),
}
c.init()

Expand All @@ -486,10 +483,7 @@ func (c *Client) init() {
}

func (c *Client) Context() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.Background()
return c.ctx
}

func (c *Client) WithContext(ctx context.Context) *Client {
Expand All @@ -498,6 +492,7 @@ func (c *Client) WithContext(ctx context.Context) *Client {
}
clone := *c
clone.ctx = ctx
clone.init()
return &clone
}

Expand Down
Loading

0 comments on commit 2cbb519

Please sign in to comment.