Skip to content

Commit

Permalink
Allow passing context where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jun 4, 2019
1 parent 3da4357 commit 09eb108
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 51 deletions.
28 changes: 18 additions & 10 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -724,16 +724,24 @@ func (c *ClusterClient) Close() error {

// Do creates a Cmd from the args and processes the cmd.
func (c *ClusterClient) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}

func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
c.Process(cmd)
c.ProcessContext(ctx, cmd)
return cmd
}

func (c *ClusterClient) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.process)
return c.ProcessContext(c.ctx, cmd)
}

func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}

func (c *ClusterClient) process(cmd Cmder) error {
func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
var node *clusterNode
var ask bool
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
Expand All @@ -755,11 +763,11 @@ func (c *ClusterClient) process(cmd Cmder) error {
pipe := node.Client.Pipeline()
_ = pipe.Process(NewCmd("ASKING"))
_ = pipe.Process(cmd)
_, err = pipe.Exec()
_, err = pipe.ExecContext(ctx)
_ = pipe.Close()
ask = false
} else {
err = node.Client.Process(cmd)
err = node.Client.ProcessContext(ctx, cmd)
}

// If there is no error - we are done.
Expand Down Expand Up @@ -1022,11 +1030,11 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().Pipelined(fn)
}

func (c *ClusterClient) processPipeline(cmds []Cmder) error {
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
}

func (c *ClusterClient) _processPipeline(cmds []Cmder) error {
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := newCmdsMap()
err := c.mapCmdsByNode(cmds, cmdsMap)
if err != nil {
Expand Down Expand Up @@ -1216,11 +1224,11 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().Pipelined(fn)
}

func (c *ClusterClient) processTxPipeline(cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processTxPipeline)
func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c._processTxPipeline)
}

func (c *ClusterClient) _processTxPipeline(cmds []Cmder) error {
func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
state, err := c.state.Get()
if err != nil {
return err
Expand Down
4 changes: 3 additions & 1 deletion iterator.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package redis

import "sync"
import (
"sync"
)

// ScanIterator is used to incrementally iterate over a collection of elements.
// It's safe for concurrent use by multiple goroutines.
Expand Down
6 changes: 3 additions & 3 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ import (

// Limiter is the interface of a rate limiter or a circuit breaker.
type Limiter interface {
// Allow returns a nil if operation is allowed or an error otherwise.
// If operation is allowed client must report the result of operation
// whether is a success or a failure.
// Allow returns nil if operation is allowed or an error otherwise.
// If operation is allowed client must ReportResult of the operation
// whether it is a success or a failure.
Allow() error
// ReportResult reports the result of previously allowed operation.
// nil indicates a success, non-nil error indicates a failure.
Expand Down
18 changes: 10 additions & 8 deletions pipeline.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package redis

import (
"context"
"sync"

"github.com/go-redis/redis/internal/pool"
)

type pipelineExecer func([]Cmder) error
type pipelineExecer func(context.Context, []Cmder) error

// Pipeliner is an mechanism to realise Redis Pipeline technique.
//
Expand All @@ -28,6 +29,7 @@ type Pipeliner interface {
Close() error
Discard() error
Exec() ([]Cmder, error)
ExecContext(ctx context.Context) ([]Cmder, error)
}

var _ Pipeliner = (*Pipeline)(nil)
Expand Down Expand Up @@ -96,6 +98,10 @@ func (c *Pipeline) discard() error {
// Exec always returns list of commands and error of the first failed
// command if any.
func (c *Pipeline) Exec() ([]Cmder, error) {
return c.ExecContext(nil)
}

func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
c.mu.Lock()
defer c.mu.Unlock()

Expand All @@ -110,10 +116,10 @@ func (c *Pipeline) Exec() ([]Cmder, error) {
cmds := c.cmds
c.cmds = nil

return cmds, c.exec(cmds)
return cmds, c.exec(ctx, cmds)
}

func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
if err := fn(c); err != nil {
return nil, err
}
Expand All @@ -122,16 +128,12 @@ func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return cmds, err
}

func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.pipelined(fn)
}

func (c *Pipeline) Pipeline() Pipeliner {
return c
}

func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.pipelined(fn)
return c.Pipelined(fn)
}

func (c *Pipeline) TxPipeline() Pipeliner {
Expand Down
56 changes: 36 additions & 20 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
}

func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error {
func (hs hooks) process(
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
ctx, err := hs.beforeProcess(ctx, cmd)
if err != nil {
return err
}

cmdErr := fn(cmd)
cmdErr := fn(ctx, cmd)

_, err = hs.afterProcess(ctx, cmd)
if err != nil {
Expand Down Expand Up @@ -83,13 +85,15 @@ func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, e
return ctx, nil
}

func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error {
func (hs hooks) processPipeline(
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
if err != nil {
return err
}

cmdsErr := fn(cmds)
cmdsErr := fn(ctx, cmds)

_, err = hs.afterProcessPipeline(ctx, cmds)
if err != nil {
Expand Down Expand Up @@ -246,14 +250,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
return nil
}

// Do creates a Cmd from the args and processes the cmd.
func (c *baseClient) Do(args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.process(cmd)
return cmd
}

func (c *baseClient) process(cmd Cmder) error {
func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
time.Sleep(c.retryBackoff(attempt))
Expand Down Expand Up @@ -328,11 +325,11 @@ func (c *baseClient) getAddr() string {
return c.opt.Addr
}

func (c *baseClient) processPipeline(cmds []Cmder) error {
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
}

func (c *baseClient) processTxPipeline(cmds []Cmder) error {
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
}

Expand Down Expand Up @@ -503,16 +500,31 @@ func (c *Client) WithContext(ctx context.Context) *Client {
return &clone
}

// Do creates a Cmd from the args and processes the cmd.
func (c *Client) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}

func (c *Client) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
_ = c.ProcessContext(ctx, cmd)
return cmd
}

func (c *Client) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.baseClient.process)
return c.ProcessContext(c.ctx, cmd)
}

func (c *Client) processPipeline(cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline)
func (c *Client) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.baseClient.process)
}

func (c *Client) processTxPipeline(cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline)
func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}

func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}

// Options returns read-only Options that were used to create the client.
Expand Down Expand Up @@ -637,7 +649,11 @@ func newConn(opt *Options, cn *pool.Conn) *Conn {
}

func (c *Conn) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
return c.ProcessContext(context.TODO(), cmd)
}

func (c *Conn) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
}

func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
Expand Down
22 changes: 15 additions & 7 deletions ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,21 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {

// Do creates a Cmd from the args and processes the cmd.
func (c *Ring) Do(args ...interface{}) *Cmd {
return c.DoContext(c.ctx, args...)
}

func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
cmd := NewCmd(args...)
c.Process(cmd)
c.ProcessContext(ctx, cmd)
return cmd
}

func (c *Ring) Process(cmd Cmder) error {
return c.hooks.process(c.ctx, cmd, c.process)
return c.ProcessContext(c.ctx, cmd)
}

func (c *Ring) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.hooks.process(ctx, cmd, c.process)
}

// Options returns read-only Options that were used to create the client.
Expand Down Expand Up @@ -532,7 +540,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
return c.shards.GetByKey(firstKey)
}

func (c *Ring) process(cmd Cmder) error {
func (c *Ring) process(ctx context.Context, cmd Cmder) error {
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
if attempt > 0 {
time.Sleep(c.retryBackoff(attempt))
Expand All @@ -544,7 +552,7 @@ func (c *Ring) process(cmd Cmder) error {
return err
}

err = shard.Client.Process(cmd)
err = shard.Client.ProcessContext(ctx, cmd)
if err == nil {
return nil
}
Expand All @@ -567,11 +575,11 @@ func (c *Ring) Pipeline() Pipeliner {
return &pipe
}

func (c *Ring) processPipeline(cmds []Cmder) error {
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
}

func (c *Ring) _processPipeline(cmds []Cmder) error {
func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
cmdsMap := make(map[string][]Cmder)
for _, cmd := range cmds {
cmdInfo := c.cmdInfo(cmd.Name())
Expand Down
6 changes: 5 additions & 1 deletion sentinel.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
}

func (c *SentinelClient) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
return c.ProcessContext(c.ctx, cmd)
}

func (c *SentinelClient) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
}

func (c *SentinelClient) pubSub() *PubSub {
Expand Down
6 changes: 5 additions & 1 deletion tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ func (c *Tx) WithContext(ctx context.Context) *Tx {
}

func (c *Tx) Process(cmd Cmder) error {
return c.baseClient.process(cmd)
return c.ProcessContext(c.ctx, cmd)
}

func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
return c.baseClient.process(ctx, cmd)
}

// Watch prepares a transaction and marks the keys to be watched
Expand Down
3 changes: 3 additions & 0 deletions universal.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ type UniversalClient interface {
Context() context.Context
AddHook(Hook)
Watch(fn func(*Tx) error, keys ...string) error
Do(args ...interface{}) *Cmd
DoContext(ctx context.Context, args ...interface{}) *Cmd
Process(cmd Cmder) error
ProcessContext(ctx context.Context, cmd Cmder) error
Subscribe(channels ...string) *PubSub
PSubscribe(channels ...string) *PubSub
Close() error
Expand Down

0 comments on commit 09eb108

Please sign in to comment.