Skip to content

Commit

Permalink
Add In Clause handling in json indexed col (Attr) (cadence-workflow#6147
Browse files Browse the repository at this point in the history
)

* add In Clause handling in json indexed col (Attr)

* add more test cases to cover string cases

* change dot to be colon
  • Loading branch information
bowenxia authored Jun 26, 2024
1 parent 0b46176 commit c7f6233
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 4 deletions.
39 changes: 37 additions & 2 deletions common/pinot/pinotQueryValidator.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (qv *VisibilityQueryValidator) ValidateQuery(whereClause string) (string, e

stmt, err := sqlparser.Parse(placeholderQuery)
if err != nil {
return "", &types.BadRequestError{Message: "Invalid query."}
return "", &types.BadRequestError{Message: "Invalid query: " + err.Error()}
}

sel, ok := stmt.(*sqlparser.Select)
Expand Down Expand Up @@ -313,6 +313,36 @@ func (qv *VisibilityQueryValidator) processSystemKey(expr sqlparser.Expr) (strin
return buf.String(), nil
}

func (qv *VisibilityQueryValidator) processInClause(expr sqlparser.Expr) (string, error) {
comparisonExpr, ok := expr.(*sqlparser.ComparisonExpr)
if !ok {
return "", errors.New("invalid IN expression")
}

colName, ok := comparisonExpr.Left.(*sqlparser.ColName)
if !ok {
return "", errors.New("invalid IN expression, left")
}

colNameStr := colName.Name.String()
valTuple, ok := comparisonExpr.Right.(sqlparser.ValTuple)
if !ok {
return "", errors.New("invalid IN expression, right")
}

values := make([]string, len(valTuple))
for i, val := range valTuple {
sqlVal, ok := val.(*sqlparser.SQLVal)
if !ok {
return "", errors.New("invalid IN expression, value")
}
values[i] = "''" + string(sqlVal.Val) + "''"
}

return fmt.Sprintf("JSON_MATCH(Attr, '\"$.%s\" IN (%s)') or JSON_MATCH(Attr, '\"$.%s[*]\" IN (%s)')",
colNameStr, strings.Join(values, ","), colNameStr, strings.Join(values, ",")), nil
}

func (qv *VisibilityQueryValidator) processCustomKey(expr sqlparser.Expr) (string, error) {
comparisonExpr := expr.(*sqlparser.ComparisonExpr)

Expand All @@ -329,6 +359,12 @@ func (qv *VisibilityQueryValidator) processCustomKey(expr sqlparser.Expr) (strin
return "", fmt.Errorf("invalid search attribute")
}

// process IN clause in json indexed col: Attr
operator := strings.ToLower(comparisonExpr.Operator)
if operator == sqlparser.InStr {
return qv.processInClause(expr)
}

// get the column value
colVal, ok := comparisonExpr.Right.(*sqlparser.SQLVal)
if !ok {
Expand All @@ -337,7 +373,6 @@ func (qv *VisibilityQueryValidator) processCustomKey(expr sqlparser.Expr) (strin

// get the value type
indexValType := common.ConvertIndexedValueTypeToInternalType(valType, log.NewNoop())
operator := comparisonExpr.Operator
colValStr := string(colVal.Val)

switch indexValType {
Expand Down
62 changes: 60 additions & 2 deletions common/pinot/pinotQueryValidator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestValidateQuery(t *testing.T) {
},
"Case7: invalid sql query": {
query: "Invalid SQL",
err: "Invalid query.",
err: "Invalid query: syntax error at position 38 near 'sql'",
},
"Case8-1: query with missing val": {
query: "CloseTime = missing",
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestValidateQuery(t *testing.T) {
},
"Case12-1: security SQL injection - with another statement": {
query: "WorkflowID = 'wid'; SELECT * FROM important_table;",
err: "Invalid query.",
err: "Invalid query: syntax error at position 53 near 'select'",
},
"Case12-2: security SQL injection - with union": {
query: "WorkflowID = 'wid' union select * from dummy",
Expand Down Expand Up @@ -239,6 +239,27 @@ func TestValidateQuery(t *testing.T) {
query: "CloseStatus = 1",
validated: "CloseStatus = 1",
},
"case20-1: in clause in Attr": {
query: "CustomKeywordField in (123)",
validated: "JSON_MATCH(Attr, '\"$.CustomKeywordField\" IN (''123'')') or JSON_MATCH(Attr, '\"$.CustomKeywordField[*]\" IN (''123'')')",
},
"case20-2: in clause in Attr with multiple values": {
query: "CustomKeywordField in (123, 456)",
validated: "JSON_MATCH(Attr, '\"$.CustomKeywordField\" IN (''123'',''456'')') or JSON_MATCH(Attr, '\"$.CustomKeywordField[*]\" IN (''123'',''456'')')",
},
"case20-3-1: in clause in Attr with a string value, double quote": {
query: "CustomKeywordField in (\"abc\")",
validated: "JSON_MATCH(Attr, '\"$.CustomKeywordField\" IN (''abc'')') or JSON_MATCH(Attr, '\"$.CustomKeywordField[*]\" IN (''abc'')')",
},
"case20-3-2: in clause in Attr with a string value, single quote": {
query: "CustomKeywordField in ('abc')",
validated: "JSON_MATCH(Attr, '\"$.CustomKeywordField\" IN (''abc'')') or JSON_MATCH(Attr, '\"$.CustomKeywordField[*]\" IN (''abc'')')",
},
"case20-4: in clause in Attr with invalid IN expression, value": {
query: "CustomKeywordField in (abc)",
validated: "",
err: "invalid IN expression, value",
},
}

for name, test := range tests {
Expand All @@ -255,6 +276,43 @@ func TestValidateQuery(t *testing.T) {
}
}

func TestProcessInClause_FailedInputExprCases(t *testing.T) {
// Define test cases
tests := map[string]struct {
inputExpr sqlparser.Expr
expectedError string
}{
"case1: in clause in Attr with invalid expr": {
inputExpr: &sqlparser.SQLVal{Type: sqlparser.StrVal, Val: []byte("invalid")},
expectedError: "invalid IN expression",
},
"case2: in clause in Attr with invalid expr, left": {
inputExpr: &sqlparser.ComparisonExpr{Operator: sqlparser.InStr},
expectedError: "invalid IN expression, left",
},
"case3: in clause in Attr with invalid expr, right": {
inputExpr: &sqlparser.ComparisonExpr{Operator: sqlparser.InStr, Left: &sqlparser.ColName{Name: sqlparser.NewColIdent("CustomKeywordField")}},
expectedError: "invalid IN expression, right",
},
}

// Create a new VisibilityQueryValidator
validSearchAttr := dynamicconfig.GetMapPropertyFn(definition.GetDefaultIndexedKeys())
qv := NewPinotQueryValidator(validSearchAttr())

for name, test := range tests {
t.Run(name, func(t *testing.T) {
// Call processInClause with the input expression
_, err := qv.processInClause(test.inputExpr)

// Check that an error was returned and that the error message matches the expected error message
if assert.Error(t, err) {
assert.Contains(t, err.Error(), test.expectedError)
}
})
}
}

func TestParseTime(t *testing.T) {
var tests = []struct {
name string
Expand Down

0 comments on commit c7f6233

Please sign in to comment.