From f09574d8264c43fa5e0b8179e585bd8981f92f71 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 30 May 2016 13:16:41 +0800 Subject: [PATCH] *: support sql mode (#1263) --- evaluator/evaluator.go | 30 ++++++++-- evaluator/evaluator_test.go | 9 ++- evaluator/helper.go | 23 +++----- executor/executor_simple.go | 37 +++++++----- executor/executor_simple_test.go | 20 ++++--- executor/executor_test.go | 19 ++++++ executor/show.go | 13 +++-- session.go | 18 +++--- sessionctx/variable/session.go | 59 ++++++++++++++++--- sessionctx/variable/session_test.go | 19 ++++++ sessionctx/variable/sysvar.go | 8 +-- table/column.go | 90 ++++++++++++++++++++++++++++- table/column_test.go | 71 +++++++++++++++++++++++ table/table.go | 43 ++++---------- table/table_test.go | 56 ------------------ table/tables/tables.go | 2 +- terror/terror.go | 3 + tidb.go | 6 +- 18 files changed, 365 insertions(+), 161 deletions(-) delete mode 100644 table/table_test.go diff --git a/evaluator/evaluator.go b/evaluator/evaluator.go index bb8679c54f5b4..e1a3064bcc05c 100644 --- a/evaluator/evaluator.go +++ b/evaluator/evaluator.go @@ -597,20 +597,40 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool { return true } - _, ok := variable.SysVars[name] + sysVar, ok := variable.SysVars[name] if !ok { // select null sys vars is not permitted e.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) return false } + if sysVar.Scope == variable.ScopeNone { + v.SetString(sysVar.Value) + return true + } if !v.IsGlobal { - if value, ok := sessionVars.Systems[name]; ok { - v.SetString(value) - return true + d := sessionVars.GetSystemVar(name) + if d.Kind() == types.KindNull { + if sysVar.Scope&variable.ScopeGlobal == 0 { + d.SetString(sysVar.Value) + } else { + // Get global system variable and fill it in session. + globalVal, err := globalVars.GetGlobalSysVar(e.ctx, name) + if err != nil { + e.err = errors.Trace(err) + return false + } + d.SetString(globalVal) + err = sessionVars.SetSystemVar(name, d) + if err != nil { + e.err = errors.Trace(err) + return false + } + } } + v.SetDatum(d) + return true } - value, err := globalVars.GetGlobalSysVar(e.ctx, name) if err != nil { e.err = errors.Trace(err) diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 2f6f6cadc5cfc..f5e8470ff5678 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -908,8 +908,7 @@ func (s *testEvaluatorSuite) TestGetTimeValue(c *C) { ctx := mock.NewContext() variable.BindSessionVars(ctx) sessionVars := variable.GetSessionVars(ctx) - - sessionVars.Systems["timestamp"] = "" + sessionVars.SetSystemVar("timestamp", types.NewStringDatum("")) v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) c.Assert(err, IsNil) @@ -917,7 +916,7 @@ func (s *testEvaluatorSuite) TestGetTimeValue(c *C) { timeValue = v.GetMysqlTime() c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - sessionVars.Systems["timestamp"] = "0" + sessionVars.SetSystemVar("timestamp", types.NewStringDatum("0")) v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) c.Assert(err, IsNil) @@ -925,7 +924,7 @@ func (s *testEvaluatorSuite) TestGetTimeValue(c *C) { timeValue = v.GetMysqlTime() c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - delete(sessionVars.Systems, "timestamp") + sessionVars.SetSystemVar("timestamp", types.Datum{}) v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) c.Assert(err, IsNil) @@ -933,7 +932,7 @@ func (s *testEvaluatorSuite) TestGetTimeValue(c *C) { timeValue = v.GetMysqlTime() c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - sessionVars.Systems["timestamp"] = "1234" + sessionVars.SetSystemVar("timestamp", types.NewStringDatum("1234")) tbl := []struct { Expr interface{} diff --git a/evaluator/helper.go b/evaluator/helper.go index c547576985f95..4295c252f5e12 100644 --- a/evaluator/helper.go +++ b/evaluator/helper.go @@ -1,7 +1,6 @@ package evaluator import ( - "strconv" "strings" "time" @@ -119,20 +118,16 @@ func getSystemTimestamp(ctx context.Context) (time.Time, error) { // check whether use timestamp varibale sessionVars := variable.GetSessionVars(ctx) - if v, ok := sessionVars.Systems["timestamp"]; ok { - if v != "" { - timestamp, err := strconv.ParseInt(v, 10, 64) - if err != nil { - return time.Time{}, errors.Trace(err) - } - - if timestamp <= 0 { - return value, nil - } - - return time.Unix(timestamp, 0), nil + ts := sessionVars.GetSystemVar("timestamp") + if ts.Kind() != types.KindNull && ts.GetString() != "" { + timestamp, err := ts.ToInt64() + if err != nil { + return time.Time{}, errors.Trace(err) } + if timestamp <= 0 { + return value, nil + } + return time.Unix(timestamp, 0), nil } - return value, nil } diff --git a/executor/executor_simple.go b/executor/executor_simple.go index 24f13e79202d4..fd5c74d3de24b 100644 --- a/executor/executor_simple.go +++ b/executor/executor_simple.go @@ -98,8 +98,14 @@ func (e *SimpleExec) executeUse(s *ast.UseStmt) error { // The server sets this variable whenever the default database changes. // See: http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_character_set_database sessionVars := variable.GetSessionVars(e.ctx) - sessionVars.Systems[variable.CharsetDatabase] = dbinfo.Charset - sessionVars.Systems[variable.CollationDatabase] = dbinfo.Collate + err := sessionVars.SetSystemVar(variable.CharsetDatabase, types.NewStringDatum(dbinfo.Charset)) + if err != nil { + return errors.Trace(err) + } + err = sessionVars.SetSystemVar(variable.CollationDatabase, types.NewStringDatum(dbinfo.Collate)) + if err != nil { + return errors.Trace(err) + } return nil } @@ -161,16 +167,13 @@ func (e *SimpleExec) executeSet(s *ast.SetStmt) error { if sysVar.Scope&variable.ScopeSession == 0 { return errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name) } - if value, err := evaluator.Eval(e.ctx, v.Value); err != nil { + value, err := evaluator.Eval(e.ctx, v.Value) + if err != nil { + return errors.Trace(err) + } + err = sessionVars.SetSystemVar(name, value) + if err != nil { return errors.Trace(err) - } else if value.Kind() == types.KindNull { - sessionVars.Systems[name] = "" - } else { - svalue, err := value.ToString() - if err != nil { - return errors.Trace(err) - } - sessionVars.Systems[name] = fmt.Sprintf("%v", svalue) } } } @@ -179,8 +182,8 @@ func (e *SimpleExec) executeSet(s *ast.SetStmt) error { func (e *SimpleExec) executeSetCharset(s *ast.SetCharsetStmt) error { collation := s.Collate + var err error if len(collation) == 0 { - var err error collation, err = charset.GetDefaultCollation(s.Charset) if err != nil { return errors.Trace(err) @@ -188,9 +191,15 @@ func (e *SimpleExec) executeSetCharset(s *ast.SetCharsetStmt) error { } sessionVars := variable.GetSessionVars(e.ctx) for _, v := range variable.SetNamesVariables { - sessionVars.Systems[v] = s.Charset + err = sessionVars.SetSystemVar(v, types.NewStringDatum(s.Charset)) + if err != nil { + return errors.Trace(err) + } + } + err = sessionVars.SetSystemVar(variable.CollationConnection, types.NewStringDatum(collation)) + if err != nil { + return errors.Trace(err) } - sessionVars.Systems[variable.CollationConnection] = collation return nil } diff --git a/executor/executor_simple_test.go b/executor/executor_simple_test.go index 98a860dc4528f..07796d70e3d93 100644 --- a/executor/executor_simple_test.go +++ b/executor/executor_simple_test.go @@ -60,17 +60,20 @@ func (s *testSuite) TestSetVar(c *C) { testSQL = "SET @@global.autocommit = 1;" tk.MustExec(testSQL) - testSQL = "SET @@global.autocommit = null;" - tk.MustExec(testSQL) + // TODO: this test case should returns error. + // testSQL = "SET @@global.autocommit = null;" + // _, err := tk.Exec(testSQL) + // c.Assert(err, NotNil) testSQL = "SET @@autocommit = 1;" tk.MustExec(testSQL) testSQL = "SET @@autocommit = null;" - tk.MustExec(testSQL) + _, err := tk.Exec(testSQL) + c.Assert(err, NotNil) errTestSql := "SET @@date_format = 1;" - _, err := tk.Exec(errTestSql) + _, err = tk.Exec(errTestSql) c.Assert(err, NotNil) errTestSql = "SET @@rewriter_enabled = 1;" @@ -109,13 +112,16 @@ func (s *testSuite) TestSetCharset(c *C) { ctx := tk.Se.(context.Context) sessionVars := variable.GetSessionVars(ctx) for _, v := range variable.SetNamesVariables { - c.Assert(sessionVars.Systems[v] != "utf8", IsTrue) + sVar := sessionVars.GetSystemVar(v) + c.Assert(sVar.GetString() != "utf8", IsTrue) } tk.MustExec(`SET NAMES utf8`) for _, v := range variable.SetNamesVariables { - c.Assert(sessionVars.Systems[v], Equals, "utf8") + sVar := sessionVars.GetSystemVar(v) + c.Assert(sVar.GetString(), Equals, "utf8") } - c.Assert(sessionVars.Systems[variable.CollationConnection], Equals, "utf8_general_ci") + sVar := sessionVars.GetSystemVar(variable.CollationConnection) + c.Assert(sVar.GetString(), Equals, "utf8_general_ci") } func (s *testSuite) TestDo(c *C) { diff --git a/executor/executor_test.go b/executor/executor_test.go index a856b03cc07d1..db34de64dc51e 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1350,3 +1350,22 @@ func (s *testSuite) TestJoinPanic(c *C) { tk.MustExec("create table events (clock int, source int)") tk.MustQuery("SELECT * FROM events e JOIN (SELECT MAX(clock) AS clock FROM events e2 GROUP BY e2.source) e3 ON e3.clock=e.clock") } + +func (s *testSuite) TestSQLMode(c *C) { + defer testleak.AfterTest(c)() + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a tinyint not null)") + tk.MustExec("set sql_mode = 'STRICT_TRANS_TABLES'") + _, err := tk.Exec("insert t values ()") + c.Check(err, NotNil) + + _, err = tk.Exec("insert t values ('1000')") + c.Check(err, NotNil) + + tk.MustExec("set sql_mode = ''") + tk.MustExec("insert t values ()") + tk.MustExec("insert t values (1000)") + tk.MustQuery("select * from t").Check(testkit.Rows("0", "127")) +} diff --git a/executor/show.go b/executor/show.go index 1057e08770723..960eab6b36a1f 100644 --- a/executor/show.go +++ b/executor/show.go @@ -286,16 +286,19 @@ func (e *ShowExec) fetchShowVariables() error { var value string if !e.GlobalScope { // Try to get Session Scope variable value first. - sv, ok := sessionVars.Systems[v.Name] - if ok { - value = sv - } else { - // If session scope variable is not set, get the global scope value. + sv := sessionVars.GetSystemVar(v.Name) + if sv.Kind() == types.KindNull { value, err = globalVars.GetGlobalSysVar(e.ctx, v.Name) if err != nil { return errors.Trace(err) } + sv.SetString(value) + err = sessionVars.SetSystemVar(v.Name, sv) + if err != nil { + return errors.Trace(err) + } } + value = sv.GetString() } else { value, err = globalVars.GetGlobalSysVar(e.ctx, v.Name) if err != nil { diff --git a/session.go b/session.go index 0685590352e69..24e088ad8bec3 100644 --- a/session.go +++ b/session.go @@ -344,21 +344,25 @@ func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string // IsAutocommit checks if it is in the auto-commit mode. func (s *session) isAutocommit(ctx context.Context) bool { - autocommit, ok := variable.GetSessionVars(ctx).Systems["autocommit"] - if !ok { + sessionVar := variable.GetSessionVars(ctx) + autocommit := sessionVar.GetSystemVar("autocommit") + if autocommit.Kind() == types.KindNull { if s.initing { return false } - var err error - autocommit, err = s.GetGlobalSysVar(ctx, "autocommit") + autocommitStr, err := s.GetGlobalSysVar(ctx, "autocommit") if err != nil { log.Errorf("Get global sys var error: %v", err) return false } - variable.GetSessionVars(ctx).Systems["autocommit"] = autocommit - ok = true + autocommit.SetString(autocommitStr) + err = sessionVar.SetSystemVar("autocommit", autocommit) + if err != nil { + log.Errorf("Set session sys var error: %v", err) + } } - if ok && (autocommit == "ON" || autocommit == "on" || autocommit == "1") { + autocommitStr := autocommit.GetString() + if autocommitStr == "ON" || autocommitStr == "on" || autocommitStr == "1" { variable.GetSessionVars(ctx).SetStatusFlag(mysql.ServerStatusAutocommit, true) return true } diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index ed30373d811f9..6563fef1567dd 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -14,13 +14,22 @@ package variable import ( + "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/terror" + "github.com/pingcap/tidb/util/types" + "strings" ) -const codeCantGetValidID terror.ErrCode = 1 +const ( + codeCantGetValidID terror.ErrCode = 1 + codeCantSetToNull terror.ErrCode = 2 +) -var errCantGetValidID = terror.ClassVariable.New(codeCantGetValidID, "cannot get valid auto-increment id in retry") +var ( + errCantGetValidID = terror.ClassVariable.New(codeCantGetValidID, "cannot get valid auto-increment id in retry") + errCantSetToNull = terror.ClassVariable.New(codeCantSetToNull, "cannot set variable to null") +) // RetryInfo saves retry information. type RetryInfo struct { @@ -63,7 +72,7 @@ type SessionVars struct { // user-defined variables Users map[string]string // system variables - Systems map[string]string + systems map[string]string // prepared statement PreparedStmts map[uint32]interface{} PreparedStmtNameToID map[string]uint32 @@ -89,6 +98,9 @@ type SessionVars struct { // Current user User string + + // Strict SQL mode + StrictSQLMode bool } // sessionVarsKeyType is a dummy type to avoid naming collision in context. @@ -105,12 +117,12 @@ const sessionVarsKey sessionVarsKeyType = 0 func BindSessionVars(ctx context.Context) { v := &SessionVars{ Users: make(map[string]string), - Systems: make(map[string]string), + systems: make(map[string]string), PreparedStmts: make(map[uint32]interface{}), PreparedStmtNameToID: make(map[string]uint32), RetryInfo: &RetryInfo{}, + StrictSQLMode: true, } - ctx.SetValue(sessionVarsKey, v) } @@ -139,8 +151,8 @@ const ( // See: https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html func GetCharsetInfo(ctx context.Context) (charset, collation string) { sessionVars := GetSessionVars(ctx) - charset = sessionVars.Systems[characterSetConnection] - collation = sessionVars.Systems[collationConnection] + charset = sessionVars.systems[characterSetConnection] + collation = sessionVars.systems[collationConnection] return } @@ -186,3 +198,36 @@ func (s *SessionVars) GetNextPreparedStmtID() uint32 { func (s *SessionVars) SetCurrentUser(user string) { s.User = user } + +// SetSystemVar sets a system variable. +func (s *SessionVars) SetSystemVar(key string, value types.Datum) error { + key = strings.ToLower(key) + if value.Kind() == types.KindNull { + return errCantSetToNull + } + sVal, err := value.ToString() + if err != nil { + return errors.Trace(err) + } + if key == "sql_mode" { + sVal = strings.ToUpper(sVal) + if strings.Contains(sVal, "STRICT_TRANS_TABLES") || strings.Contains(sVal, "STRICT_ALL_TABLES") { + s.StrictSQLMode = true + } else { + s.StrictSQLMode = false + } + } + s.systems[key] = sVal + return nil +} + +// GetSystemVar gets a system variable. +func (s *SessionVars) GetSystemVar(key string) types.Datum { + var d types.Datum + key = strings.ToLower(key) + sVal, ok := s.systems[key] + if ok { + d.SetString(sVal) + } + return d +} diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index 8a17de692c36f..edd6bb54ec9c5 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -17,6 +17,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/mock" + "github.com/pingcap/tidb/util/types" ) var _ = Suite(&testSessionSuite{}) @@ -47,4 +48,22 @@ func (*testSessionSuite) TestSession(c *C) { // For last insert id v.SetLastInsertID(uint64(1)) c.Assert(v.LastInsertID, Equals, uint64(1)) + + v.SetSystemVar("autocommit", types.NewStringDatum("1")) + val := v.GetSystemVar("autocommit") + c.Assert(val.GetString(), Equals, "1") + c.Assert(v.SetSystemVar("autocommit", types.Datum{}), NotNil) + + v.SetSystemVar("sql_mode", types.NewStringDatum("strict_trans_tables")) + val = v.GetSystemVar("sql_mode") + c.Assert(val.GetString(), Equals, "STRICT_TRANS_TABLES") + c.Assert(v.StrictSQLMode, IsTrue) + v.SetSystemVar("sql_mode", types.NewStringDatum("")) + c.Assert(v.StrictSQLMode, IsFalse) + + v.SetSystemVar("character_set_connection", types.NewStringDatum("utf8")) + v.SetSystemVar("collation_connection", types.NewStringDatum("utf8_general_ci")) + charset, collation := variable.GetCharsetInfo(ctx) + c.Assert(charset, Equals, "utf8") + c.Assert(collation, Equals, "utf8_general_ci") } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 1f281ce19cb27..93ad9a8075cff 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -26,11 +26,11 @@ type ScopeFlag uint8 const ( // ScopeNone means the system variable can not be changed dynamically. - ScopeNone ScopeFlag = iota << 0 + ScopeNone ScopeFlag = 0 // ScopeGlobal means the system variable can be changed globally. - ScopeGlobal + ScopeGlobal ScopeFlag = 1 << 0 // ScopeSession means the system variable can only be changed in current session. - ScopeSession + ScopeSession ScopeFlag = 1 << 1 ) // SysVar is for system variable. @@ -121,7 +121,7 @@ var defaultSysVars = []*SysVar{ {ScopeNone, "skip_name_resolve", "OFF"}, {ScopeNone, "performance_schema_max_file_handles", "32768"}, {ScopeSession, "transaction_allow_batching", ""}, - {ScopeGlobal | ScopeSession, "sql_mode", "NO_ENGINE_SUBSTITUTION"}, + {ScopeGlobal | ScopeSession, "sql_mode", "STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION"}, {ScopeNone, "performance_schema_max_statement_classes", "168"}, {ScopeGlobal, "server_id", "0"}, {ScopeGlobal, "innodb_flushing_avg_loops", "30"}, diff --git a/table/column.go b/table/column.go index a4116794e1a84..ca554c5c4bd08 100644 --- a/table/column.go +++ b/table/column.go @@ -21,9 +21,12 @@ import ( "strings" "github.com/juju/errors" + "github.com/ngaut/log" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/evaluator" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/types" ) @@ -94,7 +97,7 @@ func FindOnUpdateCols(cols []*Column) []*Column { func CastValues(ctx context.Context, rec []types.Datum, cols []*Column) (err error) { for _, c := range cols { var converted types.Datum - converted, err = rec[c.Offset].ConvertTo(&c.FieldType) + converted, err = CastValue(ctx, rec[c.Offset], c) if err != nil { return errors.Trace(err) } @@ -103,6 +106,19 @@ func CastValues(ctx context.Context, rec []types.Datum, cols []*Column) (err err return nil } +// CastValue casts a value based on column type. +func CastValue(ctx context.Context, val types.Datum, col *Column) (casted types.Datum, err error) { + casted, err = val.ConvertTo(&col.FieldType) + if err != nil { + if variable.GetSessionVars(ctx).StrictSQLMode { + return casted, errors.Trace(err) + } + // TODO: add warnings. + log.Warnf("cast value error %v", err) + } + return casted, nil +} + // ColDesc describes column information like MySQL desc and show columns do. type ColDesc struct { Field string @@ -229,3 +245,75 @@ func (idx *IndexedColumn) FetchValues(r []types.Datum) ([]types.Datum, error) { } return vals, nil } + +// GetColDefaultValue gets default value of the column. +func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (types.Datum, bool, error) { + // Check no default value flag. + if mysql.HasNoDefaultValueFlag(col.Flag) && col.Tp != mysql.TypeEnum { + err := ErrNoDefaultValue.Gen("Field '%s' doesn't have a default value", col.Name) + if ctx != nil { + sessVars := variable.GetSessionVars(ctx) + if !sessVars.StrictSQLMode { + // TODO: add warning. + return getZeroValue(col), true, nil + } + } + return types.Datum{}, false, errors.Trace(err) + } + + // Check and get timestamp/datetime default value. + if col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime { + if col.DefaultValue == nil { + return types.Datum{}, true, nil + } + + value, err := evaluator.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal) + if err != nil { + return types.Datum{}, true, errors.Errorf("Field '%s' get default value fail - %s", col.Name, errors.Trace(err)) + } + return value, true, nil + } else if col.Tp == mysql.TypeEnum { + // For enum type, if no default value and not null is set, + // the default value is the first element of the enum list + if col.DefaultValue == nil && mysql.HasNotNullFlag(col.Flag) { + return types.NewDatum(col.FieldType.Elems[0]), true, nil + } + } + + return types.NewDatum(col.DefaultValue), true, nil +} + +func getZeroValue(col *model.ColumnInfo) types.Datum { + var d types.Datum + switch col.Tp { + case mysql.TypeTiny, mysql.TypeInt24, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: + if mysql.HasUnsignedFlag(col.Flag) { + d.SetUint64(0) + } else { + d.SetInt64(0) + } + case mysql.TypeFloat: + d.SetFloat32(0) + case mysql.TypeDouble: + d.SetFloat64(0) + case mysql.TypeNewDecimal: + d.SetMysqlDecimal(mysql.NewDecimalFromInt(0, 0)) + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: + d.SetString("") + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + d.SetBytes([]byte{}) + case mysql.TypeDuration: + d.SetMysqlDuration(mysql.ZeroDuration) + case mysql.TypeDate, mysql.TypeNewDate: + d.SetMysqlTime(mysql.ZeroDate) + case mysql.TypeTimestamp: + d.SetMysqlTime(mysql.ZeroTimestamp) + case mysql.TypeDatetime: + d.SetMysqlTime(mysql.ZeroDatetime) + case mysql.TypeBit: + d.SetMysqlBit(mysql.Bit{Value: 0, Width: mysql.MinBitWidth}) + case mysql.TypeSet: + d.SetMysqlSet(mysql.Set{}) + } + return d +} diff --git a/table/column_test.go b/table/column_test.go index 5e17ab3cb8174..0a7fe1f08e035 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -112,6 +112,77 @@ func (s *testColumnSuite) TestDesc(c *C) { ColDescFieldNames(true) } +func (s *testColumnSuite) TestGetZeroValue(c *C) { + cases := []struct { + ft *types.FieldType + value types.Datum + }{ + { + types.NewFieldType(mysql.TypeLong), + types.NewIntDatum(0), + }, + { + &types.FieldType{ + Tp: mysql.TypeLonglong, + Flag: mysql.UnsignedFlag, + }, + types.NewUintDatum(0), + }, + { + types.NewFieldType(mysql.TypeFloat), + types.NewFloat32Datum(0), + }, + { + types.NewFieldType(mysql.TypeDouble), + types.NewFloat64Datum(0), + }, + { + types.NewFieldType(mysql.TypeNewDecimal), + types.NewDecimalDatum(mysql.NewDecimalFromInt(0, 0)), + }, + { + types.NewFieldType(mysql.TypeVarchar), + types.NewStringDatum(""), + }, + { + types.NewFieldType(mysql.TypeBlob), + types.NewBytesDatum([]byte{}), + }, + { + types.NewFieldType(mysql.TypeDuration), + types.NewDurationDatum(mysql.ZeroDuration), + }, + { + types.NewFieldType(mysql.TypeDatetime), + types.NewDatum(mysql.ZeroDatetime), + }, + { + types.NewFieldType(mysql.TypeTimestamp), + types.NewDatum(mysql.ZeroTimestamp), + }, + { + types.NewFieldType(mysql.TypeDate), + types.NewDatum(mysql.ZeroDate), + }, + { + types.NewFieldType(mysql.TypeBit), + types.NewDatum(mysql.Bit{Value: 0, Width: mysql.MinBitWidth}), + }, + { + types.NewFieldType(mysql.TypeSet), + types.NewDatum(mysql.Set{}), + }, + } + for _, ca := range cases { + colInfo := &model.ColumnInfo{FieldType: *ca.ft} + zv := getZeroValue(colInfo) + c.Assert(zv.Kind(), Equals, ca.value.Kind()) + cmp, err := zv.CompareDatum(ca.value) + c.Assert(err, IsNil) + c.Assert(cmp, Equals, 0) + } +} + func newCol(name string) *Column { return &Column{ model.ColumnInfo{ diff --git a/table/table.go b/table/table.go index 2737f7116992a..34db71116e9b1 100644 --- a/table/table.go +++ b/table/table.go @@ -18,16 +18,24 @@ package table import ( - "github.com/juju/errors" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/evaluator" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/types" ) +var ( + // ErrNoDefaultValue is used when insert a row, the column value is not given, and the column has not null flag + // and it doesn't have a default value. + ErrNoDefaultValue = terror.ClassTable.New(codeNoDefaultValue, "Field doesn't have a default value") +) + +const ( + codeNoDefaultValue = 1 +) + // RecordIterFunc is used for low-level record iteration. type RecordIterFunc func(h int64, rec []types.Datum, cols []*Column) (more bool, err error) @@ -94,34 +102,5 @@ type Table interface { // Currently, it is assigned to tables.TableFromMeta in tidb package's init function. var TableFromMeta func(alloc autoid.Allocator, tblInfo *model.TableInfo) (Table, error) -// GetColDefaultValue gets default value of the column. -func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (types.Datum, bool, error) { - // Check no default value flag. - if mysql.HasNoDefaultValueFlag(col.Flag) && col.Tp != mysql.TypeEnum { - return types.Datum{}, false, errors.Errorf("Field '%s' doesn't have a default value", col.Name) - } - - // Check and get timestamp/datetime default value. - if col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime { - if col.DefaultValue == nil { - return types.Datum{}, true, nil - } - - value, err := evaluator.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal) - if err != nil { - return types.Datum{}, true, errors.Errorf("Field '%s' get default value fail - %s", col.Name, errors.Trace(err)) - } - return value, true, nil - } else if col.Tp == mysql.TypeEnum { - // For enum type, if no default value and not null is set, - // the default value is the first element of the enum list - if col.DefaultValue == nil && mysql.HasNotNullFlag(col.Flag) { - return types.NewDatum(col.FieldType.Elems[0]), true, nil - } - } - - return types.NewDatum(col.DefaultValue), true, nil -} - // MockTableFromMeta only serves for test. var MockTableFromMeta func(tableInfo *model.TableInfo) Table diff --git a/table/table_test.go b/table/table_test.go deleted file mode 100644 index 669d225d448c7..0000000000000 --- a/table/table_test.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package table_test - -import ( - "testing" - - . "github.com/pingcap/check" - "github.com/pingcap/tidb" - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/sessionctx/db" - "github.com/pingcap/tidb/store/localstore" - "github.com/pingcap/tidb/store/localstore/goleveldb" -) - -func TestT(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testSuite{}) - -type testSuite struct { -} - -func (*testSuite) TestT(c *C) { - var ident = ast.Ident{ - Name: model.NewCIStr("t"), - } - c.Assert(ident.String(), Not(Equals), "") - driver := localstore.Driver{Driver: goleveldb.MemoryDriver{}} - store, err := driver.Open("memory") - c.Assert(err, IsNil) - se, err := tidb.CreateSession(store) - c.Assert(err, IsNil) - ctx := se.(context.Context) - db.BindCurrentSchema(ctx, "test") - fullIdent := ident.Full(ctx) - c.Assert(fullIdent.Schema.L, Equals, "test") - c.Assert(fullIdent.Name.L, Equals, "t") - c.Assert(fullIdent.String(), Not(Equals), "") - fullIdent2 := fullIdent.Full(ctx) - c.Assert(fullIdent2.Schema.L, Equals, fullIdent.Schema.L) -} diff --git a/table/tables/tables.go b/table/tables/tables.go index 2dbf938b335d0..10578e6adb375 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -340,7 +340,7 @@ func (t *Table) AddRecord(ctx context.Context, r []types.Datum) (recordID int64, if err != nil { return 0, errors.Trace(err) } - value, err = value.ConvertTo(&col.FieldType) + value, err = table.CastValue(ctx, value, col) if err != nil { return 0, errors.Trace(err) } diff --git a/terror/terror.go b/terror/terror.go index 4f81ef8021a1d..681528d9c6216 100644 --- a/terror/terror.go +++ b/terror/terror.go @@ -72,6 +72,7 @@ const ( ClassStructure ClassVariable ClassXEval + ClassTable // Add more as needed. ) @@ -110,6 +111,8 @@ func (ec ErrClass) String() string { return "structure" case ClassVariable: return "variable" + case ClassTable: + return "table" } return strconv.Itoa(int(ec)) } diff --git a/tidb.go b/tidb.go index aaca24154977c..19be7fdd4a59b 100644 --- a/tidb.go +++ b/tidb.go @@ -115,9 +115,9 @@ func SetSchemaLease(lease time.Duration) { // See: https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html func getCtxCharsetInfo(ctx context.Context) (string, string) { sessionVars := variable.GetSessionVars(ctx) - charset := sessionVars.Systems["character_set_connection"] - collation := sessionVars.Systems["collation_connection"] - return charset, collation + charset := sessionVars.GetSystemVar("character_set_connection") + collation := sessionVars.GetSystemVar("collation_connection") + return charset.GetString(), collation.GetString() } // Parse parses a query string to raw ast.StmtNode.