Skip to content

Commit

Permalink
add NamedStmt.Unsafe, inherit safety correctly and pass safety correc…
Browse files Browse the repository at this point in the history
…tly to rows obj when querying with NamedStmt, fixes jmoiron#181
  • Loading branch information
jmoiron committed Nov 14, 2015
1 parent 929a401 commit 92e3330
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 14 deletions.
11 changes: 9 additions & 2 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
if err != nil {
return nil, err
}
return &Rows{Rows: r, Mapper: n.Stmt.Mapper}, err
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
}

// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is
Expand All @@ -90,7 +90,7 @@ func (n *NamedStmt) QueryRowx(arg interface{}) *Row {

// Select using this NamedStmt
func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
rows, err := n.Query(arg)
rows, err := n.Queryx(arg)
if err != nil {
return err
}
Expand All @@ -105,6 +105,13 @@ func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
return r.scanAny(dest, false)
}

// Unsafe creates an unsafe version of the NamedStmt
func (n *NamedStmt) Unsafe() *NamedStmt {
r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString}
r.Stmt.unsafe = true
return r
}

// A union interface of preparer and binder, required to be able to prepare
// named statements (as the bindtype must be determined).
type namedPreparer interface {
Expand Down
30 changes: 18 additions & 12 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,29 +105,35 @@ type Preparer interface {

// determine if any of our extensions are unsafe
func isUnsafe(i interface{}) bool {
switch i.(type) {
switch v := i.(type) {
case Row:
return i.(Row).unsafe
return v.unsafe
case *Row:
return i.(*Row).unsafe
return v.unsafe
case Rows:
return i.(Rows).unsafe
return v.unsafe
case *Rows:
return i.(*Rows).unsafe
return v.unsafe
case NamedStmt:
return v.Stmt.unsafe
case *NamedStmt:
return v.Stmt.unsafe
case Stmt:
return i.(Stmt).unsafe
return v.unsafe
case *Stmt:
return v.unsafe
case qStmt:
return i.(qStmt).Stmt.unsafe
return v.unsafe
case *qStmt:
return i.(*qStmt).Stmt.unsafe
return v.unsafe
case DB:
return i.(DB).unsafe
return v.unsafe
case *DB:
return i.(*DB).unsafe
return v.unsafe
case Tx:
return i.(Tx).unsafe
return v.unsafe
case *Tx:
return i.(*Tx).unsafe
return v.unsafe
case sql.Rows, *sql.Rows:
return false
default:
Expand Down
47 changes: 47 additions & 0 deletions sqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,53 @@ func TestMissingNames(t *testing.T) {
}
rowsx.Close()

// test Named stmt
if !isUnsafe(db) {
t.Error("Expected db to be unsafe, but it isn't")
}
nstmt, err := db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`)
if err != nil {
t.Fatal(err)
}
// its internal stmt should be marked unsafe
if !nstmt.Stmt.unsafe {
t.Error("expected NamedStmt to be unsafe but its underlying stmt did not inherit safety")
}
pps = []PersonPlus{}
err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"})
if err != nil {
t.Fatal(err)
}
if len(pps) != 1 {
t.Errorf("Expected 1 person back, got %d", len(pps))
}

// test it with a safe db
db.unsafe = false
if isUnsafe(db) {
t.Error("expected db to be safe but it isn't")
}
nstmt, err = db.PrepareNamed(`SELECT * FROM person WHERE first_name != :name`)
if err != nil {
t.Fatal(err)
}
// it should be safe
if isUnsafe(nstmt) {
t.Error("NamedStmt did not inherit safety")
}
nstmt.Unsafe()
if !isUnsafe(nstmt) {
t.Error("expected newly unsafed NamedStmt to be unsafe")
}
pps = []PersonPlus{}
err = nstmt.Select(&pps, map[string]interface{}{"name": "Jason"})
if err != nil {
t.Fatal(err)
}
if len(pps) != 1 {
t.Errorf("Expected 1 person back, got %d", len(pps))
}

})
}

Expand Down

0 comments on commit 92e3330

Please sign in to comment.