Skip to content

Commit e517683

Browse files
committedFeb 12, 2015
Allow interpolateParams only with ascii, latin1 and utf8 collations
1 parent 20b75cd commit e517683

File tree

3 files changed

+78
-12
lines changed

3 files changed

+78
-12
lines changed
 

‎driver_test.go

+11-9
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,21 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
8787

8888
db.Exec("DROP TABLE IF EXISTS test")
8989

90-
dbp, err := sql.Open("mysql", dsn+"&interpolateParams=true")
91-
if err != nil {
92-
t.Fatalf("Error connecting: %s", err.Error())
90+
dsn2 := dsn + "&interpolateParams=true"
91+
var db2 *sql.DB
92+
if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
93+
db2, err = sql.Open("mysql", dsn2)
9394
}
94-
defer dbp.Close()
9595

9696
dbt := &DBTest{t, db}
97-
dbtp := &DBTest{t, dbp}
97+
dbt2 := &DBTest{t, db2}
9898
for _, test := range tests {
9999
test(dbt)
100100
dbt.db.Exec("DROP TABLE IF EXISTS test")
101-
test(dbtp)
102-
dbtp.db.Exec("DROP TABLE IF EXISTS test")
101+
if db2 != nil {
102+
test(dbt2)
103+
dbt2.db.Exec("DROP TABLE IF EXISTS test")
104+
}
103105
}
104106
}
105107

@@ -864,7 +866,7 @@ func TestLoadData(t *testing.T) {
864866
dbt.Fatalf("%d != %d", i, id)
865867
}
866868
if values[i-1] != value {
867-
dbt.Fatalf("%s != %s", values[i-1], value)
869+
dbt.Fatalf("%q != %q", values[i-1], value)
868870
}
869871
}
870872
err = rows.Err()
@@ -889,7 +891,7 @@ func TestLoadData(t *testing.T) {
889891

890892
// Local File
891893
RegisterLocalFile(file.Name())
892-
dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE '%q' INTO TABLE test", file.Name()))
894+
dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
893895
verifyLoadDataResult()
894896
// negative test
895897
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")

‎utils.go

+30-3
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ import (
2525
var (
2626
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
2727

28-
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
29-
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
30-
errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
28+
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
29+
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
30+
errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
31+
errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset")
3132
)
3233

3334
func init() {
@@ -147,6 +148,32 @@ func parseDSN(dsn string) (cfg *config, err error) {
147148
return nil, errInvalidDSNNoSlash
148149
}
149150

151+
if cfg.interpolateParams && cfg.collation != defaultCollation {
152+
// A whitelist of collations which safe to interpolate parameters.
153+
// ASCII and latin-1 are safe since they are single byte encoding.
154+
// utf-8 is safe since it doesn't conatins ASCII characters in trailing bytes.
155+
safeCollations := []string{"ascii_", "latin1_", "utf8_", "utf8mb4_"}
156+
157+
var collationName string
158+
for name, collation := range collations {
159+
if collation == cfg.collation {
160+
collationName = name
161+
break
162+
}
163+
}
164+
165+
safe := false
166+
for _, p := range safeCollations {
167+
if strings.HasPrefix(collationName, p) {
168+
safe = true
169+
break
170+
}
171+
}
172+
if !safe {
173+
return nil, errInvalidDSNUnsafeCollation
174+
}
175+
}
176+
150177
// Set default network if empty
151178
if cfg.net == "" {
152179
cfg.net = "tcp"

‎utils_test.go

+37
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,43 @@ func TestDSNWithCustomTLS(t *testing.T) {
116116
DeregisterTLSConfig("utils_test")
117117
}
118118

119+
func TestDSNUnsafeCollation(t *testing.T) {
120+
_, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
121+
if err != errInvalidDSNUnsafeCollation {
122+
t.Error("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err)
123+
}
124+
125+
_, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
126+
if err != nil {
127+
t.Error("Expected %v, Got %v", nil, err)
128+
}
129+
130+
_, err = parseDSN("/dbname?collation=gbk_chinese_ci")
131+
if err != nil {
132+
t.Error("Expected %v, Got %v", nil, err)
133+
}
134+
135+
_, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
136+
if err != nil {
137+
t.Error("Expected %v, Got %v", nil, err)
138+
}
139+
140+
_, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
141+
if err != nil {
142+
t.Error("Expected %v, Got %v", nil, err)
143+
}
144+
145+
_, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
146+
if err != nil {
147+
t.Error("Expected %v, Got %v", nil, err)
148+
}
149+
150+
_, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
151+
if err != nil {
152+
t.Error("Expected %v, Got %v", nil, err)
153+
}
154+
}
155+
119156
func BenchmarkParseDSN(b *testing.B) {
120157
b.ReportAllocs()
121158

0 commit comments

Comments
 (0)
Please sign in to comment.