Skip to content

Commit

Permalink
Simplify resubscribing in PubSub.
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 29, 2016
1 parent 833b0c6 commit e57ac63
Show file tree
Hide file tree
Showing 14 changed files with 90 additions and 93 deletions.
2 changes: 1 addition & 1 deletion cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (c *ClusterClient) pipelineExec(cmds []Cmder) error {
}
}

cn, err := node.Client.conn()
cn, _, err := node.Client.conn()
if err != nil {
setCmdsErr(cmds, err)
setRetErr(err)
Expand Down
4 changes: 2 additions & 2 deletions internal/pool/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func benchmarkPoolGetPut(b *testing.B, poolSize int) {

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
if err != nil {
b.Fatal(err)
}
Expand Down Expand Up @@ -48,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {

b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
if err != nil {
b.Fatal(err)
}
Expand Down
14 changes: 7 additions & 7 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type PoolStats struct {
}

type Pooler interface {
Get() (*Conn, error)
Get() (*Conn, bool, error)
Put(*Conn) error
Remove(*Conn, error) error
Len() int
Expand Down Expand Up @@ -152,9 +152,9 @@ func (p *ConnPool) popFree() *Conn {
}

// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get() (*Conn, error) {
func (p *ConnPool) Get() (*Conn, bool, error) {
if p.Closed() {
return nil, ErrClosed
return nil, false, ErrClosed
}

atomic.AddUint32(&p.stats.Requests, 1)
Expand All @@ -170,7 +170,7 @@ func (p *ConnPool) Get() (*Conn, error) {
case <-timer.C:
timers.Put(timer)
atomic.AddUint32(&p.stats.Timeouts, 1)
return nil, ErrPoolTimeout
return nil, false, ErrPoolTimeout
}

p.freeConnsMu.Lock()
Expand All @@ -180,15 +180,15 @@ func (p *ConnPool) Get() (*Conn, error) {
if cn != nil {
atomic.AddUint32(&p.stats.Hits, 1)
if !cn.IsStale(p.idleTimeout) {
return cn, nil
return cn, false, nil
}
_ = p.closeConn(cn, errConnStale)
}

newcn, err := p.NewConn()
if err != nil {
<-p.queue
return nil, err
return nil, false, err
}

p.connsMu.Lock()
Expand All @@ -198,7 +198,7 @@ func (p *ConnPool) Get() (*Conn, error) {
p.conns = append(p.conns, newcn)
p.connsMu.Unlock()

return newcn, nil
return newcn, true, nil
}

func (p *ConnPool) Put(cn *Conn) error {
Expand Down
4 changes: 2 additions & 2 deletions internal/pool/pool_single.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ func (p *SingleConnPool) First() *Conn {
return p.cn
}

func (p *SingleConnPool) Get() (*Conn, error) {
return p.cn, nil
func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, false, nil
}

func (p *SingleConnPool) Put(cn *Conn) error {
Expand Down
12 changes: 6 additions & 6 deletions internal/pool/pool_sticky.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ func (p *StickyConnPool) First() *Conn {
return cn
}

func (p *StickyConnPool) Get() (*Conn, error) {
func (p *StickyConnPool) Get() (*Conn, bool, error) {
defer p.mx.Unlock()
p.mx.Lock()

if p.closed {
return nil, ErrClosed
return nil, false, ErrClosed
}
if p.cn != nil {
return p.cn, nil
return p.cn, false, nil
}

cn, err := p.pool.Get()
cn, _, err := p.pool.Get()
if err != nil {
return nil, err
return nil, false, err
}
p.cn = cn
return cn, nil
return cn, true, nil
}

func (p *StickyConnPool) put() (err error) {
Expand Down
22 changes: 11 additions & 11 deletions internal/pool/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var _ = Describe("ConnPool", func() {
It("rate limits dial", func() {
var rateErr error
for i := 0; i < 1000; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
if err != nil {
rateErr = err
break
Expand All @@ -40,13 +40,13 @@ var _ = Describe("ConnPool", func() {

It("should unblock client when conn is removed", func() {
// Reserve one connection.
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())

// Reserve all other connections.
var cns []*pool.Conn
for i := 0; i < 9; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
cns = append(cns, cn)
}
Expand All @@ -57,7 +57,7 @@ var _ = Describe("ConnPool", func() {
defer GinkgoRecover()

started <- true
_, err := connPool.Get()
_, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
done <- true

Expand Down Expand Up @@ -113,7 +113,7 @@ var _ = Describe("conns reaper", func() {
// add stale connections
idleConns = nil
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
cn.UsedAt = time.Now().Add(-2 * idleTimeout)
conns = append(conns, cn)
Expand All @@ -122,7 +122,7 @@ var _ = Describe("conns reaper", func() {

// add fresh connections
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
conns = append(conns, cn)
}
Expand Down Expand Up @@ -167,7 +167,7 @@ var _ = Describe("conns reaper", func() {
for j := 0; j < 3; j++ {
var freeCns []*pool.Conn
for i := 0; i < 3; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
freeCns = append(freeCns, cn)
Expand All @@ -176,7 +176,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(3))
Expect(connPool.FreeLen()).To(Equal(0))

cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
Expect(cn).NotTo(BeNil())
conns = append(conns, cn)
Expand Down Expand Up @@ -224,15 +224,15 @@ var _ = Describe("race", func() {

perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
}
}
}, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
Expand All @@ -248,7 +248,7 @@ var _ = Describe("race", func() {

perform(C, func(id int) {
for i := 0; i < N; i++ {
cn, err := connPool.Get()
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Put(cn)).NotTo(HaveOccurred())
Expand Down
2 changes: 1 addition & 1 deletion pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ var _ = Describe("pool", func() {
})

It("should remove broken connections", func() {
cn, err := client.Pool().Get()
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn.NetConn = &badConn{}
Expect(client.Pool().Put(cn)).NotTo(HaveOccurred())
Expand Down
93 changes: 45 additions & 48 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,25 @@ type PubSub struct {

channels []string
patterns []string
}

func (c *PubSub) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.base.conn()
if err != nil {
return nil, false, err
}
if isNew {
c.resubscribe()
}
return cn, isNew, nil
}

nsub int // number of active subscriptions
func (c *PubSub) putConn(cn *pool.Conn, err error) {
c.base.putConn(cn, err, true)
}

func (c *PubSub) subscribe(redisCmd string, channels ...string) error {
cn, err := c.base.conn()
cn, _, err := c.conn()
if err != nil {
return err
}
Expand All @@ -44,7 +57,6 @@ func (c *PubSub) Subscribe(channels ...string) error {
err := c.subscribe("SUBSCRIBE", channels...)
if err == nil {
c.channels = appendIfNotExists(c.channels, channels...)
c.nsub += len(channels)
}
return err
}
Expand All @@ -54,43 +66,10 @@ func (c *PubSub) PSubscribe(patterns ...string) error {
err := c.subscribe("PSUBSCRIBE", patterns...)
if err == nil {
c.patterns = appendIfNotExists(c.patterns, patterns...)
c.nsub += len(patterns)
}
return err
}

func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
}
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}

func appendIfNotExists(ss []string, es ...string) []string {
for _, e := range es {
found := false
for _, s := range ss {
if s == e {
found = true
break
}
}

if !found {
ss = append(ss, e)
}
}
return ss
}

// Unsubscribes the client from the given channels, or from all of
// them if none is given.
func (c *PubSub) Unsubscribe(channels ...string) error {
Expand All @@ -116,7 +95,7 @@ func (c *PubSub) Close() error {
}

func (c *PubSub) Ping(payload string) error {
cn, err := c.base.conn()
cn, _, err := c.conn()
if err != nil {
return err
}
Expand Down Expand Up @@ -198,11 +177,7 @@ func (c *PubSub) newMessage(reply []interface{}) (interface{}, error) {
// is not received in time. This is low-level API and most clients
// should use ReceiveMessage.
func (c *PubSub) ReceiveTimeout(timeout time.Duration) (interface{}, error) {
if c.nsub == 0 {
c.resubscribe()
}

cn, err := c.base.conn()
cn, _, err := c.conn()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -274,12 +249,6 @@ func (c *PubSub) receiveMessage(timeout time.Duration) (*Message, error) {
}
}

func (c *PubSub) putConn(cn *pool.Conn, err error) {
if !c.base.putConn(cn, err, true) {
c.nsub = 0
}
}

func (c *PubSub) resubscribe() {
if c.base.closed() {
return
Expand All @@ -295,3 +264,31 @@ func (c *PubSub) resubscribe() {
}
}
}

func remove(ss []string, es ...string) []string {
if len(es) == 0 {
return ss[:0]
}
for _, e := range es {
for i, s := range ss {
if s == e {
ss = append(ss[:i], ss[i+1:]...)
break
}
}
}
return ss
}

func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}
2 changes: 1 addition & 1 deletion pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ var _ = Describe("PubSub", func() {
})

expectReceiveMessageOnError := func(pubsub *redis.PubSub) {
cn1, err := pubsub.Pool().Get()
cn1, _, err := pubsub.Pool().Get()
Expect(err).NotTo(HaveOccurred())
cn1.NetConn = &badConn{
readErr: io.EOF,
Expand Down
Loading

0 comments on commit e57ac63

Please sign in to comment.