diff --git a/bindinfo/handle.go b/bindinfo/handle.go index d57d2edb1da0f..622ac2ce6157c 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -131,15 +131,11 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { } exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source - FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime) - if err != nil { - return err - } // No need to acquire the session context lock for ExecRestrictedStmt, it // uses another background session. - rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source + FROM mysql.bind_info WHERE update_time > %? ORDER BY update_time, create_time`, updateTime) if err != nil { h.bindInfo.Unlock() @@ -700,14 +696,9 @@ func (h *BindHandle) extractCaptureFilterFromStorage() (filter *captureFilter) { tables: make(map[stmtctx.TableEntry]struct{}), } exec := h.sctx.Context.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`) - if err != nil { - logutil.BgLogger().Warn("[sql-bind] failed to parse query for mysql.capture_plan_baselines_blacklist load", zap.Error(err)) - return - } // No need to acquire the session context lock for ExecRestrictedStmt, it // uses another background session. - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT filter_type, filter_value FROM mysql.capture_plan_baselines_blacklist order by filter_type`) if err != nil { logutil.BgLogger().Warn("[sql-bind] failed to load mysql.capture_plan_baselines_blacklist", zap.Error(err)) return @@ -926,9 +917,9 @@ func (h *BindHandle) SaveEvolveTasksToStore() { } func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time.Time, error) { - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams( + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL( context.TODO(), - true, + nil, "SELECT variable_name, variable_value FROM mysql.global_variables WHERE variable_name IN (%?, %?, %?)", variable.TiDBEvolvePlanTaskMaxTime, variable.TiDBEvolvePlanTaskStartTime, @@ -937,10 +928,6 @@ func getEvolveParameters(ctx sessionctx.Context) (time.Duration, time.Time, time if err != nil { return 0, time.Time{}, time.Time{}, err } - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) - if err != nil { - return 0, time.Time{}, time.Time{}, err - } maxTime, startTimeStr, endTimeStr := int64(variable.DefTiDBEvolvePlanTaskMaxTime), variable.DefTiDBEvolvePlanTaskStartTime, variable.DefAutoAnalyzeEndTime for _, row := range rows { switch row.GetString(0) { diff --git a/ddl/column.go b/ddl/column.go index 8352c3940553a..94bcad4252fcb 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -992,12 +992,7 @@ func (w *worker) doModifyColumnTypeWithData( } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), true, valStr) - if err != nil { - job.State = model.JobStateCancelled - failpoint.Return(ver, err) - } - _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + _, _, err = ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(context.Background(), nil, valStr) if err != nil { job.State = model.JobStateCancelled failpoint.Return(ver, err) @@ -1703,11 +1698,7 @@ func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTrunc } } buf.WriteString(" limit 1") - stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, true, buf.String(), paramsList...) - if err != nil { - return errors.Trace(err) - } - rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(ctx, stmt) + rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, nil, buf.String(), paramsList...) if err != nil { return errors.Trace(err) } diff --git a/ddl/partition.go b/ddl/partition.go index 7bfa3f250c074..e34f2d0acc017 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1576,11 +1576,7 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(w.ddlJobCtx, true, sql, paramList...) - if err != nil { - return errors.Trace(err) - } - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(w.ddlJobCtx, stmt) + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(w.ddlJobCtx, nil, sql, paramList...) if err != nil { return errors.Trace(err) } diff --git a/ddl/reorg.go b/ddl/reorg.go index deeac2bfdb8ae..4e2144cebc942 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -340,11 +340,7 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { return statistics.PseudoRowCount } sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" - stmt, err := executor.ParseWithParams(w.ddlJobCtx, true, sql, tblInfo.ID) - if err != nil { - return statistics.PseudoRowCount - } - rows, _, err := executor.ExecRestrictedStmt(w.ddlJobCtx, stmt) + rows, _, err := executor.ExecRestrictedSQL(w.ddlJobCtx, nil, sql, tblInfo.ID) if err != nil { return statistics.PseudoRowCount } diff --git a/ddl/util/util.go b/ddl/util/util.go index b5b77a5c10264..b49378cf1dc8d 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -176,11 +176,7 @@ func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []str paramNames = append(paramNames, name) } buf.WriteString(")") - stmt, err := e.ParseWithParams(ctx, true, buf.String(), paramNames...) - if err != nil { - return errors.Trace(err) - } - rows, _, err := e.ExecRestrictedStmt(ctx, stmt) + rows, _, err := e.ExecRestrictedSQL(ctx, nil, buf.String(), paramNames...) if err != nil { return errors.Trace(err) } diff --git a/domain/sysvar_cache.go b/domain/sysvar_cache.go index f67d7266392b6..687649c1f66b5 100644 --- a/domain/sysvar_cache.go +++ b/domain/sysvar_cache.go @@ -95,11 +95,7 @@ func (do *Domain) fetchTableValues(ctx sessionctx.Context) (map[string]string, e tableContents := make(map[string]string) // Copy all variables from the table to tableContents exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), true, `SELECT variable_name, variable_value FROM mysql.global_variables`) - if err != nil { - return tableContents, err - } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT variable_name, variable_value FROM mysql.global_variables`) if err != nil { return nil, err } diff --git a/executor/analyze.go b/executor/analyze.go index 134b4ba303c04..9ac68e9223021 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -259,11 +259,7 @@ func (e *AnalyzeExec) saveAnalyzeOptsV2() error { idx += 1 } exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, sql.String()) - if err != nil { - return err - } - _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + _, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, sql.String()) if err != nil { return err } @@ -1656,13 +1652,7 @@ type AnalyzeFastExec struct { func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - var stmt ast.StmtNode - stmt, err = exec.ParseWithParams(context.TODO(), true, "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID()) - if err != nil { - return - } - var rows []chunk.Row - rows, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "select flag from mysql.stats_histograms where table_id = %?", e.tableID.GetStatisticsID()) if err != nil { return } diff --git a/executor/brie.go b/executor/brie.go index 49bb57cd98fab..3920254e2b869 100644 --- a/executor/brie.go +++ b/executor/brie.go @@ -462,11 +462,7 @@ func (gs *tidbGlueSession) CreateSession(store kv.Storage) (glue.Session, error) // These queries execute without privilege checking, since the calling statements // such as BACKUP and RESTORE have already been privilege checked. func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { - stmt, err := gs.se.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, true, sql) - if err != nil { - return err - } - _, _, err = gs.se.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(ctx, stmt) + _, _, err := gs.se.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, nil, sql) return err } diff --git a/executor/builder.go b/executor/builder.go index d1892401763b3..12f982527645e 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2407,11 +2407,7 @@ func (b *executorBuilder) getApproximateTableCountFromStorage(sctx sessionctx.Co if task.PartitionName != "" { sqlexec.MustFormatSQL(sql, " partition(%n)", task.PartitionName) } - stmt, err := b.ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.TODO(), true, sql.String()) - if err != nil { - return 0, false - } - rows, _, err := b.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := b.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(context.TODO(), nil, sql.String()) if err != nil { return 0, false } diff --git a/executor/ddl.go b/executor/ddl.go index dad26efad4102..10ea571b5bf31 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -492,11 +492,7 @@ func (e *DDLExec) dropTableObject(objects []*ast.TableName, obt objectType, ifEx zap.String("table", fullti.Name.O), ) exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) - if err != nil { - return err - } - _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + _, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) if err != nil { return err } diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index efbe677109c55..d68dcc7b13b6c 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -190,11 +190,7 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex func getRowCountAllTable(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, "select table_id, count from mysql.stats_meta") - if err != nil { - return nil, err - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } @@ -215,11 +211,7 @@ type tableHistID struct { func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[tableHistID]uint64, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") - if err != nil { - return nil, err - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } diff --git a/executor/inspection_profile.go b/executor/inspection_profile.go index 605f2d0efad37..15885010dce25 100644 --- a/executor/inspection_profile.go +++ b/executor/inspection_profile.go @@ -167,12 +167,7 @@ func (n *metricNode) getLabelValue(label string) *metricValue { func (n *metricNode) queryRowsByLabel(pb *profileBuilder, query string, handleRowFn func(label string, v float64)) error { exec := pb.sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, query) - if err != nil { - return err - } - - rows, _, err := pb.sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, query) if err != nil { return err } diff --git a/executor/inspection_result.go b/executor/inspection_result.go index a8e06d7c55b3f..debcf723a3e64 100644 --- a/executor/inspection_result.go +++ b/executor/inspection_result.go @@ -140,10 +140,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct e.statusToInstanceAddress = make(map[string]string) var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, "select instance,status_address from information_schema.cluster_info;") - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select instance,status_address from information_schema.cluster_info;") if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("get cluster info failed: %v", err)) } @@ -249,22 +246,14 @@ func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.C "storage.data-dir", "storage.block-cache.capacity", } - var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) } generateDetail := func(tp, item string) string { - var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, true, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select value, instance from information_schema.cluster_config where type=%? and `key`=%?;", tp, item) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) return fmt.Sprintf("the cluster has different config value of %[2]s, execute the sql to see more detail: select * from information_schema.cluster_config where type='%[1]s' and `key`='%[2]s'", @@ -347,7 +336,7 @@ func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionct } sql.Reset() fmt.Fprintf(sql, "select type,instance,value from information_schema.%s where %s", cas.table, cas.cond) - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) + stmt, err := exec.ParseWithParams(ctx, sql.String()) if err == nil { rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) } @@ -376,12 +365,8 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct if !filter.enable(item) { return nil } - var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -405,10 +390,7 @@ func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sct ipToCount[ip]++ } - stmt, err = exec.ParseWithParams(ctx, true, "select instance, value from metrics_schema.node_total_memory where time=now()") - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err = exec.ExecRestrictedSQL(ctx, nil, "select instance, value from metrics_schema.node_total_memory where time=now()") if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) } @@ -471,12 +453,8 @@ func (configInspection) convertReadableSizeToByteSize(sizeStr string) (uint64, e func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { exec := sctx.(sqlexec.RestrictedSQLExecutor) - var rows []chunk.Row // check the configuration consistent - stmt, err := exec.ParseWithParams(ctx, true, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check version consistency failed: %v", err)) } @@ -630,7 +608,6 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx condition := filter.timeRange.Condition() var results []inspectionResult - var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) sql := new(strings.Builder) for _, rule := range rules { @@ -643,10 +620,7 @@ func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx sql.Reset() fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -697,11 +671,7 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se fmt.Fprintf(sql, `select t1.job,t1.instance, t2.min_time from (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) - var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -726,10 +696,7 @@ func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx se // Check from log. sql.Reset() fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - stmt, err = exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err = exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) } @@ -843,7 +810,6 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult - var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) sql := new(strings.Builder) for _, rule := range rules { @@ -863,10 +829,7 @@ func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sess (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) } - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1016,7 +979,6 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess condition := filter.timeRange.Condition() var results []inspectionResult - var rows []chunk.Row sql := new(strings.Builder) exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { @@ -1036,10 +998,7 @@ func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sess } else { fmt.Fprintf(sql, "select instance, max(value)/%.0f as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) } - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1215,17 +1174,13 @@ func (thresholdCheckInspection) inspectThreshold3(ctx context.Context, sctx sess func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter, rules []ruleChecker) []inspectionResult { var results []inspectionResult - var rows []chunk.Row exec := sctx.(sqlexec.RestrictedSQLExecutor) for _, rule := range rules { if !filter.enable(rule.getItem()) { continue } sql := rule.genSQL(filter.timeRange) - stmt, err := exec.ParseWithParams(ctx, true, sql) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue @@ -1244,11 +1199,7 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx fmt.Fprintf(sql, `select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) exec := sctx.(sqlexec.RestrictedSQLExecutor) - var rows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) return nil @@ -1259,10 +1210,7 @@ func (c thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sql.Reset() fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' order by time`, condition, address) var subRows []chunk.Row - stmt, err := exec.ParseWithParams(ctx, true, sql.String()) - if err == nil { - subRows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } + subRows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) if err != nil { sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) continue diff --git a/executor/inspection_summary.go b/executor/inspection_summary.go index 2c58248ed4a03..ebd3f69abc4f8 100644 --- a/executor/inspection_summary.go +++ b/executor/inspection_summary.go @@ -460,11 +460,7 @@ func (e *inspectionSummaryRetriever) retrieve(ctx context.Context, sctx sessionc util.MetricSchemaName.L, name, cond) } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, sql) - if err != nil { - return nil, errors.Errorf("execute '%s' failed: %v", sql, err) - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/metrics_reader.go b/executor/metrics_reader.go index d5bb294a212e3..3e90897d03192 100644 --- a/executor/metrics_reader.go +++ b/executor/metrics_reader.go @@ -233,11 +233,7 @@ func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx. } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, sql) - if err != nil { - return nil, errors.Errorf("execute '%s' failed: %v", sql, err) - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } @@ -318,11 +314,7 @@ func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sess util.MetricSchemaName.L, name, cond) } exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, sql) - if err != nil { - return nil, errors.Errorf("execute '%s' failed: %v", sql, err) - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) if err != nil { return nil, errors.Errorf("execute '%s' failed: %v", sql, err) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go index b9c8318d1bc30..5773f80efe7a2 100644 --- a/executor/opt_rule_blacklist.go +++ b/executor/opt_rule_blacklist.go @@ -37,11 +37,7 @@ func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) e // LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") - if err != nil { - return err - } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "select HIGH_PRIORITY name from mysql.opt_rule_blacklist") if err != nil { return err } diff --git a/executor/reload_expr_pushdown_blacklist.go b/executor/reload_expr_pushdown_blacklist.go index ea873080727b3..1511e7a280195 100644 --- a/executor/reload_expr_pushdown_blacklist.go +++ b/executor/reload_expr_pushdown_blacklist.go @@ -39,11 +39,7 @@ func (e *ReloadExprPushdownBlacklistExec) Next(ctx context.Context, _ *chunk.Chu // LoadExprPushdownBlacklist loads the latest data from table mysql.expr_pushdown_blacklist. func LoadExprPushdownBlacklist(ctx sessionctx.Context) (err error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") - if err != nil { - return err - } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "select HIGH_PRIORITY name, store_type from mysql.expr_pushdown_blacklist") if err != nil { return err } diff --git a/executor/show.go b/executor/show.go index 1fef2db83c915..498825084093d 100644 --- a/executor/show.go +++ b/executor/show.go @@ -343,11 +343,7 @@ func (e *ShowExec) fetchShowBind() error { func (e *ShowExec) fetchShowEngines(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, `SELECT * FROM information_schema.engines`) - if err != nil { - return errors.Trace(err) - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } @@ -474,17 +470,6 @@ func (e *ShowExec) fetchShowTableStatus(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, `SELECT - table_name, engine, version, row_format, table_rows, - avg_row_length, data_length, max_data_length, index_length, - data_free, auto_increment, create_time, update_time, check_time, - table_collation, IFNULL(checksum,''), create_options, table_comment - FROM information_schema.tables - WHERE lower(table_schema)=%? ORDER BY table_name`, e.DBName.L) - if err != nil { - return errors.Trace(err) - } - var snapshot uint64 txn, err := e.ctx.Txn(false) if err != nil { @@ -497,7 +482,13 @@ func (e *ShowExec) fetchShowTableStatus(ctx context.Context) error { snapshot = e.ctx.GetSessionVars().SnapshotTS } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) + rows, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionWithSnapshot(snapshot)}, + `SELECT table_name, engine, version, row_format, table_rows, + avg_row_length, data_length, max_data_length, index_length, + data_free, auto_increment, create_time, update_time, check_time, + table_collation, IFNULL(checksum,''), create_options, table_comment + FROM information_schema.tables + WHERE lower(table_schema)=%? ORDER BY table_name`, e.DBName.L) if err != nil { return errors.Trace(err) } @@ -1423,11 +1414,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, strings.ToLower(hostName)) - if err != nil { - return errors.Trace(err) - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, strings.ToLower(hostName)) if err != nil { return errors.Trace(err) } @@ -1443,11 +1430,7 @@ func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { authplugin = rows[0].GetString(0) } - stmt, err = exec.ParseWithParams(ctx, true, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) - if err != nil { - return errors.Trace(err) - } - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err = exec.ExecRestrictedSQL(ctx, nil, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } diff --git a/executor/show_placement.go b/executor/show_placement.go index 164a5f9293a7c..acd6d9cccecfc 100644 --- a/executor/show_placement.go +++ b/executor/show_placement.go @@ -107,12 +107,7 @@ func (b *showPlacementLabelsResultBuilder) sortMapKeys(m map[string]interface{}) func (e *ShowExec) fetchShowPlacementLabels(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, "SELECT DISTINCT LABEL FROM %n.%n", "INFORMATION_SCHEMA", infoschema.TableTiKVStoreStatus) - if err != nil { - return errors.Trace(err) - } - - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "SELECT DISTINCT LABEL FROM %n.%n", "INFORMATION_SCHEMA", infoschema.TableTiKVStoreStatus) if err != nil { return errors.Trace(err) } diff --git a/executor/simple.go b/executor/simple.go index 0de5688e4c0cb..ae3e0b197f8fc 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -967,25 +967,17 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) if !ok { return errors.Trace(ErrPasswordFormat) } - stmt, err := exec.ParseWithParams(ctx, true, + _, _, err := exec.ExecRestrictedSQL(ctx, nil, `UPDATE %n.%n SET authentication_string=%?, plugin=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.AuthOpt.AuthPlugin, strings.ToLower(spec.User.Hostname), spec.User.Username, ) - if err != nil { - return err - } - _, _, err = exec.ExecRestrictedStmt(ctx, stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } } if len(privData) > 0 { - stmt, err := exec.ParseWithParams(ctx, true, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) - if err != nil { - return err - } - _, _, err = exec.ExecRestrictedStmt(ctx, stmt) + _, _, err := exec.ExecRestrictedSQL(ctx, nil, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } @@ -1359,11 +1351,7 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) { exec := sctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) - if err != nil { - return false, err - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, strings.ToLower(host)) if err != nil { return false, err } @@ -1442,11 +1430,7 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error // update mysql.user exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(ctx, true, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, strings.ToLower(h)) - if err != nil { - return err - } - _, _, err = exec.ExecRestrictedStmt(ctx, stmt) + _, _, err = exec.ExecRestrictedSQL(ctx, nil, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, strings.ToLower(h)) if err != nil { return err } diff --git a/expression/util.go b/expression/util.go index f75775712dce1..55948329c717a 100644 --- a/expression/util.go +++ b/expression/util.go @@ -1180,11 +1180,7 @@ func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, sctx s stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" } - stmtNode, err := exec.ParseWithParams(ctx, true, stmt, inValues...) - if err != nil { - return nil, err - } - rows, _, err := exec.ExecRestrictedStmt(ctx, stmtNode) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, stmt, inValues...) if err != nil { return nil, err } diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 4e940cd56a089..fdfdf8057ce16 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -2198,11 +2198,7 @@ func (b *PlanBuilder) genV2AnalyzeOptions( func (b *PlanBuilder) getSavedAnalyzeOpts(physicalID int64, tblInfo *model.TableInfo) (map[ast.AnalyzeOptionType]uint64, model.ColumnChoice, []*model.ColumnInfo, error) { analyzeOptions := map[ast.AnalyzeOptionType]uint64{} exec := b.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, "select sample_num,sample_rate,buckets,topn,column_choice,column_ids from mysql.analyze_options where table_id = %?", physicalID) - if err != nil { - return nil, model.DefaultChoice, nil, err - } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, "select sample_num,sample_rate,buckets,topn,column_choice,column_ids from mysql.analyze_options where table_id = %?", physicalID) if err != nil { return nil, model.DefaultChoice, nil, err } diff --git a/session/session.go b/session/session.go index 58e18b0109da7..caebed1c500c6 100644 --- a/session/session.go +++ b/session/session.go @@ -1192,11 +1192,7 @@ func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet, allo // getTableValue executes restricted sql and the result is one column. // It returns a string value. func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { - stmt, err := s.ParseWithParams(ctx, true, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) - if err != nil { - return "", err - } - rows, fields, err := s.ExecRestrictedStmt(ctx, stmt) + rows, fields, err := s.ExecRestrictedSQL(ctx, nil, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) if err != nil { return "", err } @@ -1214,11 +1210,10 @@ func (s *session) getTableValue(ctx context.Context, tblName string, varName str // replaceGlobalVariablesTableValue executes restricted sql updates the variable value // It will then notify the etcd channel that the value has changed. func (s *session) replaceGlobalVariablesTableValue(ctx context.Context, varName, val string) error { - stmt, err := s.ParseWithParams(ctx, true, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) + _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) if err != nil { return err } - _, _, err = s.ExecRestrictedStmt(ctx, stmt) domain.GetDomain(s).NotifyUpdateSysVarCache() return err } @@ -1289,11 +1284,7 @@ func (s *session) SetGlobalSysVarOnly(name, value string) (err error) { // SetTiDBTableValue implements GlobalVarAccessor.SetTiDBTableValue interface. func (s *session) SetTiDBTableValue(name, value, comment string) error { - stmt, err := s.ParseWithParams(context.TODO(), true, `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) - if err != nil { - return err - } - _, _, err = s.ExecRestrictedStmt(context.TODO(), stmt) + _, _, err := s.ExecRestrictedSQL(context.TODO(), nil, `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) return err } @@ -1421,7 +1412,7 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter logutil.Eventf(ctx, "execute: %s", sql) } - stmtNode, err := s.ParseWithParams(ctx, true, sql, args...) + stmtNode, err := s.ParseWithParams(ctx, sql, args...) if err != nil { return nil, err } @@ -1499,7 +1490,7 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) // ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. // Note that it will not do escaping if no variable arguments are passed. -func (s *session) ParseWithParams(ctx context.Context, forceUTF8SQL bool, sql string, args ...interface{}) (ast.StmtNode, error) { +func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) { var err error if len(args) > 0 { sql, err = sqlexec.EscapeSQL(sql, args...) @@ -1513,7 +1504,7 @@ func (s *session) ParseWithParams(ctx context.Context, forceUTF8SQL bool, sql st var stmts []ast.StmtNode var warns []error parseStartTime := time.Now() - if internal || forceUTF8SQL { + if internal { // Do no respect the settings from clients, if it is for internal usage. // Charsets from clients may give chance injections. // Refer to https://stackoverflow.com/questions/5741187/sql-injection-that-gets-around-mysql-real-escape-string/12118602. @@ -1563,39 +1554,57 @@ func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, if topsqlstate.TopSQLEnabled() { defer pprof.SetGoroutineLabels(ctx) } - var execOption sqlexec.ExecOption - for _, opt := range opts { - opt(&execOption) + se, clean, err := s.getInternalSession(opts) + if err != nil { + return nil, nil, err } - // Use special session to execute the sql. - tmp, err := s.sysSessionPool().Get() + defer clean() + + startTime := time.Now() + metrics.SessionRestrictedSQLCounter.Inc() + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) + rs, err := se.ExecuteStmt(ctx, stmtNode) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs, nil) if err != nil { return nil, nil, err } - defer s.sysSessionPool().Put(tmp) + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) + return rows, rs.Fields(), err +} + +func (s *session) getInternalSession(opts []sqlexec.OptionFuncAlias) (*session, func(), error) { + tmp, err := s.sysSessionPool().Get() + if err != nil { + return nil, nil, errors.Trace(err) + } se := tmp.(*session) + var execOption sqlexec.ExecOption + for _, opt := range opts { + opt(&execOption) + } - startTime := time.Now() // The special session will share the `InspectionTableCache` with current session // if the current session in inspection mode. if cache := s.sessionVars.InspectionTableCache; cache != nil { se.sessionVars.InspectionTableCache = cache - defer func() { se.sessionVars.InspectionTableCache = nil }() } if ok := s.sessionVars.OptimizerUseInvisibleIndexes; ok { se.sessionVars.OptimizerUseInvisibleIndexes = true - defer func() { se.sessionVars.OptimizerUseInvisibleIndexes = false }() } prePruneMode := se.sessionVars.PartitionPruneMode.Load() - defer func() { - if !execOption.IgnoreWarning { - if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { - warnings := se.GetSessionVars().StmtCtx.GetWarnings() - s.GetSessionVars().StmtCtx.AppendWarnings(warnings) - } - } - se.sessionVars.PartitionPruneMode.Store(prePruneMode) - }() if execOption.SnapshotTS != 0 { se.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) @@ -1605,47 +1614,77 @@ func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { return nil, nil, err } - defer func() { - if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { - logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) - } - se.sessionVars.SnapshotInfoschema = nil - }() } + prevStatsVer := se.sessionVars.AnalyzeVersion if execOption.AnalyzeVer != 0 { - prevStatsVer := se.sessionVars.AnalyzeVersion se.sessionVars.AnalyzeVersion = execOption.AnalyzeVer - defer func() { - se.sessionVars.AnalyzeVersion = prevStatsVer - }() } // for analyze stmt we need let worker session follow user session that executing stmt. se.sessionVars.PartitionPruneMode.Store(s.sessionVars.PartitionPruneMode.Load()) - metrics.SessionRestrictedSQLCounter.Inc() - ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) - ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) - rs, err := se.ExecuteStmt(ctx, stmtNode) - if err != nil { - se.sessionVars.StmtCtx.AppendError(err) - } - if rs == nil { - return nil, nil, err - } - defer func() { - if closeErr := rs.Close(); closeErr != nil { - err = closeErr + return se, func() { + se.sessionVars.AnalyzeVersion = prevStatsVer + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) } - }() - var rows []chunk.Row - rows, err = drainRecordSet(ctx, se, rs, nil) + se.sessionVars.SnapshotInfoschema = nil + if !execOption.IgnoreWarning { + if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + s.GetSessionVars().StmtCtx.AppendWarnings(warnings) + } + } + se.sessionVars.PartitionPruneMode.Store(prePruneMode) + se.sessionVars.OptimizerUseInvisibleIndexes = false + se.sessionVars.InspectionTableCache = nil + s.sysSessionPool().Put(tmp) + }, nil +} + +func (s *session) withRestrictedSQLExecutor(ctx context.Context, opts []sqlexec.OptionFuncAlias, fn func(context.Context, *session) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { + se, clean, err := s.getInternalSession(opts) if err != nil { - return nil, nil, err + return nil, nil, errors.Trace(err) } - metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) - return rows, rs.Fields(), err + defer clean() + return fn(ctx, se) +} + +func (s *session) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFuncAlias, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { + return s.withRestrictedSQLExecutor(ctx, opts, func(ctx context.Context, se *session) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := se.ParseWithParams(ctx, sql, params...) + if err != nil { + return nil, nil, errors.Trace(err) + } + if topsqlstate.TopSQLEnabled() { + defer pprof.SetGoroutineLabels(ctx) + } + startTime := time.Now() + metrics.SessionRestrictedSQLCounter.Inc() + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) + rs, err := se.ExecuteStmt(ctx, stmt) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs, nil) + if err != nil { + return nil, nil, err + } + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds()) + return rows, rs.Fields(), err + }) } func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { diff --git a/session/session_test.go b/session/session_test.go index ac68f7fab7d2c..fe841eac5ba9e 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4420,11 +4420,11 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { exec := se.(sqlexec.RestrictedSQLExecutor) // test compatibility with ExcuteInternal - _, err := exec.ParseWithParams(context.TODO(), true, "SELECT 4") + _, err := exec.ParseWithParams(context.TODO(), "SELECT 4") c.Assert(err, IsNil) // test charset attack - stmt, err := exec.ParseWithParams(context.TODO(), true, "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") + stmt, err := exec.ParseWithParams(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*") c.Assert(err, IsNil) var sb strings.Builder @@ -4434,15 +4434,15 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) { c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\"\xbf' OR 1=1 /*\" LIMIT 1") // test invalid sql - _, err = exec.ParseWithParams(context.TODO(), true, "SELECT") + _, err = exec.ParseWithParams(context.TODO(), "SELECT") c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*") // test invalid arguments to escape - _, err = exec.ParseWithParams(context.TODO(), true, "SELECT %?, %?", 3) + _, err = exec.ParseWithParams(context.TODO(), "SELECT %?, %?", 3) c.Assert(err, ErrorMatches, "missing arguments.*") // test noescape - stmt, err = exec.ParseWithParams(context.TODO(), true, "SELECT 3") + stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3") c.Assert(err, IsNil) sb.Reset() diff --git a/session/tidb_test.go b/session/tidb_test.go index 47dc3896ba90a..70831a8f64d89 100644 --- a/session/tidb_test.go +++ b/session/tidb_test.go @@ -37,7 +37,7 @@ func TestSysSessionPoolGoroutineLeak(t *testing.T) { count := 200 stmts := make([]ast.StmtNode, count) for i := 0; i < count; i++ { - stmt, err := se.ParseWithParams(context.Background(), true, "select * from mysql.user limit 1") + stmt, err := se.ParseWithParams(context.Background(), "select * from mysql.user limit 1") require.NoError(t, err) stmts[i] = stmt } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 45db4f42df329..de53c5a309cf2 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -137,7 +137,7 @@ func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context. func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, true, sql, params...) + stmt, err := exec.ParseWithParams(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) } @@ -147,7 +147,7 @@ func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...in func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, true, sql, params...) + stmt, err := exec.ParseWithParams(ctx, sql, params...) // TODO: An ugly way to set @@tidb_partition_prune_mode. Need to be improved. if _, ok := stmt.(*ast.AnalyzeTableStmt); ok { pruneMode := h.CurrentPruneMode() @@ -164,7 +164,7 @@ func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int func (h *Handle) execRestrictedSQLWithSnapshot(ctx context.Context, sql string, snapshot uint64, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := exec.ParseWithParams(ctx, true, sql, params...) + stmt, err := exec.ParseWithParams(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) } @@ -1403,14 +1403,10 @@ type statsReader struct { func (sr *statsReader) read(sql string, args ...interface{}) (rows []chunk.Row, fields []*ast.ResultField, err error) { ctx := context.TODO() - stmt, err := sr.ctx.ParseWithParams(ctx, true, sql, args...) - if err != nil { - return nil, nil, errors.Trace(err) - } if sr.snapshot > 0 { - return sr.ctx.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(sr.snapshot)) + return sr.ctx.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionWithSnapshot(sr.snapshot)}, sql, args...) } - return sr.ctx.ExecRestrictedStmt(ctx, stmt) + return sr.ctx.ExecRestrictedSQL(ctx, nil, sql, args...) } func (sr *statsReader) isHistory() bool { diff --git a/telemetry/data_cluster_hardware.go b/telemetry/data_cluster_hardware.go index c3b3d716eb682..d357e9243fd0b 100644 --- a/telemetry/data_cluster_hardware.go +++ b/telemetry/data_cluster_hardware.go @@ -69,11 +69,7 @@ func normalizeFieldName(name string) string { func getClusterHardware(ctx sessionctx.Context) ([]*clusterHardwareItem, error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) - if err != nil { - return nil, errors.Trace(err) - } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT TYPE, INSTANCE, DEVICE_TYPE, DEVICE_NAME, NAME, VALUE FROM information_schema.cluster_hardware`) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_cluster_info.go b/telemetry/data_cluster_info.go index 7ba04df5d6b9d..40f87bccdfd3d 100644 --- a/telemetry/data_cluster_info.go +++ b/telemetry/data_cluster_info.go @@ -37,11 +37,7 @@ type clusterInfoItem struct { func getClusterInfo(ctx sessionctx.Context) ([]*clusterInfoItem, error) { // Explicitly list all field names instead of using `*` to avoid potential leaking sensitive info when adding new fields in future. exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), true, `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) - if err != nil { - return nil, errors.Trace(err) - } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, `SELECT TYPE, INSTANCE, STATUS_ADDRESS, VERSION, GIT_HASH, START_TIME, UPTIME FROM information_schema.cluster_info`) if err != nil { return nil, errors.Trace(err) } diff --git a/telemetry/data_feature_usage.go b/telemetry/data_feature_usage.go index 486021a5df7d9..521d5a5a3c2f9 100644 --- a/telemetry/data_feature_usage.go +++ b/telemetry/data_feature_usage.go @@ -77,7 +77,7 @@ func getClusterIndexUsageInfo(ctx sessionctx.Context) (cu *ClusterIndexUsage, er exec := ctx.(sqlexec.RestrictedSQLExecutor) // query INFORMATION_SCHEMA.tables to get the latest table information about ClusterIndex - stmt, err := exec.ParseWithParams(context.TODO(), true, ` + rows, _, err := exec.ExecRestrictedSQL(context.TODO(), nil, ` SELECT left(sha2(TABLE_NAME, 256), 6) table_name_hash, TIDB_PK_TYPE, TABLE_SCHEMA, TABLE_NAME FROM information_schema.tables WHERE table_schema not in ('INFORMATION_SCHEMA', 'METRICS_SCHEMA', 'PERFORMANCE_SCHEMA', 'mysql') @@ -86,10 +86,6 @@ func getClusterIndexUsageInfo(ctx sessionctx.Context) (cu *ClusterIndexUsage, er if err != nil { return nil, err } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) - if err != nil { - return nil, err - } defer func() { if r := recover(); r != nil { diff --git a/util/admin/admin.go b/util/admin/admin.go index 863f843fca350..83525efa52b7b 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -26,7 +26,6 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" - "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -292,8 +291,8 @@ type RecordData struct { Values []types.Datum } -func getCount(exec sqlexec.RestrictedSQLExecutor, stmt ast.StmtNode, snapshot uint64) (int64, error) { - rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) +func getCount(exec sqlexec.RestrictedSQLExecutor, snapshot uint64, sql string, args ...interface{}) (int64, error) { + rows, _, err := exec.ExecRestrictedSQL(context.Background(), []sqlexec.OptionFuncAlias{sqlexec.ExecOptionWithSnapshot(snapshot)}, sql, args...) if err != nil { return 0, errors.Trace(err) } @@ -321,12 +320,6 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices defer func() { ctx.GetSessionVars().OptimizerUseInvisibleIndexes = false }() - // Add `` for some names like `table name`. - exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), true, "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) - if err != nil { - return 0, 0, errors.Trace(err) - } var snapshot uint64 txn, err := ctx.Txn(false) @@ -340,16 +333,14 @@ func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices snapshot = ctx.GetSessionVars().SnapshotTS } - tblCnt, err := getCount(exec, stmt, snapshot) + // Add `` for some names like `table name`. + exec := ctx.(sqlexec.RestrictedSQLExecutor) + tblCnt, err := getCount(exec, snapshot, "SELECT COUNT(*) FROM %n.%n USE INDEX()", dbName, tableName) if err != nil { return 0, 0, errors.Trace(err) } for i, idx := range indices { - stmt, err := exec.ParseWithParams(context.Background(), true, "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) - if err != nil { - return 0, i, errors.Trace(err) - } - idxCnt, err := getCount(exec, stmt, snapshot) + idxCnt, err := getCount(exec, snapshot, "SELECT COUNT(*) FROM %n.%n USE INDEX(%n)", dbName, tableName, idx) if err != nil { return 0, i, errors.Trace(err) } diff --git a/util/gcutil/gcutil.go b/util/gcutil/gcutil.go index 5d0949f162747..8c60534f9c265 100644 --- a/util/gcutil/gcutil.go +++ b/util/gcutil/gcutil.go @@ -72,11 +72,7 @@ func ValidateSnapshotWithGCSafePoint(snapshotTS, safePointTS uint64) error { // GetGCSafePoint loads GC safe point time from mysql.tidb. func GetGCSafePoint(ctx sessionctx.Context) (uint64, error) { exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.Background(), true, selectVariableValueSQL, "tikv_gc_safe_point") - if err != nil { - return 0, errors.Trace(err) - } - rows, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) + rows, _, err := exec.ExecRestrictedSQL(context.Background(), nil, selectVariableValueSQL, "tikv_gc_safe_point") if err != nil { return 0, errors.Trace(err) } diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 8ca1fbf67f918..1a45f1ea2cd28 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -45,9 +45,11 @@ type RestrictedSQLExecutor interface { // Attention: it does not prevent you from doing parse("select '%?", ";SQL injection!;") => "select '';SQL injection!;'". // One argument should be a standalone entity. It should not "concat" with other placeholders and characters. // This function only saves you from processing potentially unsafe parameters. - ParseWithParams(ctx context.Context, forceUTF8SQL bool, sql string, args ...interface{}) (ast.StmtNode, error) - // ExecRestrictedStmt run sql statement in ctx with some restriction. + ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) + // ExecRestrictedStmt run sql statement in ctx with some restrictions. ExecRestrictedStmt(ctx context.Context, stmt ast.StmtNode, opts ...OptionFuncAlias) ([]chunk.Row, []*ast.ResultField, error) + // ExecRestrictedSQL run sql string in ctx with internal session. + ExecRestrictedSQL(ctx context.Context, opts []OptionFuncAlias, sql string, args ...interface{}) ([]chunk.Row, []*ast.ResultField, error) } // ExecOption is a struct defined for ExecRestrictedStmt option.