Skip to content

Commit 9181e3a

Browse files
authored
Fix tls=true didn't work with host without port (go-sql-driver#718)
Fixes go-sql-driver#717
1 parent cd4cb90 commit 9181e3a

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

dsn.go

+9-11
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,15 @@ func (cfg *Config) normalize() error {
9494
cfg.Addr = ensureHavePort(cfg.Addr)
9595
}
9696

97+
if cfg.tls != nil {
98+
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
99+
host, _, err := net.SplitHostPort(cfg.Addr)
100+
if err == nil {
101+
cfg.tls.ServerName = host
102+
}
103+
}
104+
}
105+
97106
return nil
98107
}
99108

@@ -521,10 +530,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
521530
if boolValue {
522531
cfg.TLSConfig = "true"
523532
cfg.tls = &tls.Config{}
524-
host, _, err := net.SplitHostPort(cfg.Addr)
525-
if err == nil {
526-
cfg.tls.ServerName = host
527-
}
528533
} else {
529534
cfg.TLSConfig = "false"
530535
}
@@ -538,13 +543,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
538543
}
539544

540545
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
541-
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
542-
host, _, err := net.SplitHostPort(cfg.Addr)
543-
if err == nil {
544-
tlsConfig.ServerName = host
545-
}
546-
}
547-
548546
cfg.TLSConfig = name
549547
cfg.tls = tlsConfig
550548
} else {

dsn_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,34 @@ func TestDSNWithCustomTLS(t *testing.T) {
177177
DeregisterTLSConfig("utils_test")
178178
}
179179

180+
func TestDSNTLSConfig(t *testing.T) {
181+
expectedServerName := "example.com"
182+
dsn := "tcp(example.com:1234)/?tls=true"
183+
184+
cfg, err := ParseDSN(dsn)
185+
if err != nil {
186+
t.Error(err.Error())
187+
}
188+
if cfg.tls == nil {
189+
t.Error("cfg.tls should not be nil")
190+
}
191+
if cfg.tls.ServerName != expectedServerName {
192+
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
193+
}
194+
195+
dsn = "tcp(example.com)/?tls=true"
196+
cfg, err = ParseDSN(dsn)
197+
if err != nil {
198+
t.Error(err.Error())
199+
}
200+
if cfg.tls == nil {
201+
t.Error("cfg.tls should not be nil")
202+
}
203+
if cfg.tls.ServerName != expectedServerName {
204+
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
205+
}
206+
}
207+
180208
func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
181209
const configKey = "&%!:"
182210
dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)

0 commit comments

Comments
 (0)