diff --git a/CHANGELOG.md b/CHANGELOG.md index 091bbe84..e725c987 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## (WIP) v0.22.8 + +- Fixed '~' auto wildcard wrapping when the param has escaped `%` character ([#4704](https://github.com/pocketbase/pocketbase/discussions/4704)). + + ## v0.22.7 - Replaced the default `s3blob` driver with a trimmed vendored version to reduce the binary size with ~10MB. diff --git a/tools/search/filter.go b/tools/search/filter.go index 442986b4..04f720b3 100644 --- a/tools/search/filter.go +++ b/tools/search/filter.go @@ -432,16 +432,15 @@ func mergeParams(params ...dbx.Params) dbx.Params { } // wrapLikeParams wraps each provided param value string with `%` -// if the string doesn't contains the `%` char (including its escape sequence). +// if the param doesn't contain an explicit wildcard (`%`) character already. func wrapLikeParams(params dbx.Params) dbx.Params { result := dbx.Params{} for k, v := range params { vStr := cast.ToString(v) - if !strings.Contains(vStr, "%") { - for i := 0; i < len(dbx.DefaultLikeEscape); i += 2 { - vStr = strings.ReplaceAll(vStr, dbx.DefaultLikeEscape[i], dbx.DefaultLikeEscape[i+1]) - } + if !containsUnescapedChar(vStr, '%') { + // note: this is done to minimize the breaking changes and to preserve the original autoescape behavior + vStr = escapeUnescapedChars(vStr, '\\', '%', '_') vStr = "%" + vStr + "%" } result[k] = vStr @@ -450,6 +449,63 @@ func wrapLikeParams(params dbx.Params) dbx.Params { return result } +func escapeUnescapedChars(str string, escapeChars ...rune) string { + rs := []rune(str) + total := len(rs) + result := make([]rune, 0, total) + + var match bool + + for i := total - 1; i >= 0; i-- { + if match { + // check if already escaped + if rs[i] != '\\' { + result = append(result, '\\') + } + match = false + } else { + for _, ec := range escapeChars { + if rs[i] == ec { + match = true + break + } + } + } + + result = append(result, rs[i]) + + // in case the matching char is at the beginning + if i == 0 && match { + result = append(result, '\\') + } + } + + // reverse + for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 { + result[i], result[j] = result[j], result[i] + } + + return string(result) +} + +func containsUnescapedChar(str string, ch rune) bool { + var prev rune + + for _, c := range str { + if c == ch && prev != '\\' { + return true + } + + if c == '\\' && prev == '\\' { + prev = rune(0) // reset escape sequence + } else { + prev = c + } + } + + return false +} + // ------------------------------------------------------------------- var _ dbx.Expression = (*opExpr)(nil) diff --git a/tools/search/filter_test.go b/tools/search/filter_test.go index 2f572efa..59e37e09 100644 --- a/tools/search/filter_test.go +++ b/tools/search/filter_test.go @@ -238,3 +238,68 @@ func TestFilterDataBuildExprWithParams(t *testing.T) { t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0]) } } + +func TestLikeParamsWrapping(t *testing.T) { + // create a dummy db + sqlDB, err := sql.Open("sqlite", "file::memory:?cache=shared") + if err != nil { + t.Fatal(err) + } + db := dbx.NewFromDB(sqlDB, "sqlite") + + calledQueries := []string{} + db.QueryLogFunc = func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) { + calledQueries = append(calledQueries, sql) + } + db.ExecLogFunc = func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) { + calledQueries = append(calledQueries, sql) + } + + resolver := search.NewSimpleFieldResolver(`^test\w+$`) + + filter := search.FilterData(` + test1 ~ {:p1} || + test2 ~ {:p2} || + test3 ~ {:p3} || + test4 ~ {:p4} || + test5 ~ {:p5} || + test6 ~ {:p6} || + test7 ~ {:p7} || + test8 ~ {:p8} || + test9 ~ {:p9} || + test10 ~ {:p10} || + test11 ~ {:p11} || + test12 ~ {:p12} + `) + + replacements := []dbx.Params{ + {"p1": `abc`}, + {"p2": `ab%c`}, + {"p3": `ab\%c`}, + {"p4": `%ab\%c`}, + {"p5": `ab\\%c`}, + {"p6": `ab\\\%c`}, + {"p7": `ab_c`}, + {"p8": `ab\_c`}, + {"p9": `%ab_c`}, + {"p10": `ab\c`}, + {"p11": `_ab\c_`}, + {"p12": `ab\c%`}, + } + + expr, err := filter.BuildExpr(resolver, replacements...) + if err != nil { + t.Fatal(err) + } + + db.Select().Where(expr).Build().Execute() + + if len(calledQueries) != 1 { + t.Fatalf("Expected 1 query, got %d", len(calledQueries)) + } + + expectedQuery := `SELECT * WHERE ([[test1]] LIKE '%abc%' ESCAPE '\' OR [[test2]] LIKE 'ab%c' ESCAPE '\' OR [[test3]] LIKE 'ab\\%c' ESCAPE '\' OR [[test4]] LIKE '%ab\\%c' ESCAPE '\' OR [[test5]] LIKE 'ab\\\\%c' ESCAPE '\' OR [[test6]] LIKE 'ab\\\\\\%c' ESCAPE '\' OR [[test7]] LIKE '%ab\_c%' ESCAPE '\' OR [[test8]] LIKE '%ab\\\_c%' ESCAPE '\' OR [[test9]] LIKE '%ab_c' ESCAPE '\' OR [[test10]] LIKE '%ab\\c%' ESCAPE '\' OR [[test11]] LIKE '%\_ab\\c\_%' ESCAPE '\' OR [[test12]] LIKE 'ab\\c%' ESCAPE '\')` + if expectedQuery != calledQueries[0] { + t.Fatalf("Expected query \n%s, \ngot \n%s", expectedQuery, calledQueries[0]) + } +}