Skip to content

Commit

Permalink
PubSub conns don't share connection pool limit
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Apr 17, 2017
1 parent aeb22d6 commit 6499563
Show file tree
Hide file tree
Showing 14 changed files with 179 additions and 152 deletions.
5 changes: 3 additions & 2 deletions export_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package redis

import (
"net"
"time"

"github.com/go-redis/redis/internal/pool"
Expand All @@ -10,8 +11,8 @@ func (c *baseClient) Pool() pool.Pooler {
return c.connPool
}

func (c *PubSub) Pool() pool.Pooler {
return c.base.connPool
func (c *PubSub) SetNetConn(netConn net.Conn) {
c.cn = pool.NewConn(netConn)
}

func (c *PubSub) ReceiveMessageTimeout(timeout time.Duration) (*Message, error) {
Expand Down
4 changes: 1 addition & 3 deletions internal/pool/bench_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pool_test

import (
"errors"
"testing"
"time"

Expand Down Expand Up @@ -40,7 +39,6 @@ func BenchmarkPoolGetPut1000Conns(b *testing.B) {

func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
connPool := pool.NewConnPool(dummyDialer, poolSize, time.Second, time.Hour, time.Hour)
removeReason := errors.New("benchmark")

b.ResetTimer()

Expand All @@ -50,7 +48,7 @@ func benchmarkPoolGetRemove(b *testing.B, poolSize int) {
if err != nil {
b.Fatal(err)
}
if err := connPool.Remove(cn, removeReason); err != nil {
if err := connPool.Remove(cn); err != nil {
b.Fatal(err)
}
}
Expand Down
66 changes: 36 additions & 30 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pool

import (
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
Expand All @@ -11,11 +10,8 @@ import (
"github.com/go-redis/redis/internal"
)

var (
ErrClosed = errors.New("redis: client is closed")
ErrPoolTimeout = errors.New("redis: connection pool timeout")
errConnStale = errors.New("connection is stale")
)
var ErrClosed = errors.New("redis: client is closed")
var ErrPoolTimeout = errors.New("redis: connection pool timeout")

var timers = sync.Pool{
New: func() interface{} {
Expand All @@ -36,12 +32,17 @@ type Stats struct {
}

type Pooler interface {
NewConn() (*Conn, error)
CloseConn(*Conn) error

Get() (*Conn, bool, error)
Put(*Conn) error
Remove(*Conn, error) error
Remove(*Conn) error

Len() int
FreeLen() int
Stats() *Stats

Close() error
}

Expand Down Expand Up @@ -87,11 +88,21 @@ func NewConnPool(dial dialer, poolSize int, poolTimeout, idleTimeout, idleCheckF
}

func (p *ConnPool) NewConn() (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}

netConn, err := p.dial()
if err != nil {
return nil, err
}
return NewConn(netConn), nil

cn := NewConn(netConn)
p.connsMu.Lock()
p.conns = append(p.conns, cn)
p.connsMu.Unlock()

return cn, nil
}

func (p *ConnPool) PopFree() *Conn {
Expand Down Expand Up @@ -164,7 +175,7 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
}

if cn.IsStale(p.idleTimeout) {
p.remove(cn, errConnStale)
p.CloseConn(cn)
continue
}

Expand All @@ -178,18 +189,13 @@ func (p *ConnPool) Get() (*Conn, bool, error) {
return nil, false, err
}

p.connsMu.Lock()
p.conns = append(p.conns, newcn)
p.connsMu.Unlock()

return newcn, true, nil
}

func (p *ConnPool) Put(cn *Conn) error {
if data := cn.Rd.PeekBuffered(); data != nil {
err := fmt.Errorf("connection has unread data: %q", data)
internal.Logf(err.Error())
return p.Remove(cn, err)
internal.Logf("connection has unread data: %q", data)
return p.Remove(cn)
}
p.freeConnsMu.Lock()
p.freeConns = append(p.freeConns, cn)
Expand All @@ -198,15 +204,13 @@ func (p *ConnPool) Put(cn *Conn) error {
return nil
}

func (p *ConnPool) Remove(cn *Conn, reason error) error {
p.remove(cn, reason)
func (p *ConnPool) Remove(cn *Conn) error {
_ = p.CloseConn(cn)
<-p.queue
return nil
}

func (p *ConnPool) remove(cn *Conn, reason error) {
_ = p.closeConn(cn, reason)

func (p *ConnPool) CloseConn(cn *Conn) error {
p.connsMu.Lock()
for i, c := range p.conns {
if c == cn {
Expand All @@ -215,6 +219,15 @@ func (p *ConnPool) remove(cn *Conn, reason error) {
}
}
p.connsMu.Unlock()

return p.closeConn(cn)
}

func (p *ConnPool) closeConn(cn *Conn) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
}

// Len returns total number of connections.
Expand Down Expand Up @@ -258,7 +271,7 @@ func (p *ConnPool) Close() error {
if cn == nil {
continue
}
if err := p.closeConn(cn, ErrClosed); err != nil && firstErr == nil {
if err := p.closeConn(cn); err != nil && firstErr == nil {
firstErr = err
}
}
Expand All @@ -272,13 +285,6 @@ func (p *ConnPool) Close() error {
return firstErr
}

func (p *ConnPool) closeConn(cn *Conn, reason error) error {
if p.OnClose != nil {
_ = p.OnClose(cn)
}
return cn.Close()
}

func (p *ConnPool) reapStaleConn() bool {
if len(p.freeConns) == 0 {
return false
Expand All @@ -289,7 +295,7 @@ func (p *ConnPool) reapStaleConn() bool {
return false
}

p.remove(cn, errConnStale)
p.CloseConn(cn)
p.freeConns = append(p.freeConns[:0], p.freeConns[1:]...)

return true
Expand Down
10 changes: 9 additions & 1 deletion internal/pool/pool_single.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ func NewSingleConnPool(cn *Conn) *SingleConnPool {
}
}

func (p *SingleConnPool) NewConn() (*Conn, error) {
panic("not implemented")
}

func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}

func (p *SingleConnPool) Get() (*Conn, bool, error) {
return p.cn, false, nil
}
Expand All @@ -23,7 +31,7 @@ func (p *SingleConnPool) Put(cn *Conn) error {
return nil
}

func (p *SingleConnPool) Remove(cn *Conn, _ error) error {
func (p *SingleConnPool) Remove(cn *Conn) error {
if p.cn != cn {
panic("p.cn != cn")
}
Expand Down
24 changes: 14 additions & 10 deletions internal/pool/pool_sticky.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package pool

import (
"errors"
"sync"
)
import "sync"

type StickyConnPool struct {
pool *ConnPool
Expand All @@ -23,6 +20,14 @@ func NewStickyConnPool(pool *ConnPool, reusable bool) *StickyConnPool {
}
}

func (p *StickyConnPool) NewConn() (*Conn, error) {
panic("not implemented")
}

func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}

func (p *StickyConnPool) Get() (*Conn, bool, error) {
p.mu.Lock()
defer p.mu.Unlock()
Expand Down Expand Up @@ -58,20 +63,20 @@ func (p *StickyConnPool) Put(cn *Conn) error {
return nil
}

func (p *StickyConnPool) removeUpstream(reason error) error {
err := p.pool.Remove(p.cn, reason)
func (p *StickyConnPool) removeUpstream() error {
err := p.pool.Remove(p.cn)
p.cn = nil
return err
}

func (p *StickyConnPool) Remove(cn *Conn, reason error) error {
func (p *StickyConnPool) Remove(cn *Conn) error {
p.mu.Lock()
defer p.mu.Unlock()

if p.closed {
return nil
}
return p.removeUpstream(reason)
return p.removeUpstream()
}

func (p *StickyConnPool) Len() int {
Expand Down Expand Up @@ -111,8 +116,7 @@ func (p *StickyConnPool) Close() error {
if p.reusable {
err = p.putUpstream()
} else {
reason := errors.New("redis: unreusable sticky connection")
err = p.removeUpstream(reason)
err = p.removeUpstream()
}
}
return err
Expand Down
7 changes: 3 additions & 4 deletions internal/pool/pool_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pool_test

import (
"errors"
"testing"
"time"

Expand Down Expand Up @@ -59,7 +58,7 @@ var _ = Describe("ConnPool", func() {
// ok
}

err = connPool.Remove(cn, errors.New("test"))
err = connPool.Remove(cn)
Expect(err).NotTo(HaveOccurred())

// Check that Ping is unblocked.
Expand Down Expand Up @@ -169,7 +168,7 @@ var _ = Describe("conns reaper", func() {
Expect(connPool.Len()).To(Equal(4))
Expect(connPool.FreeLen()).To(Equal(0))

err = connPool.Remove(cn, errors.New("test"))
err = connPool.Remove(cn)
Expect(err).NotTo(HaveOccurred())

Expect(connPool.Len()).To(Equal(3))
Expand Down Expand Up @@ -219,7 +218,7 @@ var _ = Describe("race", func() {
cn, _, err := connPool.Get()
Expect(err).NotTo(HaveOccurred())
if err == nil {
Expect(connPool.Remove(cn, errors.New("test"))).NotTo(HaveOccurred())
Expect(connPool.Remove(cn)).NotTo(HaveOccurred())
}
}
})
Expand Down
2 changes: 1 addition & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (opt *Options) init() {
}
}
if opt.PoolSize == 0 {
opt.PoolSize = 100
opt.PoolSize = 10
}
if opt.DialTimeout == 0 {
opt.DialTimeout = 5 * time.Second
Expand Down
12 changes: 0 additions & 12 deletions pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,6 @@ var _ = Describe("pool", func() {
Expect(pool.Len()).To(Equal(pool.FreeLen()))
})

It("respects max size on pubsub", func() {
connPool := client.Pool()

perform(1000, func(id int) {
pubsub := client.Subscribe("test")
Expect(pubsub.Close()).NotTo(HaveOccurred())
})

Expect(connPool.Len()).To(Equal(connPool.FreeLen()))
Expect(connPool.Len()).To(BeNumerically("<=", 10))
})

It("removes broken connections", func() {
cn, _, err := client.Pool().Get()
Expect(err).NotTo(HaveOccurred())
Expand Down
Loading

0 comments on commit 6499563

Please sign in to comment.