Skip to content

Commit

Permalink
expression, executor: introduce propagateType for castDecimalAsReal (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored Aug 24, 2021
1 parent 9691e50 commit 0f51627
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 0 deletions.
21 changes: 21 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9019,6 +9019,27 @@ func (s *testSuite) TestIssue25506(c *C) {
tk.MustQuery("(select col_15 from tbl_23) union all (select col_15 from tbl_3 for update) order by col_15").Check(testkit.Rows("\x00\x00\x0F", "\x00\x00\xFF", "\x00\xFF\xFF"))
}

func (s *testSuite) TestIssue26348(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

tk.MustExec("drop table if exists t")
tk.MustExec(`CREATE TABLE t (
a varchar(8) DEFAULT NULL,
b varchar(8) DEFAULT NULL,
c decimal(20,2) DEFAULT NULL,
d decimal(15,8) DEFAULT NULL
);`)
tk.MustExec(`insert into t values(20210606, 20210606, 50000.00, 5.04600000);`)
tk.MustQuery(`select a * c *(d/36000) from t;`).Check(testkit.Rows("141642663.71666598"))
tk.MustQuery(`select cast(a as double) * cast(c as double) *cast(d/36000 as double) from t;`).Check(testkit.Rows("141642663.71666598"))
tk.MustQuery("select 20210606*50000.00*(5.04600000/36000)").Check(testkit.Rows("141642663.71666599297980"))

// differs from MySQL cause constant-fold .
tk.MustQuery("select \"20210606\"*50000.00*(5.04600000/36000)").Check(testkit.Rows("141642663.71666598"))
tk.MustQuery("select cast(\"20210606\" as double)*50000.00*(5.04600000/36000)").Check(testkit.Rows("141642663.71666598"))
}

func (s *testSuite) TestIssue26532(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
15 changes: 15 additions & 0 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ func (c *castAsRealFunctionClass) getFunction(ctx sessionctx.Context, args []Exp
sig.setPbCode(tipb.ScalarFuncSig_CastRealAsReal)
case types.ETDecimal:
sig = &builtinCastDecimalAsRealSig{bf}
PropagateType(types.ETReal, sig.getArgs()...)
sig.setPbCode(tipb.ScalarFuncSig_CastDecimalAsReal)
case types.ETDatetime, types.ETTimestamp:
sig = &builtinCastTimeAsRealSig{bf}
Expand Down Expand Up @@ -1012,6 +1013,20 @@ type builtinCastDecimalAsRealSig struct {
baseBuiltinCastFunc
}

func setDataTypeDouble(srcDecimal int) (flen, decimal int) {
decimal = mysql.NotFixedDec
flen = floatLength(srcDecimal, decimal)
return
}

func floatLength(srcDecimal int, decimalPar int) int {
const dblDIG = 15
if srcDecimal != mysql.NotFixedDec {
return dblDIG + 2 + decimalPar
}
return dblDIG + 8
}

func (b *builtinCastDecimalAsRealSig) Clone() builtinFunc {
newSig := &builtinCastDecimalAsRealSig{}
newSig.cloneFrom(&b.baseBuiltinCastFunc)
Expand Down
8 changes: 8 additions & 0 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ func (c *Constant) EvalDecimal(ctx sessionctx.Context, row chunk.Row) (*types.My
return nil, true, nil
}
res, err := dt.ToDecimal(ctx.GetSessionVars().StmtCtx)
if err != nil {
return nil, false, err
}
// The decimal may be modified during plan building.
_, frac := res.PrecisionAndFrac()
if frac < c.GetType().Decimal {
err = res.Round(res, c.GetType().Decimal, types.ModeHalfEven)
}
return res, false, err
}

Expand Down
48 changes: 48 additions & 0 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1290,3 +1290,51 @@ func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression, wrapF
}
return FoldConstant(sf), nil
}

// PropagateType propagates the type information to the `expr`.
// Note: For now, we only propagate type for the function CastDecimalAsDouble.
//
// e.g.
// > create table t(a decimal(9, 8));
// > insert into t values(5.04600000)
// > select a/36000 from t;
// Type: NEWDECIMAL
// Length: 15
// Decimals: 12
// +------------------+
// | 5.04600000/36000 |
// +------------------+
// | 0.000140166667 |
// +------------------+
//
// > select cast(a/36000 as double) as result from t;
// Type: DOUBLE
// Length: 23
// Decimals: 31
// +----------------------+
// | result |
// +----------------------+
// | 0.000140166666666666 |
// +----------------------+
// The expected `decimal` and `length` of the outer cast_as_double need to be
// propagated to the inner div.
func PropagateType(evalType types.EvalType, args ...Expression) {
switch evalType {
case types.ETReal:
expr := args[0]
oldFlen, oldDecimal := expr.GetType().Flen, expr.GetType().Decimal
newFlen, newDecimal := setDataTypeDouble(expr.GetType().Decimal)
// For float(M,D), double(M,D) or decimal(M,D), M must be >= D.
if newFlen < newDecimal {
newFlen = oldFlen - oldDecimal + newDecimal
}
if oldFlen != newFlen || oldDecimal != newDecimal {
if col, ok := args[0].(*Column); ok {
newCol := col.Clone()
newCol.(*Column).RetType = col.RetType.Clone()
args[0] = newCol
}
args[0].GetType().Flen, args[0].GetType().Decimal = newFlen, newDecimal
}
}
}

0 comments on commit 0f51627

Please sign in to comment.