Skip to content

Commit

Permalink
Simplify connection management with sticky connection pool. Fixes red…
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Mar 1, 2016
1 parent 0382d1e commit 110e93a
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 89 deletions.
10 changes: 5 additions & 5 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ func isNetworkError(err error) bool {
return ok
}

func isBadConn(cn *conn, ei error) bool {
if cn.rd.Buffered() > 0 {
return true
func isBadConn(err error) bool {
if err == nil {
return false
}
if ei == nil {
if _, ok := err.(redisError); ok {
return false
}
if _, ok := ei.(redisError); ok {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return false
}
return true
Expand Down
4 changes: 4 additions & 0 deletions export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ func (c *baseClient) Pool() pool {
return c.connPool
}

func (c *PubSub) Pool() pool {
return c.base.connPool
}

var NewConnDialer = newConnDialer

func (cn *conn) SetNetConn(netcn net.Conn) {
Expand Down
5 changes: 0 additions & 5 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os/exec"
"path/filepath"
"sync/atomic"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -243,10 +242,6 @@ func startSentinel(port, masterName, masterPort string) (*redisProcess, error) {

//------------------------------------------------------------------------------

var (
errTimeout = syscall.ETIMEDOUT
)

type badConnError string

func (e badConnError) Error() string { return string(e) }
Expand Down
14 changes: 1 addition & 13 deletions multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,6 @@ func (c *Client) Multi() *Multi {
return multi
}

func (c *Multi) putConn(cn *conn, err error) {
if isBadConn(cn, err) {
// Close current connection.
c.base.connPool.(*stickyConnPool).Reset(err)
} else {
err := c.base.connPool.Put(cn)
if err != nil {
Logger.Printf("pool.Put failed: %s", err)
}
}
}

func (c *Multi) process(cmd Cmder) {
if c.cmds == nil {
c.base.process(cmd)
Expand Down Expand Up @@ -145,7 +133,7 @@ func (c *Multi) Exec(f func() error) ([]Cmder, error) {
}

err = c.execCmds(cn, cmds)
c.putConn(cn, err)
c.base.putConn(cn, err)
return retCmds, err
}

Expand Down
27 changes: 27 additions & 0 deletions multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,31 @@ var _ = Describe("Multi", func() {
})
Expect(err).NotTo(HaveOccurred())
})

It("should recover from bad connection when there are no commands", func() {
// Put bad connection in the pool.
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())

cn.SetNetConn(&badConn{})
err = client.Pool().Put(cn)
Expect(err).NotTo(HaveOccurred())

{
tx, err := client.Watch("key")
Expect(err).To(MatchError("bad connection"))
Expect(tx).To(BeNil())
}

{
tx, err := client.Watch("key")
Expect(err).NotTo(HaveOccurred())

err = tx.Ping().Err()
Expect(err).NotTo(HaveOccurred())

err = tx.Close()
Expect(err).NotTo(HaveOccurred())
}
})
})
20 changes: 6 additions & 14 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,14 @@ func (p *connPool) Get() (cn *conn, isNew bool, err error) {

// Try to create a new one.
if p.conns.Reserve() {
isNew = true

cn, err = p.new()
if err != nil {
p.conns.Remove(nil)
return
}
p.conns.Add(cn)
isNew = true
return
}

Expand Down Expand Up @@ -481,13 +482,13 @@ func (p *stickyConnPool) Put(cn *conn) error {
return nil
}

func (p *stickyConnPool) remove(reason error) (err error) {
err = p.pool.Remove(p.cn, reason)
func (p *stickyConnPool) remove(reason error) error {
err := p.pool.Remove(p.cn, reason)
p.cn = nil
return err
}

func (p *stickyConnPool) Remove(cn *conn, _ error) error {
func (p *stickyConnPool) Remove(cn *conn, reason error) error {
defer p.mx.Unlock()
p.mx.Lock()
if p.closed {
Expand All @@ -499,7 +500,7 @@ func (p *stickyConnPool) Remove(cn *conn, _ error) error {
if cn != nil && p.cn != cn {
panic("p.cn != cn")
}
return nil
return p.remove(reason)
}

func (p *stickyConnPool) Len() int {
Expand All @@ -522,15 +523,6 @@ func (p *stickyConnPool) FreeLen() int {

func (p *stickyConnPool) Stats() *PoolStats { return nil }

func (p *stickyConnPool) Reset(reason error) (err error) {
p.mx.Lock()
if p.cn != nil {
err = p.remove(reason)
}
p.mx.Unlock()
return err
}

func (p *stickyConnPool) Close() error {
defer p.mx.Unlock()
p.mx.Lock()
Expand Down
76 changes: 45 additions & 31 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ func (c *Client) Publish(channel, message string) *IntCmd {
// http://redis.io/topics/pubsub. It's NOT safe for concurrent use by
// multiple goroutines.
type PubSub struct {
*baseClient
base *baseClient

channels []string
patterns []string

nsub int // number of active subscriptions
}

// Deprecated. Use Subscribe/PSubscribe instead.
func (c *Client) PubSub() *PubSub {
return &PubSub{
baseClient: &baseClient{
base: &baseClient{
opt: c.opt,
connPool: newStickyConnPool(c.connPool, false),
},
Expand All @@ -46,7 +48,7 @@ func (c *Client) PSubscribe(channels ...string) (*PubSub, error) {
}

func (c *PubSub) subscribe(cmd string, channels ...string) error {
cn, _, err := c.conn()
cn, _, err := c.base.conn()
if err != nil {
return err
}
Expand All @@ -65,6 +67,7 @@ func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...)
if err == nil {
c.channels = append(c.channels, channels...)
c.nsub += len(channels)
}
return err
}
Expand All @@ -74,6 +77,7 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...)
if err == nil {
c.patterns = append(c.patterns, patterns...)
c.nsub += len(patterns)
}
return err
}
Expand Down Expand Up @@ -113,8 +117,12 @@ func (c *PubSub) PUnsubscribe(patterns ...string) error {
return err
}

func (c *PubSub) Close() error {
return c.base.Close()
}

func (c *PubSub) Ping(payload string) error {
cn, _, err := c.conn()
cn, _, err := c.base.conn()
if err != nil {
return err
}
Expand Down Expand Up @@ -178,7 +186,7 @@ func (p *Pong) String() string {
return "Pong"
}

func newMessage(reply []interface{}) (interface{}, error) {
func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
switch kind := reply[0].(string); kind {
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
return &Subscription{
Expand Down Expand Up @@ -210,7 +218,11 @@ func newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients
// should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
cn, _, err := c.conn()
if c.nsub == 0 {
c.resubscribe()
}

cn, _, err := c.base.conn()
if err != nil {
return nil, err
}
Expand All @@ -222,7 +234,8 @@ func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if err != nil {
return nil, err
}
return newMessage(cmd.Val())

return c.newMessage(cmd.Val())
}

// Receive returns a message as a Subscription, Message, PMessage,
Expand All @@ -232,22 +245,6 @@ func (c *PubSub) Receive() (interface{}, error) {
return c.ReceiveTimeout(0)
}

func (c *PubSub) reconnect(reason error) {
// Close current connection.
c.connPool.(*stickyConnPool).Reset(reason)

if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
Logger.Printf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.PSubscribe(c.patterns...); err != nil {
Logger.Printf("PSubscribe failed: %s", err)
}
}
}

// ReceiveMessage returns a message or error. It automatically
// reconnects to Redis in case of network errors.
func (c *PubSub) ReceiveMessage() (*Message, error) {
Expand All @@ -259,27 +256,25 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
return nil, err
}

goodConn := errNum == 0
errNum++

if goodConn {
if errNum < 3 {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
err := c.Ping("")
if err == nil {
continue
}
Logger.Printf("PubSub.Ping failed: %s", err)
}
}

if errNum > 2 {
} else {
// 3 consequent errors - connection is bad
// and/or Redis Server is down.
// Sleep to not exceed max number of open connections.
time.Sleep(time.Second)
}
c.reconnect(err)
continue
}

// Reset error number.
// Reset error number, because we received a message.
errNum = 0

switch msg := msgi.(type) {
Expand All @@ -300,3 +295,22 @@ func (c *PubSub) ReceiveMessage() (*Message, error) {
}
}
}

func (c *PubSub) putConn(cn *conn, err error) {
if !c.base.putConn(cn, err) {
c.nsub = 0
}
}

func (c *PubSub) resubscribe() {
if len(c.channels) > 0 {
if err := c.Subscribe(c.channels...); err != nil {
Logger.Printf("Subscribe failed: %s", err)
}
}
if len(c.patterns) > 0 {
if err := c.PSubscribe(c.patterns...); err != nil {
Logger.Printf("PSubscribe failed: %s", err)
}
}
}
Loading

0 comments on commit 110e93a

Please sign in to comment.