Skip to content

Commit

Permalink
Enable keep-alive on server connections (gomodule#294)
Browse files Browse the repository at this point in the history
Add the option to configure connection keep-alive using the new DialKeepAlive option.

This is set to 5 minutes by default to ensure that half-closed connections are detected.

Without this Pub/Sub connections can get disconnected by the redis server and we will never notice which can potentially deadlock applications.

Also:
* Add --redis-address <address> command line option to tests defaulting to 127.0.0.1. Without this all tests fail on machines with multiple addresses.
  • Loading branch information
stevenh authored and garyburd committed Nov 22, 2017
1 parent 47dc60e commit 4a7d9db
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
28 changes: 22 additions & 6 deletions redis/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type DialOption struct {
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
dialer *net.Dialer
dial func(network, addr string) (net.Conn, error)
db int
password string
Expand All @@ -94,17 +95,27 @@ func DialWriteTimeout(d time.Duration) DialOption {
}}
}

// DialConnectTimeout specifies the timeout for connecting to the Redis server.
// DialConnectTimeout specifies the timeout for connecting to the Redis server when
// no DialNetDial option is specified.
func DialConnectTimeout(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
dialer := net.Dialer{Timeout: d}
do.dial = dialer.Dial
do.dialer.Timeout = d
}}
}

// DialKeepAlive specifies the keep-alive period for TCP connections to the Redis server
// when no DialNetDial option is specified.
// If zero, keep-alives are not enabled. If no DialKeepAlive option is specified then
// the default of 5 minutes is used to ensure that half-closed TCP sessions are detected.
func DialKeepAlive(d time.Duration) DialOption {
return DialOption{func(do *dialOptions) {
do.dialer.KeepAlive = d
}}
}

// DialNetDial specifies a custom dial function for creating TCP
// connections. If this option is left out, then net.Dial is
// used. DialNetDial overrides DialConnectTimeout.
// connections, otherwise a net.Dialer customized via the other options is used.
// DialNetDial overrides DialConnectTimeout and DialKeepAlive.
func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dial = dial
Expand Down Expand Up @@ -154,11 +165,16 @@ func DialUseTLS(useTLS bool) DialOption {
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
do := dialOptions{
dial: net.Dial,
dialer: &net.Dialer{
KeepAlive: time.Minute * 5,
},
}
for _, option := range options {
option.f(&do)
}
if do.dial == nil {
do.dial = do.dialer.Dial
}

netConn, err := do.dial(network, address)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion redis/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ var (
ErrNegativeInt = errNegativeInt

serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary")
serverAddress = flag.String("redis-address", "127.0.0.1", "The address of the server")
serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers")
serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`")
serverLog = ioutil.Discard
Expand Down Expand Up @@ -136,6 +137,7 @@ func startDefaultServer() error {
defaultServer, defaultServerErr = NewServer(
"default",
"--port", strconv.Itoa(*serverBasePort),
"--bind", *serverAddress,
"--save", "",
"--appendonly", "no")
return defaultServerErr
Expand All @@ -147,7 +149,7 @@ func DialDefaultServer() (Conn, error) {
if err := startDefaultServer(); err != nil {
return nil, err
}
c, err := Dial("tcp", fmt.Sprintf(":%d", *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
c, err := Dial("tcp", fmt.Sprintf("%v:%d", *serverAddress, *serverBasePort), DialReadTimeout(1*time.Second), DialWriteTimeout(1*time.Second))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 4a7d9db

Please sign in to comment.