Skip to content

Commit

Permalink
Optimize sqlx.In and add benchmark
Browse files Browse the repository at this point in the history
nussjustin committed Aug 4, 2015
1 parent 56b62f2 commit 7400168
Showing 2 changed files with 79 additions and 45 deletions.
114 changes: 70 additions & 44 deletions bind.go
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ import (
"errors"
"reflect"
"strconv"
"strings"
)

// Bindvar types supported by Rebind, BindMap and BindStruct.
@@ -92,70 +93,95 @@ func rebindBuff(bindType int, query string) string {
// and a new arg list that can be executed by a database. The `query` should
// use the `?` bindVar. The return value uses the `?` bindVar.
func In(query string, args ...interface{}) (string, []interface{}, error) {
type ra struct {
v reflect.Value
t reflect.Type
isSlice bool
// argMeta stores reflect.Value and length for slices and
// the value itself for non-slice arguments
type argMeta struct {
v reflect.Value
i interface{}
length int
}
ras := make([]ra, 0, len(args))
for _, arg := range args {

var flatArgsCount, sliceCount int

meta := make([]argMeta, len(args))

for i, arg := range args {
v := reflect.ValueOf(arg)
t, _ := baseType(v.Type(), reflect.Slice)
ras = append(ras, ra{v, t, t != nil})
}

anySlices := false
for _, s := range ras {
if s.isSlice {
anySlices = true
if s.v.Len() == 0 {
if t != nil {
meta[i].length = v.Len()
meta[i].v = v

sliceCount++
flatArgsCount += meta[i].length

if meta[i].length == 0 {
return "", nil, errors.New("empty slice passed to 'in' query")
}
} else {
meta[i].i = arg
flatArgsCount++
}
}

// don't do any parsing if there aren't any slices; note that this means
// some errors that we might have caught below will not be returned.
if !anySlices {
if sliceCount == 0 {
return query, args, nil
}

var a []interface{}
newArgs := make([]interface{}, 0, flatArgsCount)

var arg, offset int
var buf bytes.Buffer
var pos int

for _, r := range query {
if r == '?' {
if pos >= len(ras) {
// if this argument wasn't passed, lets return an error; this is
// not actually how database/sql Exec/Query works, but since we are
// creating an argument list programmatically, we want to be able
// to catch these programmer errors earlier.
return "", nil, errors.New("number of bindVars exceeds arguments")
} else if ras[pos].isSlice {
// if this argument is a slice, expand the slice into arguments and
// assume that the bindVars should be comma separated.
length := ras[pos].v.Len()
for i := 0; i < length-1; i++ {
buf.Write([]byte("?, "))
a = append(a, ras[pos].v.Index(i).Interface())
}
a = append(a, ras[pos].v.Index(length-1).Interface())
buf.WriteRune('?')
} else {
// a normal argument, procede as normal.
a = append(a, args[pos])
buf.WriteRune(r)
}
pos++
} else {
buf.WriteRune(r)
for i := strings.IndexByte(query[offset:], '?'); i != -1 && arg < len(meta); i = strings.IndexByte(query[offset:], '?') {
argMeta := meta[arg]
arg++

// not a slice, continue.
// our questionmark will either be written before the next expansion
// of a slice or after the loop when writing the rest of the query
if argMeta.length == 0 {
offset = offset + i + 1
newArgs = append(newArgs, argMeta.i)
continue
}

// write everything up to and including our ? character
buf.WriteString(query[:offset+i+1])

newArgs = append(newArgs, argMeta.v.Index(0).Interface())

for si := 1; si < argMeta.length; si++ {
buf.WriteString(", ?")
newArgs = append(newArgs, argMeta.v.Index(si).Interface())
}

// slice the query and reset the offset. this avoids some bookkeeping for
// the write after the loop
query = query[offset+i+1:]
offset = 0
}

if pos != len(ras) {
buf.WriteString(query)

if arg < len(meta) {
return "", nil, errors.New("number of bindVars less than number arguments")
}

return buf.String(), a, nil
// get the result as bytes first, to avoid converting to a string if we return
// an error
res := buf.Bytes()

if bytes.Count(res, []byte{'?'}) > flatArgsCount {
// if an argument wasn't passed, lets return an error; this is
// not actually how database/sql Exec/Query works, but since we are
// creating an argument list programmatically, we want to be able
// to catch these programmer errors earlier.
return "", nil, errors.New("number of bindVars exceeds arguments")
}

return string(res), newArgs, nil
}
10 changes: 9 additions & 1 deletion sqlx_test.go
Original file line number Diff line number Diff line change
@@ -1228,7 +1228,7 @@ func TestIn(t *testing.T) {
t.Error(err)
}
if len(a) != test.c {
t.Errorf("Expected %d args, but got %d (%+v)", len(a), test.c, a)
t.Errorf("Expected %d args, but got %d (%+v)", test.c, len(a), a)
}
if strings.Count(q, "?") != test.c {
t.Errorf("Expected %d bindVars, got %d", test.c, strings.Count(q, "?"))
@@ -1460,6 +1460,14 @@ func BenchmarkBindMap(b *testing.B) {
}
}

func BenchmarkIn(b *testing.B) {
q := `SELECT * FROM foo WHERE x = ? AND v in (?) AND y = ?`

for i := 0; i < b.N; i++ {
_, _, _ = In(q, []interface{}{"foo", []int{0, 5, 7, 2, 9}, "bar"}...)
}
}

func BenchmarkRebind(b *testing.B) {
b.StopTimer()
q1 := `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

0 comments on commit 7400168

Please sign in to comment.