From bc4a32431ff628cee756e0de6b581475e653e651 Mon Sep 17 00:00:00 2001 From: Huang Zexin Date: Mon, 12 Aug 2019 11:13:10 +0800 Subject: [PATCH] =?UTF-8?q?expression:=20remove=20unnecessary=20convertInt?= =?UTF-8?q?ToUint.=20fix=20error=20in=E2=80=A6=20(#11640)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- expression/builtin_cast.go | 29 +++++++++-------------------- expression/integration_test.go | 25 +++++++++++++++++++++++++ types/convert.go | 1 + types/etc.go | 5 +---- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 366c48def1b80..0b8d0abd4977e 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -437,6 +437,9 @@ func (b *builtinCastIntAsIntSig) Clone() builtinFunc { func (b *builtinCastIntAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { res, isNull, err = b.args[0].EvalInt(b.ctx, row) + if isNull || err != nil { + return + } if b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && res < 0 { res = 0 } @@ -463,10 +466,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b } else if b.inUnion && val < 0 { res = 0 } else { - var uVal uint64 - sc := b.ctx.GetSessionVars().StmtCtx - uVal, err = types.ConvertIntToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - res = float64(uVal) + // recall that, int to float is different from uint to float + res = float64(uint64(val)) } return res, false, err } @@ -491,13 +492,7 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe } else if b.inUnion && val < 0 { res = &types.MyDecimal{} } else { - var uVal uint64 - sc := b.ctx.GetSessionVars().StmtCtx - uVal, err = types.ConvertIntToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - if err != nil { - return res, false, err - } - res = types.NewDecFromUint(uVal) + res = types.NewDecFromUint(uint64(val)) } res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx) return res, isNull, err @@ -521,13 +516,7 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul if !mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { res = strconv.FormatInt(val, 10) } else { - var uVal uint64 - sc := b.ctx.GetSessionVars().StmtCtx - uVal, err = types.ConvertIntToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - if err != nil { - return res, false, err - } - res = strconv.FormatUint(uVal, 10) + res = strconv.FormatUint(uint64(val), 10) } res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx, false) if err != nil { @@ -748,13 +737,13 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool return res, isNull, err } if !mysql.HasUnsignedFlag(b.tp.Flag) { - res, err = types.ConvertFloatToInt(val, types.IntergerSignedLowerBound(mysql.TypeLonglong), types.IntergerSignedUpperBound(mysql.TypeLonglong), mysql.TypeDouble) + res, err = types.ConvertFloatToInt(val, types.IntergerSignedLowerBound(mysql.TypeLonglong), types.IntergerSignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) } else if b.inUnion && val < 0 { res = 0 } else { var uintVal uint64 sc := b.ctx.GetSessionVars().StmtCtx - uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeDouble) + uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) res = int64(uintVal) } return res, isNull, err diff --git a/expression/integration_test.go b/expression/integration_test.go index df87ece0ff292..59bdae56242d7 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -2200,6 +2200,31 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { msg := strings.Split(err.Error(), " ") last := msg[len(msg)-1] c.Assert(last, Equals, "bigint") + tk.MustExec(`drop table tb5;`) + + // test builtinCastIntAsDecimalSig + tk.MustExec(`create table tb5(a bigint(64) unsigned, b decimal(64, 10));`) + tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`) + tk.MustExec(`insert into tb5 (select * from tb5 where a = b);`) + result = tk.MustQuery(`select * from tb5;`) + result.Check(testkit.Rows("9223372036854775808 9223372036854775808.0000000000", "9223372036854775808 9223372036854775808.0000000000")) + tk.MustExec(`drop table tb5;`) + + // test builtinCastIntAsRealSig + tk.MustExec(`create table tb5(a bigint(64) unsigned, b double(64, 10));`) + tk.MustExec(`insert into tb5 (a, b) values (13835058000000000000, 13835058000000000000);`) + tk.MustExec(`insert into tb5 (select * from tb5 where a = b);`) + result = tk.MustQuery(`select * from tb5;`) + result.Check(testkit.Rows("13835058000000000000 13835058000000000000", "13835058000000000000 13835058000000000000")) + tk.MustExec(`drop table tb5;`) + + // test builtinCastIntAsStringSig + tk.MustExec(`create table tb5(a bigint(64) unsigned,b varchar(50));`) + tk.MustExec(`insert into tb5(a, b) values (9223372036854775808, '9223372036854775808');`) + tk.MustExec(`insert into tb5(select * from tb5 where a = b);`) + result = tk.MustQuery(`select * from tb5;`) + result.Check(testkit.Rows("9223372036854775808 9223372036854775808", "9223372036854775808 9223372036854775808")) + tk.MustExec(`drop table tb5;`) // Test corner cases of cast string as datetime result = tk.MustQuery(`select cast("170102034" as datetime);`) diff --git a/types/convert.go b/types/convert.go index 4a8368332b5f7..1f733d50354d1 100644 --- a/types/convert.go +++ b/types/convert.go @@ -98,6 +98,7 @@ func IntergerSignedLowerBound(intType byte) int64 { } // ConvertFloatToInt converts a float64 value to a int value. +// `tp` is used in err msg, if there is overflow, this func will report err according to `tp` func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) { val := RoundFloat(fval) if val < float64(lowerBound) { diff --git a/types/etc.go b/types/etc.go index b8b5af64f38d5..e29c91171e5a7 100644 --- a/types/etc.go +++ b/types/etc.go @@ -83,10 +83,7 @@ func IsTemporalWithDate(tp byte) bool { // IsBinaryStr returns a boolean indicating // whether the field type is a binary string type. func IsBinaryStr(ft *FieldType) bool { - if ft.Collate == charset.CollationBin && IsString(ft.Tp) { - return true - } - return false + return ft.Collate == charset.CollationBin && IsString(ft.Tp) } // IsNonBinaryStr returns a boolean indicating