Skip to content

Commit

Permalink
*: check privilege when reusing the cached plan (pingcap#12211)
Browse files Browse the repository at this point in the history
  • Loading branch information
cfzjywxk authored and sre-bot committed Sep 23, 2019
1 parent 48557f7 commit 06629d6
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 26 deletions.
36 changes: 29 additions & 7 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
param.InExecute = false
}
var p plannercore.Plan
p, err = plannercore.BuildLogicalPlan(ctx, e.ctx, stmt, e.is)
e.ctx.GetSessionVars().PlanID = 0
e.ctx.GetSessionVars().PlanColumnID = 0
destBuilder := plannercore.NewPlanBuilder(e.ctx, e.is, &plannercore.BlockHintProcessor{})
p, err = destBuilder.Build(ctx, stmt)
if err != nil {
return err
}
Expand All @@ -195,7 +198,12 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error {
if e.name != "" {
vars.PreparedStmtNameToID[e.name] = e.ID
}
return vars.AddPreparedStmt(e.ID, prepared)

preparedObj := &plannercore.CachedPrepareStmt{
PreparedAst: prepared,
VisitInfos: destBuilder.GetVisitInfo(),
}
return vars.AddPreparedStmt(e.ID, preparedObj)
}

// ExecuteExec represents an EXECUTE executor.
Expand Down Expand Up @@ -258,10 +266,16 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error {
if !ok {
return errors.Trace(plannercore.ErrStmtNotFound)
}
preparedPointer := vars.PreparedStmts[id]
preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if !ok {
return errors.Errorf("invalid CachedPrepareStmt type")
}
prepared := preparedObj.PreparedAst
delete(vars.PreparedStmtNameToID, e.Name)
if plannercore.PreparedPlanCacheEnabled() {
e.ctx.PreparedPlanCache().Delete(plannercore.NewPSTMTPlanCacheKey(
vars, id, vars.PreparedStmts[id].SchemaVersion,
vars, id, prepared.SchemaVersion,
))
}
vars.RemovePreparedStmt(id)
Expand Down Expand Up @@ -293,8 +307,12 @@ func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context,
Ctx: sctx,
OutputNames: execPlan.OutputNames(),
}
if prepared, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok {
stmt.Text = prepared.Stmt.Text()
if preparedPointer, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok {
preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if !ok {
return nil, errors.Errorf("invalid CachedPrepareStmt type")
}
stmt.Text = preparedObj.PreparedAst.Stmt.Text()
sctx.GetSessionVars().StmtCtx.OriginalSQL = stmt.Text
}
return stmt, nil
Expand All @@ -308,8 +326,12 @@ func getPreparedStmt(stmt *ast.ExecuteStmt, vars *variable.SessionVars) (ast.Stm
return nil, plannercore.ErrStmtNotFound
}
}
if prepared, ok := vars.PreparedStmts[execID]; ok {
return prepared.Stmt, nil
if preparedPointer, ok := vars.PreparedStmts[execID]; ok {
preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if !ok {
return nil, errors.Errorf("invalid CachedPrepareStmt type")
}
return preparedObj.PreparedAst.Stmt, nil
}
return nil, plannercore.ErrStmtNotFound
}
7 changes: 7 additions & 0 deletions planner/core/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"sync/atomic"
"time"

"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -134,3 +135,9 @@ func NewPSTMTPlanCacheValue(plan Plan, names []*types.FieldName) *PSTMTPlanCache
OutPutNames: names,
}
}

// CachedPrepareStmt store prepared ast from PrepareExec and other related fields
type CachedPrepareStmt struct {
PreparedAst *ast.Prepared
VisitInfos []visitInfo
}
27 changes: 24 additions & 3 deletions planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -189,10 +190,15 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont
if e.Name != "" {
e.ExecID = vars.PreparedStmtNameToID[e.Name]
}
prepared, ok := vars.PreparedStmts[e.ExecID]
preparedPointer, ok := vars.PreparedStmts[e.ExecID]
if !ok {
return errors.Trace(ErrStmtNotFound)
}
preparedObj, ok := preparedPointer.(*CachedPrepareStmt)
if !ok {
return errors.Errorf("invalid CachedPrepareStmt type")
}
prepared := preparedObj.PreparedAst
vars.StmtCtx.StmtType = prepared.StmtType

paramLen := len(e.PrepareParams)
Expand Down Expand Up @@ -239,15 +245,27 @@ func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Cont
}
prepared.SchemaVersion = is.SchemaMetaVersion()
}
err := e.getPhysicalPlan(ctx, sctx, is, prepared)
err := e.getPhysicalPlan(ctx, sctx, is, preparedObj)
if err != nil {
return err
}
e.Stmt = prepared.Stmt
return nil
}

func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, prepared *ast.Prepared) error {
func (e *Execute) checkPreparedPriv(ctx context.Context, sctx sessionctx.Context,
preparedObj *CachedPrepareStmt, is infoschema.InfoSchema) error {
if pm := privilege.GetPrivilegeManager(sctx); pm != nil {
if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, preparedObj.VisitInfos); err != nil {
return err
}
}
err := CheckTableLock(sctx, is, preparedObj.VisitInfos)
return err
}

func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, preparedStmt *CachedPrepareStmt) error {
prepared := preparedStmt.PreparedAst
if prepared.CachedPlan != nil {
// Rewriting the expression in the select.where condition will convert its
// type from "paramMarker" to "Constant".When Point Select queries are executed,
Expand All @@ -272,6 +290,9 @@ func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context,
if prepared.UseCache {
cacheKey = NewPSTMTPlanCacheKey(sctx.GetSessionVars(), e.ExecID, prepared.SchemaVersion)
if cacheValue, exists := sctx.PreparedPlanCache().Get(cacheKey); exists {
if err := e.checkPreparedPriv(ctx, sctx, preparedStmt, is); err != nil {
return err
}
if metrics.ResettablePlanCacheCounterFortTest {
metrics.PlanCacheCounter.WithLabelValues("prepare").Inc()
} else {
Expand Down
58 changes: 58 additions & 0 deletions planner/core/prepare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ import (
"time"

. "github.com/pingcap/check"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
Expand Down Expand Up @@ -85,6 +88,48 @@ func (s *testPrepareSuite) TestPrepareCache(c *C) {
tk.MustExec(`prepare stmt6 from "select distinct a from t order by a"`)
tk.MustQuery("execute stmt6").Check(testkit.Rows("1", "2", "3", "4", "5", "6"))
tk.MustQuery("execute stmt6").Check(testkit.Rows("1", "2", "3", "4", "5", "6"))

// test privilege change
rootSe := tk.Se
tk.MustExec("drop table if exists tp")
tk.MustExec(`create table tp(c1 int, c2 int, primary key (c1))`)
tk.MustExec(`insert into tp values(1, 1), (2, 2), (3, 3)`)

tk.MustExec(`create user 'u_tp'@'localhost'`)
tk.MustExec(`grant select on test.tp to u_tp@'localhost';flush privileges;`)

// user u_tp
userSess := newSession(c, store, "test")
c.Assert(userSess.Auth(&auth.UserIdentity{Username: "u_tp", Hostname: "localhost"}, nil, nil), IsTrue)
mustExec(c, userSess, `prepare ps_stp_r from 'select * from tp where c1 > ?'`)
mustExec(c, userSess, `set @p2 = 2`)
tk.Se = userSess
tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3"))
tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3"))
tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3"))

// root revoke
tk.Se = rootSe
tk.MustExec(`revoke all on test.tp from 'u_tp'@'localhost';flush privileges;`)

// user u_tp
tk.Se = userSess
_, err = tk.Exec(`execute ps_stp_r using @p2`)
c.Assert(err, NotNil)

// grant again
tk.Se = rootSe
tk.MustExec(`grant select on test.tp to u_tp@'localhost';flush privileges;`)

// user u_tp
tk.Se = userSess
tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3"))
tk.MustQuery(`execute ps_stp_r using @p2`).Check(testkit.Rows("3 3"))

// restore
tk.Se = rootSe
tk.MustExec("drop table if exists tp")
tk.MustExec(`DROP USER 'u_tp'@'localhost';`)
}

func (s *testPrepareSuite) TestPrepareCacheIndexScan(c *C) {
Expand Down Expand Up @@ -345,3 +390,16 @@ func (s *testPrepareSuite) TestPrepareWithWindowFunction(c *C) {
tk.MustExec("set @a=0, @b=1;")
tk.MustQuery("execute stmt2 using @a, @b").Check(testkit.Rows("0", "0"))
}

func newSession(c *C, store kv.Storage, dbName string) session.Session {
se, err := session.CreateSession4Test(store)
c.Assert(err, IsNil)
mustExec(c, se, "create database if not exists "+dbName)
mustExec(c, se, "use "+dbName)
return se
}

func mustExec(c *C, se session.Session, sql string) {
_, err := se.Execute(context.Background(), sql)
c.Assert(err, IsNil)
}
13 changes: 10 additions & 3 deletions server/conn_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/parser/mysql"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/hack"
Expand Down Expand Up @@ -624,8 +625,14 @@ func (cc *clientConn) handleSetOption(data []byte) (err error) {

func (cc *clientConn) preparedStmt2String(stmtID uint32) string {
sv := cc.ctx.GetSessionVars()
if prepared, ok := sv.PreparedStmts[stmtID]; ok {
return prepared.Stmt.Text() + sv.PreparedParams.String()
preparedPointer, ok := sv.PreparedStmts[stmtID]
if !ok {
return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10)
}
return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10)
preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if !ok {
return "invalidate CachedPrepareStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10)
}
preparedAst := preparedObj.PreparedAst
return preparedAst.Stmt.Text() + sv.PreparedParams.String()
}
31 changes: 21 additions & 10 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,21 @@ func (s *session) cleanRetryInfo() {

planCacheEnabled := plannercore.PreparedPlanCacheEnabled()
var cacheKey kvcache.Key
var preparedAst *ast.Prepared
if planCacheEnabled {
firstStmtID := retryInfo.DroppedPreparedStmtIDs[0]
cacheKey = plannercore.NewPSTMTPlanCacheKey(
s.sessionVars, firstStmtID, s.sessionVars.PreparedStmts[firstStmtID].SchemaVersion,
)
if preparedPointer, ok := s.sessionVars.PreparedStmts[firstStmtID]; ok {
preparedObj, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if ok {
preparedAst = preparedObj.PreparedAst
cacheKey = plannercore.NewPSTMTPlanCacheKey(s.sessionVars, firstStmtID, preparedAst.SchemaVersion)
}
}
}
for i, stmtID := range retryInfo.DroppedPreparedStmtIDs {
if planCacheEnabled {
if i > 0 {
plannercore.SetPstmtIDSchemaVersion(cacheKey, stmtID, s.sessionVars.PreparedStmts[stmtID].SchemaVersion)
if i > 0 && preparedAst != nil {
plannercore.SetPstmtIDSchemaVersion(cacheKey, stmtID, preparedAst.SchemaVersion)
}
s.PreparedPlanCache().Delete(cacheKey)
}
Expand Down Expand Up @@ -1172,7 +1177,8 @@ func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields

// CachedPlanExec short path currently ONLY for cached "point select plan" execution
func (s *session) CachedPlanExec(ctx context.Context,
stmtID uint32, prepared *ast.Prepared, args []types.Datum) (sqlexec.RecordSet, error) {
stmtID uint32, prepareStmt *plannercore.CachedPrepareStmt, args []types.Datum) (sqlexec.RecordSet, error) {
prepared := prepareStmt.PreparedAst
// compile ExecStmt
is := executor.GetInfoSchema(s)
execAst := &ast.ExecuteStmt{ExecID: stmtID}
Expand Down Expand Up @@ -1205,7 +1211,8 @@ func (s *session) CachedPlanExec(ctx context.Context,
// IsCachedExecOk check if we can execute using plan cached in prepared structure
// Be careful for the short path, current precondition is ths cached plan satisfying
// IsPointGetWithPKOrUniqueKeyByAutoCommit
func (s *session) IsCachedExecOk(ctx context.Context, prepared *ast.Prepared) (bool, error) {
func (s *session) IsCachedExecOk(ctx context.Context, preparedStmt *plannercore.CachedPrepareStmt) (bool, error) {
prepared := preparedStmt.PreparedAst
if prepared.CachedPlan == nil {
return false, nil
}
Expand All @@ -1222,18 +1229,22 @@ func (s *session) IsCachedExecOk(ctx context.Context, prepared *ast.Prepared) (b
func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args []types.Datum) (sqlexec.RecordSet, error) {
var err error
s.sessionVars.StartTime = time.Now()
prepared, ok := s.sessionVars.PreparedStmts[stmtID]
preparedPointer, ok := s.sessionVars.PreparedStmts[stmtID]
if !ok {
err = plannercore.ErrStmtNotFound
logutil.Logger(ctx).Error("prepared statement not found", zap.Uint32("stmtID", stmtID))
return nil, err
}
ok, err = s.IsCachedExecOk(ctx, prepared)
preparedStmt, ok := preparedPointer.(*plannercore.CachedPrepareStmt)
if !ok {
return nil, errors.Errorf("invalid CachedPrepareStmt type")
}
ok, err = s.IsCachedExecOk(ctx, preparedStmt)
if err != nil {
return nil, err
}
if ok {
return s.CachedPlanExec(ctx, stmtID, prepared, args)
return s.CachedPlanExec(ctx, stmtID, preparedStmt, args)
}
s.PrepareTxnCtx(ctx)
st, err := executor.CompileExecutePreparedStmt(ctx, s, stmtID, args)
Expand Down
6 changes: 3 additions & 3 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ type SessionVars struct {
// systems variables, don't modify it directly, use GetSystemVar/SetSystemVar method.
systems map[string]string
// PreparedStmts stores prepared statement.
PreparedStmts map[uint32]*ast.Prepared
PreparedStmts map[uint32]interface{}
PreparedStmtNameToID map[string]uint32
// preparedStmtID is id of prepared statement.
preparedStmtID uint32
Expand Down Expand Up @@ -459,7 +459,7 @@ func NewSessionVars() *SessionVars {
vars := &SessionVars{
Users: make(map[string]string),
systems: make(map[string]string),
PreparedStmts: make(map[uint32]*ast.Prepared),
PreparedStmts: make(map[uint32]interface{}),
PreparedStmtNameToID: make(map[string]uint32),
PreparedParams: make([]types.Datum, 0, 10),
TxnCtx: &TransactionContext{},
Expand Down Expand Up @@ -673,7 +673,7 @@ func (s *SessionVars) setDDLReorgPriority(val string) {
}

// AddPreparedStmt adds prepareStmt to current session and count in global.
func (s *SessionVars) AddPreparedStmt(stmtID uint32, stmt *ast.Prepared) error {
func (s *SessionVars) AddPreparedStmt(stmtID uint32, stmt interface{}) error {
if _, exists := s.PreparedStmts[stmtID]; !exists {
valStr, _ := s.GetSystemVar(MaxPreparedStmtCount)
maxPreparedStmtCount, err := strconv.ParseInt(valStr, 10, 64)
Expand Down

0 comments on commit 06629d6

Please sign in to comment.