Skip to content

Commit

Permalink
table: check and truncate UTF8 string in CastValue. (pingcap#2819)
Browse files Browse the repository at this point in the history
  • Loading branch information
coocood authored Mar 15, 2017
1 parent 0ef5f84 commit cca10ba
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 15 deletions.
16 changes: 15 additions & 1 deletion executor/statement_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
package executor_test

import (
"fmt"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
Expand All @@ -38,7 +41,7 @@ func (s *testSuite) TestStatementContext(c *C) {

tk.MustExec(strictModeSQL)
tk.MustQuery("select * from sc where a > cast(1.1 as decimal)").Check(testkit.Rows("2"))
_, err := tk.Exec(`select * from sc where a > cast(1.1 as decimal);
_, err := tk.Exec(`select * from sc where a > cast(1.1 as decimal);
update sc set a = 4 where a > cast(1.1 as decimal)`)
c.Check(terror.ErrorEqual(err, types.ErrTruncated), IsTrue)

Expand All @@ -65,4 +68,15 @@ func (s *testSuite) TestStatementContext(c *C) {
tk.MustExec("update sc set a = 4 where a > '1x'")
tk.MustExec("delete from sc where a < '1x'")
tk.MustQuery("select * from sc where a > '1x'").Check(testkit.Rows("4"))

// Test invalid UTF8
tk.MustExec("create table sc2 (a varchar(255))")
// Insert an invalid UTF8
tk.MustExec("insert sc2 values (unhex('4040ffff'))")
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Greater, uint16(0))
tk.MustQuery("select * from sc2").Check(testkit.Rows(fmt.Sprintf("%v", []byte("@@"))))
tk.MustExec(strictModeSQL)
_, err = tk.Exec("insert sc2 values (unhex('4040ffff'))")
c.Assert(err, NotNil)
c.Assert(terror.ErrorEqual(err, table.ErrTruncateWrongValue), IsTrue)
}
15 changes: 15 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,18 @@ func (sc *StatementContext) AppendWarning(warn error) {
}
sc.mu.Unlock()
}

// HandleTruncate ignores or returns the error based on the StatementContext state.
func (sc *StatementContext) HandleTruncate(err error) error {
if err == nil {
return nil
}
if sc.IgnoreTruncate {
return nil
}
if sc.TruncateAsWarning {
sc.AppendWarning(err)
return nil
}
return err
}
24 changes: 18 additions & 6 deletions table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package table

import (
"strings"
"unicode/utf8"

"github.com/juju/errors"
"github.com/ngaut/log"
Expand Down Expand Up @@ -113,15 +114,26 @@ func CastValues(ctx context.Context, rec []types.Datum, cols []*Column, ignoreEr

// CastValue casts a value based on column type.
func CastValue(ctx context.Context, val types.Datum, col *model.ColumnInfo) (casted types.Datum, err error) {
casted, err = val.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType)
sc := ctx.GetSessionVars().StmtCtx
casted, err = val.ConvertTo(sc, &col.FieldType)
// TODO: make sure all truncate errors are handled by ConvertTo.
err = sc.HandleTruncate(err)
if err != nil {
if ctx.GetSessionVars().StrictSQLMode {
return casted, errors.Trace(err)
return casted, errors.Trace(err)
}
if !mysql.IsUTF8Charset(col.Charset) {
return casted, nil
}
str := casted.GetString()
for i, r := range str {
if r == utf8.RuneError {
// Truncate to valid utf8 string.
casted = types.NewStringDatum(str[:i])
err = sc.HandleTruncate(ErrTruncateWrongValue)
break
}
// TODO: add warnings.
log.Warnf("cast value error %v", err)
}
return casted, nil
return casted, errors.Trace(err)
}

// ColDesc describes column information like MySQL desc and show columns do.
Expand Down
20 changes: 12 additions & 8 deletions table/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ var (
ErrIndexStateCantNone = terror.ClassTable.New(codeIndexStateCantNone, "index can not be in none state")
// ErrInvalidRecordKey returns for invalid record key.
ErrInvalidRecordKey = terror.ClassTable.New(codeInvalidRecordKey, "invalid record key")
// ErrTruncateWrongValue returns for truncate wrong value for field.
ErrTruncateWrongValue = terror.ClassTable.New(codeTruncateWrongValue, "Incorrect value")
)

// RecordIterFunc is used for low-level record iteration.
Expand Down Expand Up @@ -137,10 +139,11 @@ const (
codeIndexStateCantNone = 8
codeInvalidRecordKey = 9

codeColumnCantNull = 1048
codeUnknownColumn = 1054
codeDuplicateColumn = 1110
codeNoDefaultValue = 1364
codeColumnCantNull = 1048
codeUnknownColumn = 1054
codeDuplicateColumn = 1110
codeNoDefaultValue = 1364
codeTruncateWrongValue = 1366
)

// Slice is used for table sorting.
Expand All @@ -156,10 +159,11 @@ func (s Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

func init() {
tableMySQLErrCodes := map[terror.ErrCode]uint16{
codeColumnCantNull: mysql.ErrBadNull,
codeUnknownColumn: mysql.ErrBadField,
codeDuplicateColumn: mysql.ErrFieldSpecifiedTwice,
codeNoDefaultValue: mysql.ErrNoDefaultForField,
codeColumnCantNull: mysql.ErrBadNull,
codeUnknownColumn: mysql.ErrBadField,
codeDuplicateColumn: mysql.ErrFieldSpecifiedTwice,
codeNoDefaultValue: mysql.ErrNoDefaultForField,
codeTruncateWrongValue: mysql.ErrTruncatedWrongValueForField,
}
terror.ErrClassToMySQLCodes[terror.ClassTable] = tableMySQLErrCodes
}

0 comments on commit cca10ba

Please sign in to comment.