@@ -41,7 +41,7 @@ func init() {
41
41
tlsConfigRegister = make (map [string ]* tls.Config )
42
42
}
43
43
44
- // Registers a custom tls.Config to be used with sql.Open.
44
+ // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
45
45
// Use the key as a value in the DSN where tls=value.
46
46
//
47
47
// rootCertPool := x509.NewCertPool()
@@ -64,11 +64,16 @@ func init() {
64
64
// })
65
65
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
66
66
//
67
- func RegisterTLSConfig (key string , config * tls.Config ) {
67
+ func RegisterTLSConfig (key string , config * tls.Config ) error {
68
+ if _ , isBool := readBool (key ); isBool || strings .ToLower (key ) == "skip-verify" {
69
+ return fmt .Errorf ("Key '%s' is reserved" , key )
70
+ }
71
+
68
72
tlsConfigRegister [key ] = config
73
+ return nil
69
74
}
70
75
71
- // Removes tls.Config associated with key.
76
+ // DeregisterTLSConfig removes the tls.Config associated with key.
72
77
func DeregisterTLSConfig (key string ) {
73
78
delete (tlsConfigRegister , key )
74
79
}
@@ -104,11 +109,21 @@ func parseDSN(dsn string) (cfg *config, err error) {
104
109
105
110
// Disable INFILE whitelist / enable all files
106
111
case "allowAllFiles" :
107
- cfg .allowAllFiles = readBool (value )
112
+ var isBool bool
113
+ cfg .allowAllFiles , isBool = readBool (value )
114
+ if ! isBool {
115
+ err = fmt .Errorf ("Invalid Bool value: %s" , value )
116
+ return
117
+ }
108
118
109
119
// Switch "rowsAffected" mode
110
120
case "clientFoundRows" :
111
- cfg .clientFoundRows = readBool (value )
121
+ var isBool bool
122
+ cfg .clientFoundRows , isBool = readBool (value )
123
+ if ! isBool {
124
+ err = fmt .Errorf ("Invalid Bool value: %s" , value )
125
+ return
126
+ }
112
127
113
128
// Time Location
114
129
case "loc" :
@@ -126,13 +141,20 @@ func parseDSN(dsn string) (cfg *config, err error) {
126
141
127
142
// TLS-Encryption
128
143
case "tls" :
129
- if readBool (value ) {
130
- cfg .tls = & tls.Config {}
131
- } else if strings .ToLower (value ) == "skip-verify" {
132
- cfg .tls = & tls.Config {InsecureSkipVerify : true }
133
- // TODO: Check for Boolean false
134
- } else if tlsConfig , ok := tlsConfigRegister [value ]; ok {
135
- cfg .tls = tlsConfig
144
+ boolValue , isBool := readBool (value )
145
+ if isBool {
146
+ if boolValue {
147
+ cfg .tls = & tls.Config {}
148
+ }
149
+ } else {
150
+ if strings .ToLower (value ) == "skip-verify" {
151
+ cfg .tls = & tls.Config {InsecureSkipVerify : true }
152
+ } else if tlsConfig , ok := tlsConfigRegister [value ]; ok {
153
+ cfg .tls = tlsConfig
154
+ } else {
155
+ err = fmt .Errorf ("Invalid value / unknown config name: %s" , value )
156
+ return
157
+ }
136
158
}
137
159
138
160
default :
@@ -191,14 +213,18 @@ func scramblePassword(scramble, password []byte) []byte {
191
213
return scramble
192
214
}
193
215
194
- func readBool (value string ) bool {
195
- switch strings .ToLower (value ) {
196
- case "true" :
197
- return true
198
- case "1" :
199
- return true
216
+ // Returns the bool value of the input.
217
+ // The 2nd return value indicates if the input was a valid bool value
218
+ func readBool (input string ) (value bool , valid bool ) {
219
+ switch input {
220
+ case "1" , "true" , "TRUE" , "True" :
221
+ return true , true
222
+ case "0" , "false" , "FALSE" , "False" :
223
+ return false , true
200
224
}
201
- return false
225
+
226
+ // Not a valid bool value
227
+ return
202
228
}
203
229
204
230
/******************************************************************************
0 commit comments