Skip to content

Commit

Permalink
*: add ExecRestrictedSQL() func for RestrictedSQLExecutor interfa…
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 authored Jan 24, 2022
1 parent 98068fe commit 2e1cff7
Show file tree
Hide file tree
Showing 32 changed files with 182 additions and 361 deletions.
23 changes: 5 additions & 18 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,11 @@ func (h *BindHandle) Update(fullLoad bool) (err error) {
}

exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), true, `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source
FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime)
if err != nil {
return err
}

// No need to acquire the session context lock for ExecRestrictedStmt, it
// uses another background session.
rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source
FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime)

if err != nil {
h.bindInfo.Unlock()
Expand Down Expand Up @@ -700,14 +696,9 @@ func (h *BindHandle) extractCaptureFilterFromStorage() (filter *captureFilter) {
tables: make(map[stmtctx.TableEntry]struct{}),
}
exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), true, `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`)
if err != nil {
logutil.BgLogger().Warn("[sql-bind] failed to parse query for mysql.capture_plan_baselines_blacklist load", zap.Error(err))
return
}
// No need to acquire the session context lock for ExecRestrictedStmt, it
// uses another background session.
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`)
if err != nil {
logutil.BgLogger().Warn("[sql-bind] failed to load mysql.capture_plan_baselines_blacklist", zap.Error(err))
return
Expand Down Expand Up @@ -926,9 +917,9 @@ func (h *BindHandle) SaveEvolveTasksToStore() {
}

func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time.Time, error) {
stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(
context.TODO(),
true,
nil,
"SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name IN (%?, %?, %?)",
variable.TiDBEvolvePlanTaskMaxTime,
variable.TiDBEvolvePlanTaskStartTime,
Expand All @@ -937,10 +928,6 @@ func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time
if err != nil {
return 0, time.Time{}, time.Time{}, err
}
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt)
if err != nil {
return 0, time.Time{}, time.Time{}, err
}
maxTime, startTimeStr, endTimeStr := int64(variable.DefTiDBEvolvePlanTaskMaxTime), variable.DefTiDBEvolvePlanTaskStartTime, variable.DefAutoAnalyzeEndTime
for _, row := range rows {
switch row.GetString(0) {
Expand Down
13 changes: 2 additions & 11 deletions ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -992,12 +992,7 @@ func (w *worker) doModifyColumnTypeWithData(
}
defer w.sessPool.put(ctx)

stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), true, valStr)
if err != nil {
job.State = model.JobStateCancelled
failpoint.Return(ver, err)
}
_, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt)
_, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(context.Background(), nil, valStr)
if err != nil {
job.State = model.JobStateCancelled
failpoint.Return(ver, err)
Expand Down Expand Up @@ -1703,11 +1698,7 @@ func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTrunc
}
}
buf.WriteString(" limit 1")
stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, true, buf.String(), paramsList...)
if err != nil {
return errors.Trace(err)
}
rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(ctx, stmt)
rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, nil, buf.String(), paramsList...)
if err != nil {
return errors.Trace(err)
}
Expand Down
6 changes: 1 addition & 5 deletions ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -1576,11 +1576,7 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde
}
defer w.sessPool.put(ctx)

stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(w.ddlJobCtx, true, sql, paramList...)
if err != nil {
return errors.Trace(err)
}
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(w.ddlJobCtx, stmt)
rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ddlJobCtx, nil, sql, paramList...)
if err != nil {
return errors.Trace(err)
}
Expand Down
6 changes: 1 addition & 5 deletions ddl/reorg.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,7 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 {
return statistics.PseudoRowCount
}
sql := "select table_rows from information_schema.tables where tidb_table_id=%?;"
stmt, err := executor.ParseWithParams(w.ddlJobCtx, true, sql, tblInfo.ID)
if err != nil {
return statistics.PseudoRowCount
}
rows, _, err := executor.ExecRestrictedStmt(w.ddlJobCtx, stmt)
rows, _, err := executor.ExecRestrictedSQL(w.ddlJobCtx, nil, sql, tblInfo.ID)
if err != nil {
return statistics.PseudoRowCount
}
Expand Down
6 changes: 1 addition & 5 deletions ddl/util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,7 @@ func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []str
paramNames = append(paramNames, name)
}
buf.WriteString(")")
stmt, err := e.ParseWithParams(ctx, true, buf.String(), paramNames...)
if err != nil {
return errors.Trace(err)
}
rows, _, err := e.ExecRestrictedStmt(ctx, stmt)
rows, _, err := e.ExecRestrictedSQL(ctx, nil, buf.String(), paramNames...)
if err != nil {
return errors.Trace(err)
}
Expand Down
6 changes: 1 addition & 5 deletions domain/sysvar_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,7 @@ func (do *Domain) fetchTableValues(ctx sessionctx.Context) (map[string]string, e
tableContents := make(map[string]string)
// Copy all variables from the table to tableContents
exec := ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.Background(), true, `SELECT variable_name, variable_value FROM mysql.global_variables`)
if err != nil {
return tableContents, err
}
rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt)
rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT variable_name, variable_value FROM mysql.global_variables`)
if err != nil {
return nil, err
}
Expand Down
14 changes: 2 additions & 12 deletions executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,7 @@ func (e *AnalyzeExec) saveAnalyzeOptsV2() error {
idx += 1
}
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), true, sql.String())
if err != nil {
return err
}
_, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt)
_, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, sql.String())
if err != nil {
return err
}
Expand Down Expand Up @@ -1656,13 +1652,7 @@ type AnalyzeFastExec struct {

func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) {
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
var stmt ast.StmtNode
stmt, err = exec.ParseWithParams(context.TODO(), true, "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID())
if err != nil {
return
}
var rows []chunk.Row
rows, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt)
rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID())
if err != nil {
return
}
Expand Down
6 changes: 1 addition & 5 deletions executor/brie.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error)
// These queries execute without privilege checking, since the calling statements
// such as BACKUP and RESTORE have already been privilege checked.
func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error {
stmt, err := gs.se.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, true, sql)
if err != nil {
return err
}
_, _, err = gs.se.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(ctx, stmt)
_, _, err := gs.se.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, nil, sql)
return err
}

Expand Down
6 changes: 1 addition & 5 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2407,11 +2407,7 @@ func (b *executorBuilder) getApproximateTableCountFromStorage(sctx sessionctx.Co
if task.PartitionName != "" {
sqlexec.MustFormatSQL(sql, " partition(%n)", task.PartitionName)
}
stmt, err := b.ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.TODO(), true, sql.String())
if err != nil {
return 0, false
}
rows, _, err := b.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt)
rows, _, err := b.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(context.TODO(), nil, sql.String())
if err != nil {
return 0, false
}
Expand Down
6 changes: 1 addition & 5 deletions executor/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,11 +492,7 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx
zap.String("table", fullti.Name.O),
)
exec := e.ctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), true, "admin check table %n.%n", fullti.Schema.O, fullti.Name.O)
if err != nil {
return err
}
_, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt)
_, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "admin check table %n.%n", fullti.Schema.O, fullti.Name.O)
if err != nil {
return err
}
Expand Down
12 changes: 2 additions & 10 deletions executor/infoschema_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,7 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex

func getRowCountAllTable(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, error) {
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, true, "select table_id, count from mysql.stats_meta")
if err != nil {
return nil, err
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select table_id, count from mysql.stats_meta")
if err != nil {
return nil, err
}
Expand All @@ -215,11 +211,7 @@ type tableHistID struct {

func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[tableHistID]uint64, error) {
exec := sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(ctx, true, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0")
if err != nil {
return nil, err
}
rows, _, err := exec.ExecRestrictedStmt(ctx, stmt)
rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0")
if err != nil {
return nil, err
}
Expand Down
7 changes: 1 addition & 6 deletions executor/inspection_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,7 @@ func (n *metricNode) getLabelValue(label string) *metricValue {

func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error {
exec := pb.sctx.(sqlexec.RestrictedSQLExecutor)
stmt, err := exec.ParseWithParams(context.TODO(), true, query)
if err != nil {
return err
}

rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt)
rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, query)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 2e1cff7

Please sign in to comment.