diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 945ec4a032f01..d707b8eb0c4bf 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -400,7 +400,7 @@ func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, tp byte, fsp i return v.GetBinaryLiteral().ToString(), nil } // For other kind of fields (e.g. INT), we supply its integer value so that it acts as integers. - return v.GetBinaryLiteral().ToInt() + return v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx) } if tp == mysql.TypeBit { diff --git a/tablecodec/tablecodec.go b/tablecodec/tablecodec.go index 94fb3a1fab8c6..cbd7c48be6734 100644 --- a/tablecodec/tablecodec.go +++ b/tablecodec/tablecodec.go @@ -193,7 +193,7 @@ func DecodeRowKey(key kv.Key) (int64, error) { // EncodeValue encodes a go value to bytes. func EncodeValue(sc *stmtctx.StatementContext, raw types.Datum) ([]byte, error) { var v types.Datum - err := flatten(raw, sc.TimeZone, &v) + err := flatten(sc, raw, &v) if err != nil { return nil, errors.Trace(err) } @@ -216,7 +216,7 @@ func EncodeRow(sc *stmtctx.StatementContext, row []types.Datum, colIDs []int64, for i, c := range row { id := colIDs[i] values[2*i].SetInt64(id) - err := flatten(c, sc.TimeZone, &values[2*i+1]) + err := flatten(sc, c, &values[2*i+1]) if err != nil { return nil, errors.Trace(err) } @@ -228,13 +228,13 @@ func EncodeRow(sc *stmtctx.StatementContext, row []types.Datum, colIDs []int64, return codec.EncodeValue(sc, valBuf, values...) } -func flatten(data types.Datum, loc *time.Location, ret *types.Datum) error { +func flatten(sc *stmtctx.StatementContext, data types.Datum, ret *types.Datum) error { switch data.Kind() { case types.KindMysqlTime: // for mysql datetime, timestamp and date type t := data.GetMysqlTime() - if t.Type == mysql.TypeTimestamp && loc != time.UTC { - err := t.ConvertTimeZone(loc, time.UTC) + if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC { + err := t.ConvertTimeZone(sc.TimeZone, time.UTC) if err != nil { return errors.Trace(err) } @@ -254,7 +254,7 @@ func flatten(data types.Datum, loc *time.Location, ret *types.Datum) error { return nil case types.KindBinaryLiteral, types.KindMysqlBit: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - val, err := data.GetBinaryLiteral().ToInt() + val, err := data.GetBinaryLiteral().ToInt(sc) if err != nil { return errors.Trace(err) } diff --git a/types/binary_literal.go b/types/binary_literal.go index d7de61c4eb5f9..7e72bca891cf0 100644 --- a/types/binary_literal.go +++ b/types/binary_literal.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/juju/errors" + "github.com/pingcap/tidb/sessionctx/stmtctx" ) // BinaryLiteral is the internal type for storing bit / hex literal type. @@ -100,14 +101,18 @@ func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string { } // ToInt returns the int value for the literal. -func (b BinaryLiteral) ToInt() (uint64, error) { +func (b BinaryLiteral) ToInt(sc *stmtctx.StatementContext) (uint64, error) { buf := trimLeadingZeroBytes(b) length := len(buf) if length == 0 { return 0, nil } if length > 8 { - return math.MaxUint64, ErrTruncated + var err error = ErrTruncatedWrongVal.GenByArgs("BINARY", b) + if sc != nil { + err = sc.HandleTruncate(err) + } + return math.MaxUint64, err } // Note: the byte-order is BigEndian. val := uint64(buf[0]) @@ -117,6 +122,19 @@ func (b BinaryLiteral) ToInt() (uint64, error) { return val, nil } +// Compare compares BinaryLiteral to another one +func (b BinaryLiteral) Compare(b2 BinaryLiteral) int { + bufB := trimLeadingZeroBytes(b) + bufB2 := trimLeadingZeroBytes(b2) + if len(bufB) > len(bufB2) { + return 1 + } + if len(bufB) < len(bufB2) { + return -1 + } + return bytes.Compare(bufB, bufB2) +} + // ParseBitStr parses bit string. // The string format can be b'val', B'val' or 0bval, val must be 0 or 1. // See https://dev.mysql.com/doc/refman/5.7/en/bit-value-literals.html diff --git a/types/binary_literal_test.go b/types/binary_literal_test.go index 2028ed1b39883..a8788ecd422bc 100644 --- a/types/binary_literal_test.go +++ b/types/binary_literal_test.go @@ -15,6 +15,7 @@ package types import ( . "github.com/pingcap/check" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/testleak" ) @@ -199,10 +200,11 @@ func (s *testBinaryLiteralSuite) TestToInt(c *C) { {"0x1010ffff8080ff12", 0x1010ffff8080ff12, false}, {"0x1010ffff8080ff12ff", 0xffffffffffffffff, true}, } + sc := new(stmtctx.StatementContext) for _, t := range tbl { hex, err := ParseHexStr(t.Input) c.Assert(err, IsNil) - intValue, err := hex.ToInt() + intValue, err := hex.ToInt(sc) if t.HasError { c.Assert(err, NotNil) } else { @@ -242,3 +244,19 @@ func (s *testBinaryLiteralSuite) TestNewBinaryLiteralFromUint(c *C) { c.Assert([]byte(hex), DeepEquals, t.Expected, Commentf("%#v", t)) } } + +func (s *testBinaryLiteralSuite) TestCompare(c *C) { + tbl := []struct { + a BinaryLiteral + b BinaryLiteral + cmp int + }{ + {BinaryLiteral{0, 0, 1}, BinaryLiteral{2}, -1}, + {BinaryLiteral{0, 1}, BinaryLiteral{0, 0, 2}, -1}, + {BinaryLiteral{0, 1}, BinaryLiteral{1}, 0}, + {BinaryLiteral{0, 2, 1}, BinaryLiteral{1, 2}, 1}, + } + for _, t := range tbl { + c.Assert(t.a.Compare(t.b), Equals, t.cmp) + } +} diff --git a/types/datum.go b/types/datum.go index 4008e46125091..1605a03ed3546 100644 --- a/types/datum.go +++ b/types/datum.go @@ -514,7 +514,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er fVal := d.GetMysqlEnum().ToNumber() return CompareFloat64(fVal, f), nil case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt() + val, err := d.GetBinaryLiteral().ToInt(sc) fVal := float64(val) return CompareFloat64(fVal, f), errors.Trace(err) case KindMysqlSet: @@ -610,7 +610,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter case KindBinaryLiteral, KindMysqlBit: return CompareString(d.GetBinaryLiteral().ToString(), b.ToString()), nil default: - val, err := b.ToInt() + val, err := b.ToInt(sc) if err != nil { return 0, errors.Trace(err) } @@ -722,7 +722,7 @@ func (d *Datum) convertToFloat(sc *stmtctx.StatementContext, target *FieldType) case KindMysqlEnum: f = d.GetMysqlEnum().ToNumber() case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt() + val, err1 := d.GetBinaryLiteral().ToInt(sc) f, err = float64(val), err1 case KindMysqlJSON: f, err = ConvertJSONToFloat(sc, d.GetMysqlJSON()) @@ -889,7 +889,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) ( case KindMysqlSet: val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp) case KindBinaryLiteral, KindMysqlBit: - val, err = d.GetBinaryLiteral().ToInt() + val, err = d.GetBinaryLiteral().ToInt(sc) case KindMysqlJSON: var i64 int64 i64, err = ConvertJSONToInt(sc, d.GetMysqlJSON(), true) @@ -1047,7 +1047,7 @@ func (d *Datum) convertToMysqlDecimal(sc *stmtctx.StatementContext, target *Fiel case KindMysqlSet: err = dec.FromFloat64(d.GetMysqlSet().ToNumber()) case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt() + val, err1 := d.GetBinaryLiteral().ToInt(sc) err = err1 dec.FromUint(val) case KindMysqlJSON: @@ -1143,7 +1143,7 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp var err error switch d.k { case KindString, KindBytes: - uintValue, err = BinaryLiteral(d.b).ToInt() + uintValue, err = BinaryLiteral(d.b).ToInt(sc) default: uintDatum, err1 := d.convertToUint(sc, target) uintValue, err = uintDatum.GetUint64(), err1 @@ -1269,7 +1269,7 @@ func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) { case KindMysqlSet: isZero = d.GetMysqlSet().ToNumber() == 0 case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt() + val, err1 := d.GetBinaryLiteral().ToInt(sc) isZero, err = val == 0, err1 default: return 0, errors.Errorf("cannot convert %v(type %T) to bool", d.GetValue(), d.GetValue()) @@ -1308,7 +1308,7 @@ func ConvertDatumToDecimal(sc *stmtctx.StatementContext, d Datum) (*MyDecimal, e case KindMysqlSet: dec.FromUint(d.GetMysqlSet().Value) case KindBinaryLiteral, KindMysqlBit: - val, err1 := d.GetBinaryLiteral().ToInt() + val, err1 := d.GetBinaryLiteral().ToInt(sc) dec.FromUint(val) err = err1 case KindMysqlJSON: @@ -1406,7 +1406,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e case KindMysqlJSON: return ConvertJSONToInt(sc, d.GetMysqlJSON(), false) case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt() + val, err := d.GetBinaryLiteral().ToInt(sc) return int64(val), errors.Trace(err) default: return 0, errors.Errorf("cannot convert %v(type %T) to int64", d.GetValue(), d.GetValue()) @@ -1442,7 +1442,7 @@ func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) { case KindMysqlSet: return d.GetMysqlSet().ToNumber(), nil case KindBinaryLiteral, KindMysqlBit: - val, err := d.GetBinaryLiteral().ToInt() + val, err := d.GetBinaryLiteral().ToInt(sc) return float64(val), errors.Trace(err) case KindMysqlJSON: f, err := ConvertJSONToFloat(sc, d.GetMysqlJSON()) @@ -1590,7 +1590,7 @@ func CoerceDatum(sc *stmtctx.StatementContext, a, b Datum) (x, y Datum, err erro y.SetFloat64(float64(y.GetUint64())) case KindBinaryLiteral, KindMysqlBit: var fval uint64 - fval, err = y.GetBinaryLiteral().ToInt() + fval, err = y.GetBinaryLiteral().ToInt(sc) if err != nil { return x, y, errors.Trace(err) } diff --git a/types/datum_eval.go b/types/datum_eval.go index f828de7233690..02067901c05a9 100644 --- a/types/datum_eval.go +++ b/types/datum_eval.go @@ -62,7 +62,7 @@ func CoerceArithmetic(sc *stmtctx.StatementContext, a Datum) (d Datum, err error d.SetMysqlDecimal(de) return d, nil case KindBinaryLiteral, KindMysqlBit: - val, err1 := a.GetBinaryLiteral().ToInt() + val, err1 := a.GetBinaryLiteral().ToInt(sc) d.SetUint64(val) return d, err1 case KindMysqlEnum: diff --git a/types/datum_test.go b/types/datum_test.go index d34bc33acd7b0..b2fd3f202dfd6 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -158,6 +158,10 @@ func (ts *testTypeConvertSuite) TestToInt64(c *C) { v, err := Convert(3.1415926, ft) c.Assert(err, IsNil) testDatumToInt64(c, v, int64(3)) + + binLit, err := ParseHexStr("0x9999999999999999999999999999999999999999999") + c.Assert(err, IsNil) + testDatumToInt64(c, binLit, -1) } func (ts *testTypeConvertSuite) TestToFloat32(c *C) { diff --git a/util/chunk/compare.go b/util/chunk/compare.go index 53b456c06aad7..cb21d420906bb 100644 --- a/util/chunk/compare.go +++ b/util/chunk/compare.go @@ -17,7 +17,6 @@ import ( "sort" "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" ) @@ -150,11 +149,7 @@ func cmpBit(l Row, lCol int, r Row, rCol int) int { } lBit := types.BinaryLiteral(l.GetBytes(lCol)) rBit := types.BinaryLiteral(r.GetBytes(rCol)) - lUint, err := lBit.ToInt() - terror.Log(err) - rUint, err := rBit.ToInt() - terror.Log(err) - return types.CompareUint64(lUint, rUint) + return lBit.Compare(rBit) } func cmpJSON(l Row, lCol int, r Row, rCol int) int { diff --git a/util/codec/codec.go b/util/codec/codec.go index 9dc92656cae89..da32927d0598f 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -105,7 +105,7 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab b = encodeUnsignedInt(b, uint64(vals[i].GetMysqlSet().ToNumber()), comparable) case types.KindMysqlBit, types.KindBinaryLiteral: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - val, err := vals[i].GetBinaryLiteral().ToInt() + val, err := vals[i].GetBinaryLiteral().ToInt(sc) terror.Log(errors.Trace(err)) b = encodeUnsignedInt(b, val, comparable) case types.KindMysqlJSON: @@ -244,7 +244,7 @@ func encodeChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTy b = encodeUnsignedInt(b, uint64(row.GetSet(i).ToNumber()), comparable) case mysql.TypeBit: // We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit. - val, err := types.BinaryLiteral(row.GetBytes(i)).ToInt() + val, err := types.BinaryLiteral(row.GetBytes(i)).ToInt(sc) terror.Log(errors.Trace(err)) b = encodeUnsignedInt(b, val, comparable) case mysql.TypeJSON: