From a7b0a1f915dd6ec7fe6334fdd2720ca8da961407 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Sat, 9 Dec 2017 01:27:58 -0600 Subject: [PATCH] server,tidb: move cancel function from session to clientConn (#5346) --- executor/prepared_test.go | 20 ++++++++-------- new_session_test.go | 7 +++--- server/conn.go | 21 +++++++++++++---- server/conn_stmt.go | 3 +-- server/driver.go | 5 +--- server/driver_tidb.go | 9 ++------ server/server.go | 8 ++++++- session.go | 48 ++++++++++++--------------------------- tidb_test.go | 5 ++-- util/testkit/testkit.go | 5 ++-- 10 files changed, 63 insertions(+), 68 deletions(-) diff --git a/executor/prepared_test.go b/executor/prepared_test.go index ddd85873a9d1c..8d6ccdb4c249c 100644 --- a/executor/prepared_test.go +++ b/executor/prepared_test.go @@ -29,6 +29,7 @@ func (s *testSuite) TestPrepared(c *C) { orgEnable := cfg.PreparedPlanCache.Enabled orgCapacity := cfg.PreparedPlanCache.Capacity flags := []bool{false, true} + goCtx := goctx.Background() for _, flag := range flags { cfg.PreparedPlanCache.Enabled = flag cfg.PreparedPlanCache.Capacity = 100 @@ -69,7 +70,7 @@ func (s *testSuite) TestPrepared(c *C) { query := "select c1, c2 from prepare_test where c1 = ?" stmtId, _, _, err := tk.Se.PrepareStmt(query) c.Assert(err, IsNil) - _, err = tk.Se.ExecutePreparedStmt(stmtId, 1) + _, err = tk.Se.ExecutePreparedStmt(goCtx, stmtId, 1) c.Assert(err, IsNil) // Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text. @@ -78,12 +79,12 @@ func (s *testSuite) TestPrepared(c *C) { c.Assert(stmt.OriginText(), Equals, query) // Check that rebuild plan works. - tk.Se.PrepareTxnCtx(goctx.Background()) + tk.Se.PrepareTxnCtx(goCtx) err = stmt.RebuildPlan() c.Assert(err, IsNil) - rs, err := stmt.Exec(goctx.Background()) + rs, err := stmt.Exec(goCtx) c.Assert(err, IsNil) - _, err = rs.Next(goctx.Background()) + _, err = rs.Next(goCtx) c.Assert(err, IsNil) c.Assert(rs.Close(), IsNil) @@ -92,17 +93,17 @@ func (s *testSuite) TestPrepared(c *C) { tk.Exec("create table prepare2 (a int)") // Should success as the changed schema do not affect the prepared statement. - _, err = tk.Se.ExecutePreparedStmt(stmtId, 1) + _, err = tk.Se.ExecutePreparedStmt(goCtx, stmtId, 1) c.Assert(err, IsNil) // Drop a column so the prepared statement become invalid. tk.MustExec("alter table prepare_test drop column c2") - _, err = tk.Se.ExecutePreparedStmt(stmtId, 1) + _, err = tk.Se.ExecutePreparedStmt(goCtx, stmtId, 1) c.Assert(plan.ErrUnknownColumn.Equal(err), IsTrue) tk.MustExec("drop table prepare_test") - _, err = tk.Se.ExecutePreparedStmt(stmtId, 1) + _, err = tk.Se.ExecutePreparedStmt(goCtx, stmtId, 1) c.Assert(plan.ErrSchemaChanged.Equal(err), IsTrue) // issue 3381 @@ -114,7 +115,7 @@ func (s *testSuite) TestPrepared(c *C) { // Coverage. exec := &executor.ExecuteExec{} - exec.Next(goctx.Background()) + exec.Next(goCtx) exec.Close() } cfg.PreparedPlanCache.Enabled = orgEnable @@ -126,6 +127,7 @@ func (s *testSuite) TestPreparedLimitOffset(c *C) { orgEnable := cfg.PreparedPlanCache.Enabled orgCapacity := cfg.PreparedPlanCache.Capacity flags := []bool{false, true} + goCtx := goctx.Background() for _, flag := range flags { cfg.PreparedPlanCache.Enabled = flag cfg.PreparedPlanCache.Capacity = 100 @@ -148,7 +150,7 @@ func (s *testSuite) TestPreparedLimitOffset(c *C) { stmtID, _, _, err := tk.Se.PrepareStmt("select id from prepare_test limit ?") c.Assert(err, IsNil) - _, err = tk.Se.ExecutePreparedStmt(stmtID, 1) + _, err = tk.Se.ExecutePreparedStmt(goCtx, stmtID, 1) c.Assert(err, IsNil) } cfg.PreparedPlanCache.Enabled = orgEnable diff --git a/new_session_test.go b/new_session_test.go index 743f40bf68e5b..09109e69e3ca4 100644 --- a/new_session_test.go +++ b/new_session_test.go @@ -693,11 +693,12 @@ func (s *testSessionSuite) TestPrepare(c *C) { tk.MustExec("create table t(id TEXT)") tk.MustExec(`INSERT INTO t VALUES ("id");`) id, ps, _, err := tk.Se.PrepareStmt("select id+? from t") + goCtx := goctx.Background() c.Assert(err, IsNil) c.Assert(id, Equals, uint32(1)) c.Assert(ps, Equals, 1) tk.MustExec(`set @a=1`) - _, err = tk.Se.ExecutePreparedStmt(id, "1") + _, err = tk.Se.ExecutePreparedStmt(goCtx, id, "1") c.Assert(err, IsNil) err = tk.Se.DropPreparedStmt(id) c.Assert(err, IsNil) @@ -718,10 +719,10 @@ func (s *testSessionSuite) TestPrepare(c *C) { tk.MustExec("insert multiexec values (1, 1), (2, 2)") id, _, _, err = tk.Se.PrepareStmt("select a from multiexec where b = ? order by b") c.Assert(err, IsNil) - rs, err := tk.Se.ExecutePreparedStmt(id, 1) + rs, err := tk.Se.ExecutePreparedStmt(goCtx, id, 1) c.Assert(err, IsNil) rs.Close() - rs, err = tk.Se.ExecutePreparedStmt(id, 2) + rs, err = tk.Se.ExecutePreparedStmt(goCtx, id, 2) rs.Close() c.Assert(err, IsNil) } diff --git a/server/conn.go b/server/conn.go index da9c1068f6964..62f6b3f751b93 100644 --- a/server/conn.go +++ b/server/conn.go @@ -44,6 +44,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" log "github.com/Sirupsen/logrus" @@ -77,7 +78,13 @@ type clientConn struct { lastCmd string // latest sql query string, currently used for logging error. ctx QueryCtx // an interface to execute sql statements. attrs map[string]string // attributes parsed from client handshake response, not used for now. - killed bool + + // cancelFunc is used for cancelling the execution of current transaction. + mu struct { + sync.RWMutex + cancelFunc goctx.CancelFunc + } + killed bool } func (cc *clientConn) String() string { @@ -502,6 +509,12 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { func (cc *clientConn) dispatch(data []byte) error { span := opentracing.StartSpan("server.dispatch") goCtx := opentracing.ContextWithSpan(goctx.Background(), span) + + goCtx1, cancelFunc := goctx.WithCancel(goCtx) + cc.mu.Lock() + cc.mu.cancelFunc = cancelFunc + cc.mu.Unlock() + cmd := data[0] data = data[1:] cc.lastCmd = hack.String(data) @@ -527,11 +540,11 @@ func (cc *clientConn) dispatch(data []byte) error { if len(data) > 0 && data[len(data)-1] == 0 { data = data[:len(data)-1] } - return cc.handleQuery(goCtx, hack.String(data)) + return cc.handleQuery(goCtx1, hack.String(data)) case mysql.ComPing: return cc.writeOK() case mysql.ComInitDB: - if err := cc.useDB(goCtx, hack.String(data)); err != nil { + if err := cc.useDB(goCtx1, hack.String(data)); err != nil { return errors.Trace(err) } return cc.writeOK() @@ -540,7 +553,7 @@ func (cc *clientConn) dispatch(data []byte) error { case mysql.ComStmtPrepare: return cc.handleStmtPrepare(hack.String(data)) case mysql.ComStmtExecute: - return cc.handleStmtExecute(goCtx, data) + return cc.handleStmtExecute(goCtx1, data) case mysql.ComStmtClose: return cc.handleStmtClose(data) case mysql.ComStmtSendLongData: diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 6f46b4a7fe205..a33290f01d8c6 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -106,7 +106,6 @@ func (cc *clientConn) handleStmtExecute(goCtx goctx.Context, data []byte) (err e if len(data) < 9 { return mysql.ErrMalformPacket } - pos := 0 stmtID := binary.LittleEndian.Uint32(data[0:4]) pos += 4 @@ -164,7 +163,7 @@ func (cc *clientConn) handleStmtExecute(goCtx goctx.Context, data []byte) (err e return errors.Trace(err) } } - rs, err := stmt.Execute(args...) + rs, err := stmt.Execute(goCtx, args...) if err != nil { return errors.Trace(err) } diff --git a/server/driver.go b/server/driver.go index 5f23b5f6de920..5a82d51615de5 100644 --- a/server/driver.go +++ b/server/driver.go @@ -84,9 +84,6 @@ type QueryCtx interface { ShowProcess() util.ProcessInfo SetSessionManager(util.SessionManager) - - // Cancel the execution of current transaction. - Cancel() } // PreparedStatement is the interface to use a prepared statement. @@ -95,7 +92,7 @@ type PreparedStatement interface { ID() int // Execute executes the statement. - Execute(args ...interface{}) (ResultSet, error) + Execute(goctx.Context, ...interface{}) (ResultSet, error) // AppendParam appends parameter to the statement. AppendParam(paramID int, data []byte) error diff --git a/server/driver_tidb.go b/server/driver_tidb.go index 5acd92a5bdbad..db3a374ac535a 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -64,8 +64,8 @@ func (ts *TiDBStatement) ID() int { } // Execute implements PreparedStatement Execute method. -func (ts *TiDBStatement) Execute(args ...interface{}) (rs ResultSet, err error) { - tidbRecordset, err := ts.ctx.session.ExecutePreparedStmt(ts.id, args...) +func (ts *TiDBStatement) Execute(goCtx goctx.Context, args ...interface{}) (rs ResultSet, err error) { + tidbRecordset, err := ts.ctx.session.ExecutePreparedStmt(goCtx, ts.id, args...) if err != nil { return nil, errors.Trace(err) } @@ -280,11 +280,6 @@ func (tc *TiDBContext) ShowProcess() util.ProcessInfo { return tc.session.ShowProcess() } -// Cancel implements QueryCtx Cancel method. -func (tc *TiDBContext) Cancel() { - tc.session.Cancel() -} - type tidbResultSet struct { recordSet ast.RecordSet columns []*ColumnInfo diff --git a/server/server.go b/server/server.go index 0885d2f109196..39142e4fe9b9c 100644 --- a/server/server.go +++ b/server/server.go @@ -339,7 +339,13 @@ func (s *Server) Kill(connectionID uint64, query bool) { return } - conn.ctx.Cancel() + conn.mu.RLock() + cancelFunc := conn.mu.cancelFunc + conn.mu.RUnlock() + if cancelFunc != nil { + cancelFunc() + } + if !query { conn.killed = true } diff --git a/session.go b/session.go index 8a806f34bab75..a81bb648dd671 100644 --- a/session.go +++ b/session.go @@ -71,7 +71,7 @@ type Session interface { // PrepareStmt executes prepare statement in binary protocol. PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) // ExecutePreparedStmt executes a prepared statement. - ExecutePreparedStmt(stmtID uint32, param ...interface{}) (ast.RecordSet, error) + ExecutePreparedStmt(goCtx goctx.Context, stmtID uint32, param ...interface{}) (ast.RecordSet, error) DropPreparedStmt(stmtID uint32) error SetClientCapability(uint32) // Set client capability flags. SetConnectionID(uint64) @@ -80,8 +80,6 @@ type Session interface { SetSessionManager(util.SessionManager) Close() Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool - // Cancel the execution of current transaction. - Cancel() ShowProcess() util.ProcessInfo // PrePareTxnCtx is exported for test. PrepareTxnCtx(goctx.Context) @@ -123,9 +121,6 @@ type session struct { mu struct { sync.RWMutex values map[fmt.Stringer]interface{} - - // cancelFunc is used for cancelling the execution of current transaction. - cancelFunc goctx.CancelFunc } store kv.Storage @@ -140,15 +135,6 @@ type session struct { statsCollector *statistics.SessionStatsCollector } -// Cancel cancels the execution of current transaction. -func (s *session) Cancel() { - // TODO: How to wait for the resource to release and make sure - // it's not leak? - s.mu.RLock() - s.mu.cancelFunc() - s.mu.RUnlock() -} - func (s *session) cleanRetryInfo() { if !s.sessionVars.RetryInfo.Retrying { retryInfo := s.sessionVars.RetryInfo @@ -699,14 +685,12 @@ func (s *session) executeStatement(goCtx goctx.Context, connID uint64, stmtNode } func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.RecordSet, err error) { - span, goCtx1 := opentracing.StartSpanFromContext(goCtx, "session.Execute") - defer span.Finish() + if span := opentracing.SpanFromContext(goCtx); span != nil { + span, goCtx = opentracing.StartSpanFromContext(goCtx, "session.Execute") + defer span.Finish() + } - goCtx2, cancelFunc := goctx.WithCancel(goCtx1) - s.mu.Lock() - s.mu.cancelFunc = cancelFunc - s.mu.Unlock() - s.PrepareTxnCtx(goCtx2) + s.PrepareTxnCtx(goCtx) var ( cacheKey kvcache.Key cacheValue kvcache.Value @@ -733,9 +717,9 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec Ctx: s, } - s.PrepareTxnCtx(goCtx2) + s.PrepareTxnCtx(goCtx) executor.ResetStmtCtx(s, stmtNode) - if recordSets, err = s.executeStatement(goCtx2, connID, stmtNode, stmt, recordSets); err != nil { + if recordSets, err = s.executeStatement(goCtx, connID, stmtNode, stmt, recordSets); err != nil { return nil, errors.Trace(err) } } else { @@ -743,7 +727,7 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec // Step1: Compile query string to abstract syntax trees(ASTs). startTS := time.Now() - stmtNodes, err := s.ParseSQL(goCtx2, sql, charset, collation) + stmtNodes, err := s.ParseSQL(goCtx, sql, charset, collation) if err != nil { log.Warnf("[%d] parse error:\n%v\n%s", connID, err, sql) return nil, errors.Trace(err) @@ -752,16 +736,16 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec compiler := executor.Compiler{Ctx: s} for _, stmtNode := range stmtNodes { - s.PrepareTxnCtx(goCtx2) + s.PrepareTxnCtx(goCtx) // Step2: Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). startTS = time.Now() // Some executions are done in compile stage, so we reset them before compile. executor.ResetStmtCtx(s, stmtNode) - stmt, err := compiler.Compile(goCtx2, stmtNode) + stmt, err := compiler.Compile(goCtx, stmtNode) if err != nil { log.Warnf("[%d] compile error:\n%v\n%s", connID, err, sql) - terror.Log(errors.Trace(s.RollbackTxn(goCtx2))) + terror.Log(errors.Trace(s.RollbackTxn(goCtx))) return nil, errors.Trace(err) } sessionExecuteCompileDuration.Observe(time.Since(startTS).Seconds()) @@ -772,7 +756,7 @@ func (s *session) Execute(goCtx goctx.Context, sql string) (recordSets []ast.Rec } // Step4: Execute the physical plan. - if recordSets, err = s.executeStatement(goCtx2, connID, stmtNode, stmt, recordSets); err != nil { + if recordSets, err = s.executeStatement(goCtx, connID, stmtNode, stmt, recordSets); err != nil { return nil, errors.Trace(err) } } @@ -843,15 +827,11 @@ func checkArgs(args ...interface{}) error { } // ExecutePreparedStmt executes a prepared statement. -func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (ast.RecordSet, error) { +func (s *session) ExecutePreparedStmt(goCtx goctx.Context, stmtID uint32, args ...interface{}) (ast.RecordSet, error) { err := checkArgs(args...) if err != nil { return nil, errors.Trace(err) } - goCtx, cancelFunc := goctx.WithCancel(goctx.TODO()) - s.mu.Lock() - s.mu.cancelFunc = cancelFunc - s.mu.Unlock() s.PrepareTxnCtx(goCtx) st, err := executor.CompileExecutePreparedStmt(s, stmtID, args...) if err != nil { diff --git a/tidb_test.go b/tidb_test.go index 194f45aa6bf9f..051355e7d1b8f 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -208,8 +208,9 @@ func removeStore(c *C, dbPath string) { } func exec(se Session, sql string, args ...interface{}) (ast.RecordSet, error) { + goCtx := goctx.Background() if len(args) == 0 { - rs, err := se.Execute(goctx.Background(), sql) + rs, err := se.Execute(goCtx, sql) if err == nil && len(rs) > 0 { return rs[0], nil } @@ -219,7 +220,7 @@ func exec(se Session, sql string, args ...interface{}) (ast.RecordSet, error) { if err != nil { return nil, err } - rs, err := se.ExecutePreparedStmt(stmtID, args...) + rs, err := se.ExecutePreparedStmt(goCtx, stmtID, args...) if err != nil { return nil, err } diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index a7cb3cc8b4641..77a3beac43c75 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -124,9 +124,10 @@ func (tk *TestKit) Exec(sql string, args ...interface{}) (ast.RecordSet, error) id := atomic.AddUint64(&connectionID, 1) tk.Se.SetConnectionID(id) } + goCtx := goctx.Background() if len(args) == 0 { var rss []ast.RecordSet - rss, err = tk.Se.Execute(goctx.Background(), sql) + rss, err = tk.Se.Execute(goCtx, sql) if err == nil && len(rss) > 0 { return rss[0], nil } @@ -136,7 +137,7 @@ func (tk *TestKit) Exec(sql string, args ...interface{}) (ast.RecordSet, error) if err != nil { return nil, errors.Trace(err) } - rs, err := tk.Se.ExecutePreparedStmt(stmtID, args...) + rs, err := tk.Se.ExecutePreparedStmt(goCtx, stmtID, args...) if err != nil { return nil, errors.Trace(err) }