diff --git a/executor/builder.go b/executor/builder.go index 6b97abba5853f..b76db26b0af82 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -78,6 +78,8 @@ func (b *executorBuilder) build(p plan.Plan) Executor { return b.buildSelectLock(v) case *plan.ShowDDL: return b.buildShowDDL(v) + case *plan.Simple: + return b.buildSimple(v) case *plan.Sort: return b.buildSort(v) case *plan.TableScan: @@ -327,3 +329,7 @@ func (b *executorBuilder) buildDelete(v *plan.Delete) Executor { IsMultiTable: v.IsMultiTable, } } + +func (b *executorBuilder) buildSimple(v *plan.Simple) Executor { + return &SimpleExec{Statement: v.Statement, ctx: b.ctx} +} diff --git a/executor/converter/convert_stmt.go b/executor/converter/convert_stmt.go index 40703f054c027..7e1eacac5fb8a 100644 --- a/executor/converter/convert_stmt.go +++ b/executor/converter/convert_stmt.go @@ -868,51 +868,6 @@ func convertRollback(converter *expressionConverter, v *ast.RollbackStmt) (*stmt }, nil } -func convertUse(converter *expressionConverter, v *ast.UseStmt) (*stmts.UseStmt, error) { - return &stmts.UseStmt{ - DBName: v.DBName, - Text: v.Text(), - }, nil -} - -func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAssignment) (*stmts.VariableAssignment, error) { - oldValue, err := convertExpr(converter, v.Value) - if err != nil { - return nil, errors.Trace(err) - } - - return &stmts.VariableAssignment{ - IsGlobal: v.IsGlobal, - IsSystem: v.IsSystem, - Name: v.Name, - Value: oldValue, - Text: v.Text(), - }, nil -} - -func convertSet(converter *expressionConverter, v *ast.SetStmt) (*stmts.SetStmt, error) { - oldSet := &stmts.SetStmt{ - Text: v.Text(), - Variables: make([]*stmts.VariableAssignment, len(v.Variables)), - } - for i, val := range v.Variables { - oldAssign, err := convertVariableAssignment(converter, val) - if err != nil { - return nil, errors.Trace(err) - } - oldSet.Variables[i] = oldAssign - } - return oldSet, nil -} - -func convertSetCharset(converter *expressionConverter, v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { - return &stmts.SetCharsetStmt{ - Charset: v.Charset, - Collate: v.Collate, - Text: v.Text(), - }, nil -} - func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { return &stmts.SetPwdStmt{ User: v.User, diff --git a/executor/converter/converter.go b/executor/converter/converter.go index dc106ce18fac2..e1f986008a794 100644 --- a/executor/converter/converter.go +++ b/executor/converter/converter.go @@ -62,12 +62,8 @@ func (con *Converter) Convert(node ast.Node) (stmt.Statement, error) { return convertRollback(c, v) case *ast.SelectStmt: return convertSelect(c, v) - case *ast.SetCharsetStmt: - return convertSetCharset(c, v) case *ast.SetPwdStmt: return convertSetPwd(c, v) - case *ast.SetStmt: - return convertSet(c, v) case *ast.ShowStmt: return convertShow(c, v) case *ast.TruncateTableStmt: @@ -76,8 +72,6 @@ func (con *Converter) Convert(node ast.Node) (stmt.Statement, error) { return convertUnion(c, v) case *ast.UpdateStmt: return convertUpdate(c, v) - case *ast.UseStmt: - return convertUse(c, v) } return nil, nil } diff --git a/executor/executor_simple.go b/executor/executor_simple.go new file mode 100644 index 0000000000000..b932ecf6d7662 --- /dev/null +++ b/executor/executor_simple.go @@ -0,0 +1,165 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "fmt" + "strings" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/optimizer/evaluator" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/db" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/charset" + "github.com/pingcap/tidb/util/types" +) + +// SimpleExec represents simple statement executor. +// For statements do simple execution. +// includes `UseStmt`, 'SetStmt`, `SetCharsetStmt`. +// TODO: list all simple statements. +type SimpleExec struct { + Statement ast.StmtNode + ctx context.Context + done bool +} + +// Fields implements Executor Fields interface. +func (e *SimpleExec) Fields() []*ast.ResultField { + return nil +} + +// Next implements Execution Next interface. +func (e *SimpleExec) Next() (*Row, error) { + if e.done { + return nil, nil + } + var err error + switch x := e.Statement.(type) { + case *ast.UseStmt: + err = e.executeUse(x) + case *ast.SetStmt: + err = e.executeSet(x) + case *ast.SetCharsetStmt: + err = e.executeSetCharset(x) + } + if err != nil { + return nil, errors.Trace(err) + } + e.done = true + return nil, nil +} + +// Close implements Executor Close interface. +func (e *SimpleExec) Close() error { + return nil +} + +func (e *SimpleExec) executeUse(s *ast.UseStmt) error { + dbname := model.NewCIStr(s.DBName) + dbinfo, exists := sessionctx.GetDomain(e.ctx).InfoSchema().SchemaByName(dbname) + if !exists { + return infoschema.DatabaseNotExists.Gen("database %s not exists", dbname) + } + db.BindCurrentSchema(e.ctx, dbname.O) + // character_set_database is the character set used by the default database. + // The server sets this variable whenever the default database changes. + // See: http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_character_set_database + sessionVars := variable.GetSessionVars(e.ctx) + sessionVars.Systems[variable.CharsetDatabase] = dbinfo.Charset + sessionVars.Systems[variable.CollationDatabase] = dbinfo.Collate + return nil +} + +func (e *SimpleExec) executeSet(s *ast.SetStmt) error { + sessionVars := variable.GetSessionVars(e.ctx) + globalVars := variable.GetGlobalVarAccessor(e.ctx) + for _, v := range s.Variables { + // Variable is case insensitive, we use lower case. + name := strings.ToLower(v.Name) + if !v.IsSystem { + // User variable. + value, err := evaluator.Eval(e.ctx, v.Value) + if err != nil { + return errors.Trace(err) + } + + if value == nil { + delete(sessionVars.Users, name) + } else { + sessionVars.Users[name] = fmt.Sprintf("%v", value) + } + return nil + } + sysVar := variable.GetSysVar(name) + if sysVar == nil { + return variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) + } + if sysVar.Scope == variable.ScopeNone { + return errors.Errorf("Variable '%s' is a read only variable", name) + } + if v.IsGlobal { + if sysVar.Scope&variable.ScopeGlobal > 0 { + value, err := evaluator.Eval(e.ctx, v.Value) + if err != nil { + return errors.Trace(err) + } + if value == nil { + value = "" + } + svalue, err := types.ToString(value) + if err != nil { + return errors.Trace(err) + } + err = globalVars.SetGlobalSysVar(e.ctx, name, svalue) + return errors.Trace(err) + } + return errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name) + } + if sysVar.Scope&variable.ScopeSession > 0 { + if value, err := evaluator.Eval(e.ctx, v.Value); err != nil { + return errors.Trace(err) + } else if value == nil { + sessionVars.Systems[name] = "" + } else { + sessionVars.Systems[name] = fmt.Sprintf("%v", value) + } + return nil + } + return errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name) + } + return nil +} + +func (e *SimpleExec) executeSetCharset(s *ast.SetCharsetStmt) error { + collation := s.Collate + if len(collation) == 0 { + var err error + collation, err = charset.GetDefaultCollation(s.Charset) + if err != nil { + return errors.Trace(err) + } + } + sessionVars := variable.GetSessionVars(e.ctx) + for _, v := range variable.SetNamesVariables { + sessionVars.Systems[v] = s.Charset + } + sessionVars.Systems[variable.CollationConnection] = collation + return nil +} diff --git a/executor/executor_simple_test.go b/executor/executor_simple_test.go new file mode 100644 index 0000000000000..a65eccc4b11fe --- /dev/null +++ b/executor/executor_simple_test.go @@ -0,0 +1,100 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/testkit" +) + +func (s *testSuite) TestCharsetDatabase(c *C) { + tk := testkit.NewTestKit(c, s.store) + testSQL := `create database if not exists cd_test_utf8 CHARACTER SET utf8 COLLATE utf8_bin;` + tk.MustExec(testSQL) + + testSQL = `create database if not exists cd_test_latin1 CHARACTER SET latin1 COLLATE latin1_swedish_ci;` + tk.MustExec(testSQL) + + testSQL = `use cd_test_utf8;` + tk.MustExec(testSQL) + tk.MustQuery(`select @@character_set_database;`).Check(testkit.Rows("utf8")) + tk.MustQuery(`select @@collation_database;`).Check(testkit.Rows("utf8_bin")) + + testSQL = `use cd_test_latin1;` + tk.MustExec(testSQL) + tk.MustQuery(`select @@character_set_database;`).Check(testkit.Rows("latin1")) + tk.MustQuery(`select @@collation_database;`).Check(testkit.Rows("latin1_swedish_ci")) +} + +func (s *testSuite) TestSet(c *C) { + tk := testkit.NewTestKit(c, s.store) + testSQL := "SET @a = 1;" + tk.MustExec(testSQL) + + testSQL = `SET @a = "1";` + tk.MustExec(testSQL) + + testSQL = "SET @a = null;" + tk.MustExec(testSQL) + + testSQL = "SET @@global.autocommit = 1;" + tk.MustExec(testSQL) + + testSQL = "SET @@global.autocommit = null;" + tk.MustExec(testSQL) + + testSQL = "SET @@autocommit = 1;" + tk.MustExec(testSQL) + + testSQL = "SET @@autocommit = null;" + tk.MustExec(testSQL) + + errTestSql := "SET @@date_format = 1;" + _, err := tk.Exec(errTestSql) + c.Assert(err, NotNil) + + errTestSql = "SET @@rewriter_enabled = 1;" + _, err = tk.Exec(errTestSql) + c.Assert(err, NotNil) + + errTestSql = "SET xxx = abcd;" + _, err = tk.Exec(errTestSql) + c.Assert(err, NotNil) + + errTestSql = "SET @@global.a = 1;" + _, err = tk.Exec(errTestSql) + c.Assert(err, NotNil) + + errTestSql = "SET @@global.timestamp = 1;" + _, err = tk.Exec(errTestSql) + c.Assert(err, NotNil) +} + +func (s *testSuite) TestSetCharset(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`SET NAMES latin1`) + + ctx := tk.Se.(context.Context) + sessionVars := variable.GetSessionVars(ctx) + for _, v := range variable.SetNamesVariables { + c.Assert(sessionVars.Systems[v] != "utf8", IsTrue) + } + tk.MustExec(`SET NAMES utf8`) + for _, v := range variable.SetNamesVariables { + c.Assert(sessionVars.Systems[v], Equals, "utf8") + } + c.Assert(sessionVars.Systems[variable.CollationConnection], Equals, "utf8_general_ci") +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 3fce439a98c81..4ff4f4b41f11a 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -88,6 +88,7 @@ func IsSupported(node ast.Node) bool { switch node.(type) { case *ast.SelectStmt, *ast.PrepareStmt, *ast.ExecuteStmt, *ast.DeallocateStmt, *ast.AdminStmt, *ast.UpdateStmt, *ast.DeleteStmt, *ast.UnionStmt: + case *ast.UseStmt, *ast.SetStmt, *ast.SetCharsetStmt: default: return false } diff --git a/optimizer/plan/planbuilder.go b/optimizer/plan/planbuilder.go index 827bf39883102..162dcdb70c293 100644 --- a/optimizer/plan/planbuilder.go +++ b/optimizer/plan/planbuilder.go @@ -71,6 +71,12 @@ func (b *planBuilder) build(node ast.Node) Plan { return b.buildUnion(x) case *ast.UpdateStmt: return b.buildUpdate(x) + case *ast.UseStmt: + return b.buildSimple(x) + case *ast.SetCharsetStmt: + return b.buildSimple(x) + case *ast.SetStmt: + return b.buildSimple(x) } b.err = ErrUnsupportedType.Gen("Unsupported type %T", node) return nil @@ -790,3 +796,7 @@ func columnOffsetInFields(cn *ast.ColumnName, fields []*ast.ResultField) (int, e } return offset, nil } + +func (b *planBuilder) buildSimple(node ast.StmtNode) Plan { + return &Simple{Statement: node} +} diff --git a/optimizer/plan/plans.go b/optimizer/plan/plans.go index 72efd1e08d6d7..fd5c3acc815e8 100644 --- a/optimizer/plan/plans.go +++ b/optimizer/plan/plans.go @@ -568,3 +568,20 @@ func (p *Filter) SetLimit(limit float64) { // We assume 50% of the src row is filtered out. p.src.SetLimit(limit * 2) } + +// Simple represents a simple statement plan which doesn't need any optimization. +type Simple struct { + basePlan + + Statement ast.StmtNode +} + +// Accept implements Plan Accept interface. +func (p *Simple) Accept(v Visitor) (Plan, bool) { + np, skip := v.Enter(p) + if skip { + v.Leave(np) + } + p = np.(*Simple) + return v.Leave(p) +} diff --git a/optimizer/resolver.go b/optimizer/resolver.go index 4022c6787a2f8..19a32da61b0aa 100644 --- a/optimizer/resolver.go +++ b/optimizer/resolver.go @@ -171,6 +171,14 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b nr.currentContext().inOrderBy = true case *ast.SelectStmt: nr.pushContext() + case *ast.SetStmt: + for _, assign := range v.Variables { + if cn, ok := assign.Value.(*ast.ColumnNameExpr); ok && cn.Name.Table.L == "" { + // Convert column name expression to string value expression. + assign.Value = ast.NewValueExpr(cn.Name.Name.O) + } + } + nr.pushContext() case *ast.TableRefsClause: nr.currentContext().inTableRefs = true case *ast.UnionStmt: @@ -233,6 +241,8 @@ func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) { nr.useOuterContext = true } nr.popContext() + case *ast.SetStmt: + nr.popContext() case *ast.SubqueryExpr: if nr.useOuterContext { // TODO: check this diff --git a/session_test.go b/session_test.go index 198cfe2ced807..61fdfcd6a51c9 100644 --- a/session_test.go +++ b/session_test.go @@ -283,7 +283,7 @@ func (s *testSessionSuite) TestInTrans(c *C) { checkInTrans(c, se, "commit", 0) checkInTrans(c, se, "insert t values ()", 0) - checkInTrans(c, se, "set autocommit=O;", 0) + checkInTrans(c, se, "set autocommit=0;", 0) checkInTrans(c, se, "begin", 1) checkInTrans(c, se, "insert t values ()", 1) checkInTrans(c, se, "commit", 0) diff --git a/stmt/stmts/set.go b/stmt/stmts/set.go deleted file mode 100644 index f8c63bce094d4..0000000000000 --- a/stmt/stmts/set.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts - -import ( - "fmt" - "strings" - - "github.com/juju/errors" - "github.com/ngaut/log" - "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/rset" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/stmt" - "github.com/pingcap/tidb/util/charset" - "github.com/pingcap/tidb/util/format" - "github.com/pingcap/tidb/util/types" -) - -var ( - _ stmt.Statement = (*SetStmt)(nil) - _ stmt.Statement = (*SetCharsetStmt)(nil) -) - -// VariableAssignment is a varible assignment struct. -type VariableAssignment struct { - Name string - Value expression.Expression - IsGlobal bool - IsSystem bool - - Text string -} - -// getValue gets VariableAssignment value from context. -// See: https://github.com/mysql/mysql-server/blob/5.7/sql/set_var.cc#L679 -func (v *VariableAssignment) getValue(ctx context.Context) (interface{}, error) { - switch vv := v.Value.(type) { - case *expression.Ident: - return vv.O, nil - default: - return vv.Eval(ctx, nil) - } -} - -// String implements the fmt.Stringer interface. -func (v *VariableAssignment) String() string { - if !v.IsSystem { - return fmt.Sprintf("@%s=%s", v.Name, v.Value.String()) - } - - if v.IsGlobal { - return fmt.Sprintf("@@global.%s=%s", v.Name, v.Value.String()) - } - return fmt.Sprintf("@@session.%s=%s", v.Name, v.Value.String()) -} - -// SetStmt is a statement to assigns values to different types of variables. -// See: https://dev.mysql.com/doc/refman/5.7/en/set-statement.html -type SetStmt struct { - Variables []*VariableAssignment - - Text string -} - -// Explain implements the stmt.Statement Explain interface. -func (s *SetStmt) Explain(ctx context.Context, w format.Formatter) { - w.Format("%s\n", s.Text) -} - -// IsDDL implements the stmt.Statement IsDDL interface. -func (s *SetStmt) IsDDL() bool { - return false -} - -// OriginText implements the stmt.Statement OriginText interface. -func (s *SetStmt) OriginText() string { - return s.Text -} - -// SetText implements the stmt.Statement SetText interface. -func (s *SetStmt) SetText(text string) { - s.Text = text -} - -// Exec implements the stmt.Statement Exec interface. -func (s *SetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - log.Debug("Set sys/user variables") - - sessionVars := variable.GetSessionVars(ctx) - globalVars := variable.GetGlobalVarAccessor(ctx) - for _, v := range s.Variables { - // Variable is case insensitive, we use lower case. - name := strings.ToLower(v.Name) - if !v.IsSystem { - // User variable. - value, err := v.getValue(ctx) - if err != nil { - return nil, errors.Trace(err) - } - - if value == nil { - delete(sessionVars.Users, name) - } else { - sessionVars.Users[name] = fmt.Sprintf("%v", value) - } - return nil, nil - } - sysVar := variable.GetSysVar(name) - if sysVar == nil { - return nil, variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name) - } - if sysVar.Scope == variable.ScopeNone { - return nil, errors.Errorf("Variable '%s' is a read only variable", name) - } - if v.IsGlobal { - if sysVar.Scope&variable.ScopeGlobal > 0 { - value, err := v.getValue(ctx) - if err != nil { - return nil, errors.Trace(err) - } - if value == nil { - value = "" - } - svalue, err := types.ToString(value) - if err != nil { - return nil, errors.Trace(err) - } - err = globalVars.SetGlobalSysVar(ctx, name, svalue) - return nil, errors.Trace(err) - } - return nil, errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name) - } - if sysVar.Scope&variable.ScopeSession > 0 { - if value, err := v.getValue(ctx); err != nil { - return nil, errors.Trace(err) - } else if value == nil { - sessionVars.Systems[name] = "" - } else { - sessionVars.Systems[name] = fmt.Sprintf("%v", value) - } - return nil, nil - } - return nil, errors.Errorf("Variable '%s' is a GLOBAL variable and should be set with SET GLOBAL", name) - } - - return nil, nil -} - -// SetCharsetStmt is a statement to assign values to character and collation variables. -// See: https://dev.mysql.com/doc/refman/5.7/en/set-statement.html -type SetCharsetStmt struct { - Charset string - Collate string - - Text string -} - -// Explain implements the stmt.Statement Explain interface. -func (s *SetCharsetStmt) Explain(ctx context.Context, w format.Formatter) { - w.Format("%s\n", s.Text) -} - -// IsDDL implements the stmt.Statement IsDDL interface. -func (s *SetCharsetStmt) IsDDL() bool { - return false -} - -// OriginText implements the stmt.Statement OriginText interface. -func (s *SetCharsetStmt) OriginText() string { - return s.Text -} - -// SetText implements the stmt.Statement SetText interface. -func (s *SetCharsetStmt) SetText(text string) { - s.Text = text -} - -// Exec implements the stmt.Statement Exec interface. -// SET NAMES sets the three session system variables character_set_client, character_set_connection, -// and character_set_results to the given character set. Setting character_set_connection to charset_name -// also sets collation_connection to the default collation for charset_name. -// The optional COLLATE clause may be used to specify a collation explicitly. -func (s *SetCharsetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - log.Debug("Set charset to ", s.Charset) - collation := s.Collate - if len(collation) == 0 { - collation, err = charset.GetDefaultCollation(s.Charset) - if err != nil { - return nil, errors.Trace(err) - } - } - sessionVars := variable.GetSessionVars(ctx) - for _, v := range variable.SetNamesVariables { - sessionVars.Systems[v] = s.Charset - } - sessionVars.Systems[variable.CollationConnection] = collation - return nil, nil -} diff --git a/stmt/stmts/set_test.go b/stmt/stmts/set_test.go deleted file mode 100644 index a6e4af7577dbb..0000000000000 --- a/stmt/stmts/set_test.go +++ /dev/null @@ -1,151 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts_test - -import ( - . "github.com/pingcap/check" - "github.com/pingcap/tidb" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/stmt/stmts" - "github.com/pingcap/tidb/util/mock" -) - -func (s *testStmtSuite) TestSet(c *C) { - testSQL := "SET @a = 1;" - mustExec(c, s.testDB, testSQL) - - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.SetStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsFalse) - c.Assert(len(testStmt.Variables[0].String()), Greater, 0) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - testSQL = `SET @a = "1";` - mustExec(c, s.testDB, testSQL) - - testSQL = "SET @a = null;" - mustExec(c, s.testDB, testSQL) - - testSQL = "SET @@global.autocommit = 1;" - mustExec(c, s.testDB, testSQL) - - stmtList, err = tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok = stmtList[0].(*stmts.SetStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsFalse) - c.Assert(len(testStmt.Variables[0].String()), Greater, 0) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf = newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - testSQL = "SET @@global.autocommit = null;" - mustExec(c, s.testDB, testSQL) - - testSQL = "SET @@autocommit = 1;" - mustExec(c, s.testDB, testSQL) - - stmtList, err = tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok = stmtList[0].(*stmts.SetStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsFalse) - c.Assert(len(testStmt.Variables[0].String()), Greater, 0) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf = newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - testSQL = "SET @@autocommit = null;" - mustExec(c, s.testDB, testSQL) - - errTestSql := "SET @@date_format = 1;" - tx := mustBegin(c, s.testDB) - _, err = tx.Exec(errTestSql) - c.Assert(err, NotNil) - tx.Rollback() - - errTestSql = "SET @@rewriter_enabled = 1;" - tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errTestSql) - c.Assert(err, NotNil) - tx.Rollback() - - errTestSql = "SET xxx = abcd;" - tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errTestSql) - c.Assert(err, NotNil) - tx.Rollback() - - errTestSql = "SET @@global.a = 1;" - tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errTestSql) - c.Assert(err, NotNil) - tx.Rollback() - - errTestSql = "SET @@global.timestamp = 1;" - tx = mustBegin(c, s.testDB) - _, err = tx.Exec(errTestSql) - c.Assert(err, NotNil) - tx.Rollback() -} - -func (s *testStmtSuite) TestSetCharsetStmt(c *C) { - testSQL := `SET NAMES utf8;` - - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.SetCharsetStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsFalse) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - ctx := mock.NewContext() - variable.BindSessionVars(ctx) - sessionVars := variable.GetSessionVars(ctx) - for _, v := range variable.SetNamesVariables { - c.Assert(sessionVars.Systems[v] != "utf8", IsTrue) - } - _, err = testStmt.Exec(ctx) - c.Assert(err, IsNil) - for _, v := range variable.SetNamesVariables { - c.Assert(sessionVars.Systems[v], Equals, "utf8") - } - c.Assert(sessionVars.Systems[variable.CollationConnection], Equals, "utf8_general_ci") - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) -} diff --git a/stmt/stmts/use.go b/stmt/stmts/use.go deleted file mode 100644 index ac0273c7ed1c3..0000000000000 --- a/stmt/stmts/use.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts - -import ( - "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/infoschema" - "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/rset" - "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/db" - "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/stmt" - "github.com/pingcap/tidb/util/format" -) - -var _ stmt.Statement = (*UseStmt)(nil) - -// UseStmt is a statement to use the DBName database as the current database. -// See: https://dev.mysql.com/doc/refman/5.7/en/use.html -type UseStmt struct { - DBName string - - Text string -} - -// Explain implements the stmt.Statement Explain interface. -func (s *UseStmt) Explain(ctx context.Context, w format.Formatter) { - w.Format("%s\n", s.Text) -} - -// IsDDL implements the stmt.Statement IsDDL interface. -func (s *UseStmt) IsDDL() bool { - return false -} - -// OriginText implements the stmt.Statement OriginText interface. -func (s *UseStmt) OriginText() string { - return s.Text -} - -// SetText implements the stmt.Statement SetText interface. -func (s *UseStmt) SetText(text string) { - s.Text = text -} - -// Exec implements the stmt.Statement Exec interface. -func (s *UseStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - dbname := model.NewCIStr(s.DBName) - dbinfo, exists := sessionctx.GetDomain(ctx).InfoSchema().SchemaByName(dbname) - if !exists { - return nil, infoschema.DatabaseNotExists.Gen("database %s not exists", dbname) - } - db.BindCurrentSchema(ctx, dbname.O) - s.updateSysVars(ctx, dbinfo) - return nil, nil -} - -// Update system variables -func (s *UseStmt) updateSysVars(ctx context.Context, dbinfo *model.DBInfo) { - // character_set_database is the character set used by the default database. - // The server sets this variable whenever the default database changes. - // See: http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_character_set_database - sessionVars := variable.GetSessionVars(ctx) - sessionVars.Systems[variable.CharsetDatabase] = dbinfo.Charset - sessionVars.Systems[variable.CollationDatabase] = dbinfo.Collate -} diff --git a/stmt/stmts/use_test.go b/stmt/stmts/use_test.go deleted file mode 100644 index 010c79a293d2e..0000000000000 --- a/stmt/stmts/use_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts_test - -import ( - . "github.com/pingcap/check" - "github.com/pingcap/tidb" - "github.com/pingcap/tidb/stmt/stmts" -) - -func (s *testStmtSuite) TestUse(c *C) { - testSQL := `create database if not exists use_test;` - mustExec(c, s.testDB, testSQL) - - testSQL = `use test;` - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.UseStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsFalse) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - errTestSQL := `use xxx;` - tx := mustBegin(c, s.testDB) - _, err = tx.Exec(errTestSQL) - c.Assert(err, NotNil) - tx.Rollback() -} - -func (s *testStmtSuite) TestCharsetDatabase(c *C) { - testSQL := `create database if not exists cd_test_utf8 CHARACTER SET utf8 COLLATE utf8_bin;` - mustExec(c, s.testDB, testSQL) - - testSQL = `create database if not exists cd_test_latin1 CHARACTER SET latin1 COLLATE latin1_swedish_ci;` - mustExec(c, s.testDB, testSQL) - - testSQL = `use cd_test_utf8;` - mustExec(c, s.testDB, testSQL) - - tx := mustBegin(c, s.testDB) - rows, err := tx.Query(`select @@character_set_database;`) - c.Assert(err, IsNil) - matchRows(c, rows, [][]interface{}{{"utf8"}}) - rows, err = tx.Query(`select @@collation_database;`) - c.Assert(err, IsNil) - matchRows(c, rows, [][]interface{}{{"utf8_bin"}}) - mustCommit(c, tx) - - testSQL = `use cd_test_latin1;` - mustExec(c, s.testDB, testSQL) - - tx = mustBegin(c, s.testDB) - rows, err = tx.Query(`select @@character_set_database;`) - c.Assert(err, IsNil) - matchRows(c, rows, [][]interface{}{{"latin1"}}) - rows, err = tx.Query(`select @@collation_database;`) - c.Assert(err, IsNil) - matchRows(c, rows, [][]interface{}{{"latin1_swedish_ci"}}) - mustCommit(c, tx) -}