diff --git a/cluster.go b/cluster.go index 223d2522c..43cbf6434 100644 --- a/cluster.go +++ b/cluster.go @@ -704,7 +704,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error { } err = c.pipelineProcessCmds(cn, cmds, failedCmds) - node.Client.putConn(cn, err, false) + node.Client.putConn(cn, err) } if len(failedCmds) == 0 { @@ -840,7 +840,7 @@ func (c *ClusterClient) txPipelineExec(cmds []Cmder) error { } err = c.txPipelineProcessCmds(node, cn, cmds, failedCmds) - node.Client.putConn(cn, err, false) + node.Client.putConn(cn, err) } if len(failedCmds) == 0 { diff --git a/pubsub.go b/pubsub.go index e47978c26..3680323a0 100644 --- a/pubsub.go +++ b/pubsub.go @@ -20,87 +20,112 @@ type PubSub struct { cn *pool.Conn closed bool - cmd *Cmd - subMu sync.Mutex channels []string patterns []string + + cmd *Cmd } -func (c *PubSub) conn() (*pool.Conn, error) { - cn, isNew, err := c._conn() +func (c *PubSub) conn() (*pool.Conn, bool, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil, false, pool.ErrClosed + } + + if c.cn != nil { + return c.cn, false, nil + } + + cn, err := c.base.connPool.NewConn() if err != nil { - return nil, err + return nil, false, err } - if isNew { - if err := c.resubscribe(); err != nil { - internal.Logf("resubscribe failed: %s", err) + if !cn.Inited { + if err := c.base.initConn(cn); err != nil { + _ = c.base.connPool.Remove(cn) + return nil, false, err } } - return cn, nil + if err := c.resubscribe(cn); err != nil { + return nil, false, err + } + + c.cn = cn + return cn, true, nil } -func (c *PubSub) resubscribe() error { +func (c *PubSub) resubscribe(cn *pool.Conn) error { c.subMu.Lock() - channels := c.channels - patterns := c.patterns - c.subMu.Unlock() + defer c.subMu.Unlock() var firstErr error - if len(channels) > 0 { - if err := c.subscribe("subscribe", channels...); err != nil && firstErr == nil { + if len(c.channels) > 0 { + if err := c._subscribe(cn, "subscribe", c.channels...); err != nil && firstErr == nil { firstErr = err } } - if len(patterns) > 0 { - if err := c.subscribe("psubscribe", patterns...); err != nil && firstErr == nil { + if len(c.patterns) > 0 { + if err := c._subscribe(cn, "psubscribe", c.patterns...); err != nil && firstErr == nil { firstErr = err } } return firstErr } -func (c *PubSub) _conn() (*pool.Conn, bool, error) { +func (c *PubSub) putConn(cn *pool.Conn, err error) { + if !internal.IsBadConn(err, true) { + return + } + + c.mu.Lock() + if c.cn == cn { + _ = c.closeConn() + } + c.mu.Unlock() +} + +func (c *PubSub) closeConn() error { + err := c.base.connPool.CloseConn(c.cn) + c.cn = nil + return err +} + +func (c *PubSub) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { - return nil, false, pool.ErrClosed + return pool.ErrClosed } + c.closed = true if c.cn != nil { - return c.cn, false, nil + return c.closeConn() } + return nil +} - cn, err := c.base.connPool.NewConn() +func (c *PubSub) subscribe(redisCmd string, channels ...string) error { + cn, isNew, err := c.conn() if err != nil { - return nil, false, err + return err } - if !cn.Inited { - if err := c.base.initConn(cn); err != nil { - _ = c.base.connPool.Remove(cn) - return nil, false, err - } + if isNew { + return nil } - c.cn = cn - return cn, true, nil -} - -func (c *PubSub) putConn(cn *pool.Conn, err error) { - if internal.IsBadConn(err, true) { - c.mu.Lock() - if c.cn == cn { - _ = c.closeConn() - } - c.mu.Unlock() - } + err = c._subscribe(cn, redisCmd, channels...) + c.putConn(cn, err) + return err } -func (c *PubSub) subscribe(redisCmd string, channels ...string) error { +func (c *PubSub) _subscribe(cn *pool.Conn, redisCmd string, channels ...string) error { args := make([]interface{}, 1+len(channels)) args[0] = redisCmd for i, channel := range channels { @@ -108,19 +133,8 @@ func (c *PubSub) subscribe(redisCmd string, channels ...string) error { } cmd := NewSliceCmd(args...) - cn, isNew, err := c._conn() - if err != nil { - return err - } - - if isNew { - return c.resubscribe() - } - cn.SetWriteTimeout(c.base.opt.WriteTimeout) - err = writeCmd(cn, cmd) - c.putConn(cn, err) - return err + return writeCmd(cn, cmd) } // Subscribes the client to the specified channels. @@ -157,28 +171,6 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error { return c.subscribe("punsubscribe", patterns...) } -func (c *PubSub) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.closed { - return pool.ErrClosed - } - c.closed = true - - if c.cn != nil { - _ = c.closeConn() - } - - return nil -} - -func (c *PubSub) closeConn() error { - err := c.base.connPool.CloseConn(c.cn) - c.cn = nil - return err -} - func (c *PubSub) Ping(payload ...string) error { args := []interface{}{"ping"} if len(payload) == 1 { @@ -186,7 +178,7 @@ func (c *PubSub) Ping(payload ...string) error { } cmd := NewCmd(args...) - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return err } @@ -279,7 +271,7 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) { c.cmd = NewCmd() } - cn, err := c.conn() + cn, _, err := c.conn() if err != nil { return nil, err } diff --git a/redis.go b/redis.go index ecf1fc0af..b71b9fc60 100644 --- a/redis.go +++ b/redis.go @@ -42,8 +42,8 @@ func (c *baseClient) conn() (*pool.Conn, bool, error) { return cn, isNew, nil } -func (c *baseClient) putConn(cn *pool.Conn, err error, allowTimeout bool) bool { - if internal.IsBadConn(err, allowTimeout) { +func (c *baseClient) putConn(cn *pool.Conn, err error) bool { + if internal.IsBadConn(err, false) { _ = c.connPool.Remove(cn) return false } @@ -104,7 +104,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetWriteTimeout(c.opt.WriteTimeout) if err := writeCmd(cn, cmd); err != nil { - c.putConn(cn, err, false) + c.putConn(cn, err) cmd.setErr(err) if err != nil && internal.IsRetryableError(err) { continue @@ -114,7 +114,7 @@ func (c *baseClient) defaultProcess(cmd Cmder) error { cn.SetReadTimeout(c.cmdTimeout(cmd)) err = cmd.readReply(cn) - c.putConn(cn, err, false) + c.putConn(cn, err) if err != nil && internal.IsRetryableError(err) { continue } @@ -167,7 +167,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer { } canRetry, err := p(cn, cmds) - c.putConn(cn, err, false) + c.putConn(cn, err) if err == nil { return nil } diff --git a/ring.go b/ring.go index 69715b6b7..d13a33b70 100644 --- a/ring.go +++ b/ring.go @@ -428,7 +428,7 @@ func (c *Ring) pipelineExec(cmds []Cmder) (firstErr error) { } canRetry, err := shard.Client.pipelineProcessCmds(cn, cmds) - shard.Client.putConn(cn, err, false) + shard.Client.putConn(cn, err) if err == nil { continue }