@@ -13,6 +13,7 @@ import (
13
13
"database/sql/driver"
14
14
"errors"
15
15
"net"
16
+ "strconv"
16
17
"strings"
17
18
"time"
18
19
)
@@ -26,26 +27,28 @@ type mysqlConn struct {
26
27
maxPacketAllowed int
27
28
maxWriteSize int
28
29
flags clientFlag
30
+ status statusFlag
29
31
sequence uint8
30
32
parseTime bool
31
33
strict bool
32
34
}
33
35
34
36
type config struct {
35
- user string
36
- passwd string
37
- net string
38
- addr string
39
- dbname string
40
- params map [string ]string
41
- loc * time.Location
42
- tls * tls.Config
43
- timeout time.Duration
44
- collation uint8
45
- allowAllFiles bool
46
- allowOldPasswords bool
47
- clientFoundRows bool
48
- columnsWithAlias bool
37
+ user string
38
+ passwd string
39
+ net string
40
+ addr string
41
+ dbname string
42
+ params map [string ]string
43
+ loc * time.Location
44
+ tls * tls.Config
45
+ timeout time.Duration
46
+ collation uint8
47
+ allowAllFiles bool
48
+ allowOldPasswords bool
49
+ clientFoundRows bool
50
+ columnsWithAlias bool
51
+ substitutePlaceholder bool
49
52
}
50
53
51
54
// Handles parameters set in DSN after the connection is established
@@ -162,28 +165,146 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
162
165
return stmt , err
163
166
}
164
167
168
+ func (mc * mysqlConn ) escapeBytes (v []byte ) string {
169
+ buf := make ([]byte , len (v )* 2 + 2 )
170
+ buf [0 ] = '\''
171
+ pos := 1
172
+ if mc .status & statusNoBackslashEscapes == 0 {
173
+ for _ , c := range v {
174
+ switch c {
175
+ case '\x00' :
176
+ buf [pos ] = '\\'
177
+ buf [pos + 1 ] = '0'
178
+ pos += 2
179
+ case '\n' :
180
+ buf [pos ] = '\\'
181
+ buf [pos + 1 ] = 'n'
182
+ pos += 2
183
+ case '\r' :
184
+ buf [pos ] = '\\'
185
+ buf [pos + 1 ] = 'r'
186
+ pos += 2
187
+ case '\x1a' :
188
+ buf [pos ] = '\\'
189
+ buf [pos + 1 ] = 'Z'
190
+ pos += 2
191
+ case '\'' :
192
+ buf [pos ] = '\\'
193
+ buf [pos + 1 ] = '\''
194
+ pos += 2
195
+ case '"' :
196
+ buf [pos ] = '\\'
197
+ buf [pos + 1 ] = '"'
198
+ pos += 2
199
+ case '\\' :
200
+ buf [pos ] = '\\'
201
+ buf [pos + 1 ] = '\\'
202
+ pos += 2
203
+ default :
204
+ buf [pos ] = c
205
+ pos += 1
206
+ }
207
+ }
208
+ } else {
209
+ for _ , c := range v {
210
+ if c == '\'' {
211
+ buf [pos ] = '\''
212
+ buf [pos + 1 ] = '\''
213
+ pos += 2
214
+ } else {
215
+ buf [pos ] = c
216
+ pos ++
217
+ }
218
+ }
219
+ }
220
+ buf [pos ] = '\''
221
+ return string (buf [:pos + 1 ])
222
+ }
223
+
224
+ func (mc * mysqlConn ) buildQuery (query string , args []driver.Value ) (string , error ) {
225
+ chunks := strings .Split (query , "?" )
226
+ if len (chunks ) != len (args )+ 1 {
227
+ return "" , driver .ErrSkip
228
+ }
229
+
230
+ parts := make ([]string , len (chunks )+ len (args ))
231
+ parts [0 ] = chunks [0 ]
232
+
233
+ for i , arg := range args {
234
+ pos := i * 2 + 1
235
+ parts [pos + 1 ] = chunks [i + 1 ]
236
+ if arg == nil {
237
+ parts [pos ] = "NULL"
238
+ continue
239
+ }
240
+ switch v := arg .(type ) {
241
+ case int64 :
242
+ parts [pos ] = strconv .FormatInt (v , 10 )
243
+ case float64 :
244
+ parts [pos ] = strconv .FormatFloat (v , 'f' , - 1 , 64 )
245
+ case bool :
246
+ if v {
247
+ parts [pos ] = "1"
248
+ } else {
249
+ parts [pos ] = "0"
250
+ }
251
+ case time.Time :
252
+ if v .IsZero () {
253
+ parts [pos ] = "'0000-00-00'"
254
+ } else {
255
+ fmt := "'2006-01-02 15:04:05.999999'"
256
+ parts [pos ] = v .In (mc .cfg .loc ).Format (fmt )
257
+ }
258
+ case []byte :
259
+ if v == nil {
260
+ parts [pos ] = "NULL"
261
+ } else {
262
+ parts [pos ] = mc .escapeBytes (v )
263
+ }
264
+ case string :
265
+ parts [pos ] = mc .escapeBytes ([]byte (v ))
266
+ default :
267
+ return "" , driver .ErrSkip
268
+ }
269
+ }
270
+ pktSize := len (query ) + 4 // 4 bytes for header.
271
+ for _ , p := range parts {
272
+ pktSize += len (p )
273
+ }
274
+ if pktSize > mc .maxPacketAllowed {
275
+ return "" , driver .ErrSkip
276
+ }
277
+ return strings .Join (parts , "" ), nil
278
+ }
279
+
165
280
func (mc * mysqlConn ) Exec (query string , args []driver.Value ) (driver.Result , error ) {
166
281
if mc .netConn == nil {
167
282
errLog .Print (ErrInvalidConn )
168
283
return nil , driver .ErrBadConn
169
284
}
170
- if len (args ) == 0 { // no args, fastpath
171
- mc .affectedRows = 0
172
- mc .insertId = 0
173
-
174
- err := mc .exec (query )
175
- if err == nil {
176
- return & mysqlResult {
177
- affectedRows : int64 (mc .affectedRows ),
178
- insertId : int64 (mc .insertId ),
179
- }, err
285
+ if len (args ) != 0 {
286
+ if ! mc .cfg .substitutePlaceholder {
287
+ return nil , driver .ErrSkip
180
288
}
181
- return nil , err
289
+ // try client-side prepare to reduce roundtrip
290
+ prepared , err := mc .buildQuery (query , args )
291
+ if err != nil {
292
+ return nil , err
293
+ }
294
+ query = prepared
295
+ args = nil
182
296
}
297
+ mc .affectedRows = 0
298
+ mc .insertId = 0
183
299
184
- // with args, must use prepared stmt
185
- return nil , driver .ErrSkip
186
-
300
+ err := mc .exec (query )
301
+ if err == nil {
302
+ return & mysqlResult {
303
+ affectedRows : int64 (mc .affectedRows ),
304
+ insertId : int64 (mc .insertId ),
305
+ }, err
306
+ }
307
+ return nil , err
187
308
}
188
309
189
310
// Internal function to execute commands
@@ -212,31 +333,38 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
212
333
errLog .Print (ErrInvalidConn )
213
334
return nil , driver .ErrBadConn
214
335
}
215
- if len (args ) == 0 { // no args, fastpath
216
- // Send command
217
- err := mc .writeCommandPacketStr (comQuery , query )
336
+ if len (args ) != 0 {
337
+ if ! mc .cfg .substitutePlaceholder {
338
+ return nil , driver .ErrSkip
339
+ }
340
+ // try client-side prepare to reduce roundtrip
341
+ prepared , err := mc .buildQuery (query , args )
342
+ if err != nil {
343
+ return nil , err
344
+ }
345
+ query = prepared
346
+ args = nil
347
+ }
348
+ // Send command
349
+ err := mc .writeCommandPacketStr (comQuery , query )
350
+ if err == nil {
351
+ // Read Result
352
+ var resLen int
353
+ resLen , err = mc .readResultSetHeaderPacket ()
218
354
if err == nil {
219
- // Read Result
220
- var resLen int
221
- resLen , err = mc .readResultSetHeaderPacket ()
222
- if err == nil {
223
- rows := new (textRows )
224
- rows .mc = mc
225
-
226
- if resLen == 0 {
227
- // no columns, no more data
228
- return emptyRows {}, nil
229
- }
230
- // Columns
231
- rows .columns , err = mc .readColumns (resLen )
232
- return rows , err
355
+ rows := new (textRows )
356
+ rows .mc = mc
357
+
358
+ if resLen == 0 {
359
+ // no columns, no more data
360
+ return emptyRows {}, nil
233
361
}
362
+ // Columns
363
+ rows .columns , err = mc .readColumns (resLen )
364
+ return rows , err
234
365
}
235
- return nil , err
236
366
}
237
-
238
- // with args, must use prepared stmt
239
- return nil , driver .ErrSkip
367
+ return nil , err
240
368
}
241
369
242
370
// Gets the value of the given MySQL System Variable
0 commit comments