Skip to content

Commit

Permalink
fixes jmoiron#73, obey configured mapper on NamedStmt and other named…
Browse files Browse the repository at this point in the history
… queries

tiny slightly backwards incompatible change which probably won't affect most code using sqlx;  the embedded sql.Rows in sqlx.Rows has been changed from an sql.Rows to an *sql.Rows.  This was done to future proof it from possible copy-related issues (eg. copying mutexes or other private data would fail and break the rows), but no such issues have been reported or seen.
jmoiron committed Jul 23, 2014
1 parent 66d77f1 commit 2f383ca
Showing 4 changed files with 131 additions and 25 deletions.
34 changes: 19 additions & 15 deletions named.go
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ func (n *NamedStmt) Close() error {

// Exec executes a named statement using the struct passed.
func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
args, err := bindAnyArgs(n.Params, arg)
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return *new(sql.Result), err
}
@@ -46,7 +46,7 @@ func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {

// Query executes a named statement using the struct argument, returning rows.
func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
args, err := bindAnyArgs(n.Params, arg)
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return nil, err
}
@@ -57,7 +57,7 @@ func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
// returns a *sqlx.Row instead.
func (n *NamedStmt) QueryRow(arg interface{}) *Row {
args, err := bindAnyArgs(n.Params, arg)
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
if err != nil {
return &Row{err: err}
}
@@ -79,7 +79,7 @@ func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: *r, Mapper: n.Stmt.Mapper}, err
return &Rows{Rows: r, Mapper: n.Stmt.Mapper}, err
}

// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is
@@ -129,17 +129,17 @@ func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) {
}, nil
}

func bindAnyArgs(names []string, arg interface{}) ([]interface{}, error) {
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMapArgs(names, maparg)
}
return bindArgs(names, arg)
return bindArgs(names, arg, m)
}

// private interface to generate a list of interfaces from a given struct
// type, given a list of names to pull out of the struct. Used by public
// BindStruct interface.
func bindArgs(names []string, arg interface{}) ([]interface{}, error) {
func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
arglist := make([]interface{}, 0, len(names))

// grab the indirected value of arg
@@ -148,11 +148,11 @@ func bindArgs(names []string, arg interface{}) ([]interface{}, error) {
v = v.Elem()
}

m := mapper()
fields := m.TraversalsByName(v.Type(), names)
for i, t := range fields {
if len(t) == 0 {
return arglist, fmt.Errorf("could not find name %s in %v", names[i], arg)
fmt.Println(fields, names)
return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg)
}
val := reflectx.FieldByIndexesReadOnly(v, t)
arglist = append(arglist, val.Interface())
@@ -168,7 +168,7 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err
for _, name := range names {
val, ok := arg[name]
if !ok {
return arglist, fmt.Errorf("could not find name %s in %v", name, arg)
return arglist, fmt.Errorf("could not find name %s in %#v", name, arg)
}
arglist = append(arglist, val)
}
@@ -178,13 +178,13 @@ func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, err
// bindStruct binds a named parameter query with fields from a struct argument.
// The rules for binding field names to parameter names follow the same
// conventions as for StructScan, including obeying the `db` struct tags.
func bindStruct(bindType int, query string, arg interface{}) (string, []interface{}, error) {
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
bound, names, err := compileNamedQuery([]byte(query), bindType)
if err != nil {
return "", []interface{}{}, err
}

arglist, err := bindArgs(names, arg)
arglist, err := bindArgs(names, arg, m)
if err != nil {
return "", []interface{}{}, err
}
@@ -285,17 +285,21 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e

// Bind binds a struct or a map to a query with named parameters.
func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
return bindNamedMapper(bindType, query, arg, mapper())
}

func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
if maparg, ok := arg.(map[string]interface{}); ok {
return bindMap(bindType, query, maparg)
}
return bindStruct(bindType, query, arg)
return bindStruct(bindType, query, arg, m)
}

// NamedQuery binds a named query and then runs Query on the result using the
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
// map[string]interface{} types.
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
q, args, err := BindNamed(BindType(e.DriverName()), query, arg)
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
@@ -306,7 +310,7 @@ func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
// then runs Exec on the result. Returns an error from the binding
// or the query excution itself.
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
q, args, err := BindNamed(BindType(e.DriverName()), query, arg)
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
if err != nil {
return nil, err
}
2 changes: 2 additions & 0 deletions reflectx/reflect.go
Original file line number Diff line number Diff line change
@@ -114,6 +114,8 @@ func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
mustBe(t, reflect.Struct)
nm := m.TypeMap(t)

// fmt.Printf("%#v\n", nm)

r := make([][]int, 0, len(names))
for _, name := range names {
traversal, ok := nm[name]
10 changes: 5 additions & 5 deletions sqlx.go
Original file line number Diff line number Diff line change
@@ -112,7 +112,7 @@ func isUnsafe(i interface{}) bool {
}
}

func mapperFor(i Preparer) *reflectx.Mapper {
func mapperFor(i interface{}) *reflectx.Mapper {
switch i.(type) {
case DB:
return i.(DB).Mapper
@@ -292,7 +292,7 @@ func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: *r, unsafe: db.unsafe, Mapper: db.Mapper}, err
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
}

// QueryRowx queries the database and returns an *sqlx.Row.
@@ -366,7 +366,7 @@ func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: *r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
}

// QueryRowx within a transaction.
@@ -477,7 +477,7 @@ func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: *r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
}

func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row {
@@ -492,7 +492,7 @@ func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
// Rows is a wrapper around sql.Rows which caches costly reflect operations
// during a looped StructScan
type Rows struct {
sql.Rows
*sql.Rows
unsafe bool
Mapper *reflectx.Mapper
// these fields cache memory use for a rows during iteration w/ structScan
110 changes: 105 additions & 5 deletions sqlx_test.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ import (
"time"

_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx/reflectx"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
@@ -461,8 +462,16 @@ func TestNamedQuery(t *testing.T) {
first_name text NULL,
last_name text NULL,
email text NULL
);
CREATE TABLE jsperson (
"FIRST" text NULL,
last_name text NULL,
"EMAIL" text NULL
);`,
drop: `drop table person;`,
drop: `
drop table person;
drop table jsperson;
`,
}

RunWithSchema(schema, t, func(db *DB, t *testing.T) {
@@ -501,6 +510,97 @@ func TestNamedQuery(t *testing.T) {
t.Error("Expected first name of `doe`, got " + p2.LastName.String)
}
}

// these are tests for #73; they verify that named queries work if you've
// changed the db mapper. This code checks both NamedQuery "ad-hoc" style
// queries and NamedStmt queries, which use different code paths internally.
old := *db.Mapper

type JsonPerson struct {
FirstName sql.NullString `json:"FIRST"`
LastName sql.NullString `json:"last_name"`
Email sql.NullString
}

jp := JsonPerson{
FirstName: sql.NullString{"ben", true},
LastName: sql.NullString{"smith", true},
Email: sql.NullString{"[email protected]", true},
}

db.Mapper = reflectx.NewMapperFunc("json", strings.ToUpper)

// prepare queries for case sensitivity to test our ToUpper function.
// postgres and sqlite accept "", but mysql uses ``; since Go's multi-line
// strings are `` we use "" by default and swap out for MySQL
pdb := func(s string, db *DB) string {
if db.DriverName() == "mysql" {
return strings.Replace(s, `"`, "`", -1)
}
return s
}

q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`
_, err = db.NamedExec(pdb(q1, db), jp)
if err != nil {
t.Fatal(err, db.DriverName())
}

// Checks that a person pulled out of the db matches the one we put in
check := func(t *testing.T, rows *Rows) {
jp = JsonPerson{}
for rows.Next() {
err = rows.StructScan(&jp)
if err != nil {
t.Error(err)
}
if jp.FirstName.String != "ben" {
t.Errorf("Expected first name of `ben`, got `%s` (%s) ", jp.FirstName.String, db.DriverName())
}
if jp.LastName.String != "smith" {
t.Errorf("Expected LastName of `smith`, got `%s` (%s)", jp.LastName.String, db.DriverName())
}
if jp.Email.String != "[email protected]" {
t.Errorf("Expected first name of `doe`, got `%s` (%s)", jp.Email.String, db.DriverName())
}
}
}

ns, err := db.PrepareNamed(pdb(`
SELECT * FROM jsperson
WHERE
"FIRST"=:FIRST AND
last_name=:last_name AND
"EMAIL"=:EMAIL
`, db))

if err != nil {
t.Fatal(err)
}
rows, err = ns.Queryx(jp)
if err != nil {
t.Fatal(err)
}

check(t, rows)

// Check exactly the same thing, but with db.NamedQuery, which does not go
// through the PrepareNamed/NamedStmt path.
rows, err = db.NamedQuery(pdb(`
SELECT * FROM jsperson
WHERE
"FIRST"=:FIRST AND
last_name=:last_name AND
"EMAIL"=:EMAIL
`, db), jp)
if err != nil {
t.Fatal(err)
}

check(t, rows)

db.Mapper = &old

})
}

@@ -977,7 +1077,7 @@ func TestBindStruct(t *testing.T) {

am := tt{"Jason Moiron", 30, "Jason", "Moiron"}

bq, args, _ := bindStruct(QUESTION, q1, am)
bq, args, _ := bindStruct(QUESTION, q1, am, mapper())
expect := `INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?)`
if bq != expect {
t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect)
@@ -1000,7 +1100,7 @@ func TestBindStruct(t *testing.T) {
}

am2 := tt2{"Hello", "World"}
bq, args, _ = bindStruct(QUESTION, "INSERT INTO foo (a, b) VALUES (:field_2, :field_1)", am2)
bq, args, _ = bindStruct(QUESTION, "INSERT INTO foo (a, b) VALUES (:field_2, :field_1)", am2, mapper())
expect = `INSERT INTO foo (a, b) VALUES (?, ?)`
if bq != expect {
t.Errorf("Interpolation of query failed: got `%v`, expected `%v`\n", bq, expect)
@@ -1017,7 +1117,7 @@ func TestBindStruct(t *testing.T) {
am3.Field1 = "Hello"
am3.Field2 = "World"

bq, args, err = bindStruct(QUESTION, "INSERT INTO foo (a, b, c) VALUES (:name, :field_1, :field_2)", am3)
bq, args, err = bindStruct(QUESTION, "INSERT INTO foo (a, b, c) VALUES (:name, :field_1, :field_2)", am3, mapper())

if err != nil {
t.Fatal(err)
@@ -1051,7 +1151,7 @@ func BenchmarkBindStruct(b *testing.B) {
am := t{"Jason Moiron", 30, "Jason", "Moiron"}
b.StartTimer()
for i := 0; i < b.N; i++ {
bindStruct(DOLLAR, q1, am)
bindStruct(DOLLAR, q1, am, mapper())
//bindMap(QUESTION, q1, am)
}
}

0 comments on commit 2f383ca

Please sign in to comment.