Skip to content

Commit

Permalink
*: extract a builtin function factory for date arithmetic (pingcap#2403)
Browse files Browse the repository at this point in the history
* *: refactor DATE_ARITH according to pingcap#2368
  • Loading branch information
zyguan authored and shenli committed Jan 7, 2017
1 parent 1dd0945 commit 2521ab4
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 165 deletions.
13 changes: 8 additions & 5 deletions ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ const (
Curtime = "curtime"
Date = "date"
DateDiff = "datediff"
DateArith = "date_arith"
DateAdd = "date_add"
AddDate = "adddate"
DateSub = "date_sub"
SubDate = "subdate"
DateFormat = "date_format"
Day = "day"
DayName = "dayname"
Expand Down Expand Up @@ -259,14 +262,14 @@ const (
type DateArithType byte

const (
// DateAdd is to run adddate or date_add function option.
// DateArithAdd is to run adddate or date_add function option.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
DateAdd DateArithType = iota + 1
// DateSub is to run subdate or date_sub function option.
DateArithAdd DateArithType = iota + 1
// DateArithSub is to run subdate or date_sub function option.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
DateSub
DateArithSub
)

const (
Expand Down
4 changes: 2 additions & 2 deletions ast/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (ts *testFunctionsSuite) TestAggregateFuncExtractor(c *C) {

extractor = &AggregateFuncExtractor{}
expr = &FuncCallExpr{
FnName: model.NewCIStr("DATE_ARITH"),
FnName: model.NewCIStr("FAKE_FUNC"),
}
expr.Accept(extractor)
c.Assert(extractor.AggFuncs, HasLen, 0)
Expand All @@ -43,7 +43,7 @@ func (ts *testFunctionsSuite) TestAggregateFuncExtractor(c *C) {
r := &AggregateFuncExpr{}
expr = &BinaryOperationExpr{
L: &FuncCallExpr{
FnName: model.NewCIStr("DATE_ARITH"),
FnName: model.NewCIStr("FAKE_FUNC"),
},
R: r,
}
Expand Down
5 changes: 4 additions & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,11 @@ var Funcs = map[string]Func{
ast.CurrentDate: {builtinCurrentDate, 0, 0},
ast.CurrentTime: {builtinCurrentTime, 0, 1},
ast.Date: {builtinDate, 1, 1},
ast.DateArith: {builtinDateArith, 4, 4},
ast.DateDiff: {builtinDateDiff, 2, 2},
ast.DateAdd: {dateArithFuncFactory(ast.DateArithAdd), 3, 3},
ast.AddDate: {dateArithFuncFactory(ast.DateArithAdd), 3, 3},
ast.DateSub: {dateArithFuncFactory(ast.DateArithSub), 3, 3},
ast.SubDate: {dateArithFuncFactory(ast.DateArithSub), 3, 3},
ast.DateFormat: {builtinDateFormat, 2, 2},
ast.CurrentTimestamp: {builtinNow, 0, 1},
ast.Curtime: {builtinCurrentTime, 0, 1},
Expand Down
169 changes: 84 additions & 85 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,101 +660,100 @@ func checkFsp(sc *variable.StatementContext, arg types.Datum) (int, error) {
return int(fsp), nil
}

func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, err error) {
// Op is used for distinguishing date_add and date_sub.
// args[0] -> Op
// args[1] -> Date
// args[2] -> Interval Value
// args[3] -> Interval Unit
// health check for date and interval
if args[1].IsNull() || args[2].IsNull() {
return d, nil
}
nodeDate := args[1]
nodeIntervalValue := args[2]
nodeIntervalUnit := args[3].GetString()
if nodeIntervalValue.IsNull() {
return d, nil
}
// parse date
fieldType := mysql.TypeDate
var resultField *types.FieldType
switch nodeDate.Kind() {
case types.KindMysqlTime:
x := nodeDate.GetMysqlTime()
if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) {
fieldType = mysql.TypeDatetime
func dateArithFuncFactory(op ast.DateArithType) BuiltinFunc {
return func(args []types.Datum, ctx context.Context) (d types.Datum, err error) {
// args[0] -> Date
// args[1] -> Interval Value
// args[2] -> Interval Unit
// health check for date and interval
if args[0].IsNull() || args[1].IsNull() {
return d, nil
}
case types.KindString:
x := nodeDate.GetString()
if !types.IsDateFormat(x) {
fieldType = mysql.TypeDatetime
nodeDate := args[0]
nodeIntervalValue := args[1]
nodeIntervalUnit := args[2].GetString()
if nodeIntervalValue.IsNull() {
return d, nil
}
case types.KindInt64:
x := nodeDate.GetInt64()
if t, err1 := types.ParseTimeFromInt64(x); err1 == nil {
if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) {
// parse date
fieldType := mysql.TypeDate
var resultField *types.FieldType
switch nodeDate.Kind() {
case types.KindMysqlTime:
x := nodeDate.GetMysqlTime()
if (x.Type == mysql.TypeDatetime) || (x.Type == mysql.TypeTimestamp) {
fieldType = mysql.TypeDatetime
}
case types.KindString:
x := nodeDate.GetString()
if !types.IsDateFormat(x) {
fieldType = mysql.TypeDatetime
}
case types.KindInt64:
x := nodeDate.GetInt64()
if t, err1 := types.ParseTimeFromInt64(x); err1 == nil {
if (t.Type == mysql.TypeDatetime) || (t.Type == mysql.TypeTimestamp) {
fieldType = mysql.TypeDatetime
}
}
}
}
sc := ctx.GetSessionVars().StmtCtx
if types.IsClockUnit(nodeIntervalUnit) {
fieldType = mysql.TypeDatetime
}
resultField = types.NewFieldType(fieldType)
resultField.Decimal = types.MaxFsp
value, err := nodeDate.ConvertTo(ctx.GetSessionVars().StmtCtx, resultField)
if err != nil {
return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate)
}
if value.IsNull() {
return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue())
}
if value.Kind() != types.KindMysqlTime {
return d, errInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue())
}
result := value.GetMysqlTime()
// parse interval
var interval string
if strings.ToLower(nodeIntervalUnit) == "day" {
day, err1 := parseDayInterval(sc, nodeIntervalValue)
if err1 != nil {
return d, errInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalValue.GetString())
sc := ctx.GetSessionVars().StmtCtx
if types.IsClockUnit(nodeIntervalUnit) {
fieldType = mysql.TypeDatetime
}
interval = fmt.Sprintf("%d", day)
} else {
if nodeIntervalValue.Kind() == types.KindString {
interval = fmt.Sprintf("%v", nodeIntervalValue.GetString())
} else {
ii, err1 := nodeIntervalValue.ToInt64(sc)
resultField = types.NewFieldType(fieldType)
resultField.Decimal = types.MaxFsp
value, err := nodeDate.ConvertTo(ctx.GetSessionVars().StmtCtx, resultField)
if err != nil {
return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %T", nodeDate)
}
if value.IsNull() {
return d, errInvalidOperation.Gen("DateArith invalid args, need date but get %v", value.GetValue())
}
if value.Kind() != types.KindMysqlTime {
return d, errInvalidOperation.Gen("DateArith need time type, but got %T", value.GetValue())
}
result := value.GetMysqlTime()
// parse interval
var interval string
if strings.ToLower(nodeIntervalUnit) == "day" {
day, err1 := parseDayInterval(sc, nodeIntervalValue)
if err1 != nil {
return d, errors.Trace(err1)
return d, errInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalValue.GetString())
}
interval = fmt.Sprintf("%d", day)
} else {
if nodeIntervalValue.Kind() == types.KindString {
interval = fmt.Sprintf("%v", nodeIntervalValue.GetString())
} else {
ii, err1 := nodeIntervalValue.ToInt64(sc)
if err1 != nil {
return d, errors.Trace(err1)
}
interval = fmt.Sprintf("%v", ii)
}
interval = fmt.Sprintf("%v", ii)
}
year, month, day, duration, err := types.ExtractTimeValue(nodeIntervalUnit, interval)
if err != nil {
return d, errors.Trace(err)
}
if op == ast.DateArithSub {
year, month, day, duration = -year, -month, -day, -duration
}
// TODO: Consider time_zone variable.
t, err := result.Time.GoTime(time.Local)
if err != nil {
return d, errors.Trace(err)
}
t = t.Add(duration)
t = t.AddDate(int(year), int(month), int(day))
if t.Nanosecond() == 0 {
result.Fsp = 0
}
result.Time = types.FromGoTime(t)
d.SetMysqlTime(result)
return d, nil
}
year, month, day, duration, err := types.ExtractTimeValue(nodeIntervalUnit, interval)
if err != nil {
return d, errors.Trace(err)
}
op := args[0].GetInterface().(ast.DateArithType)
if op == ast.DateSub {
year, month, day, duration = -year, -month, -day, -duration
}
// TODO: Consider time_zone variable.
t, err := result.Time.GoTime(time.Local)
if err != nil {
return d, errors.Trace(err)
}
t = t.Add(duration)
t = t.AddDate(int(year), int(month), int(day))
if t.Nanosecond() == 0 {
result.Fsp = 0
}
result.Time = types.FromGoTime(t)
d.SetMysqlTime(result)
return d, nil
}

var reg = regexp.MustCompile(`[\d]+`)
Expand Down
20 changes: 11 additions & 9 deletions expression/builtin_time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,28 +596,30 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) {
}
}

func (s *testEvaluatorSuite) TestDateArith(c *C) {
func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) {
defer testleak.AfterTest(c)()

date := []string{"2016-12-31", "2017-01-01"}
dateAdd := dateArithFuncFactory(ast.DateArithAdd)
dateSub := dateArithFuncFactory(ast.DateArithSub)

args := types.MakeDatums(ast.DateAdd, date[0], 1, "DAY")
v, err := builtinDateArith(args, s.ctx)
args := types.MakeDatums(date[0], 1, "DAY")
v, err := dateAdd(args, s.ctx)
c.Assert(err, IsNil)
c.Assert(v.GetMysqlTime().String(), Equals, date[1])

args = types.MakeDatums(ast.DateSub, date[1], 1, "DAY")
v, err = builtinDateArith(args, s.ctx)
args = types.MakeDatums(date[1], 1, "DAY")
v, err = dateSub(args, s.ctx)
c.Assert(err, IsNil)
c.Assert(v.GetMysqlTime().String(), Equals, date[0])

args = types.MakeDatums(ast.DateAdd, date[0], nil, "DAY")
v, err = builtinDateArith(args, s.ctx)
args = types.MakeDatums(date[0], nil, "DAY")
v, err = dateAdd(args, s.ctx)
c.Assert(err, IsNil)
c.Assert(v.IsNull(), IsTrue)

args = types.MakeDatums(ast.DateSub, date[1], nil, "DAY")
v, err = builtinDateArith(args, s.ctx)
args = types.MakeDatums(date[1], nil, "DAY")
v, err = dateSub(args, s.ctx)
c.Assert(err, IsNil)
c.Assert(v.IsNull(), IsTrue)
}
39 changes: 13 additions & 26 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,6 @@ import (
DatabaseOptionListOpt "CREATE Database specification list opt"
CreateTableStmt "CREATE TABLE statement"
CreateUserStmt "CREATE User statement"
DateArithOpt "Date arith dateadd or datesub option"
DateArithMultiFormsOpt "Date arith adddate or subdate option"
DBName "Database Name"
DeallocateStmt "Deallocate prepared statement"
Default "DEFAULT clause"
Expand Down Expand Up @@ -716,7 +714,9 @@ import (
NotKeywordToken "Tokens not mysql keyword but treated specially"
UnReservedKeyword "MySQL unreserved keywords"
ReservedKeyword "MySQL reserved keywords"
FunctionNameConflict "Built-in function call names which are conflict with keywords"
FunctionNameConflict "Built-in function call names which are conflict with keywords"
FunctionNameDateArith "Date arith function call names (date_add or date_sub)"
FunctionNameDateArithMultiForms "Date arith function call names (adddate or subdate)"

%precedence lowestOpt
%token tableRefPriority
Expand Down Expand Up @@ -2661,36 +2661,33 @@ FunctionCallNonKeyword:
{
$$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: []ast.ExprNode{$3.(ast.ExprNode)}}
}
| DateArithMultiFormsOpt '(' Expression ',' Expression ')'
| FunctionNameDateArithMultiForms '(' Expression ',' Expression ')'
{
$$ = &ast.FuncCallExpr{
FnName: model.NewCIStr("DATE_ARITH"),
FnName: model.NewCIStr($1),
Args: []ast.ExprNode{
ast.NewValueExpr($1),
$3.(ast.ExprNode),
$5.(ast.ExprNode),
ast.NewValueExpr("DAY"),
},
}
}
| DateArithMultiFormsOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
| FunctionNameDateArithMultiForms '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
{
$$ = &ast.FuncCallExpr{
FnName: model.NewCIStr("DATE_ARITH"),
FnName: model.NewCIStr($1),
Args: []ast.ExprNode{
ast.NewValueExpr($1),
$3.(ast.ExprNode),
$6.(ast.ExprNode),
ast.NewValueExpr($7),
},
}
}
| DateArithOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
| FunctionNameDateArith '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
{
$$ = &ast.FuncCallExpr{
FnName: model.NewCIStr("DATE_ARITH"),
FnName: model.NewCIStr($1),
Args: []ast.ExprNode{
ast.NewValueExpr($1),
$3.(ast.ExprNode),
$6.(ast.ExprNode),
ast.NewValueExpr($7),
Expand Down Expand Up @@ -3101,25 +3098,15 @@ FunctionCallNonKeyword:
}


DateArithOpt:
FunctionNameDateArith:
"DATE_ADD"
{
$$ = ast.DateAdd
}
| "DATE_SUB"
{
$$ = ast.DateSub
}

DateArithMultiFormsOpt:

FunctionNameDateArithMultiForms:
"ADDDATE"
{
$$ = ast.DateAdd
}
| "SUBDATE"
{
$$ = ast.DateSub
}


TrimDirection:
"BOTH"
Expand Down
Loading

0 comments on commit 2521ab4

Please sign in to comment.