From 06629d6ac47cfa3be37b14d1b80efcb7b753a305 Mon Sep 17 00:00:00 2001 From: cfzjywxk Date: Mon, 23 Sep 2019 20:36:45 +0800 Subject: [PATCH] *: check privilege when reusing the cached plan (#12211) --- executor/prepared.go | 36 +++++++++++++++++---- planner/core/cache.go | 7 ++++ planner/core/common_plans.go | 27 ++++++++++++++-- planner/core/prepare_test.go | 58 ++++++++++++++++++++++++++++++++++ server/conn_stmt.go | 13 ++++++-- session/session.go | 31 ++++++++++++------ sessionctx/variable/session.go | 6 ++-- 7 files changed, 152 insertions(+), 26 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index adfd733b5af4f..76762c5ccd992 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -182,7 +182,10 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { param.InExecute = false } var p plannercore.Plan - p, err = plannercore.BuildLogicalPlan(ctx, e.ctx, stmt, e.is) + e.ctx.GetSessionVars().PlanID = 0 + e.ctx.GetSessionVars().PlanColumnID = 0 + destBuilder := plannercore.NewPlanBuilder(e.ctx, e.is, &plannercore.BlockHintProcessor{}) + p, err = destBuilder.Build(ctx, stmt) if err != nil { return err } @@ -195,7 +198,12 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.name != "" { vars.PreparedStmtNameToID[e.name] = e.ID } - return vars.AddPreparedStmt(e.ID, prepared) + + preparedObj := &plannercore.CachedPrepareStmt{ + PreparedAst: prepared, + VisitInfos: destBuilder.GetVisitInfo(), + } + return vars.AddPreparedStmt(e.ID, preparedObj) } // ExecuteExec represents an EXECUTE executor. @@ -258,10 +266,16 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error { if !ok { return errors.Trace(plannercore.ErrStmtNotFound) } + preparedPointer := vars.PreparedStmts[id] + preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) + if !ok { + return errors.Errorf("invalid CachedPrepareStmt type") + } + prepared := preparedObj.PreparedAst delete(vars.PreparedStmtNameToID, e.Name) if plannercore.PreparedPlanCacheEnabled() { e.ctx.PreparedPlanCache().Delete(plannercore.NewPSTMTPlanCacheKey( - vars, id, vars.PreparedStmts[id].SchemaVersion, + vars, id, prepared.SchemaVersion, )) } vars.RemovePreparedStmt(id) @@ -293,8 +307,12 @@ func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context, Ctx: sctx, OutputNames: execPlan.OutputNames(), } - if prepared, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok { - stmt.Text = prepared.Stmt.Text() + if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok { + preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) + if !ok { + return nil, errors.Errorf("invalid CachedPrepareStmt type") + } + stmt.Text = preparedObj.PreparedAst.Stmt.Text() sctx.GetSessionVars().StmtCtx.OriginalSQL = stmt.Text } return stmt, nil @@ -308,8 +326,12 @@ func getPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (ast.Stm return nil, plannercore.ErrStmtNotFound } } - if prepared, ok := vars.PreparedStmts[execID]; ok { - return prepared.Stmt, nil + if preparedPointer, ok := vars.PreparedStmts[execID]; ok { + preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) + if !ok { + return nil, errors.Errorf("invalid CachedPrepareStmt type") + } + return preparedObj.PreparedAst.Stmt, nil } return nil, plannercore.ErrStmtNotFound } diff --git a/planner/core/cache.go b/planner/core/cache.go index cf324868e4718..af5c574395582 100644 --- a/planner/core/cache.go +++ b/planner/core/cache.go @@ -17,6 +17,7 @@ import ( "sync/atomic" "time" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -134,3 +135,9 @@ func NewPSTMTPlanCacheValue(plan Plan, names []*types.FieldName) *PSTMTPlanCache OutPutNames: names, } } + +// CachedPrepareStmt store prepared ast from PrepareExec and other related fields +type CachedPrepareStmt struct { + PreparedAst *ast.Prepared + VisitInfos []visitInfo +} diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index 690c293f11085..549c68ad14622 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -189,10 +190,15 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont if e.Name != "" { e.ExecID = vars.PreparedStmtNameToID[e.Name] } - prepared, ok := vars.PreparedStmts[e.ExecID] + preparedPointer, ok := vars.PreparedStmts[e.ExecID] if !ok { return errors.Trace(ErrStmtNotFound) } + preparedObj, ok := preparedPointer.(*CachedPrepareStmt) + if !ok { + return errors.Errorf("invalid CachedPrepareStmt type") + } + prepared := preparedObj.PreparedAst vars.StmtCtx.StmtType = prepared.StmtType paramLen := len(e.PrepareParams) @@ -239,7 +245,7 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont } prepared.SchemaVersion = is.SchemaMetaVersion() } - err := e.getPhysicalPlan(ctx, sctx, is, prepared) + err := e.getPhysicalPlan(ctx, sctx, is, preparedObj) if err != nil { return err } @@ -247,7 +253,19 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont return nil } -func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, prepared *ast.Prepared) error { +func (e *Execute) checkPreparedPriv(ctx context.Context, sctx sessionctx.Context, + preparedObj *CachedPrepareStmt, is infoschema.InfoSchema) error { + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { + if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, preparedObj.VisitInfos); err != nil { + return err + } + } + err := CheckTableLock(sctx, is, preparedObj.VisitInfos) + return err +} + +func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, preparedStmt *CachedPrepareStmt) error { + prepared := preparedStmt.PreparedAst if prepared.CachedPlan != nil { // Rewriting the expression in the select.where condition will convert its // type from "paramMarker" to "Constant".When Point Select queries are executed, @@ -272,6 +290,9 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, if prepared.UseCache { cacheKey = NewPSTMTPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion) if cacheValue, exists := sctx.PreparedPlanCache().Get(cacheKey); exists { + if err := e.checkPreparedPriv(ctx, sctx, preparedStmt, is); err != nil { + return err + } if metrics.ResettablePlanCacheCounterFortTest { metrics.PlanCacheCounter.WithLabelValues("prepare").Inc() } else { diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index 9835080dfe36e..35597764c5b2a 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -20,11 +20,14 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/parser/auth" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" @@ -85,6 +88,48 @@ func (s *testPrepareSuite) TestPrepareCache(c *C) { tk.MustExec(`prepare stmt6 from "select distinct a from t order by a"`) tk.MustQuery("execute stmt6").Check(testkit.Rows("1", "2", "3", "4", "5", "6")) tk.MustQuery("execute stmt6").Check(testkit.Rows("1", "2", "3", "4", "5", "6")) + + // test privilege change + rootSe := tk.Se + tk.MustExec("drop table if exists tp") + tk.MustExec(`create table tp(c1 int, c2 int, primary key (c1))`) + tk.MustExec(`insert into tp values(1, 1), (2, 2), (3, 3)`) + + tk.MustExec(`create user 'u_tp'@'localhost'`) + tk.MustExec(`grant select on test.tp to u_tp@'localhost';flush privileges;`) + + // user u_tp + userSess := newSession(c, store, "test") + c.Assert(userSess.Auth(&auth.UserIdentity{Username: "u_tp", Hostname: "localhost"}, nil, nil), IsTrue) + mustExec(c, userSess, `prepare ps_stp_r from 'select * from tp where c1 > ?'`) + mustExec(c, userSess, `set @p2 = 2`) + tk.Se = userSess + tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3")) + tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3")) + tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3")) + + // root revoke + tk.Se = rootSe + tk.MustExec(`revoke all on test.tp from 'u_tp'@'localhost';flush privileges;`) + + // user u_tp + tk.Se = userSess + _, err = tk.Exec(`execute ps_stp_r using @p2`) + c.Assert(err, NotNil) + + // grant again + tk.Se = rootSe + tk.MustExec(`grant select on test.tp to u_tp@'localhost';flush privileges;`) + + // user u_tp + tk.Se = userSess + tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3")) + tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3")) + + // restore + tk.Se = rootSe + tk.MustExec("drop table if exists tp") + tk.MustExec(`DROP USER 'u_tp'@'localhost';`) } func (s *testPrepareSuite) TestPrepareCacheIndexScan(c *C) { @@ -345,3 +390,16 @@ func (s *testPrepareSuite) TestPrepareWithWindowFunction(c *C) { tk.MustExec("set @a=0, @b=1;") tk.MustQuery("execute stmt2 using @a, @b").Check(testkit.Rows("0", "0")) } + +func newSession(c *C, store kv.Storage, dbName string) session.Session { + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + mustExec(c, se, "create database if not exists "+dbName) + mustExec(c, se, "use "+dbName) + return se +} + +func mustExec(c *C, se session.Session, sql string) { + _, err := se.Execute(context.Background(), sql) + c.Assert(err, IsNil) +} diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 3652598dae8d4..39794d252e2b9 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -44,6 +44,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" + plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/hack" @@ -624,8 +625,14 @@ func (cc *clientConn) handleSetOption(data []byte) (err error) { func (cc *clientConn) preparedStmt2String(stmtID uint32) string { sv := cc.ctx.GetSessionVars() - if prepared, ok := sv.PreparedStmts[stmtID]; ok { - return prepared.Stmt.Text() + sv.PreparedParams.String() + preparedPointer, ok := sv.PreparedStmts[stmtID] + if !ok { + return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10) } - return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10) + preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) + if !ok { + return "invalidate CachedPrepareStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10) + } + preparedAst := preparedObj.PreparedAst + return preparedAst.Stmt.Text() + sv.PreparedParams.String() } diff --git a/session/session.go b/session/session.go index ddc5e7597682d..3ce811ab673a9 100644 --- a/session/session.go +++ b/session/session.go @@ -266,16 +266,21 @@ func (s *session) cleanRetryInfo() { planCacheEnabled := plannercore.PreparedPlanCacheEnabled() var cacheKey kvcache.Key + var preparedAst *ast.Prepared if planCacheEnabled { firstStmtID := retryInfo.DroppedPreparedStmtIDs[0] - cacheKey = plannercore.NewPSTMTPlanCacheKey( - s.sessionVars, firstStmtID, s.sessionVars.PreparedStmts[firstStmtID].SchemaVersion, - ) + if preparedPointer, ok := s.sessionVars.PreparedStmts[firstStmtID]; ok { + preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt) + if ok { + preparedAst = preparedObj.PreparedAst + cacheKey = plannercore.NewPSTMTPlanCacheKey(s.sessionVars, firstStmtID, preparedAst.SchemaVersion) + } + } } for i, stmtID := range retryInfo.DroppedPreparedStmtIDs { if planCacheEnabled { - if i > 0 { - plannercore.SetPstmtIDSchemaVersion(cacheKey, stmtID, s.sessionVars.PreparedStmts[stmtID].SchemaVersion) + if i > 0 && preparedAst != nil { + plannercore.SetPstmtIDSchemaVersion(cacheKey, stmtID, preparedAst.SchemaVersion) } s.PreparedPlanCache().Delete(cacheKey) } @@ -1172,7 +1177,8 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields // CachedPlanExec short path currently ONLY for cached "point select plan" execution func (s *session) CachedPlanExec(ctx context.Context, - stmtID uint32, prepared *ast.Prepared, args []types.Datum) (sqlexec.RecordSet, error) { + stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, args []types.Datum) (sqlexec.RecordSet, error) { + prepared := prepareStmt.PreparedAst // compile ExecStmt is := executor.GetInfoSchema(s) execAst := &ast.ExecuteStmt{ExecID: stmtID} @@ -1205,7 +1211,8 @@ func (s *session) CachedPlanExec(ctx context.Context, // IsCachedExecOk check if we can execute using plan cached in prepared structure // Be careful for the short path, current precondition is ths cached plan satisfying // IsPointGetWithPKOrUniqueKeyByAutoCommit -func (s *session) IsCachedExecOk(ctx context.Context, prepared *ast.Prepared) (bool, error) { +func (s *session) IsCachedExecOk(ctx context.Context, preparedStmt *plannercore.CachedPrepareStmt) (bool, error) { + prepared := preparedStmt.PreparedAst if prepared.CachedPlan == nil { return false, nil } @@ -1222,18 +1229,22 @@ func (s *session) IsCachedExecOk(ctx context.Context, prepared *ast.Prepared) (b func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args []types.Datum) (sqlexec.RecordSet, error) { var err error s.sessionVars.StartTime = time.Now() - prepared, ok := s.sessionVars.PreparedStmts[stmtID] + preparedPointer, ok := s.sessionVars.PreparedStmts[stmtID] if !ok { err = plannercore.ErrStmtNotFound logutil.Logger(ctx).Error("prepared statement not found", zap.Uint32("stmtID", stmtID)) return nil, err } - ok, err = s.IsCachedExecOk(ctx, prepared) + preparedStmt, ok := preparedPointer.(*plannercore.CachedPrepareStmt) + if !ok { + return nil, errors.Errorf("invalid CachedPrepareStmt type") + } + ok, err = s.IsCachedExecOk(ctx, preparedStmt) if err != nil { return nil, err } if ok { - return s.CachedPlanExec(ctx, stmtID, prepared, args) + return s.CachedPlanExec(ctx, stmtID, preparedStmt, args) } s.PrepareTxnCtx(ctx) st, err := executor.CompileExecutePreparedStmt(ctx, s, stmtID, args) diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index aecdfdaafab8d..c7ad86f29e5e0 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -202,7 +202,7 @@ type SessionVars struct { // systems variables, don't modify it directly, use GetSystemVar/SetSystemVar method. systems map[string]string // PreparedStmts stores prepared statement. - PreparedStmts map[uint32]*ast.Prepared + PreparedStmts map[uint32]interface{} PreparedStmtNameToID map[string]uint32 // preparedStmtID is id of prepared statement. preparedStmtID uint32 @@ -459,7 +459,7 @@ func NewSessionVars() *SessionVars { vars := &SessionVars{ Users: make(map[string]string), systems: make(map[string]string), - PreparedStmts: make(map[uint32]*ast.Prepared), + PreparedStmts: make(map[uint32]interface{}), PreparedStmtNameToID: make(map[string]uint32), PreparedParams: make([]types.Datum, 0, 10), TxnCtx: &TransactionContext{}, @@ -673,7 +673,7 @@ func (s *SessionVars) setDDLReorgPriority(val string) { } // AddPreparedStmt adds prepareStmt to current session and count in global. -func (s *SessionVars) AddPreparedStmt(stmtID uint32, stmt *ast.Prepared) error { +func (s *SessionVars) AddPreparedStmt(stmtID uint32, stmt interface{}) error { if _, exists := s.PreparedStmts[stmtID]; !exists { valStr, _ := s.GetSystemVar(MaxPreparedStmtCount) maxPreparedStmtCount, err := strconv.ParseInt(valStr, 10, 64)