Skip to content

Commit f3b82fd

Browse files
committed
Merge remote-tracking branch 'upstream/pr/297'
Conflicts: connection.go utils_test.go
2 parents 511937c + b7c2c47 commit f3b82fd

File tree

6 files changed

+230
-62
lines changed

6 files changed

+230
-62
lines changed

connection.go

+177-49
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"database/sql/driver"
1414
"errors"
1515
"net"
16+
"strconv"
1617
"strings"
1718
"time"
1819
)
@@ -26,26 +27,28 @@ type mysqlConn struct {
2627
maxPacketAllowed int
2728
maxWriteSize int
2829
flags clientFlag
30+
status statusFlag
2931
sequence uint8
3032
parseTime bool
3133
strict bool
3234
}
3335

3436
type config struct {
35-
user string
36-
passwd string
37-
net string
38-
addr string
39-
dbname string
40-
params map[string]string
41-
loc *time.Location
42-
tls *tls.Config
43-
timeout time.Duration
44-
collation uint8
45-
allowAllFiles bool
46-
allowOldPasswords bool
47-
clientFoundRows bool
48-
columnsWithAlias bool
37+
user string
38+
passwd string
39+
net string
40+
addr string
41+
dbname string
42+
params map[string]string
43+
loc *time.Location
44+
tls *tls.Config
45+
timeout time.Duration
46+
collation uint8
47+
allowAllFiles bool
48+
allowOldPasswords bool
49+
clientFoundRows bool
50+
columnsWithAlias bool
51+
substitutePlaceholder bool
4952
}
5053

5154
// Handles parameters set in DSN after the connection is established
@@ -162,28 +165,146 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
162165
return stmt, err
163166
}
164167

168+
func (mc *mysqlConn) escapeBytes(v []byte) string {
169+
buf := make([]byte, len(v)*2+2)
170+
buf[0] = '\''
171+
pos := 1
172+
if mc.status&statusNoBackslashEscapes == 0 {
173+
for _, c := range v {
174+
switch c {
175+
case '\x00':
176+
buf[pos] = '\\'
177+
buf[pos+1] = '0'
178+
pos += 2
179+
case '\n':
180+
buf[pos] = '\\'
181+
buf[pos+1] = 'n'
182+
pos += 2
183+
case '\r':
184+
buf[pos] = '\\'
185+
buf[pos+1] = 'r'
186+
pos += 2
187+
case '\x1a':
188+
buf[pos] = '\\'
189+
buf[pos+1] = 'Z'
190+
pos += 2
191+
case '\'':
192+
buf[pos] = '\\'
193+
buf[pos+1] = '\''
194+
pos += 2
195+
case '"':
196+
buf[pos] = '\\'
197+
buf[pos+1] = '"'
198+
pos += 2
199+
case '\\':
200+
buf[pos] = '\\'
201+
buf[pos+1] = '\\'
202+
pos += 2
203+
default:
204+
buf[pos] = c
205+
pos += 1
206+
}
207+
}
208+
} else {
209+
for _, c := range v {
210+
if c == '\'' {
211+
buf[pos] = '\''
212+
buf[pos+1] = '\''
213+
pos += 2
214+
} else {
215+
buf[pos] = c
216+
pos++
217+
}
218+
}
219+
}
220+
buf[pos] = '\''
221+
return string(buf[:pos+1])
222+
}
223+
224+
func (mc *mysqlConn) buildQuery(query string, args []driver.Value) (string, error) {
225+
chunks := strings.Split(query, "?")
226+
if len(chunks) != len(args)+1 {
227+
return "", driver.ErrSkip
228+
}
229+
230+
parts := make([]string, len(chunks)+len(args))
231+
parts[0] = chunks[0]
232+
233+
for i, arg := range args {
234+
pos := i*2 + 1
235+
parts[pos+1] = chunks[i+1]
236+
if arg == nil {
237+
parts[pos] = "NULL"
238+
continue
239+
}
240+
switch v := arg.(type) {
241+
case int64:
242+
parts[pos] = strconv.FormatInt(v, 10)
243+
case float64:
244+
parts[pos] = strconv.FormatFloat(v, 'f', -1, 64)
245+
case bool:
246+
if v {
247+
parts[pos] = "1"
248+
} else {
249+
parts[pos] = "0"
250+
}
251+
case time.Time:
252+
if v.IsZero() {
253+
parts[pos] = "'0000-00-00'"
254+
} else {
255+
fmt := "'2006-01-02 15:04:05.999999'"
256+
parts[pos] = v.In(mc.cfg.loc).Format(fmt)
257+
}
258+
case []byte:
259+
if v == nil {
260+
parts[pos] = "NULL"
261+
} else {
262+
parts[pos] = mc.escapeBytes(v)
263+
}
264+
case string:
265+
parts[pos] = mc.escapeBytes([]byte(v))
266+
default:
267+
return "", driver.ErrSkip
268+
}
269+
}
270+
pktSize := len(query) + 4 // 4 bytes for header.
271+
for _, p := range parts {
272+
pktSize += len(p)
273+
}
274+
if pktSize > mc.maxPacketAllowed {
275+
return "", driver.ErrSkip
276+
}
277+
return strings.Join(parts, ""), nil
278+
}
279+
165280
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
166281
if mc.netConn == nil {
167282
errLog.Print(ErrInvalidConn)
168283
return nil, driver.ErrBadConn
169284
}
170-
if len(args) == 0 { // no args, fastpath
171-
mc.affectedRows = 0
172-
mc.insertId = 0
173-
174-
err := mc.exec(query)
175-
if err == nil {
176-
return &mysqlResult{
177-
affectedRows: int64(mc.affectedRows),
178-
insertId: int64(mc.insertId),
179-
}, err
285+
if len(args) != 0 {
286+
if !mc.cfg.substitutePlaceholder {
287+
return nil, driver.ErrSkip
180288
}
181-
return nil, err
289+
// try client-side prepare to reduce roundtrip
290+
prepared, err := mc.buildQuery(query, args)
291+
if err != nil {
292+
return nil, err
293+
}
294+
query = prepared
295+
args = nil
182296
}
297+
mc.affectedRows = 0
298+
mc.insertId = 0
183299

184-
// with args, must use prepared stmt
185-
return nil, driver.ErrSkip
186-
300+
err := mc.exec(query)
301+
if err == nil {
302+
return &mysqlResult{
303+
affectedRows: int64(mc.affectedRows),
304+
insertId: int64(mc.insertId),
305+
}, err
306+
}
307+
return nil, err
187308
}
188309

189310
// Internal function to execute commands
@@ -212,31 +333,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
212333
errLog.Print(ErrInvalidConn)
213334
return nil, driver.ErrBadConn
214335
}
215-
if len(args) == 0 { // no args, fastpath
216-
// Send command
217-
err := mc.writeCommandPacketStr(comQuery, query)
336+
if len(args) != 0 {
337+
if !mc.cfg.substitutePlaceholder {
338+
return nil, driver.ErrSkip
339+
}
340+
// try client-side prepare to reduce roundtrip
341+
prepared, err := mc.buildQuery(query, args)
342+
if err != nil {
343+
return nil, err
344+
}
345+
query = prepared
346+
args = nil
347+
}
348+
// Send command
349+
err := mc.writeCommandPacketStr(comQuery, query)
350+
if err == nil {
351+
// Read Result
352+
var resLen int
353+
resLen, err = mc.readResultSetHeaderPacket()
218354
if err == nil {
219-
// Read Result
220-
var resLen int
221-
resLen, err = mc.readResultSetHeaderPacket()
222-
if err == nil {
223-
rows := new(textRows)
224-
rows.mc = mc
225-
226-
if resLen == 0 {
227-
// no columns, no more data
228-
return emptyRows{}, nil
229-
}
230-
// Columns
231-
rows.columns, err = mc.readColumns(resLen)
232-
return rows, err
355+
rows := new(textRows)
356+
rows.mc = mc
357+
358+
if resLen == 0 {
359+
// no columns, no more data
360+
return emptyRows{}, nil
233361
}
362+
// Columns
363+
rows.columns, err = mc.readColumns(resLen)
364+
return rows, err
234365
}
235-
return nil, err
236366
}
237-
238-
// with args, must use prepared stmt
239-
return nil, driver.ErrSkip
367+
return nil, err
240368
}
241369

242370
// Gets the value of the given MySQL System Variable

const.go

+22
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,25 @@ const (
130130
flagUnknown3
131131
flagUnknown4
132132
)
133+
134+
// http://dev.mysql.com/doc/internals/en/status-flags.html
135+
136+
type statusFlag uint16
137+
138+
const (
139+
statusInTrans statusFlag = 1 << iota
140+
statusInAutocommit
141+
statusUnknown1
142+
statusMoreResultsExists
143+
statusNoGoodIndexUsed
144+
statusNoIndexUsed
145+
statusCursorExists
146+
statusLastRowSent
147+
statusDbDropped
148+
statusNoBackslashEscapes
149+
statusMetadataChanged
150+
statusQueryWasSlow
151+
statusPsOutParams
152+
statusInTransReadonly
153+
statusSessionStateChanged
154+
)

driver_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,19 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
8787

8888
db.Exec("DROP TABLE IF EXISTS test")
8989

90+
dbp, err := sql.Open("mysql", dsn+"&substitutePlaceholder=true")
91+
if err != nil {
92+
t.Fatalf("Error connecting: %s", err.Error())
93+
}
94+
defer dbp.Close()
95+
9096
dbt := &DBTest{t, db}
97+
dbtp := &DBTest{t, dbp}
9198
for _, test := range tests {
9299
test(dbt)
93100
dbt.db.Exec("DROP TABLE IF EXISTS test")
101+
test(dbtp)
102+
dbtp.db.Exec("DROP TABLE IF EXISTS test")
94103
}
95104
}
96105

packets.go

+1
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
484484
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
485485

486486
// server_status [2 bytes]
487+
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
487488

488489
// warning count [2 bytes]
489490
if !mc.strict {

utils.go

+8
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,14 @@ func parseDSNParams(cfg *config, params string) (err error) {
180180
// cfg params
181181
switch value := param[1]; param[0] {
182182

183+
// Enable client side placeholder substitution
184+
case "substitutePlaceholder":
185+
var isBool bool
186+
cfg.substitutePlaceholder, isBool = readBool(value)
187+
if !isBool {
188+
return fmt.Errorf("Invalid Bool value: %s", value)
189+
}
190+
183191
// Disable INFILE whitelist / enable all files
184192
case "allowAllFiles":
185193
var isBool bool

utils_test.go

+13-13
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ var testDSNs = []struct {
2222
out string
2323
loc *time.Location
2424
}{
25-
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
26-
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true}", time.UTC},
27-
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
28-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
29-
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
30-
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false}", time.UTC},
31-
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.Local},
32-
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
33-
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
34-
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
35-
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
36-
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
37-
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false}", time.UTC},
25+
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
26+
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true substitutePlaceholder:false}", time.UTC},
27+
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
28+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
29+
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
30+
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
31+
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.Local},
32+
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
33+
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
34+
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
35+
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
36+
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
37+
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false substitutePlaceholder:false}", time.UTC},
3838
}
3939

4040
func TestDSNParser(t *testing.T) {

0 commit comments

Comments
 (0)