Skip to content

Commit

Permalink
support sqlserver parameters syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
marpio committed Apr 15, 2018
1 parent cf35089 commit fb9af31
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
5 changes: 5 additions & 0 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ const (
QUESTION
DOLLAR
NAMED
AT
)

// BindType returns the bindtype for a given database given a drivername.
Expand All @@ -29,6 +30,8 @@ func BindType(driverName string) int {
return QUESTION
case "oci8", "ora", "goracle":
return NAMED
case "sqlserver":
return AT
}
return UNKNOWN
}
Expand Down Expand Up @@ -56,6 +59,8 @@ func Rebind(bindType int, query string) string {
rqb = append(rqb, '$')
case NAMED:
rqb = append(rqb, ':', 'a', 'r', 'g')
case AT:
rqb = append(rqb, '@', 'p')
}

j++
Expand Down
6 changes: 6 additions & 0 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e
rebound = append(rebound, byte(b))
}
currentVar++
case AT:
rebound = append(rebound, '@', 'p')
for _, b := range strconv.Itoa(currentVar) {
rebound = append(rebound, byte(b))
}
currentVar++
}
// add this byte to string unless it was not part of the name
if i != last {
Expand Down
13 changes: 11 additions & 2 deletions named_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import (

func TestCompileQuery(t *testing.T) {
table := []struct {
Q, R, D, N string
V []string
Q, R, D, T, N string
V []string
}{
// basic test for named parameters, invalid char ',' terminating
{
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`,
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
V: []string{"name", "age", "first", "last"},
},
Expand All @@ -23,20 +24,23 @@ func TestCompileQuery(t *testing.T) {
Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT * FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`,
T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
V: []string{"name1", "name2"},
},
{
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
V: []string{"first_name", "last_name"},
},
Expand Down Expand Up @@ -74,6 +78,11 @@ func TestCompileQuery(t *testing.T) {
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd)
}

qt, _, _ := compileNamedQuery([]byte(test.Q), AT)
if qt != test.T {
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt)
}

qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED)
if qq != test.N {
t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq))
Expand Down
11 changes: 11 additions & 0 deletions sqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,17 @@ func TestRebind(t *testing.T) {
t.Errorf("q2 failed")
}

s1 = Rebind(AT, q1)
s2 = Rebind(AT, q2)

if s1 != `INSERT INTO foo (a, b, c, d, e, f, g, h, i) VALUES (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)` {
t.Errorf("q1 failed")
}

if s2 != `INSERT INTO foo (a, b, c) VALUES (@p1, @p2, "foo"), ("Hi", @p3, @p4)` {
t.Errorf("q2 failed")
}

s1 = Rebind(NAMED, q1)
s2 = Rebind(NAMED, q2)

Expand Down

0 comments on commit fb9af31

Please sign in to comment.