Skip to content

Commit

Permalink
Support DialContext on Pool
Browse files Browse the repository at this point in the history
Support Pool.DialContext option

Fixes gomodule#406
  • Loading branch information
izumin5210 authored and garyburd committed Mar 22, 2019
1 parent 9f26187 commit 39e2c31
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
25 changes: 20 additions & 5 deletions redis/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ var (
// return &redis.Pool{
// MaxIdle: 3,
// IdleTimeout: 240 * time.Second,
// // Dial or DialContext must be set. When both are set, DialContext takes precedence over Dial.
// Dial: func () (redis.Conn, error) { return redis.Dial("tcp", addr) },
// }
// }
Expand Down Expand Up @@ -126,6 +127,13 @@ type Pool struct {
// (subscribed to pubsub channel, transaction started, ...).
Dial func() (Conn, error)

// DialContext is an application supplied function for creating and configuring a
// connection with the given context.
//
// The connection returned from Dial must not be in a special state
// (subscribed to pubsub channel, transaction started, ...).
DialContext func(ctx context.Context) (Conn, error)

// TestOnBorrow is an optional application supplied function for checking
// the health of an idle connection before the connection is used again by
// the application. Argument t is the time that the connection was returned
Expand Down Expand Up @@ -293,10 +301,7 @@ func (p *Pool) lazyInit() {

// get prunes stale connections and returns a connection from the idle list or
// creates a new connection.
func (p *Pool) get(ctx interface {
Done() <-chan struct{}
Err() error
}) (*poolConn, error) {
func (p *Pool) get(ctx context.Context) (*poolConn, error) {

// Handle limit for p.Wait == true.
var waited time.Duration
Expand Down Expand Up @@ -372,7 +377,7 @@ func (p *Pool) get(ctx interface {

p.active++
p.mu.Unlock()
c, err := p.Dial()
c, err := p.dial(ctx)
if err != nil {
c = nil
p.mu.Lock()
Expand All @@ -385,6 +390,16 @@ func (p *Pool) get(ctx interface {
return &poolConn{c: c, created: nowFunc()}, err
}

func (p *Pool) dial(ctx context.Context) (Conn, error) {
if p.DialContext != nil {
return p.DialContext(ctx)
}
if p.Dial != nil {
return p.Dial()
}
return nil, errors.New("redigo: must pass Dial or DialContext to pool")
}

func (p *Pool) put(pc *poolConn, forceClose bool) error {
p.mu.Lock()
if !p.closed && !forceClose {
Expand Down
20 changes: 20 additions & 0 deletions redis/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func (d *poolDialer) dial() (redis.Conn, error) {
return &poolTestConn{d: d, Conn: c}, nil
}

func (d *poolDialer) dialContext(ctx context.Context) (redis.Conn, error) {
return d.dial()
}

func (d *poolDialer) check(message string, p *redis.Pool, dialed, open, inuse int) {
d.checkAll(message, p, dialed, open, inuse, 0, 0)
}
Expand Down Expand Up @@ -820,6 +824,22 @@ func TestWaitPoolGetContext(t *testing.T) {
defer c.Close()
}

func TestWaitPoolGetContextWithDialContext(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
MaxIdle: 1,
MaxActive: 1,
DialContext: d.dialContext,
Wait: true,
}
defer p.Close()
c, err := p.GetContext(context.Background())
if err != nil {
t.Fatalf("GetContext returned %v", err)
}
defer c.Close()
}

func TestWaitPoolGetAfterClose(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
Expand Down

0 comments on commit 39e2c31

Please sign in to comment.