From 3da4357c0cc5d528b3ef49307f7a34db40eaba0e Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sat, 18 May 2019 14:00:07 +0300 Subject: [PATCH] Pass network and addr to dialer --- cluster.go | 3 +++ options.go | 10 ++++++---- redis_test.go | 7 ++++--- sentinel.go | 7 ++++--- universal.go | 8 +++++++- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/cluster.go b/cluster.go index cf96035c5..e1c1b125a 100644 --- a/cluster.go +++ b/cluster.go @@ -53,6 +53,8 @@ type ClusterOptions struct { // Following options are copied from Options struct. + Dialer func(network, addr string) (net.Conn, error) + OnConnect func(*Conn) error Password string @@ -122,6 +124,7 @@ func (opt *ClusterOptions) clientOptions() *Options { const disableIdleCheck = -1 return &Options{ + Dialer: opt.Dialer, OnConnect: opt.OnConnect, MaxRetries: opt.MaxRetries, diff --git a/options.go b/options.go index b6fabf3f2..d00ead00f 100644 --- a/options.go +++ b/options.go @@ -34,7 +34,7 @@ type Options struct { // Dialer creates new network connection and has priority over // Network and Addr options. - Dialer func() (net.Conn, error) + Dialer func(network, addr string) (net.Conn, error) // Hook that is called when new connection is established. OnConnect func(*Conn) error @@ -105,13 +105,13 @@ func (opt *Options) init() { opt.Addr = "localhost:6379" } if opt.Dialer == nil { - opt.Dialer = func() (net.Conn, error) { + opt.Dialer = func(network, addr string) (net.Conn, error) { netDialer := &net.Dialer{ Timeout: opt.DialTimeout, KeepAlive: 5 * time.Minute, } if opt.TLSConfig == nil { - return netDialer.Dial(opt.Network, opt.Addr) + return netDialer.Dial(network, addr) } else { return tls.DialWithDialer(netDialer, opt.Network, opt.Addr, opt.TLSConfig) } @@ -215,7 +215,9 @@ func ParseURL(redisURL string) (*Options, error) { func newConnPool(opt *Options) *pool.ConnPool { return pool.NewConnPool(&pool.Options{ - Dialer: opt.Dialer, + Dialer: func() (net.Conn, error) { + return opt.Dialer(opt.Network, opt.Addr) + }, PoolSize: opt.PoolSize, MinIdleConns: opt.MinIdleConns, MaxConnAge: opt.MaxConnAge, diff --git a/redis_test.go b/redis_test.go index ceeb87e2d..e829686c4 100644 --- a/redis_test.go +++ b/redis_test.go @@ -39,9 +39,10 @@ var _ = Describe("Client", func() { It("should support custom dialers", func() { custom := redis.NewClient(&redis.Options{ - Addr: ":1234", - Dialer: func() (net.Conn, error) { - return net.Dial("tcp", redisAddr) + Network: "tcp", + Addr: redisAddr, + Dialer: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) }, }) diff --git a/sentinel.go b/sentinel.go index bf43a165b..889835007 100644 --- a/sentinel.go +++ b/sentinel.go @@ -25,6 +25,7 @@ type FailoverOptions struct { // Following options are copied from Options struct. + Dialer func(network, addr string) (net.Conn, error) OnConnect func(*Conn) error Password string @@ -50,8 +51,8 @@ type FailoverOptions struct { func (opt *FailoverOptions) options() *Options { return &Options{ - Addr: "FailoverClient", - + Addr: "FailoverClient", + Dialer: opt.Dialer, OnConnect: opt.OnConnect, DB: opt.DB, @@ -304,7 +305,7 @@ func (c *sentinelFailover) Pool() *pool.ConnPool { return c.pool } -func (c *sentinelFailover) dial() (net.Conn, error) { +func (c *sentinelFailover) dial(network, addr string) (net.Conn, error) { addr, err := c.MasterAddr() if err != nil { return nil, err diff --git a/universal.go b/universal.go index fd240d007..656bc5c6a 100644 --- a/universal.go +++ b/universal.go @@ -3,6 +3,7 @@ package redis import ( "context" "crypto/tls" + "net" "time" ) @@ -19,6 +20,7 @@ type UniversalOptions struct { // Common options. + Dialer func(network, addr string) (net.Conn, error) OnConnect func(*Conn) error Password string MaxRetries int @@ -54,6 +56,7 @@ func (o *UniversalOptions) cluster() *ClusterOptions { return &ClusterOptions{ Addrs: o.Addrs, + Dialer: o.Dialer, OnConnect: o.OnConnect, Password: o.Password, @@ -89,7 +92,9 @@ func (o *UniversalOptions) failover() *FailoverOptions { return &FailoverOptions{ SentinelAddrs: o.Addrs, MasterName: o.MasterName, - OnConnect: o.OnConnect, + + Dialer: o.Dialer, + OnConnect: o.OnConnect, DB: o.DB, Password: o.Password, @@ -121,6 +126,7 @@ func (o *UniversalOptions) simple() *Options { return &Options{ Addr: addr, + Dialer: o.Dialer, OnConnect: o.OnConnect, DB: o.DB,