diff --git a/expression/substring.go b/expression/substring.go index 73a0a422a89e2..3e517924b93f7 100644 --- a/expression/substring.go +++ b/expression/substring.go @@ -15,9 +15,16 @@ package expression import ( "fmt" + "strings" "github.com/juju/errors" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/util/types" +) + +var ( + _ Expression = (*FunctionSubstring)(nil) + _ Expression = (*FunctionSubstringIndex)(nil) ) // FunctionSubstring returns the substring as specified. @@ -112,3 +119,86 @@ func (f *FunctionSubstring) Eval(ctx context.Context, args map[interface{}]inter func (f *FunctionSubstring) Accept(v Visitor) (Expression, error) { return v.VisitFunctionSubstring(f) } + +// FunctionSubstringIndex returns the substring as specified. +// See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring-index +type FunctionSubstringIndex struct { + StrExpr Expression + Delim Expression + Count Expression +} + +// Clone implements the Expression Clone interface. +func (f *FunctionSubstringIndex) Clone() Expression { + nf := &FunctionSubstringIndex{ + StrExpr: f.StrExpr.Clone(), + Delim: f.Delim.Clone(), + Count: f.Count.Clone(), + } + return nf +} + +// IsStatic implements the Expression IsStatic interface. +func (f *FunctionSubstringIndex) IsStatic() bool { + return f.StrExpr.IsStatic() && f.Delim.IsStatic() && f.Count.IsStatic() +} + +// String implements the Expression String interface. +func (f *FunctionSubstringIndex) String() string { + return fmt.Sprintf("SUBSTRING_INDEX(%s, %s, %s)", f.StrExpr, f.Delim, f.Count) +} + +// Eval implements the Expression Eval interface. +func (f *FunctionSubstringIndex) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { + fs, err := f.StrExpr.Eval(ctx, args) + if err != nil { + return nil, errors.Trace(err) + } + str, ok := fs.(string) + if !ok { + return nil, errors.Errorf("Substring_Index invalid args, need string but get %T", fs) + } + + t, err := f.Delim.Eval(ctx, args) + if err != nil { + return nil, errors.Trace(err) + } + delim, ok := t.(string) + if !ok { + return nil, errors.Errorf("Substring_Index invalid delim, need string but get %T", t) + } + + t, err = f.Count.Eval(ctx, args) + if err != nil { + return nil, errors.Trace(err) + } + c, err := types.ToInt64(t) + if err != nil { + return nil, errors.Trace(err) + } + count := int(c) + strs := strings.Split(str, delim) + var ( + start = 0 + end = len(strs) + ) + if count > 0 { + // If count is positive, everything to the left of the final delimiter (counting from the left) is returned. + if count < end { + end = count + } + } else { + // If count is negative, everything to the right of the final delimiter (counting from the right) is returned. + count = -count + if count < end { + start = end - count + } + } + substrs := strs[start:end] + return strings.Join(substrs, delim), nil +} + +// Accept implements Expression Accept interface. +func (f *FunctionSubstringIndex) Accept(v Visitor) (Expression, error) { + return v.VisitFunctionSubstringIndex(f) +} diff --git a/expression/substring_test.go b/expression/substring_test.go index 8a6b2cbf50d04..4c27b677c37e2 100644 --- a/expression/substring_test.go +++ b/expression/substring_test.go @@ -87,3 +87,44 @@ func (s *testSubstringSuite) TestSubstring(c *C) { c.Assert(err, NotNil) } } + +func (s *testSubstringSuite) TestSubstringIndex(c *C) { + tbl := []struct { + str string + delim string + count int64 + result string + }{ + {"www.mysql.com", ".", 2, "www.mysql"}, + {"www.mysql.com", ".", -2, "mysql.com"}, + {"www.mysql.com", ".", 20, "www.mysql.com"}, + {"www.mysql.com", ".", -20, "www.mysql.com"}, + {"www.mysql.com", "_", 2, "www.mysql.com"}, + {"www.mysql.com", "_", 0, ""}, + } + for _, v := range tbl { + f := FunctionSubstringIndex{ + StrExpr: &Value{Val: v.str}, + Delim: &Value{Val: v.delim}, + Count: &Value{Val: v.count}, + } + c.Assert(f.IsStatic(), Equals, true) + + fs := f.String() + c.Assert(len(fs), Greater, 0) + + f1 := f.Clone() + + r, err := f.Eval(nil, nil) + c.Assert(err, IsNil) + s, ok := r.(string) + c.Assert(ok, Equals, true) + c.Assert(s, Equals, v.result) + + r1, err := f1.Eval(nil, nil) + c.Assert(err, IsNil) + s1, ok := r1.(string) + c.Assert(ok, Equals, true) + c.Assert(s, Equals, s1) + } +} diff --git a/expression/visitor.go b/expression/visitor.go index 7f31259dc0e1c..180161fd504a7 100644 --- a/expression/visitor.go +++ b/expression/visitor.go @@ -44,6 +44,9 @@ type Visitor interface { // VisitFunctionSubstring visits FunctionSubstring expression. VisitFunctionSubstring(ss *FunctionSubstring) (Expression, error) + // VisitFunctionSubstringIndex visits FunctionSubstringIndex expression. + VisitFunctionSubstringIndex(ss *FunctionSubstringIndex) (Expression, error) + // VisitExistsSubQuery visits ExistsSubQuery expression. VisitExistsSubQuery(es *ExistsSubQuery) (Expression, error) @@ -243,6 +246,27 @@ func (bv *BaseVisitor) VisitFunctionSubstring(ss *FunctionSubstring) (Expression return ss, nil } +// VisitFunctionSubstringIndex implements Visitor interface. +func (bv *BaseVisitor) VisitFunctionSubstringIndex(ss *FunctionSubstringIndex) (Expression, error) { + var err error + ss.StrExpr, err = ss.StrExpr.Accept(bv.V) + if err != nil { + return ss, errors.Trace(err) + } + ss.Delim, err = ss.Delim.Accept(bv.V) + if err != nil { + return ss, errors.Trace(err) + } + if ss.Count == nil { + return ss, nil + } + ss.Count, err = ss.Count.Accept(bv.V) + if err != nil { + return ss, errors.Trace(err) + } + return ss, nil +} + // VisitIdent implements Visitor interface. func (bv *BaseVisitor) VisitIdent(i *Ident) (Expression, error) { return i, nil diff --git a/parser/parser.y b/parser/parser.y index 4e092e1da8738..3ff0283486d39 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -210,6 +210,7 @@ import ( start "START" stringType "string" substring "SUBSTRING" + substringIndex "SUBSTRING_INDEX" sum "SUM" sysVar "SYS_VAR" sysDate "SYSDATE" @@ -1614,7 +1615,7 @@ UnReservedKeyword: NotKeywordToken: "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" | "HOUR" | "IFNULL" | "LENGTH" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "SECOND" | "SQL_CALC_FOUND_ROWS" -| "SUBSTRING" %prec lowerThanLeftParen | "SUM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" +| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" /************************************************************************************ * @@ -2262,6 +2263,14 @@ FunctionCallNonKeyword: Len: $7.(expression.Expression), } } +| "SUBSTRING_INDEX" '(' Expression ',' Expression ',' Expression ')' + { + $$ = &expression.FunctionSubstringIndex{ + StrExpr: $3.(expression.Expression), + Delim: $5.(expression.Expression), + Count: $7.(expression.Expression), + } + } | "SYSDATE" '(' ExpressionOpt ')' { args := []expression.Expression{} diff --git a/parser/parser_test.go b/parser/parser_test.go index 2822abce3a45b..2b54ff6472361 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -274,6 +274,9 @@ func (s *testParserSuite) TestParser0(c *C) { {"SELECT CURRENT_USER();", true}, {"SELECT CURRENT_USER;", true}, + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', 2);", true}, + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', -2);", true}, + // For delete statement {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, diff --git a/parser/scanner.l b/parser/scanner.l index 957fb02521d30..b16127d86aa89 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -355,6 +355,7 @@ show {s}{h}{o}{w} some {s}{o}{m}{e} start {s}{t}{a}{r}{t} substring {s}{u}{b}{s}{t}{r}{i}{n}{g} +substring_index {s}{u}{b}{s}{t}{r}{i}{n}{g}_{i}{n}{d}{e}{x} sum {s}{u}{m} sysdate {s}{y}{s}{d}{a}{t}{e} table {t}{a}{b}{l}{e} @@ -694,6 +695,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} {show} return show {substring} lval.item = string(l.val) return substring +{substring_index} lval.item = string(l.val) + return substringIndex {sum} lval.item = string(l.val) return sum {sysdate} lval.item = string(l.val)