Skip to content

Commit

Permalink
Use Context.Deadline to set net.Conn deadline
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jun 8, 2019
1 parent 4fe609d commit 5460bc1
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 46 deletions.
16 changes: 8 additions & 8 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) erro
return
}

err = c.pipelineProcessCmds(node, cn, cmds, failedCmds)
err = c.pipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
node.Client.releaseConnStrict(cn, err)
}(node, cmds)
}
Expand Down Expand Up @@ -1129,9 +1129,9 @@ func (c *ClusterClient) cmdsAreReadOnly(cmds []Cmder) bool {
}

func (c *ClusterClient) pipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...)
})
if err != nil {
Expand All @@ -1142,7 +1142,7 @@ func (c *ClusterClient) pipelineProcessCmds(
return err
}

return cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return c.pipelineReadCmds(node, rd, cmds, failedCmds)
})
}
Expand Down Expand Up @@ -1266,7 +1266,7 @@ func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) er
return
}

err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds)
err = c.txPipelineProcessCmds(ctx, node, cn, cmds, failedCmds)
node.Client.releaseConnStrict(cn, err)
}(node, cmds)
}
Expand All @@ -1292,9 +1292,9 @@ func (c *ClusterClient) mapCmdsBySlot(cmds []Cmder) map[int][]Cmder {
}

func (c *ClusterClient) txPipelineProcessCmds(
node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
) error {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
Expand All @@ -1305,7 +1305,7 @@ func (c *ClusterClient) txPipelineProcessCmds(
return err
}

err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := c.txPipelineReadQueued(rd, cmds, failedCmds)
if err != nil {
setCmdsErr(cmds, err)
Expand Down
50 changes: 28 additions & 22 deletions internal/pool/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pool

import (
"context"
"net"
"sync/atomic"
"time"
Expand Down Expand Up @@ -48,24 +49,6 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
cn.wr.Reset(netConn)
}

func (cn *Conn) setReadTimeout(timeout time.Duration) error {
now := time.Now()
cn.SetUsedAt(now)
if timeout > 0 {
return cn.netConn.SetReadDeadline(now.Add(timeout))
}
return cn.netConn.SetReadDeadline(noDeadline)
}

func (cn *Conn) setWriteTimeout(timeout time.Duration) error {
now := time.Now()
cn.SetUsedAt(now)
if timeout > 0 {
return cn.netConn.SetWriteDeadline(now.Add(timeout))
}
return cn.netConn.SetWriteDeadline(noDeadline)
}

func (cn *Conn) Write(b []byte) (int, error) {
return cn.netConn.Write(b)
}
Expand All @@ -74,13 +57,17 @@ func (cn *Conn) RemoteAddr() net.Addr {
return cn.netConn.RemoteAddr()
}

func (cn *Conn) WithReader(timeout time.Duration, fn func(rd *proto.Reader) error) error {
_ = cn.setReadTimeout(timeout)
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
tm := cn.deadline(ctx, timeout)
_ = cn.netConn.SetReadDeadline(tm)
return fn(cn.rd)
}

func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error {
_ = cn.setWriteTimeout(timeout)
func (cn *Conn) WithWriter(
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
tm := cn.deadline(ctx, timeout)
_ = cn.netConn.SetWriteDeadline(tm)

firstErr := fn(cn.wr)
err := cn.wr.Flush()
Expand All @@ -93,3 +80,22 @@ func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) erro
func (cn *Conn) Close() error {
return cn.netConn.Close()
}

func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
if ctx != nil {
tm, ok := ctx.Deadline()
if ok {
cn.SetUsedAt(tm)
return tm
}
}

now := time.Now()
if timeout > 0 {
cn.SetUsedAt(now)
return now.Add(timeout)
}

cn.SetUsedAt(now)
return noDeadline
}
11 changes: 6 additions & 5 deletions pubsub.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis

import (
"context"
"errors"
"fmt"
"strings"
Expand Down Expand Up @@ -83,8 +84,8 @@ func (c *PubSub) _conn(newChannels []string) (*pool.Conn, error) {
return cn, nil
}

func (c *PubSub) writeCmd(cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
}
Expand Down Expand Up @@ -128,7 +129,7 @@ func (c *PubSub) _subscribe(
args = append(args, channel)
}
cmd := NewSliceCmd(args...)
return c.writeCmd(cn, cmd)
return c.writeCmd(context.TODO(), cn, cmd)
}

func (c *PubSub) releaseConn(cn *pool.Conn, err error, allowTimeout bool) {
Expand Down Expand Up @@ -258,7 +259,7 @@ func (c *PubSub) Ping(payload ...string) error {
return err
}

err = c.writeCmd(cn, cmd)
err = c.writeCmd(context.TODO(), cn, cmd)
c.releaseConn(cn, err, false)
return err
}
Expand Down Expand Up @@ -350,7 +351,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
return nil, err
}

err = cn.WithReader(timeout, func(rd *proto.Reader) error {
err = cn.WithReader(context.TODO(), timeout, func(rd *proto.Reader) error {
return c.cmd.readReply(rd)
})

Expand Down
24 changes: 14 additions & 10 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return err
}

err = cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
err = cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
})
if err != nil {
Expand All @@ -277,7 +277,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
return err
}

err = cn.WithReader(c.cmdTimeout(cmd), cmd.readReply)
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
c.releaseConn(cn, err)
if err != nil && internal.IsRetryableError(err, cmd.readTimeout() == nil) {
continue
Expand Down Expand Up @@ -333,7 +333,7 @@ func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}

type pipelineProcessor func(*pool.Conn, []Cmder) (bool, error)
type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error)

func (c *baseClient) generalProcessPipeline(
ctx context.Context, cmds []Cmder, p pipelineProcessor,
Expand All @@ -349,7 +349,7 @@ func (c *baseClient) generalProcessPipeline(
return err
}

canRetry, err := p(cn, cmds)
canRetry, err := p(ctx, cn, cmds)
c.releaseConnStrict(cn, err)

if !canRetry || !internal.IsRetryableError(err, true) {
Expand All @@ -359,16 +359,18 @@ func (c *baseClient) generalProcessPipeline(
return cmdsFirstErr(cmds)
}

func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
func (c *baseClient) pipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmds...)
})
if err != nil {
setCmdsErr(cmds, err)
return true, err
}

err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
return pipelineReadCmds(rd, cmds)
})
return true, err
Expand All @@ -384,16 +386,18 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
return nil
}

func (c *baseClient) txPipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
err := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
func (c *baseClient) txPipelineProcessCmds(
ctx context.Context, cn *pool.Conn, cmds []Cmder,
) (bool, error) {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return txPipelineWriteMulti(wr, cmds)
})
if err != nil {
setCmdsErr(cmds, err)
return true, err
}

err = cn.WithReader(c.opt.ReadTimeout, func(rd *proto.Reader) error {
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
err := txPipelineReadQueued(rd, cmds)
if err != nil {
setCmdsErr(cmds, err)
Expand Down
2 changes: 1 addition & 1 deletion ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
return
}

canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds)
canRetry, err := shard.Client.pipelineProcessCmds(ctx, cn, cmds)
shard.Client.releaseConnStrict(cn, err)

if canRetry && internal.IsRetryableError(err, true) {
Expand Down

0 comments on commit 5460bc1

Please sign in to comment.