From cca10bab58a767570c3cf672f84b9a9660488317 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 15 Mar 2017 11:26:19 +0800 Subject: [PATCH] table: check and truncate UTF8 string in CastValue. (#2819) --- executor/statement_context_test.go | 16 +++++++++++++++- sessionctx/variable/session.go | 15 +++++++++++++++ table/column.go | 24 ++++++++++++++++++------ table/table.go | 20 ++++++++++++-------- 4 files changed, 60 insertions(+), 15 deletions(-) diff --git a/executor/statement_context_test.go b/executor/statement_context_test.go index 9fcc20c244d12..fd359b2d58fa7 100644 --- a/executor/statement_context_test.go +++ b/executor/statement_context_test.go @@ -14,7 +14,10 @@ package executor_test import ( + "fmt" + . "github.com/pingcap/check" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" @@ -38,7 +41,7 @@ func (s *testSuite) TestStatementContext(c *C) { tk.MustExec(strictModeSQL) tk.MustQuery("select * from sc where a > cast(1.1 as decimal)").Check(testkit.Rows("2")) - _, err := tk.Exec(`select * from sc where a > cast(1.1 as decimal); + _, err := tk.Exec(`select * from sc where a > cast(1.1 as decimal); update sc set a = 4 where a > cast(1.1 as decimal)`) c.Check(terror.ErrorEqual(err, types.ErrTruncated), IsTrue) @@ -65,4 +68,15 @@ func (s *testSuite) TestStatementContext(c *C) { tk.MustExec("update sc set a = 4 where a > '1x'") tk.MustExec("delete from sc where a < '1x'") tk.MustQuery("select * from sc where a > '1x'").Check(testkit.Rows("4")) + + // Test invalid UTF8 + tk.MustExec("create table sc2 (a varchar(255))") + // Insert an invalid UTF8 + tk.MustExec("insert sc2 values (unhex('4040ffff'))") + c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Greater, uint16(0)) + tk.MustQuery("select * from sc2").Check(testkit.Rows(fmt.Sprintf("%v", []byte("@@")))) + tk.MustExec(strictModeSQL) + _, err = tk.Exec("insert sc2 values (unhex('4040ffff'))") + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, table.ErrTruncateWrongValue), IsTrue) } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index c247617447627..40ca58d88095e 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -344,3 +344,18 @@ func (sc *StatementContext) AppendWarning(warn error) { } sc.mu.Unlock() } + +// HandleTruncate ignores or returns the error based on the StatementContext state. +func (sc *StatementContext) HandleTruncate(err error) error { + if err == nil { + return nil + } + if sc.IgnoreTruncate { + return nil + } + if sc.TruncateAsWarning { + sc.AppendWarning(err) + return nil + } + return err +} diff --git a/table/column.go b/table/column.go index e40492ae2d7c2..a51a59c254b90 100644 --- a/table/column.go +++ b/table/column.go @@ -19,6 +19,7 @@ package table import ( "strings" + "unicode/utf8" "github.com/juju/errors" "github.com/ngaut/log" @@ -113,15 +114,26 @@ func CastValues(ctx context.Context, rec []types.Datum, cols []*Column, ignoreEr // CastValue casts a value based on column type. func CastValue(ctx context.Context, val types.Datum, col *model.ColumnInfo) (casted types.Datum, err error) { - casted, err = val.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType) + sc := ctx.GetSessionVars().StmtCtx + casted, err = val.ConvertTo(sc, &col.FieldType) + // TODO: make sure all truncate errors are handled by ConvertTo. + err = sc.HandleTruncate(err) if err != nil { - if ctx.GetSessionVars().StrictSQLMode { - return casted, errors.Trace(err) + return casted, errors.Trace(err) + } + if !mysql.IsUTF8Charset(col.Charset) { + return casted, nil + } + str := casted.GetString() + for i, r := range str { + if r == utf8.RuneError { + // Truncate to valid utf8 string. + casted = types.NewStringDatum(str[:i]) + err = sc.HandleTruncate(ErrTruncateWrongValue) + break } - // TODO: add warnings. - log.Warnf("cast value error %v", err) } - return casted, nil + return casted, errors.Trace(err) } // ColDesc describes column information like MySQL desc and show columns do. diff --git a/table/table.go b/table/table.go index 2f812080934c7..f01078d047bdd 100644 --- a/table/table.go +++ b/table/table.go @@ -53,6 +53,8 @@ var ( ErrIndexStateCantNone = terror.ClassTable.New(codeIndexStateCantNone, "index can not be in none state") // ErrInvalidRecordKey returns for invalid record key. ErrInvalidRecordKey = terror.ClassTable.New(codeInvalidRecordKey, "invalid record key") + // ErrTruncateWrongValue returns for truncate wrong value for field. + ErrTruncateWrongValue = terror.ClassTable.New(codeTruncateWrongValue, "Incorrect value") ) // RecordIterFunc is used for low-level record iteration. @@ -137,10 +139,11 @@ const ( codeIndexStateCantNone = 8 codeInvalidRecordKey = 9 - codeColumnCantNull = 1048 - codeUnknownColumn = 1054 - codeDuplicateColumn = 1110 - codeNoDefaultValue = 1364 + codeColumnCantNull = 1048 + codeUnknownColumn = 1054 + codeDuplicateColumn = 1110 + codeNoDefaultValue = 1364 + codeTruncateWrongValue = 1366 ) // Slice is used for table sorting. @@ -156,10 +159,11 @@ func (s Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func init() { tableMySQLErrCodes := map[terror.ErrCode]uint16{ - codeColumnCantNull: mysql.ErrBadNull, - codeUnknownColumn: mysql.ErrBadField, - codeDuplicateColumn: mysql.ErrFieldSpecifiedTwice, - codeNoDefaultValue: mysql.ErrNoDefaultForField, + codeColumnCantNull: mysql.ErrBadNull, + codeUnknownColumn: mysql.ErrBadField, + codeDuplicateColumn: mysql.ErrFieldSpecifiedTwice, + codeNoDefaultValue: mysql.ErrNoDefaultForField, + codeTruncateWrongValue: mysql.ErrTruncatedWrongValueForField, } terror.ErrClassToMySQLCodes[terror.ClassTable] = tableMySQLErrCodes }