Skip to content

Commit

Permalink
Add DialUseTLS and TLS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
garyburd committed Sep 22, 2017
1 parent 70e1b19 commit 8bed3b5
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 12 deletions.
24 changes: 13 additions & 11 deletions redis/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ type dialOptions struct {
dial func(network, addr string) (net.Conn, error)
db int
password string
dialTLS bool
useTLS bool
skipVerify bool
tlsConfig *tls.Config
}
Expand Down Expand Up @@ -135,14 +135,22 @@ func DialTLSConfig(c *tls.Config) DialOption {
}}
}

// DialTLSSkipVerify to disable server name verification when connecting
// over TLS. Has no effect when not dialing a TLS connection.
// DialTLSSkipVerify disables server name verification when connecting over
// TLS. Has no effect when not dialing a TLS connection.
func DialTLSSkipVerify(skip bool) DialOption {
return DialOption{func(do *dialOptions) {
do.skipVerify = skip
}}
}

// DialUseTLS specifies whether TLS should be used when connecting to the
// server. This option is ignore by DialURL.
func DialUseTLS(useTLS bool) DialOption {
return DialOption{func(do *dialOptions) {
do.useTLS = useTLS
}}
}

// Dial connects to the Redis server at the given network and
// address using the specified options.
func Dial(network, address string, options ...DialOption) (Conn, error) {
Expand All @@ -158,7 +166,7 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
return nil, err
}

if do.dialTLS {
if do.useTLS {
tlsConfig := cloneTLSClientConfig(do.tlsConfig, do.skipVerify)
if tlsConfig.ServerName == "" {
host, _, err := net.SplitHostPort(address)
Expand Down Expand Up @@ -202,10 +210,6 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
return c, nil
}

func dialTLS(do *dialOptions) {
do.dialTLS = true
}

var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)

// DialURL connects to a Redis server at the given URL using the Redis
Expand Down Expand Up @@ -257,9 +261,7 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) {
return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
}

if u.Scheme == "rediss" {
options = append([]DialOption{{dialTLS}}, options...)
}
options = append(options, DialUseTLS(u.Scheme == "rediss"))

return Dial("tcp", address, options...)
}
Expand Down
145 changes: 144 additions & 1 deletion redis/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ package redis_test

import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"math"
"net"
Expand All @@ -41,11 +44,21 @@ func (*testConn) SetReadDeadline(t time.Time) error { return nil }
func (*testConn) SetWriteDeadline(t time.Time) error { return nil }

func dialTestConn(r io.Reader, w io.Writer) redis.DialOption {
return redis.DialNetDial(func(net, addr string) (net.Conn, error) {
return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
return &testConn{Reader: r, Writer: w}, nil
})
}

func dialTestConnTLS(r io.Reader, w io.Writer) redis.DialOption {
return redis.DialNetDial(func(network, addr string) (net.Conn, error) {
client, server := net.Pipe()
tlsServer := tls.Server(server, &serverTLSConfig)
go io.Copy(tlsServer, r)
go io.Copy(w, tlsServer)
return client, nil
})
}

type durationArg struct {
time.Duration
}
Expand Down Expand Up @@ -551,6 +564,74 @@ func TestDialURLDatabase(t *testing.T) {
}
}

func checkPingPong(t *testing.T, buf *bytes.Buffer, c redis.Conn) {
resp, err := c.Do("PING")
if err != nil {
t.Fatal("ping error:", err)
}
expected := "*1\r\n$4\r\nPING\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("commands = %q, want %q", actual, expected)
}
if resp != "PONG" {
t.Errorf("resp = %v, want %v", resp, "PONG")
}
}

func pingRespReader() io.Reader { return strings.NewReader("+PONG\r\n") }

func TestDialURLTLS(t *testing.T) {
var buf bytes.Buffer
c, err := redis.DialURL("rediss://example.com/",
redis.DialTLSConfig(&clientTLSConfig),
dialTestConnTLS(pingRespReader(), &buf))
if err != nil {
t.Fatal("dial error:", err)
}
defer c.Close()
checkPingPong(t, &buf, c)
}

func TestDialURLIgnoreUseTLS(t *testing.T) {
var buf bytes.Buffer
c, err := redis.DialURL("redis://example.com/",
redis.DialTLSConfig(&clientTLSConfig),
dialTestConn(pingRespReader(), &buf),
redis.DialUseTLS(true))
if err != nil {
t.Fatal("dial error:", err)
}
defer c.Close()
checkPingPong(t, &buf, c)
}

func TestDialUseTLS(t *testing.T) {
var buf bytes.Buffer
c, err := redis.Dial("tcp", "example.com:6379",
redis.DialTLSConfig(&clientTLSConfig),
dialTestConnTLS(pingRespReader(), &buf),
redis.DialUseTLS(true))
if err != nil {
t.Fatal("dial error:", err)
}
defer c.Close()
checkPingPong(t, &buf, c)
}

func TestDialTLSSKipVerify(t *testing.T) {
var buf bytes.Buffer
c, err := redis.Dial("tcp", "example.com:6379",
dialTestConnTLS(pingRespReader(), &buf),
redis.DialTLSSkipVerify(true),
redis.DialUseTLS(true))
if err != nil {
t.Fatal("dial error:", err)
}
defer c.Close()
checkPingPong(t, &buf, c)
}

// Connect to local instance of Redis running on the default port.
func ExampleDial() {
c, err := redis.Dial("tcp", ":6379")
Expand Down Expand Up @@ -680,3 +761,65 @@ func BenchmarkDoPing(b *testing.B) {
}
}
}

var clientTLSConfig, serverTLSConfig tls.Config

func init() {

// The certificate and key for testing TLS dial options was created
// using the command
//
// go run GOROOT/src/crypto/tls/generate_cert.go \
// --rsa-bits 1024 \
// --host 127.0.0.1,::1,example.com --ca \
// --start-date "Jan 1 00:00:00 1970" \
// --duration=1000000h
//
// where GOROOT is the value of GOROOT reported by go env.
localhostCert := []byte(`
-----BEGIN CERTIFICATE-----
MIICFDCCAX2gAwIBAgIRAJfBL4CUxkXcdlFurb3K+iowDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
gYkCgYEArizw8WxMUQ3bGHLeuJ4fDrEpy+L2pqrbYRlKk1DasJ/VkB8bImzIpe6+
LGjiYIxvnDCOJ3f3QplcQuiuMyl6f2irJlJsbFT8Lo/3obnuTKAIaqUdJUqBg6y+
JaL8Auk97FvunfKFv8U1AIhgiLzAfQ/3Eaq1yi87Ra6pMjGbTtcCAwEAAaNoMGYw
DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF
MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA
AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAdZ8daIVkyhVwflt5I19m0oq1TycbGO1+
ach7T6cZiBQeNR/SJtxr/wKPEpmvUgbv2BfFrKJ8QoIHYsbNSURTWSEa02pfw4k9
6RQhij3ZkG79Ituj5OYRORV6Z0HUW32r670BtcuHuAhq7YA6Nxy4FtSt7bAlVdRt
rrKgNsltzMk=
-----END CERTIFICATE-----`)

localhostKey := []byte(`
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCuLPDxbExRDdsYct64nh8OsSnL4vamqtthGUqTUNqwn9WQHxsi
bMil7r4saOJgjG+cMI4nd/dCmVxC6K4zKXp/aKsmUmxsVPwuj/ehue5MoAhqpR0l
SoGDrL4lovwC6T3sW+6d8oW/xTUAiGCIvMB9D/cRqrXKLztFrqkyMZtO1wIDAQAB
AoGACrc5G6FOEK6JjDeE/Fa+EmlT6PdNtXNNi+vCas3Opo8u1G8VfEi1D4BgstrB
Eq+RLkrOdB8tVyuYQYWPMhabMqF+hhKJN72j0OwfuPlVvTInwb/cKjo/zbH1IA+Y
HenHNK4ywv7/p/9/MvQPJ3I32cQBCgGUW5chVSH5M1sj5gECQQDabQAI1X0uDqCm
KbX9gXVkAgxkFddrt6LBHt57xujFcqEKFE7nwKhDh7DweVs/VEJ+kpid4z+UnLOw
KjtP9JolAkEAzCNBphQ//IsbH5rNs10wIUw3Ks/Oepicvr6kUFbIv+neRzi1iJHa
m6H7EayK3PWgax6BAsR/t0Jc9XV7r2muSwJAVzN09BHnK+ADGtNEKLTqXMbEk6B0
pDhn7ZmZUOkUPN+Kky+QYM11X6Bob1jDqQDGmymDbGUxGO+GfSofC8inUQJAGfci
Eo3g1a6b9JksMPRZeuLG4ZstGErxJRH6tH1Va5PDwitka8qhk8o2tTjNMO3NSdLH
diKoXBcE2/Pll5pJoQJBAIMiiMIzXJhnN4mX8may44J/HvMlMf2xuVH2gNMwmZuc
Bjqn3yoLHaoZVvbWOi0C2TCN4FjXjaLNZGifQPbIcaA=
-----END RSA PRIVATE KEY-----`)

cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
panic(fmt.Sprintf("error creating key pair: %v", err))
}
serverTLSConfig.Certificates = []tls.Certificate{cert}

certificate, err := x509.ParseCertificate(serverTLSConfig.Certificates[0].Certificate[0])
if err != nil {
panic(fmt.Sprintf("error parsing x509 certificate: %v", err))
}

clientTLSConfig.RootCAs = x509.NewCertPool()
clientTLSConfig.RootCAs.AddCert(certificate)
}

0 comments on commit 8bed3b5

Please sign in to comment.