diff --git a/executor/write_test.go b/executor/write_test.go index b558d2e03d8ca..fc728ad43bfec 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -137,6 +137,14 @@ func (s *testSuite) TestInsert(c *C) { tk.MustExec("update t t1 set id = (select count(*) + 1 from t t2 where t1.id = t2.id)") r = tk.MustQuery("select * from t;") r.Check(testkit.Rows("2")) + + // issue 3235 + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(c decimal(5, 5))") + _, err = tk.Exec("insert into t value(0)") + c.Assert(err, IsNil) + _, err = tk.Exec("insert into t value(1)") + c.Assert(types.ErrOverflow.Equal(err), IsTrue) } func (s *testSuite) TestInsertAutoInc(c *C) { diff --git a/parser/parser.y b/parser/parser.y index 0d5a24cc241a6..1030b2d1ceeff 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -3944,6 +3944,8 @@ CastType: if fopt.Flen == types.UnspecifiedLength { x.Flen = mysql.GetDefaultFieldLength(mysql.TypeNewDecimal) x.Decimal = mysql.GetDefaultDecimal(mysql.TypeNewDecimal) + } else if fopt.Decimal == types.UnspecifiedLength { + x.Decimal = mysql.GetDefaultDecimal(mysql.TypeNewDecimal) } $$ = x } diff --git a/util/types/datum.go b/util/types/datum.go index d55c1eccf16a4..3bfdae6d73e7e 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -1048,20 +1048,29 @@ func (d *Datum) convertToMysqlDecimal(sc *variable.StatementContext, target *Fie default: return invalidConv(d, target.Tp) } - if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength { + dec, err = produceDecWithSpecifiedTp(dec, target, sc) + ret.SetValue(dec) + return ret, errors.Trace(err) +} + +// produceDecToSpecifiedTp produces a new decimal according to `Tp`. +func produceDecWithSpecifiedTp(dec *MyDecimal, tp *FieldType, sc *variable.StatementContext) (_ *MyDecimal, err error) { + flen, decimal := tp.Flen, tp.Decimal + if flen != UnspecifiedLength && decimal != UnspecifiedLength { prec, frac := dec.PrecisionAndFrac() - if prec-frac > target.Flen-target.Decimal { - dec = NewMaxOrMinDec(dec.IsNegative(), target.Flen, target.Decimal) - err = ErrOverflow.GenByArgs("DECIMAL", fmt.Sprintf("(%d, %d)", target.Flen, target.Decimal)) - } else if frac != target.Decimal { - dec.Round(dec, target.Decimal, ModeHalfEven) - if frac > target.Decimal { - err = errors.Trace(handleTruncateError(sc)) + if !dec.IsZero() && prec-frac > flen-decimal { + dec = NewMaxOrMinDec(dec.IsNegative(), flen, decimal) + // TODO: we may need a OverlowAsWarning. + // select (cast 111 as decimal(1)) causes a warning in MySQL. + err = ErrOverflow.GenByArgs("DECIMAL", fmt.Sprintf("(%d, %d)", flen, decimal)) + } else if frac != decimal { + dec.Round(dec, decimal, ModeHalfEven) + if !dec.IsZero() && frac > decimal { + err = sc.HandleTruncate(ErrTruncated) } } } - ret.SetValue(dec) - return ret, err + return dec, errors.Trace(err) } func (d *Datum) convertToMysqlYear(sc *variable.StatementContext, target *FieldType) (Datum, error) { diff --git a/util/types/mydecimal.go b/util/types/mydecimal.go index 11bfdadd4b10a..80b0df1979e5f 100644 --- a/util/types/mydecimal.go +++ b/util/types/mydecimal.go @@ -1182,6 +1182,18 @@ func (d *MyDecimal) PrecisionAndFrac() (precision, frac int) { return } +// IsZero checks whether it's a zero decimal. +func (d *MyDecimal) IsZero() bool { + isZero := true + for _, val := range d.wordBuf { + if val != 0 { + isZero = false + break + } + } + return isZero +} + // FromBin Restores decimal from its binary fixed-length representation. func (d *MyDecimal) FromBin(bin []byte, precision, frac int) (binSize int, err error) { if len(bin) == 0 {