Skip to content

Commit

Permalink
*: support sql mode (pingcap#1263)
Browse files Browse the repository at this point in the history
  • Loading branch information
coocood committed May 30, 2016
1 parent 1b21901 commit f09574d
Show file tree
Hide file tree
Showing 18 changed files with 365 additions and 161 deletions.
30 changes: 25 additions & 5 deletions evaluator/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,20 +597,40 @@ func (e *Evaluator) variable(v *ast.VariableExpr) bool {
return true
}

_, ok := variable.SysVars[name]
sysVar, ok := variable.SysVars[name]
if !ok {
// select null sys vars is not permitted
e.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
return false
}
if sysVar.Scope == variable.ScopeNone {
v.SetString(sysVar.Value)
return true
}

if !v.IsGlobal {
if value, ok := sessionVars.Systems[name]; ok {
v.SetString(value)
return true
d := sessionVars.GetSystemVar(name)
if d.Kind() == types.KindNull {
if sysVar.Scope&variable.ScopeGlobal == 0 {
d.SetString(sysVar.Value)
} else {
// Get global system variable and fill it in session.
globalVal, err := globalVars.GetGlobalSysVar(e.ctx, name)
if err != nil {
e.err = errors.Trace(err)
return false
}
d.SetString(globalVal)
err = sessionVars.SetSystemVar(name, d)
if err != nil {
e.err = errors.Trace(err)
return false
}
}
}
v.SetDatum(d)
return true
}

value, err := globalVars.GetGlobalSysVar(e.ctx, name)
if err != nil {
e.err = errors.Trace(err)
Expand Down
9 changes: 4 additions & 5 deletions evaluator/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -908,32 +908,31 @@ func (s *testEvaluatorSuite) TestGetTimeValue(c *C) {
ctx := mock.NewContext()
variable.BindSessionVars(ctx)
sessionVars := variable.GetSessionVars(ctx)

sessionVars.Systems["timestamp"] = ""
sessionVars.SetSystemVar("timestamp", types.NewStringDatum(""))
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
c.Assert(err, IsNil)

c.Assert(v.Kind(), Equals, types.KindMysqlTime)
timeValue = v.GetMysqlTime()
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")

sessionVars.Systems["timestamp"] = "0"
sessionVars.SetSystemVar("timestamp", types.NewStringDatum("0"))
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
c.Assert(err, IsNil)

c.Assert(v.Kind(), Equals, types.KindMysqlTime)
timeValue = v.GetMysqlTime()
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")

delete(sessionVars.Systems, "timestamp")
sessionVars.SetSystemVar("timestamp", types.Datum{})
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
c.Assert(err, IsNil)

c.Assert(v.Kind(), Equals, types.KindMysqlTime)
timeValue = v.GetMysqlTime()
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")

sessionVars.Systems["timestamp"] = "1234"
sessionVars.SetSystemVar("timestamp", types.NewStringDatum("1234"))

tbl := []struct {
Expr interface{}
Expand Down
23 changes: 9 additions & 14 deletions evaluator/helper.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package evaluator

import (
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -119,20 +118,16 @@ func getSystemTimestamp(ctx context.Context) (time.Time, error) {

// check whether use timestamp varibale
sessionVars := variable.GetSessionVars(ctx)
if v, ok := sessionVars.Systems["timestamp"]; ok {
if v != "" {
timestamp, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return time.Time{}, errors.Trace(err)
}

if timestamp <= 0 {
return value, nil
}

return time.Unix(timestamp, 0), nil
ts := sessionVars.GetSystemVar("timestamp")
if ts.Kind() != types.KindNull && ts.GetString() != "" {
timestamp, err := ts.ToInt64()
if err != nil {
return time.Time{}, errors.Trace(err)
}
if timestamp <= 0 {
return value, nil
}
return time.Unix(timestamp, 0), nil
}

return value, nil
}
37 changes: 23 additions & 14 deletions executor/executor_simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ func (e *SimpleExec) executeUse(s *ast.UseStmt) error {
// The server sets this variable whenever the default database changes.
// See: http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_character_set_database
sessionVars := variable.GetSessionVars(e.ctx)
sessionVars.Systems[variable.CharsetDatabase] = dbinfo.Charset
sessionVars.Systems[variable.CollationDatabase] = dbinfo.Collate
err := sessionVars.SetSystemVar(variable.CharsetDatabase, types.NewStringDatum(dbinfo.Charset))
if err != nil {
return errors.Trace(err)
}
err = sessionVars.SetSystemVar(variable.CollationDatabase, types.NewStringDatum(dbinfo.Collate))
if err != nil {
return errors.Trace(err)
}
return nil
}

Expand Down Expand Up @@ -161,16 +167,13 @@ func (e *SimpleExec) executeSet(s *ast.SetStmt) error {
if sysVar.Scope&variable.ScopeSession == 0 {
return errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name)
}
if value, err := evaluator.Eval(e.ctx, v.Value); err != nil {
value, err := evaluator.Eval(e.ctx, v.Value)
if err != nil {
return errors.Trace(err)
}
err = sessionVars.SetSystemVar(name, value)
if err != nil {
return errors.Trace(err)
} else if value.Kind() == types.KindNull {
sessionVars.Systems[name] = ""
} else {
svalue, err := value.ToString()
if err != nil {
return errors.Trace(err)
}
sessionVars.Systems[name] = fmt.Sprintf("%v", svalue)
}
}
}
Expand All @@ -179,18 +182,24 @@ func (e *SimpleExec) executeSet(s *ast.SetStmt) error {

func (e *SimpleExec) executeSetCharset(s *ast.SetCharsetStmt) error {
collation := s.Collate
var err error
if len(collation) == 0 {
var err error
collation, err = charset.GetDefaultCollation(s.Charset)
if err != nil {
return errors.Trace(err)
}
}
sessionVars := variable.GetSessionVars(e.ctx)
for _, v := range variable.SetNamesVariables {
sessionVars.Systems[v] = s.Charset
err = sessionVars.SetSystemVar(v, types.NewStringDatum(s.Charset))
if err != nil {
return errors.Trace(err)
}
}
err = sessionVars.SetSystemVar(variable.CollationConnection, types.NewStringDatum(collation))
if err != nil {
return errors.Trace(err)
}
sessionVars.Systems[variable.CollationConnection] = collation
return nil
}

Expand Down
20 changes: 13 additions & 7 deletions executor/executor_simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,20 @@ func (s *testSuite) TestSetVar(c *C) {
testSQL = "SET @@global.autocommit = 1;"
tk.MustExec(testSQL)

testSQL = "SET @@global.autocommit = null;"
tk.MustExec(testSQL)
// TODO: this test case should returns error.
// testSQL = "SET @@global.autocommit = null;"
// _, err := tk.Exec(testSQL)
// c.Assert(err, NotNil)

testSQL = "SET @@autocommit = 1;"
tk.MustExec(testSQL)

testSQL = "SET @@autocommit = null;"
tk.MustExec(testSQL)
_, err := tk.Exec(testSQL)
c.Assert(err, NotNil)

errTestSql := "SET @@date_format = 1;"
_, err := tk.Exec(errTestSql)
_, err = tk.Exec(errTestSql)
c.Assert(err, NotNil)

errTestSql = "SET @@rewriter_enabled = 1;"
Expand Down Expand Up @@ -109,13 +112,16 @@ func (s *testSuite) TestSetCharset(c *C) {
ctx := tk.Se.(context.Context)
sessionVars := variable.GetSessionVars(ctx)
for _, v := range variable.SetNamesVariables {
c.Assert(sessionVars.Systems[v] != "utf8", IsTrue)
sVar := sessionVars.GetSystemVar(v)
c.Assert(sVar.GetString() != "utf8", IsTrue)
}
tk.MustExec(`SET NAMES utf8`)
for _, v := range variable.SetNamesVariables {
c.Assert(sessionVars.Systems[v], Equals, "utf8")
sVar := sessionVars.GetSystemVar(v)
c.Assert(sVar.GetString(), Equals, "utf8")
}
c.Assert(sessionVars.Systems[variable.CollationConnection], Equals, "utf8_general_ci")
sVar := sessionVars.GetSystemVar(variable.CollationConnection)
c.Assert(sVar.GetString(), Equals, "utf8_general_ci")
}

func (s *testSuite) TestDo(c *C) {
Expand Down
19 changes: 19 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,22 @@ func (s *testSuite) TestJoinPanic(c *C) {
tk.MustExec("create table events (clock int, source int)")
tk.MustQuery("SELECT * FROM events e JOIN (SELECT MAX(clock) AS clock FROM events e2 GROUP BY e2.source) e3 ON e3.clock=e.clock")
}

func (s *testSuite) TestSQLMode(c *C) {
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a tinyint not null)")
tk.MustExec("set sql_mode = 'STRICT_TRANS_TABLES'")
_, err := tk.Exec("insert t values ()")
c.Check(err, NotNil)

_, err = tk.Exec("insert t values ('1000')")
c.Check(err, NotNil)

tk.MustExec("set sql_mode = ''")
tk.MustExec("insert t values ()")
tk.MustExec("insert t values (1000)")
tk.MustQuery("select * from t").Check(testkit.Rows("0", "127"))
}
13 changes: 8 additions & 5 deletions executor/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,19 @@ func (e *ShowExec) fetchShowVariables() error {
var value string
if !e.GlobalScope {
// Try to get Session Scope variable value first.
sv, ok := sessionVars.Systems[v.Name]
if ok {
value = sv
} else {
// If session scope variable is not set, get the global scope value.
sv := sessionVars.GetSystemVar(v.Name)
if sv.Kind() == types.KindNull {
value, err = globalVars.GetGlobalSysVar(e.ctx, v.Name)
if err != nil {
return errors.Trace(err)
}
sv.SetString(value)
err = sessionVars.SetSystemVar(v.Name, sv)
if err != nil {
return errors.Trace(err)
}
}
value = sv.GetString()
} else {
value, err = globalVars.GetGlobalSysVar(e.ctx, v.Name)
if err != nil {
Expand Down
18 changes: 11 additions & 7 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,21 +344,25 @@ func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string

// IsAutocommit checks if it is in the auto-commit mode.
func (s *session) isAutocommit(ctx context.Context) bool {
autocommit, ok := variable.GetSessionVars(ctx).Systems["autocommit"]
if !ok {
sessionVar := variable.GetSessionVars(ctx)
autocommit := sessionVar.GetSystemVar("autocommit")
if autocommit.Kind() == types.KindNull {
if s.initing {
return false
}
var err error
autocommit, err = s.GetGlobalSysVar(ctx, "autocommit")
autocommitStr, err := s.GetGlobalSysVar(ctx, "autocommit")
if err != nil {
log.Errorf("Get global sys var error: %v", err)
return false
}
variable.GetSessionVars(ctx).Systems["autocommit"] = autocommit
ok = true
autocommit.SetString(autocommitStr)
err = sessionVar.SetSystemVar("autocommit", autocommit)
if err != nil {
log.Errorf("Set session sys var error: %v", err)
}
}
if ok && (autocommit == "ON" || autocommit == "on" || autocommit == "1") {
autocommitStr := autocommit.GetString()
if autocommitStr == "ON" || autocommitStr == "on" || autocommitStr == "1" {
variable.GetSessionVars(ctx).SetStatusFlag(mysql.ServerStatusAutocommit, true)
return true
}
Expand Down
Loading

0 comments on commit f09574d

Please sign in to comment.