Skip to content

Commit 877a977

Browse files
nemithmethane
authored andcommittedMay 10, 2019
move tls and pubkey object creation to Config.normalize() (go-sql-driver#958)
This is still less than ideal since we cannot directly pass in tls.Config into Config and have it be used, but it is sill backwards compatable. In the future this should be revisited to be able to use a custome tls.Config passed directly in without string parsing/registering.
1 parent 8056f2c commit 877a977

File tree

4 files changed

+76
-27
lines changed

4 files changed

+76
-27
lines changed
 

‎AUTHORS

+2-1
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,12 @@ Zhenye Xie <xiezhenye at gmail.com>
9090

9191
Barracuda Networks, Inc.
9292
Counting Ltd.
93+
Facebook Inc.
9394
GitHub Inc.
9495
Google Inc.
9596
InfoSum Ltd.
9697
Keybase Inc.
98+
Multiplay Ltd.
9799
Percona LLC
98100
Pivotal Inc.
99101
Stripe Inc.
100-
Multiplay Ltd.

‎dsn.go

+27-23
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,35 @@ func (cfg *Config) normalize() error {
113113
default:
114114
return errors.New("default addr for network '" + cfg.Net + "' unknown")
115115
}
116-
117116
} else if cfg.Net == "tcp" {
118117
cfg.Addr = ensureHavePort(cfg.Addr)
119118
}
120119

121-
if cfg.tls != nil {
122-
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
123-
host, _, err := net.SplitHostPort(cfg.Addr)
124-
if err == nil {
125-
cfg.tls.ServerName = host
126-
}
120+
switch cfg.TLSConfig {
121+
case "false", "":
122+
// don't set anything
123+
case "true":
124+
cfg.tls = &tls.Config{}
125+
case "skip-verify", "preferred":
126+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
127+
default:
128+
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
129+
if cfg.tls == nil {
130+
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
131+
}
132+
}
133+
134+
if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
135+
host, _, err := net.SplitHostPort(cfg.Addr)
136+
if err == nil {
137+
cfg.tls.ServerName = host
138+
}
139+
}
140+
141+
if cfg.ServerPubKey != "" {
142+
cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
143+
if cfg.pubKey == nil {
144+
return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
127145
}
128146
}
129147

@@ -552,13 +570,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
552570
if err != nil {
553571
return fmt.Errorf("invalid value for server pub key name: %v", err)
554572
}
555-
556-
if pubKey := getServerPubKey(name); pubKey != nil {
557-
cfg.ServerPubKey = name
558-
cfg.pubKey = pubKey
559-
} else {
560-
return errors.New("invalid value / unknown server pub key name: " + name)
561-
}
573+
cfg.ServerPubKey = name
562574

563575
// Strict mode
564576
case "strict":
@@ -577,25 +589,17 @@ func parseDSNParams(cfg *Config, params string) (err error) {
577589
if isBool {
578590
if boolValue {
579591
cfg.TLSConfig = "true"
580-
cfg.tls = &tls.Config{}
581592
} else {
582593
cfg.TLSConfig = "false"
583594
}
584595
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
585596
cfg.TLSConfig = vl
586-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
587597
} else {
588598
name, err := url.QueryUnescape(value)
589599
if err != nil {
590600
return fmt.Errorf("invalid value for TLS config name: %v", err)
591601
}
592-
593-
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
594-
cfg.TLSConfig = name
595-
cfg.tls = tlsConfig
596-
} else {
597-
return errors.New("invalid value / unknown config name: " + name)
598-
}
602+
cfg.TLSConfig = name
599603
}
600604

601605
// I/O write Timeout

‎dsn_test.go

+46-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ var testDSNs = []struct {
3939
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
4040
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
4141
}, {
42-
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
43-
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
42+
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
43+
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
4444
}, {
4545
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
4646
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
@@ -358,6 +358,50 @@ func TestCloneConfig(t *testing.T) {
358358
}
359359
}
360360

361+
func TestNormalizeTLSConfig(t *testing.T) {
362+
tt := []struct {
363+
tlsConfig string
364+
want *tls.Config
365+
}{
366+
{"", nil},
367+
{"false", nil},
368+
{"true", &tls.Config{ServerName: "myserver"}},
369+
{"skip-verify", &tls.Config{InsecureSkipVerify: true}},
370+
{"preferred", &tls.Config{InsecureSkipVerify: true}},
371+
{"test_tls_config", &tls.Config{ServerName: "myServerName"}},
372+
}
373+
374+
RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"})
375+
defer func() { DeregisterTLSConfig("test_tls_config") }()
376+
377+
for _, tc := range tt {
378+
t.Run(tc.tlsConfig, func(t *testing.T) {
379+
cfg := &Config{
380+
Addr: "myserver:3306",
381+
TLSConfig: tc.tlsConfig,
382+
}
383+
384+
cfg.normalize()
385+
386+
if cfg.tls == nil {
387+
if tc.want != nil {
388+
t.Fatal("wanted a tls config but got nil instead")
389+
}
390+
return
391+
}
392+
393+
if cfg.tls.ServerName != tc.want.ServerName {
394+
t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
395+
tc.want.ServerName, cfg.tls.ServerName)
396+
}
397+
if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
398+
t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
399+
tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
400+
}
401+
})
402+
}
403+
}
404+
361405
func BenchmarkParseDSN(b *testing.B) {
362406
b.ReportAllocs()
363407

‎utils.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ var (
5656
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
5757
//
5858
func RegisterTLSConfig(key string, config *tls.Config) error {
59-
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
59+
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
6060
return fmt.Errorf("key '%s' is reserved", key)
6161
}
6262

0 commit comments

Comments
 (0)