Skip to content

Commit

Permalink
[pocketbase#4704] fixed '~' autowildcard wrapping when the string has…
Browse files Browse the repository at this point in the history
… escaped % character
  • Loading branch information
ganigeorgiev committed Apr 5, 2024
1 parent ac76166 commit 63bcffb
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## (WIP) v0.22.8

- Fixed '~' auto wildcard wrapping when the param has escaped `%` character ([#4704](https://github.com/pocketbase/pocketbase/discussions/4704)).


## v0.22.7

- Replaced the default `s3blob` driver with a trimmed vendored version to reduce the binary size with ~10MB.
Expand Down
66 changes: 61 additions & 5 deletions tools/search/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,16 +432,15 @@ func mergeParams(params ...dbx.Params) dbx.Params {
}

// wrapLikeParams wraps each provided param value string with `%`
// if the string doesn't contains the `%` char (including its escape sequence).
// if the param doesn't contain an explicit wildcard (`%`) character already.
func wrapLikeParams(params dbx.Params) dbx.Params {
result := dbx.Params{}

for k, v := range params {
vStr := cast.ToString(v)
if !strings.Contains(vStr, "%") {
for i := 0; i < len(dbx.DefaultLikeEscape); i += 2 {
vStr = strings.ReplaceAll(vStr, dbx.DefaultLikeEscape[i], dbx.DefaultLikeEscape[i+1])
}
if !containsUnescapedChar(vStr, '%') {
// note: this is done to minimize the breaking changes and to preserve the original autoescape behavior
vStr = escapeUnescapedChars(vStr, '\\', '%', '_')
vStr = "%" + vStr + "%"
}
result[k] = vStr
Expand All @@ -450,6 +449,63 @@ func wrapLikeParams(params dbx.Params) dbx.Params {
return result
}

func escapeUnescapedChars(str string, escapeChars ...rune) string {
rs := []rune(str)
total := len(rs)
result := make([]rune, 0, total)

var match bool

for i := total - 1; i >= 0; i-- {
if match {
// check if already escaped
if rs[i] != '\\' {
result = append(result, '\\')
}
match = false
} else {
for _, ec := range escapeChars {
if rs[i] == ec {
match = true
break
}
}
}

result = append(result, rs[i])

// in case the matching char is at the beginning
if i == 0 && match {
result = append(result, '\\')
}
}

// reverse
for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
result[i], result[j] = result[j], result[i]
}

return string(result)
}

func containsUnescapedChar(str string, ch rune) bool {
var prev rune

for _, c := range str {
if c == ch && prev != '\\' {
return true
}

if c == '\\' && prev == '\\' {
prev = rune(0) // reset escape sequence
} else {
prev = c
}
}

return false
}

// -------------------------------------------------------------------

var _ dbx.Expression = (*opExpr)(nil)
Expand Down
65 changes: 65 additions & 0 deletions tools/search/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,68 @@ func TestFilterDataBuildExprWithParams(t *testing.T) {
t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
}
}

func TestLikeParamsWrapping(t *testing.T) {
// create a dummy db
sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared")
if err != nil {
t.Fatal(err)
}
db := dbx.NewFromDB(sqlDB, "sqlite")

calledQueries := []string{}
db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) {
calledQueries = append(calledQueries, sql)
}
db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) {
calledQueries = append(calledQueries, sql)
}

resolver := search.NewSimpleFieldResolver(`^test\w+$`)

filter := search.FilterData(`
test1 ~ {:p1} ||
test2 ~ {:p2} ||
test3 ~ {:p3} ||
test4 ~ {:p4} ||
test5 ~ {:p5} ||
test6 ~ {:p6} ||
test7 ~ {:p7} ||
test8 ~ {:p8} ||
test9 ~ {:p9} ||
test10 ~ {:p10} ||
test11 ~ {:p11} ||
test12 ~ {:p12}
`)

replacements := []dbx.Params{
{"p1": `abc`},
{"p2": `ab%c`},
{"p3": `ab\%c`},
{"p4": `%ab\%c`},
{"p5": `ab\\%c`},
{"p6": `ab\\\%c`},
{"p7": `ab_c`},
{"p8": `ab\_c`},
{"p9": `%ab_c`},
{"p10": `ab\c`},
{"p11": `_ab\c_`},
{"p12": `ab\c%`},
}

expr, err := filter.BuildExpr(resolver, replacements...)
if err != nil {
t.Fatal(err)
}

db.Select().Where(expr).Build().Execute()

if len(calledQueries) != 1 {
t.Fatalf("Expected 1 query, got %d", len(calledQueries))
}

expectedQuery := `SELECT * WHERE ([[test1]] LIKE '%abc%' ESCAPE '\' OR [[test2]] LIKE 'ab%c' ESCAPE '\' OR [[test3]] LIKE 'ab\\%c' ESCAPE '\' OR [[test4]] LIKE '%ab\\%c' ESCAPE '\' OR [[test5]] LIKE 'ab\\\\%c' ESCAPE '\' OR [[test6]] LIKE 'ab\\\\\\%c' ESCAPE '\' OR [[test7]] LIKE '%ab\_c%' ESCAPE '\' OR [[test8]] LIKE '%ab\\\_c%' ESCAPE '\' OR [[test9]] LIKE '%ab_c' ESCAPE '\' OR [[test10]] LIKE '%ab\\c%' ESCAPE '\' OR [[test11]] LIKE '%\_ab\\c\_%' ESCAPE '\' OR [[test12]] LIKE 'ab\\c%' ESCAPE '\')`
if expectedQuery != calledQueries[0] {
t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0])
}
}

0 comments on commit 63bcffb

Please sign in to comment.