Skip to content

Commit

Permalink
Check context.Done while waiting for a connection
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jun 8, 2019
1 parent 530e66a commit 35932b7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
25 changes: 17 additions & 8 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ func (p *ConnPool) NewConn() (*Conn, error) {
return p._NewConn(nil, false)
}

func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(c, pooled)
func (p *ConnPool) _NewConn(ctx context.Context, pooled bool) (*Conn, error) {
cn, err := p.newConn(ctx, pooled)
if err != nil {
return nil, err
}
Expand All @@ -149,7 +149,7 @@ func (p *ConnPool) _NewConn(c context.Context, pooled bool) (*Conn, error) {
return cn, nil
}

func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}
Expand All @@ -158,7 +158,7 @@ func (p *ConnPool) newConn(c context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}

netConn, err := p.opt.Dialer(c)
netConn, err := p.opt.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
Expand Down Expand Up @@ -205,12 +205,12 @@ func (p *ConnPool) getLastDialError() error {
}

// Get returns existed connection from the pool or creates a new one.
func (p *ConnPool) Get(c context.Context) (*Conn, error) {
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
if p.closed() {
return nil, ErrClosed
}

err := p.waitTurn()
err := p.waitTurn(ctx)
if err != nil {
return nil, err
}
Expand All @@ -235,7 +235,7 @@ func (p *ConnPool) Get(c context.Context) (*Conn, error) {

atomic.AddUint32(&p.stats.Misses, 1)

newcn, err := p._NewConn(c, true)
newcn, err := p._NewConn(ctx, true)
if err != nil {
p.freeTurn()
return nil, err
Expand All @@ -248,15 +248,24 @@ func (p *ConnPool) getTurn() {
p.queue <- struct{}{}
}

func (p *ConnPool) waitTurn() error {
func (p *ConnPool) waitTurn(ctx context.Context) error {
var done <-chan struct{}
if ctx != nil {
done = ctx.Done()
}

select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
return nil
default:
timer := timers.Get().(*time.Timer)
timer.Reset(p.opt.PoolTimeout)

select {
case <-done:
return ctx.Err()
case p.queue <- struct{}{}:
if !timer.Stop() {
<-timer.C
Expand Down
2 changes: 1 addition & 1 deletion internal/pool/pool_single.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (p *SingleConnPool) CloseConn(*Conn) error {
panic("not implemented")
}

func (p *SingleConnPool) Get(c context.Context) (*Conn, error) {
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
return p.cn, nil
}

Expand Down
4 changes: 2 additions & 2 deletions internal/pool/pool_sticky.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (p *StickyConnPool) CloseConn(*Conn) error {
panic("not implemented")
}

func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
p.mu.Lock()
defer p.mu.Unlock()

Expand All @@ -42,7 +42,7 @@ func (p *StickyConnPool) Get(c context.Context) (*Conn, error) {
return p.cn, nil
}

cn, err := p.pool.Get(c)
cn, err := p.pool.Get(ctx)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 35932b7

Please sign in to comment.