diff --git a/plan/typeinferer.go b/plan/typeinferer.go index b713d3270d14d..a673908d2bb6c 100644 --- a/plan/typeinferer.go +++ b/plan/typeinferer.go @@ -172,7 +172,7 @@ func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) { x.Type.Init(mysql.TypeLonglong) case opcode.Plus, opcode.Minus, opcode.Mul, opcode.Mod: if x.L.GetType() != nil && x.R.GetType() != nil { - xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp) + xTp := mergeArithType(x.L.GetType(), x.R.GetType()) x.Type.Init(xTp) leftUnsigned := x.L.GetType().Flag & mysql.UnsignedFlag rightUnsigned := x.R.GetType().Flag & mysql.UnsignedFlag @@ -181,7 +181,7 @@ func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) { } case opcode.Div: if x.L.GetType() != nil && x.R.GetType() != nil { - xTp := mergeArithType(x.L.GetType().Tp, x.R.GetType().Tp) + xTp := mergeArithType(x.L.GetType(), x.R.GetType()) if xTp == mysql.TypeLonglong { xTp = mysql.TypeNewDecimal } @@ -192,7 +192,21 @@ func (v *typeInferrer) binaryOperation(x *ast.BinaryOperationExpr) { x.Type.Collate = charset.CollationBin } -func mergeArithType(a, b byte) byte { +// toArithType converts DateTime, Duration and Timestamp types to NewDecimal type if Decimal > 0. +func toArithType(ft *types.FieldType) (tp byte) { + tp = ft.Tp + if types.IsTypeTime(tp) { + if ft.Decimal > 0 { + tp = mysql.TypeNewDecimal + } else { + tp = mysql.TypeLonglong + } + } + return +} + +func mergeArithType(fta, ftb *types.FieldType) byte { + a, b := toArithType(fta), toArithType(ftb) switch a { case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat: return mysql.TypeDouble @@ -220,7 +234,7 @@ func (v *typeInferrer) unaryOperation(x *ast.UnaryOperationExpr) { x.Type.Init(mysql.TypeLonglong) if x.V.GetType() != nil { switch x.V.GetType().Tp { - case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat: + case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: x.Type.Tp = mysql.TypeDouble case mysql.TypeNewDecimal: x.Type.Tp = mysql.TypeNewDecimal @@ -279,7 +293,7 @@ func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) { if len(x.Args) > 0 { tp = x.Args[0].GetType() for i := 1; i < len(x.Args); i++ { - mergeArithType(tp.Tp, x.Args[i].GetType().Tp) + mergeArithType(tp, x.Args[i].GetType()) } } case "interval": diff --git a/plan/typeinferer_test.go b/plan/typeinferer_test.go index b365d4262d896..7e46a333b822f 100644 --- a/plan/typeinferer_test.go +++ b/plan/typeinferer_test.go @@ -41,7 +41,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { defer store.Close() testKit := testkit.NewTestKit(c, store) testKit.MustExec("use test") - testKit.MustExec("create table t (c1 int, c2 double, c3 text)") + testKit.MustExec("create table t (c1 int, c2 double, c3 text, c4 timestamp)") cases := []struct { expr string tp byte @@ -51,6 +51,8 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {"+1", mysql.TypeLonglong, charset.CharsetBin}, {"-1", mysql.TypeLonglong, charset.CharsetBin}, {"-'1'", mysql.TypeDouble, charset.CharsetBin}, + {"-curtime()", mysql.TypeDouble, charset.CharsetBin}, + {"-now()", mysql.TypeDouble, charset.CharsetBin}, {"~1", mysql.TypeLonglong, charset.CharsetBin}, {"1e0", mysql.TypeDouble, charset.CharsetBin}, {"1.0", mysql.TypeNewDecimal, charset.CharsetBin}, @@ -73,6 +75,22 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {"1 + '1'", mysql.TypeDouble, charset.CharsetBin}, {"1 + 1.1", mysql.TypeNewDecimal, charset.CharsetBin}, + {"now() + 0", mysql.TypeLonglong, charset.CharsetBin}, + {"curtime() + 0", mysql.TypeLonglong, charset.CharsetBin}, + {"now(0) + 0", mysql.TypeLonglong, charset.CharsetBin}, + {"now(2) + 0", mysql.TypeNewDecimal, charset.CharsetBin}, + {"now() + 1.1", mysql.TypeNewDecimal, charset.CharsetBin}, + {"now() + '1'", mysql.TypeDouble, charset.CharsetBin}, + {"now(2) + '1'", mysql.TypeDouble, charset.CharsetBin}, + {"now() + curtime()", mysql.TypeLonglong, charset.CharsetBin}, + {"now() + now()", mysql.TypeLonglong, charset.CharsetBin}, + {"now() + now(2)", mysql.TypeNewDecimal, charset.CharsetBin}, + {"c2 + now()", mysql.TypeDouble, charset.CharsetBin}, + {"c4 + 1", mysql.TypeLonglong, charset.CharsetBin}, + {"c4 + 1.1", mysql.TypeNewDecimal, charset.CharsetBin}, + {"c4 + '1.1'", mysql.TypeDouble, charset.CharsetBin}, + {"1.1 + now()", mysql.TypeNewDecimal, charset.CharsetBin}, + {"1 + now()", mysql.TypeLonglong, charset.CharsetBin}, {"1 div 2", mysql.TypeLonglong, charset.CharsetBin}, {"1 / 2", mysql.TypeNewDecimal, charset.CharsetBin}, @@ -106,6 +124,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {"curtime()", mysql.TypeDuration, charset.CharsetBin}, {"current_time()", mysql.TypeDuration, charset.CharsetBin}, {"curtime()", mysql.TypeDuration, charset.CharsetBin}, + {"curtime(2)", mysql.TypeDuration, charset.CharsetBin}, {"current_timestamp()", mysql.TypeDatetime, charset.CharsetBin}, {"utc_timestamp()", mysql.TypeDatetime, charset.CharsetBin}, {"microsecond('2009-12-31 23:59:59.000010')", mysql.TypeLonglong, charset.CharsetBin}, diff --git a/util/types/etc.go b/util/types/etc.go index c9439c8d90586..cc17fa73855da 100644 --- a/util/types/etc.go +++ b/util/types/etc.go @@ -49,6 +49,17 @@ func IsTypeChar(tp byte) bool { } } +// IsTypeTime returns a boolean indicating +// whether the tp is the time type like a datetime type, a duration type, or a timestamp type. +func IsTypeTime(tp byte) bool { + switch tp { + case mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: + return true + default: + return false + } +} + // IsTypePrefixable returns a boolean indicating // whether an index on a column with the tp can be defined with a prefix. func IsTypePrefixable(tp byte) bool {