Skip to content

Commit

Permalink
feat: Add context dial support (gomodule#476)
Browse files Browse the repository at this point in the history
Add support context on dial using new DialContext and DialContextFunc DialOption.
  • Loading branch information
178inaba authored May 24, 2020
1 parent b685725 commit 2eadaa0
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
28 changes: 23 additions & 5 deletions redis/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package redis
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -77,7 +78,7 @@ type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dialer *net.Dialer
dial func(network, addr string) (net.Conn, error)
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
db int
password string
clientName string
Expand Down Expand Up @@ -123,7 +124,18 @@ func DialKeepAlive(d time.Duration) DialOption {
// DialNetDial overrides DialConnectTimeout and DialKeepAlive.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
do.dialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dial(network, addr)
}
}}
}

// DialContextFunc specifies a custom dial function with context for creating TCP
// connections, otherwise a net.Dialer customized via the other options is used.
// DialContextFunc overrides DialConnectTimeout and DialKeepAlive.
func DialContextFunc(f func(ctx context.Context, network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dialContext = f
}}
}

Expand Down Expand Up @@ -177,6 +189,12 @@ func DialUseTLS(useTLS bool) DialOption {
// Dial connects to the Redis server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
return DialContext(context.Background(), network, address, options...)
}

// DialContext connects to the Redis server at the given network and
// address using the specified options and context.
func DialContext(ctx context.Context, network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dialer: &net.Dialer{
KeepAlive: time.Minute * 5,
Expand All @@ -185,11 +203,11 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
for _, option := range options {
option.f(&do)
}
if do.dial == nil {
do.dial = do.dialer.Dial
if do.dialContext == nil {
do.dialContext = do.dialer.DialContext
}

netConn, err := do.dial(network, address)
netConn, err := do.dialContext(ctx, network, address)
if err != nil {
return nil, err
}
Expand Down
42 changes: 42 additions & 0 deletions redis/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package redis_test

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand Down Expand Up @@ -531,6 +532,37 @@ func TestReadTimeout(t *testing.T) {
}
}

func TestDialContextFunc(t *testing.T) {
var isPassed bool
f := func(ctx context.Context, network, addr string) (net.Conn, error) {
isPassed = true
return &testConn{}, nil
}

_, err := redis.DialContext(context.Background(), "", "", redis.DialContextFunc(f))
if err != nil {
t.Fatalf("DialContext returned %v", err)
}

if !isPassed {
t.Fatal("DialContextFunc not passed")
}
}

func TestDialContext_CanceledContext(t *testing.T) {
addr, err := redis.DefaultServerAddr()
if err != nil {
t.Fatalf("redis.DefaultServerAddr returned %v", err)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()

if _, err = redis.DialContext(ctx, "tcp", addr); err == nil {
t.Fatalf("DialContext returned nil, expect error")
}
}

var dialErrors = []struct {
rawurl string
expectedError string
Expand Down Expand Up @@ -725,6 +757,16 @@ func ExampleDial() {
defer c.Close()
}

// Connect to local instance of Redis running on the default port using the provided context.
func ExampleDialContext() {
ctx := context.Background()
c, err := redis.DialContext(ctx, "tcp", ":6379")
if err != nil {
// handle error
}
defer c.Close()
}

// Connect to remote instance of Redis using a URL.
func ExampleDialURL() {
c, err := redis.DialURL(os.Getenv("REDIS_URL"))
Expand Down
43 changes: 43 additions & 0 deletions redis/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"errors"
"io"
"net"
"reflect"
"sync"
"testing"
Expand Down Expand Up @@ -840,6 +841,48 @@ func TestWaitPoolGetContextWithDialContext(t *testing.T) {
defer c.Close()
}

func TestPoolGetContext_DialContext(t *testing.T) {
var isPassed bool
f := func(ctx context.Context, network, addr string) (net.Conn, error) {
isPassed = true
return &testConn{}, nil
}

p := &redis.Pool{
DialContext: func(ctx context.Context) (redis.Conn, error) {
return redis.DialContext(ctx, "", "", redis.DialContextFunc(f))
},
}
defer p.Close()

if _, err := p.GetContext(context.Background()); err != nil {
t.Fatalf("GetContext returned %v", err)
}

if !isPassed {
t.Fatal("DialContextFunc not passed")
}
}

func TestPoolGetContext_DialContext_CanceledContext(t *testing.T) {
addr, err := redis.DefaultServerAddr()
if err != nil {
t.Fatalf("redis.DefaultServerAddr returned %v", err)
}

p := &redis.Pool{
DialContext: func(ctx context.Context) (redis.Conn, error) { return redis.DialContext(ctx, "tcp", addr) },
}
defer p.Close()

ctx, cancel := context.WithCancel(context.Background())
cancel()

if _, err := p.GetContext(ctx); err == nil {
t.Fatalf("GetContext returned nil, expect error")
}
}

func TestWaitPoolGetAfterClose(t *testing.T) {
d := poolDialer{t: t}
p := &redis.Pool{
Expand Down

0 comments on commit 2eadaa0

Please sign in to comment.