diff --git a/ast/functions.go b/ast/functions.go index 6f18bfaaaf7dd..ecf30b695157d 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -96,7 +96,10 @@ const ( Curtime = "curtime" Date = "date" DateDiff = "datediff" - DateArith = "date_arith" + DateAdd = "date_add" + AddDate = "adddate" + DateSub = "date_sub" + SubDate = "subdate" DateFormat = "date_format" Day = "day" DayName = "dayname" @@ -259,14 +262,14 @@ const ( type DateArithType byte const ( - // DateAdd is to run adddate or date_add function option. + // DateArithAdd is to run adddate or date_add function option. // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add - DateAdd DateArithType = iota + 1 - // DateSub is to run subdate or date_sub function option. + DateArithAdd DateArithType = iota + 1 + // DateArithSub is to run subdate or date_sub function option. // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub - DateSub + DateArithSub ) const ( diff --git a/ast/functions_test.go b/ast/functions_test.go index 3da2a22e849d9..a41e49acfe2dc 100644 --- a/ast/functions_test.go +++ b/ast/functions_test.go @@ -34,7 +34,7 @@ func (ts *testFunctionsSuite) TestAggregateFuncExtractor(c *C) { extractor = &AggregateFuncExtractor{} expr = &FuncCallExpr{ - FnName: model.NewCIStr("DATE_ARITH"), + FnName: model.NewCIStr("FAKE_FUNC"), } expr.Accept(extractor) c.Assert(extractor.AggFuncs, HasLen, 0) @@ -43,7 +43,7 @@ func (ts *testFunctionsSuite) TestAggregateFuncExtractor(c *C) { r := &AggregateFuncExpr{} expr = &BinaryOperationExpr{ L: &FuncCallExpr{ - FnName: model.NewCIStr("DATE_ARITH"), + FnName: model.NewCIStr("FAKE_FUNC"), }, R: r, } diff --git a/expression/builtin.go b/expression/builtin.go index 54199f6eaeee8..e69b939d5f013 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -178,8 +178,11 @@ var Funcs = map[string]Func{ ast.CurrentDate: {builtinCurrentDate, 0, 0}, ast.CurrentTime: {builtinCurrentTime, 0, 1}, ast.Date: {builtinDate, 1, 1}, - ast.DateArith: {builtinDateArith, 4, 4}, ast.DateDiff: {builtinDateDiff, 2, 2}, + ast.DateAdd: {dateArithFuncFactory(ast.DateArithAdd), 3, 3}, + ast.AddDate: {dateArithFuncFactory(ast.DateArithAdd), 3, 3}, + ast.DateSub: {dateArithFuncFactory(ast.DateArithSub), 3, 3}, + ast.SubDate: {dateArithFuncFactory(ast.DateArithSub), 3, 3}, ast.DateFormat: {builtinDateFormat, 2, 2}, ast.CurrentTimestamp: {builtinNow, 0, 1}, ast.Curtime: {builtinCurrentTime, 0, 1}, diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 37296035a6f85..d34f1de342647 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -660,101 +660,100 @@ func checkFsp(sc *variable.StatementContext, arg types.Datum) (int, error) { return int(fsp), nil } -func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, err error) { - // Op is used for distinguishing date_add and date_sub. - // args[0] -> Op - // args[1] -> Date - // args[2] -> Interval Value - // args[3] -> Interval Unit - // health check for date and interval - if args[1].IsNull() || args[2].IsNull() { - return d, nil - } - nodeDate := args[1] - nodeIntervalValue := args[2] - nodeIntervalUnit := args[3].GetString() - if nodeIntervalValue.IsNull() { - return d, nil - } - // parse date - fieldType := mysql.TypeDate - var resultField *types.FieldType - switch nodeDate.Kind() { - case types.KindMysqlTime: - x := nodeDate.GetMysqlTime() - if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) { - fieldType = mysql.TypeDatetime +func dateArithFuncFactory(op ast.DateArithType) BuiltinFunc { + return func(args []types.Datum, ctx context.Context) (d types.Datum, err error) { + // args[0] -> Date + // args[1] -> Interval Value + // args[2] -> Interval Unit + // health check for date and interval + if args[0].IsNull() || args[1].IsNull() { + return d, nil } - case types.KindString: - x := nodeDate.GetString() - if !types.IsDateFormat(x) { - fieldType = mysql.TypeDatetime + nodeDate := args[0] + nodeIntervalValue := args[1] + nodeIntervalUnit := args[2].GetString() + if nodeIntervalValue.IsNull() { + return d, nil } - case types.KindInt64: - x := nodeDate.GetInt64() - if t, err1 := types.ParseTimeFromInt64(x); err1 == nil { - if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) { + // parse date + fieldType := mysql.TypeDate + var resultField *types.FieldType + switch nodeDate.Kind() { + case types.KindMysqlTime: + x := nodeDate.GetMysqlTime() + if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) { fieldType = mysql.TypeDatetime } + case types.KindString: + x := nodeDate.GetString() + if !types.IsDateFormat(x) { + fieldType = mysql.TypeDatetime + } + case types.KindInt64: + x := nodeDate.GetInt64() + if t, err1 := types.ParseTimeFromInt64(x); err1 == nil { + if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) { + fieldType = mysql.TypeDatetime + } + } } - } - sc := ctx.GetSessionVars().StmtCtx - if types.IsClockUnit(nodeIntervalUnit) { - fieldType = mysql.TypeDatetime - } - resultField = types.NewFieldType(fieldType) - resultField.Decimal = types.MaxFsp - value, err := nodeDate.ConvertTo(ctx.GetSessionVars().StmtCtx, resultField) - if err != nil { - return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate) - } - if value.IsNull() { - return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue()) - } - if value.Kind() != types.KindMysqlTime { - return d, errInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue()) - } - result := value.GetMysqlTime() - // parse interval - var interval string - if strings.ToLower(nodeIntervalUnit) == "day" { - day, err1 := parseDayInterval(sc, nodeIntervalValue) - if err1 != nil { - return d, errInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalValue.GetString()) + sc := ctx.GetSessionVars().StmtCtx + if types.IsClockUnit(nodeIntervalUnit) { + fieldType = mysql.TypeDatetime } - interval = fmt.Sprintf("%d", day) - } else { - if nodeIntervalValue.Kind() == types.KindString { - interval = fmt.Sprintf("%v", nodeIntervalValue.GetString()) - } else { - ii, err1 := nodeIntervalValue.ToInt64(sc) + resultField = types.NewFieldType(fieldType) + resultField.Decimal = types.MaxFsp + value, err := nodeDate.ConvertTo(ctx.GetSessionVars().StmtCtx, resultField) + if err != nil { + return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate) + } + if value.IsNull() { + return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue()) + } + if value.Kind() != types.KindMysqlTime { + return d, errInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue()) + } + result := value.GetMysqlTime() + // parse interval + var interval string + if strings.ToLower(nodeIntervalUnit) == "day" { + day, err1 := parseDayInterval(sc, nodeIntervalValue) if err1 != nil { - return d, errors.Trace(err1) + return d, errInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalValue.GetString()) + } + interval = fmt.Sprintf("%d", day) + } else { + if nodeIntervalValue.Kind() == types.KindString { + interval = fmt.Sprintf("%v", nodeIntervalValue.GetString()) + } else { + ii, err1 := nodeIntervalValue.ToInt64(sc) + if err1 != nil { + return d, errors.Trace(err1) + } + interval = fmt.Sprintf("%v", ii) } - interval = fmt.Sprintf("%v", ii) } + year, month, day, duration, err := types.ExtractTimeValue(nodeIntervalUnit, interval) + if err != nil { + return d, errors.Trace(err) + } + if op == ast.DateArithSub { + year, month, day, duration = -year, -month, -day, -duration + } + // TODO: Consider time_zone variable. + t, err := result.Time.GoTime(time.Local) + if err != nil { + return d, errors.Trace(err) + } + t = t.Add(duration) + t = t.AddDate(int(year), int(month), int(day)) + if t.Nanosecond() == 0 { + result.Fsp = 0 + } + result.Time = types.FromGoTime(t) + d.SetMysqlTime(result) + return d, nil } - year, month, day, duration, err := types.ExtractTimeValue(nodeIntervalUnit, interval) - if err != nil { - return d, errors.Trace(err) - } - op := args[0].GetInterface().(ast.DateArithType) - if op == ast.DateSub { - year, month, day, duration = -year, -month, -day, -duration - } - // TODO: Consider time_zone variable. - t, err := result.Time.GoTime(time.Local) - if err != nil { - return d, errors.Trace(err) - } - t = t.Add(duration) - t = t.AddDate(int(year), int(month), int(day)) - if t.Nanosecond() == 0 { - result.Fsp = 0 - } - result.Time = types.FromGoTime(t) - d.SetMysqlTime(result) - return d, nil } var reg = regexp.MustCompile(`[\d]+`) diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 8cad589fdce67..4db2c560a8114 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -596,28 +596,30 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { } } -func (s *testEvaluatorSuite) TestDateArith(c *C) { +func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { defer testleak.AfterTest(c)() date := []string{"2016-12-31", "2017-01-01"} + dateAdd := dateArithFuncFactory(ast.DateArithAdd) + dateSub := dateArithFuncFactory(ast.DateArithSub) - args := types.MakeDatums(ast.DateAdd, date[0], 1, "DAY") - v, err := builtinDateArith(args, s.ctx) + args := types.MakeDatums(date[0], 1, "DAY") + v, err := dateAdd(args, s.ctx) c.Assert(err, IsNil) c.Assert(v.GetMysqlTime().String(), Equals, date[1]) - args = types.MakeDatums(ast.DateSub, date[1], 1, "DAY") - v, err = builtinDateArith(args, s.ctx) + args = types.MakeDatums(date[1], 1, "DAY") + v, err = dateSub(args, s.ctx) c.Assert(err, IsNil) c.Assert(v.GetMysqlTime().String(), Equals, date[0]) - args = types.MakeDatums(ast.DateAdd, date[0], nil, "DAY") - v, err = builtinDateArith(args, s.ctx) + args = types.MakeDatums(date[0], nil, "DAY") + v, err = dateAdd(args, s.ctx) c.Assert(err, IsNil) c.Assert(v.IsNull(), IsTrue) - args = types.MakeDatums(ast.DateSub, date[1], nil, "DAY") - v, err = builtinDateArith(args, s.ctx) + args = types.MakeDatums(date[1], nil, "DAY") + v, err = dateSub(args, s.ctx) c.Assert(err, IsNil) c.Assert(v.IsNull(), IsTrue) } diff --git a/parser/parser.y b/parser/parser.y index f1a5656576e1b..898cefb2c61f9 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -475,8 +475,6 @@ import ( DatabaseOptionListOpt "CREATE Database specification list opt" CreateTableStmt "CREATE TABLE statement" CreateUserStmt "CREATE User statement" - DateArithOpt "Date arith dateadd or datesub option" - DateArithMultiFormsOpt "Date arith adddate or subdate option" DBName "Database Name" DeallocateStmt "Deallocate prepared statement" Default "DEFAULT clause" @@ -716,7 +714,9 @@ import ( NotKeywordToken "Tokens not mysql keyword but treated specially" UnReservedKeyword "MySQL unreserved keywords" ReservedKeyword "MySQL reserved keywords" - FunctionNameConflict "Built-in function call names which are conflict with keywords" + FunctionNameConflict "Built-in function call names which are conflict with keywords" + FunctionNameDateArith "Date arith function call names (date_add or date_sub)" + FunctionNameDateArithMultiForms "Date arith function call names (adddate or subdate)" %precedence lowestOpt %token tableRefPriority @@ -2661,36 +2661,33 @@ FunctionCallNonKeyword: { $$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } -| DateArithMultiFormsOpt '(' Expression ',' Expression ')' +| FunctionNameDateArithMultiForms '(' Expression ',' Expression ')' { $$ = &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_ARITH"), + FnName: model.NewCIStr($1), Args: []ast.ExprNode{ - ast.NewValueExpr($1), $3.(ast.ExprNode), $5.(ast.ExprNode), ast.NewValueExpr("DAY"), }, } } -| DateArithMultiFormsOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')' +| FunctionNameDateArithMultiForms '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { $$ = &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_ARITH"), + FnName: model.NewCIStr($1), Args: []ast.ExprNode{ - ast.NewValueExpr($1), $3.(ast.ExprNode), $6.(ast.ExprNode), ast.NewValueExpr($7), }, } } -| DateArithOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')' +| FunctionNameDateArith '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { $$ = &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_ARITH"), + FnName: model.NewCIStr($1), Args: []ast.ExprNode{ - ast.NewValueExpr($1), $3.(ast.ExprNode), $6.(ast.ExprNode), ast.NewValueExpr($7), @@ -3101,25 +3098,15 @@ FunctionCallNonKeyword: } -DateArithOpt: +FunctionNameDateArith: "DATE_ADD" - { - $$ = ast.DateAdd - } | "DATE_SUB" - { - $$ = ast.DateSub - } -DateArithMultiFormsOpt: + +FunctionNameDateArithMultiForms: "ADDDATE" - { - $$ = ast.DateAdd - } | "SUBDATE" - { - $$ = ast.DateSub - } + TrimDirection: "BOTH" diff --git a/plan/expression_test.go b/plan/expression_test.go index bd743a1c9b0bc..d65e5c135abc1 100644 --- a/plan/expression_test.go +++ b/plan/expression_test.go @@ -366,50 +366,42 @@ func (s *testExpressionSuite) TestDateArith(c *C) { // run the test cases for _, t := range tests { - op := ast.NewValueExpr(ast.DateAdd) var interval ast.ExprNode if n, ok := t.Interval.(ast.ExprNode); ok { interval = n } else { interval = ast.NewValueExpr(t.Interval) } - expr := &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_ARITH"), - Args: []ast.ExprNode{ - op, - ast.NewValueExpr(t.Date), - interval, - ast.NewValueExpr(t.Unit), - }, - } - ast.SetFlag(expr) - v, err := evalAstExpr(expr, s.ctx) - if t.error == true { - c.Assert(err, NotNil) - } else { - c.Assert(err, IsNil) - if v.IsNull() { - c.Assert(nil, Equals, t.AddResult) - } else { - c.Assert(v.Kind(), Equals, types.KindMysqlTime) - value := v.GetMysqlTime() - c.Assert(value.String(), Equals, t.AddResult) + for _, x := range []struct { + fnName string + result interface{} + }{ + {ast.DateAdd, t.AddResult}, + {ast.DateSub, t.SubResult}, + {ast.AddDate, t.AddResult}, + {ast.SubDate, t.SubResult}, + } { + expr := &ast.FuncCallExpr{ + FnName: model.NewCIStr(x.fnName), + Args: []ast.ExprNode{ + ast.NewValueExpr(t.Date), + interval, + ast.NewValueExpr(t.Unit), + }, } - } - - op = ast.NewValueExpr(ast.DateSub) - expr.Args[0] = op - v, err = evalAstExpr(expr, s.ctx) - if t.error == true { - c.Assert(err, NotNil) - } else { - c.Assert(err, IsNil) - if v.IsNull() { - c.Assert(nil, Equals, t.AddResult) + ast.SetFlag(expr) + v, err := evalAstExpr(expr, s.ctx) + if t.error == true { + c.Assert(err, NotNil) } else { - c.Assert(v.Kind(), Equals, types.KindMysqlTime) - value := v.GetMysqlTime() - c.Assert(value.String(), Equals, t.SubResult) + c.Assert(err, IsNil) + if v.IsNull() { + c.Assert(nil, Equals, x.result) + } else { + c.Assert(v.Kind(), Equals, types.KindMysqlTime) + value := v.GetMysqlTime() + c.Assert(value.String(), Equals, x.result) + } } } } diff --git a/plan/typeinferer.go b/plan/typeinferer.go index 960eba24a20ed..0cae18de35ae3 100644 --- a/plan/typeinferer.go +++ b/plan/typeinferer.go @@ -292,7 +292,7 @@ func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { case "curtime", "current_time", "timediff": tp = types.NewFieldType(mysql.TypeDuration) tp.Decimal = v.getFsp(x) - case "current_timestamp", "date_arith": + case "current_timestamp", "date_add", "date_sub", "adddate", "subdate": tp = types.NewFieldType(mysql.TypeDatetime) case "microsecond", "second", "minute", "hour", "day", "week", "month", "year", "dayofweek", "dayofmonth", "dayofyear", "weekday", "weekofyear", "yearweek", "datediff",