Skip to content

Commit

Permalink
types: handle truncate error in BinaryLiteral.ToInt (pingcap#6163)
Browse files Browse the repository at this point in the history
Truncate error make be treated as warning, we should handle it in `BinaryLiteral.ToInt`.

And compare BinaryLiteral directly instead of converting to uint64 because that the compare function doesn't have statement context.
  • Loading branch information
coocood authored Mar 28, 2018
1 parent c86cd59 commit 9ca8689
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 30 deletions.
2 changes: 1 addition & 1 deletion ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func getDefaultValue(ctx sessionctx.Context, c *ast.ColumnOption, tp byte, fsp i
return v.GetBinaryLiteral().ToString(), nil
}
// For other kind of fields (e.g. INT), we supply its integer value so that it acts as integers.
return v.GetBinaryLiteral().ToInt()
return v.GetBinaryLiteral().ToInt(ctx.GetSessionVars().StmtCtx)
}

if tp == mysql.TypeBit {
Expand Down
12 changes: 6 additions & 6 deletions tablecodec/tablecodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func DecodeRowKey(key kv.Key) (int64, error) {
// EncodeValue encodes a go value to bytes.
func EncodeValue(sc *stmtctx.StatementContext, raw types.Datum) ([]byte, error) {
var v types.Datum
err := flatten(raw, sc.TimeZone, &v)
err := flatten(sc, raw, &v)
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -216,7 +216,7 @@ func EncodeRow(sc *stmtctx.StatementContext, row []types.Datum, colIDs []int64,
for i, c := range row {
id := colIDs[i]
values[2*i].SetInt64(id)
err := flatten(c, sc.TimeZone, &values[2*i+1])
err := flatten(sc, c, &values[2*i+1])
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -228,13 +228,13 @@ func EncodeRow(sc *stmtctx.StatementContext, row []types.Datum, colIDs []int64,
return codec.EncodeValue(sc, valBuf, values...)
}

func flatten(data types.Datum, loc *time.Location, ret *types.Datum) error {
func flatten(sc *stmtctx.StatementContext, data types.Datum, ret *types.Datum) error {
switch data.Kind() {
case types.KindMysqlTime:
// for mysql datetime, timestamp and date type
t := data.GetMysqlTime()
if t.Type == mysql.TypeTimestamp && loc != time.UTC {
err := t.ConvertTimeZone(loc, time.UTC)
if t.Type == mysql.TypeTimestamp && sc.TimeZone != time.UTC {
err := t.ConvertTimeZone(sc.TimeZone, time.UTC)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -254,7 +254,7 @@ func flatten(data types.Datum, loc *time.Location, ret *types.Datum) error {
return nil
case types.KindBinaryLiteral, types.KindMysqlBit:
// We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit.
val, err := data.GetBinaryLiteral().ToInt()
val, err := data.GetBinaryLiteral().ToInt(sc)
if err != nil {
return errors.Trace(err)
}
Expand Down
22 changes: 20 additions & 2 deletions types/binary_literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strings"

"github.com/juju/errors"
"github.com/pingcap/tidb/sessionctx/stmtctx"
)

// BinaryLiteral is the internal type for storing bit / hex literal type.
Expand Down Expand Up @@ -100,14 +101,18 @@ func (b BinaryLiteral) ToBitLiteralString(trimLeadingZero bool) string {
}

// ToInt returns the int value for the literal.
func (b BinaryLiteral) ToInt() (uint64, error) {
func (b BinaryLiteral) ToInt(sc *stmtctx.StatementContext) (uint64, error) {
buf := trimLeadingZeroBytes(b)
length := len(buf)
if length == 0 {
return 0, nil
}
if length > 8 {
return math.MaxUint64, ErrTruncated
var err error = ErrTruncatedWrongVal.GenByArgs("BINARY", b)
if sc != nil {
err = sc.HandleTruncate(err)
}
return math.MaxUint64, err
}
// Note: the byte-order is BigEndian.
val := uint64(buf[0])
Expand All @@ -117,6 +122,19 @@ func (b BinaryLiteral) ToInt() (uint64, error) {
return val, nil
}

// Compare compares BinaryLiteral to another one
func (b BinaryLiteral) Compare(b2 BinaryLiteral) int {
bufB := trimLeadingZeroBytes(b)
bufB2 := trimLeadingZeroBytes(b2)
if len(bufB) > len(bufB2) {
return 1
}
if len(bufB) < len(bufB2) {
return -1
}
return bytes.Compare(bufB, bufB2)
}

// ParseBitStr parses bit string.
// The string format can be b'val', B'val' or 0bval, val must be 0 or 1.
// See https://dev.mysql.com/doc/refman/5.7/en/bit-value-literals.html
Expand Down
20 changes: 19 additions & 1 deletion types/binary_literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package types

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/util/testleak"
)

Expand Down Expand Up @@ -199,10 +200,11 @@ func (s *testBinaryLiteralSuite) TestToInt(c *C) {
{"0x1010ffff8080ff12", 0x1010ffff8080ff12, false},
{"0x1010ffff8080ff12ff", 0xffffffffffffffff, true},
}
sc := new(stmtctx.StatementContext)
for _, t := range tbl {
hex, err := ParseHexStr(t.Input)
c.Assert(err, IsNil)
intValue, err := hex.ToInt()
intValue, err := hex.ToInt(sc)
if t.HasError {
c.Assert(err, NotNil)
} else {
Expand Down Expand Up @@ -242,3 +244,19 @@ func (s *testBinaryLiteralSuite) TestNewBinaryLiteralFromUint(c *C) {
c.Assert([]byte(hex), DeepEquals, t.Expected, Commentf("%#v", t))
}
}

func (s *testBinaryLiteralSuite) TestCompare(c *C) {
tbl := []struct {
a BinaryLiteral
b BinaryLiteral
cmp int
}{
{BinaryLiteral{0, 0, 1}, BinaryLiteral{2}, -1},
{BinaryLiteral{0, 1}, BinaryLiteral{0, 0, 2}, -1},
{BinaryLiteral{0, 1}, BinaryLiteral{1}, 0},
{BinaryLiteral{0, 2, 1}, BinaryLiteral{1, 2}, 1},
}
for _, t := range tbl {
c.Assert(t.a.Compare(t.b), Equals, t.cmp)
}
}
22 changes: 11 additions & 11 deletions types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ func (d *Datum) compareFloat64(sc *stmtctx.StatementContext, f float64) (int, er
fVal := d.GetMysqlEnum().ToNumber()
return CompareFloat64(fVal, f), nil
case KindBinaryLiteral, KindMysqlBit:
val, err := d.GetBinaryLiteral().ToInt()
val, err := d.GetBinaryLiteral().ToInt(sc)
fVal := float64(val)
return CompareFloat64(fVal, f), errors.Trace(err)
case KindMysqlSet:
Expand Down Expand Up @@ -610,7 +610,7 @@ func (d *Datum) compareBinaryLiteral(sc *stmtctx.StatementContext, b BinaryLiter
case KindBinaryLiteral, KindMysqlBit:
return CompareString(d.GetBinaryLiteral().ToString(), b.ToString()), nil
default:
val, err := b.ToInt()
val, err := b.ToInt(sc)
if err != nil {
return 0, errors.Trace(err)
}
Expand Down Expand Up @@ -722,7 +722,7 @@ func (d *Datum) convertToFloat(sc *stmtctx.StatementContext, target *FieldType)
case KindMysqlEnum:
f = d.GetMysqlEnum().ToNumber()
case KindBinaryLiteral, KindMysqlBit:
val, err1 := d.GetBinaryLiteral().ToInt()
val, err1 := d.GetBinaryLiteral().ToInt(sc)
f, err = float64(val), err1
case KindMysqlJSON:
f, err = ConvertJSONToFloat(sc, d.GetMysqlJSON())
Expand Down Expand Up @@ -889,7 +889,7 @@ func (d *Datum) convertToUint(sc *stmtctx.StatementContext, target *FieldType) (
case KindMysqlSet:
val, err = ConvertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp)
case KindBinaryLiteral, KindMysqlBit:
val, err = d.GetBinaryLiteral().ToInt()
val, err = d.GetBinaryLiteral().ToInt(sc)
case KindMysqlJSON:
var i64 int64
i64, err = ConvertJSONToInt(sc, d.GetMysqlJSON(), true)
Expand Down Expand Up @@ -1047,7 +1047,7 @@ func (d *Datum) convertToMysqlDecimal(sc *stmtctx.StatementContext, target *Fiel
case KindMysqlSet:
err = dec.FromFloat64(d.GetMysqlSet().ToNumber())
case KindBinaryLiteral, KindMysqlBit:
val, err1 := d.GetBinaryLiteral().ToInt()
val, err1 := d.GetBinaryLiteral().ToInt(sc)
err = err1
dec.FromUint(val)
case KindMysqlJSON:
Expand Down Expand Up @@ -1143,7 +1143,7 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp
var err error
switch d.k {
case KindString, KindBytes:
uintValue, err = BinaryLiteral(d.b).ToInt()
uintValue, err = BinaryLiteral(d.b).ToInt(sc)
default:
uintDatum, err1 := d.convertToUint(sc, target)
uintValue, err = uintDatum.GetUint64(), err1
Expand Down Expand Up @@ -1269,7 +1269,7 @@ func (d *Datum) ToBool(sc *stmtctx.StatementContext) (int64, error) {
case KindMysqlSet:
isZero = d.GetMysqlSet().ToNumber() == 0
case KindBinaryLiteral, KindMysqlBit:
val, err1 := d.GetBinaryLiteral().ToInt()
val, err1 := d.GetBinaryLiteral().ToInt(sc)
isZero, err = val == 0, err1
default:
return 0, errors.Errorf("cannot convert %v(type %T) to bool", d.GetValue(), d.GetValue())
Expand Down Expand Up @@ -1308,7 +1308,7 @@ func ConvertDatumToDecimal(sc *stmtctx.StatementContext, d Datum) (*MyDecimal, e
case KindMysqlSet:
dec.FromUint(d.GetMysqlSet().Value)
case KindBinaryLiteral, KindMysqlBit:
val, err1 := d.GetBinaryLiteral().ToInt()
val, err1 := d.GetBinaryLiteral().ToInt(sc)
dec.FromUint(val)
err = err1
case KindMysqlJSON:
Expand Down Expand Up @@ -1406,7 +1406,7 @@ func (d *Datum) toSignedInteger(sc *stmtctx.StatementContext, tp byte) (int64, e
case KindMysqlJSON:
return ConvertJSONToInt(sc, d.GetMysqlJSON(), false)
case KindBinaryLiteral, KindMysqlBit:
val, err := d.GetBinaryLiteral().ToInt()
val, err := d.GetBinaryLiteral().ToInt(sc)
return int64(val), errors.Trace(err)
default:
return 0, errors.Errorf("cannot convert %v(type %T) to int64", d.GetValue(), d.GetValue())
Expand Down Expand Up @@ -1442,7 +1442,7 @@ func (d *Datum) ToFloat64(sc *stmtctx.StatementContext) (float64, error) {
case KindMysqlSet:
return d.GetMysqlSet().ToNumber(), nil
case KindBinaryLiteral, KindMysqlBit:
val, err := d.GetBinaryLiteral().ToInt()
val, err := d.GetBinaryLiteral().ToInt(sc)
return float64(val), errors.Trace(err)
case KindMysqlJSON:
f, err := ConvertJSONToFloat(sc, d.GetMysqlJSON())
Expand Down Expand Up @@ -1590,7 +1590,7 @@ func CoerceDatum(sc *stmtctx.StatementContext, a, b Datum) (x, y Datum, err erro
y.SetFloat64(float64(y.GetUint64()))
case KindBinaryLiteral, KindMysqlBit:
var fval uint64
fval, err = y.GetBinaryLiteral().ToInt()
fval, err = y.GetBinaryLiteral().ToInt(sc)
if err != nil {
return x, y, errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion types/datum_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func CoerceArithmetic(sc *stmtctx.StatementContext, a Datum) (d Datum, err error
d.SetMysqlDecimal(de)
return d, nil
case KindBinaryLiteral, KindMysqlBit:
val, err1 := a.GetBinaryLiteral().ToInt()
val, err1 := a.GetBinaryLiteral().ToInt(sc)
d.SetUint64(val)
return d, err1
case KindMysqlEnum:
Expand Down
4 changes: 4 additions & 0 deletions types/datum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ func (ts *testTypeConvertSuite) TestToInt64(c *C) {
v, err := Convert(3.1415926, ft)
c.Assert(err, IsNil)
testDatumToInt64(c, v, int64(3))

binLit, err := ParseHexStr("0x9999999999999999999999999999999999999999999")
c.Assert(err, IsNil)
testDatumToInt64(c, binLit, -1)
}

func (ts *testTypeConvertSuite) TestToFloat32(c *C) {
Expand Down
7 changes: 1 addition & 6 deletions util/chunk/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"sort"

"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
)
Expand Down Expand Up @@ -150,11 +149,7 @@ func cmpBit(l Row, lCol int, r Row, rCol int) int {
}
lBit := types.BinaryLiteral(l.GetBytes(lCol))
rBit := types.BinaryLiteral(r.GetBytes(rCol))
lUint, err := lBit.ToInt()
terror.Log(err)
rUint, err := rBit.ToInt()
terror.Log(err)
return types.CompareUint64(lUint, rUint)
return lBit.Compare(rBit)
}

func cmpJSON(l Row, lCol int, r Row, rCol int) int {
Expand Down
4 changes: 2 additions & 2 deletions util/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func encode(sc *stmtctx.StatementContext, b []byte, vals []types.Datum, comparab
b = encodeUnsignedInt(b, uint64(vals[i].GetMysqlSet().ToNumber()), comparable)
case types.KindMysqlBit, types.KindBinaryLiteral:
// We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit.
val, err := vals[i].GetBinaryLiteral().ToInt()
val, err := vals[i].GetBinaryLiteral().ToInt(sc)
terror.Log(errors.Trace(err))
b = encodeUnsignedInt(b, val, comparable)
case types.KindMysqlJSON:
Expand Down Expand Up @@ -244,7 +244,7 @@ func encodeChunkRow(sc *stmtctx.StatementContext, b []byte, row chunk.Row, allTy
b = encodeUnsignedInt(b, uint64(row.GetSet(i).ToNumber()), comparable)
case mysql.TypeBit:
// We don't need to handle errors here since the literal is ensured to be able to store in uint64 in convertToMysqlBit.
val, err := types.BinaryLiteral(row.GetBytes(i)).ToInt()
val, err := types.BinaryLiteral(row.GetBytes(i)).ToInt(sc)
terror.Log(errors.Trace(err))
b = encodeUnsignedInt(b, val, comparable)
case mysql.TypeJSON:
Expand Down

0 comments on commit 9ca8689

Please sign in to comment.