Skip to content

Commit c24056d

Browse files
committed
strict tls.Config key check
1 parent abd1799 commit c24056d

File tree

3 files changed

+56
-22
lines changed

3 files changed

+56
-22
lines changed

connection.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,19 @@ func (mc *mysqlConn) handleParams() (err error) {
6868

6969
// time.Time parsing
7070
case "parseTime":
71-
mc.parseTime = readBool(val)
71+
var isBool bool
72+
mc.parseTime, isBool = readBool(val)
73+
if !isBool {
74+
return errors.New("Invalid Bool value: " + val)
75+
}
7276

7377
// Strict mode
7478
case "strict":
75-
mc.strict = readBool(val)
79+
var isBool bool
80+
mc.strict, isBool = readBool(val)
81+
if !isBool {
82+
return errors.New("Invalid Bool value: " + val)
83+
}
7684

7785
// Compression
7886
case "compress":

driver_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,7 @@ func TestStmtMultiRows(t *testing.T) {
10531053
}
10541054

10551055
func TestConcurrent(t *testing.T) {
1056-
if readBool(os.Getenv("MYSQL_TEST_CONCURRENT")) != true {
1056+
if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled {
10571057
t.Skip("CONCURRENT env var not set")
10581058
}
10591059

utils.go

+45-19
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func init() {
4141
tlsConfigRegister = make(map[string]*tls.Config)
4242
}
4343

44-
// Registers a custom tls.Config to be used with sql.Open.
44+
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
4545
// Use the key as a value in the DSN where tls=value.
4646
//
4747
// rootCertPool := x509.NewCertPool()
@@ -64,11 +64,16 @@ func init() {
6464
// })
6565
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
6666
//
67-
func RegisterTLSConfig(key string, config *tls.Config) {
67+
func RegisterTLSConfig(key string, config *tls.Config) error {
68+
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
69+
return fmt.Errorf("Key '%s' is reserved", key)
70+
}
71+
6872
tlsConfigRegister[key] = config
73+
return nil
6974
}
7075

71-
// Removes tls.Config associated with key.
76+
// DeregisterTLSConfig removes the tls.Config associated with key.
7277
func DeregisterTLSConfig(key string) {
7378
delete(tlsConfigRegister, key)
7479
}
@@ -104,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
104109

105110
// Disable INFILE whitelist / enable all files
106111
case "allowAllFiles":
107-
cfg.allowAllFiles = readBool(value)
112+
var isBool bool
113+
cfg.allowAllFiles, isBool = readBool(value)
114+
if !isBool {
115+
err = fmt.Errorf("Invalid Bool value: %s", value)
116+
return
117+
}
108118

109119
// Switch "rowsAffected" mode
110120
case "clientFoundRows":
111-
cfg.clientFoundRows = readBool(value)
121+
var isBool bool
122+
cfg.clientFoundRows, isBool = readBool(value)
123+
if !isBool {
124+
err = fmt.Errorf("Invalid Bool value: %s", value)
125+
return
126+
}
112127

113128
// Time Location
114129
case "loc":
@@ -126,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
126141

127142
// TLS-Encryption
128143
case "tls":
129-
if readBool(value) {
130-
cfg.tls = &tls.Config{}
131-
} else if strings.ToLower(value) == "skip-verify" {
132-
cfg.tls = &tls.Config{InsecureSkipVerify: true}
133-
// TODO: Check for Boolean false
134-
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
135-
cfg.tls = tlsConfig
144+
boolValue, isBool := readBool(value)
145+
if isBool {
146+
if boolValue {
147+
cfg.tls = &tls.Config{}
148+
}
149+
} else {
150+
if strings.ToLower(value) == "skip-verify" {
151+
cfg.tls = &tls.Config{InsecureSkipVerify: true}
152+
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
153+
cfg.tls = tlsConfig
154+
} else {
155+
err = fmt.Errorf("Invalid value / unknown config name: %s", value)
156+
return
157+
}
136158
}
137159

138160
default:
@@ -191,14 +213,18 @@ func scramblePassword(scramble, password []byte) []byte {
191213
return scramble
192214
}
193215

194-
func readBool(value string) bool {
195-
switch strings.ToLower(value) {
196-
case "true":
197-
return true
198-
case "1":
199-
return true
216+
// Returns the bool value of the input.
217+
// The 2nd return value indicates if the input was a valid bool value
218+
func readBool(input string) (value bool, valid bool) {
219+
switch input {
220+
case "1", "true", "TRUE", "True":
221+
return true, true
222+
case "0", "false", "FALSE", "False":
223+
return false, true
200224
}
201-
return false
225+
226+
// Not a valid bool value
227+
return
202228
}
203229

204230
/******************************************************************************

0 commit comments

Comments
 (0)