diff --git a/expression/aggregation/count.go b/expression/aggregation/count.go index 2b5236b38a446..2cd939af58191 100644 --- a/expression/aggregation/count.go +++ b/expression/aggregation/count.go @@ -14,7 +14,6 @@ package aggregation import ( - log "github.com/Sirupsen/logrus" "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" @@ -39,11 +38,7 @@ func (cf *countFunction) Clone() Aggregation { // CalculateDefaultValue implements Aggregation interface. func (cf *countFunction) CalculateDefaultValue(schema *expression.Schema, ctx context.Context) (d types.Datum, valid bool) { for _, arg := range cf.Args { - result, err := expression.EvaluateExprWithNull(ctx, schema, arg) - if err != nil { - log.Warnf("Evaluate expr with null failed in function %s, err msg is %s", cf, err.Error()) - return d, false - } + result := expression.EvaluateExprWithNull(ctx, schema, arg) if con, ok := result.(*expression.Constant); ok { if con.Value.IsNull() { return types.NewDatum(0), true diff --git a/expression/aggregation/first_row.go b/expression/aggregation/first_row.go index 61d16779193d7..049bedae06abf 100644 --- a/expression/aggregation/first_row.go +++ b/expression/aggregation/first_row.go @@ -14,7 +14,6 @@ package aggregation import ( - log "github.com/Sirupsen/logrus" "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" @@ -70,11 +69,7 @@ func (ff *firstRowFunction) GetPartialResult(ctx *AggEvaluateContext) []types.Da // CalculateDefaultValue implements Aggregation interface. func (ff *firstRowFunction) CalculateDefaultValue(schema *expression.Schema, ctx context.Context) (d types.Datum, valid bool) { arg := ff.Args[0] - result, err := expression.EvaluateExprWithNull(ctx, schema, arg) - if err != nil { - log.Warnf("Evaluate expr with null failed in function %s, err msg is %s", ff, err.Error()) - return d, false - } + result := expression.EvaluateExprWithNull(ctx, schema, arg) if con, ok := result.(*expression.Constant); ok { return con.Value, true } diff --git a/expression/aggregation/max_min.go b/expression/aggregation/max_min.go index c7a0ef323d391..fc0e2378b0c0e 100644 --- a/expression/aggregation/max_min.go +++ b/expression/aggregation/max_min.go @@ -14,7 +14,6 @@ package aggregation import ( - log "github.com/Sirupsen/logrus" "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" @@ -39,11 +38,7 @@ func (mmf *maxMinFunction) Clone() Aggregation { // CalculateDefaultValue implements Aggregation interface. func (mmf *maxMinFunction) CalculateDefaultValue(schema *expression.Schema, ctx context.Context) (d types.Datum, valid bool) { arg := mmf.Args[0] - result, err := expression.EvaluateExprWithNull(ctx, schema, arg) - if err != nil { - log.Warnf("Evaluate expr with null failed in function %s, err msg is %s", mmf, err.Error()) - return d, false - } + result := expression.EvaluateExprWithNull(ctx, schema, arg) if con, ok := result.(*expression.Constant); ok { return con.Value, true } diff --git a/expression/aggregation/sum.go b/expression/aggregation/sum.go index 9e058d62804aa..e1319e1989d88 100644 --- a/expression/aggregation/sum.go +++ b/expression/aggregation/sum.go @@ -62,12 +62,9 @@ func (sf *sumFunction) GetPartialResult(ctx *AggEvaluateContext) []types.Datum { // CalculateDefaultValue implements Aggregation interface. func (sf *sumFunction) CalculateDefaultValue(schema *expression.Schema, ctx context.Context) (d types.Datum, valid bool) { arg := sf.Args[0] - result, err := expression.EvaluateExprWithNull(ctx, schema, arg) - if err != nil { - log.Warnf("Evaluate expr with null failed in function %s, err msg is %s", sf, err.Error()) - return d, false - } + result := expression.EvaluateExprWithNull(ctx, schema, arg) if con, ok := result.(*expression.Constant); ok { + var err error d, err = calculateSum(ctx.GetSessionVars().StmtCtx, d, con.Value) if err != nil { log.Warnf("CalculateSum failed in function %s, err msg is %s", sf, err.Error()) diff --git a/expression/expression.go b/expression/expression.go index aac4a0d3ca922..16b8bbbf5fd32 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -215,35 +215,25 @@ func SplitDNFItems(onExpr Expression) []Expression { // EvaluateExprWithNull sets columns in schema as null and calculate the final result of the scalar function. // If the Expression is a non-constant value, it means the result is unknown. -func EvaluateExprWithNull(ctx context.Context, schema *Schema, expr Expression) (Expression, error) { +func EvaluateExprWithNull(ctx context.Context, schema *Schema, expr Expression) Expression { switch x := expr.(type) { case *ScalarFunction: - var err error args := make([]Expression, len(x.GetArgs())) for i, arg := range x.GetArgs() { - args[i], err = EvaluateExprWithNull(ctx, schema, arg) - if err != nil { - return nil, errors.Trace(err) - } - } - newFunc, err := NewFunction(ctx, x.FuncName.L, types.NewFieldType(mysql.TypeTiny), args...) - if err != nil { - return nil, errors.Trace(err) + args[i] = EvaluateExprWithNull(ctx, schema, arg) } - return newFunc, nil + return NewFunctionInternal(ctx, x.FuncName.L, types.NewFieldType(mysql.TypeTiny), args...) case *Column: if !schema.Contains(x) { - return x, nil + return x } - constant := &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)} - return constant, nil + return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)} case *Constant: if x.DeferredExpr != nil { - newConst := FoldConstant(x) - return newConst, nil + return FoldConstant(x) } } - return expr.Clone(), nil + return expr.Clone() } // TableInfo2Schema converts table info to schema with empty DBName. diff --git a/expression/expression_test.go b/expression/expression_test.go index dda54bc44f733..4ddbe4d84f3c1 100644 --- a/expression/expression_test.go +++ b/expression/expression_test.go @@ -42,14 +42,12 @@ func (s *testEvaluatorSuite) TestEvaluateExprWithNull(c *C) { // ifnull(null, ifnull(col1, 1)) schema := &Schema{Columns: []*Column{col0}} - res, err := EvaluateExprWithNull(s.ctx, schema, ifnullOuter) - c.Assert(err, IsNil) + res := EvaluateExprWithNull(s.ctx, schema, ifnullOuter) c.Assert(res.String(), Equals, "ifnull(, ifnull(col1, 1))") schema.Columns = append(schema.Columns, col1) // ifnull(null, ifnull(null, 1)) - res, err = EvaluateExprWithNull(s.ctx, schema, ifnullOuter) - c.Assert(err, IsNil) + res = EvaluateExprWithNull(s.ctx, schema, ifnullOuter) c.Assert(res.Equal(One, s.ctx), IsTrue) } diff --git a/plan/decorrelate.go b/plan/decorrelate.go index 46c0c16150a31..7a1a1ac0316a3 100644 --- a/plan/decorrelate.go +++ b/plan/decorrelate.go @@ -70,10 +70,7 @@ func (a *LogicalAggregation) canPullUp() bool { } for _, f := range a.AggFuncs { for _, arg := range f.GetArgs() { - expr, err := expression.EvaluateExprWithNull(a.ctx, a.children[0].Schema(), arg) - if err != nil { - return false - } + expr := expression.EvaluateExprWithNull(a.ctx, a.children[0].Schema(), arg) if con, ok := expr.(*expression.Constant); !ok || !con.Value.IsNull() { return false } diff --git a/plan/eliminate_projection.go b/plan/eliminate_projection.go index f6507c2e6aac2..6bdcf8600aca6 100644 --- a/plan/eliminate_projection.go +++ b/plan/eliminate_projection.go @@ -14,10 +14,8 @@ package plan import ( - "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/terror" ) // canProjectionBeEliminatedLoose checks whether a projection can be eliminated, returns true if @@ -83,8 +81,7 @@ func doPhysicalProjectionElimination(p PhysicalPlan) PhysicalPlan { return p } child := p.Children()[0] - err := RemovePlan(p) - terror.Log(errors.Trace(err)) + removePlan(p) return child.(PhysicalPlan) } @@ -166,8 +163,7 @@ func (pe *projectionEliminater) eliminate(p LogicalPlan, replace map[string]*exp for i, col := range proj.Schema().Columns { replace[string(col.HashCode())] = exprs[i].(*expression.Column) } - err := RemovePlan(p) - terror.Log(errors.Trace(err)) + removePlan(p) return child.(LogicalPlan) } diff --git a/plan/logical_plans.go b/plan/logical_plans.go index 114102448f440..32b1f031107d9 100644 --- a/plan/logical_plans.go +++ b/plan/logical_plans.go @@ -14,7 +14,6 @@ package plan import ( - "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" @@ -407,35 +406,3 @@ type Delete struct { Tables []*ast.TableName IsMultiTable bool } - -// setParentAndChildren sets parent and children relationship. -func setParentAndChildren(parent Plan, children ...Plan) { - if children == nil || parent == nil { - return - } - for _, child := range children { - child.SetParents(parent) - } - parent.SetChildren(children...) -} - -// RemovePlan means removing a plan. -func RemovePlan(p Plan) error { - parents := p.Parents() - children := p.Children() - if len(parents) > 1 || len(children) != 1 { - return SystemInternalErrorType.Gen("can't remove this plan") - } - if len(parents) == 0 { - child := children[0] - child.SetParents() - return nil - } - parent, child := parents[0], children[0] - err := parent.ReplaceChild(p, child) - if err != nil { - return errors.Trace(err) - } - err = child.ReplaceParent(p, parent) - return errors.Trace(err) -} diff --git a/plan/plan.go b/plan/plan.go index ef0e6743d5379..19c0a0a1fc9d1 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -30,10 +30,6 @@ import ( // It is created from ast.Node first, then optimized by the optimizer, // finally used by the executor to create a Cursor which executes the statement. type Plan interface { - // ReplaceParent means replacing a parent with another one. - ReplaceParent(parent, newPar Plan) error - // ReplaceChild means replacing a child with another one. - ReplaceChild(children, newChild Plan) error // Get all the parents. Parents() []Plan // Get all the children. @@ -156,7 +152,7 @@ type LogicalPlan interface { // PredicatePushDown pushes down the predicates in the where/on/having clauses as deeply as possible. // It will accept a predicate that is an expression slice, and return the expressions that can't be pushed. // Because it might change the root if the having clause exists, we need to return a plan that represents a new root. - PredicatePushDown([]expression.Expression) ([]expression.Expression, LogicalPlan, error) + PredicatePushDown([]expression.Expression) ([]expression.Expression, LogicalPlan) // PruneColumns prunes the unused columns. PruneColumns([]*expression.Column) @@ -275,22 +271,6 @@ func newBasePhysicalPlan(basePlan *basePlan) basePhysicalPlan { } } -// PredicatePushDown implements LogicalPlan interface. -func (p *baseLogicalPlan) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { - if len(p.basePlan.children) == 0 { - return predicates, p.basePlan.self.(LogicalPlan), nil - } - child := p.basePlan.children[0].(LogicalPlan) - rest, _, err := child.PredicatePushDown(predicates) - if err != nil { - return nil, nil, errors.Trace(err) - } - if len(rest) > 0 { - addSelection(p.basePlan.self, child, rest, p.basePlan.allocator) - } - return nil, p.basePlan.self.(LogicalPlan), nil -} - func (p *baseLogicalPlan) extractCorrelatedCols() []*expression.CorrelatedColumn { var corCols []*expression.CorrelatedColumn for _, child := range p.basePlan.children { @@ -373,28 +353,6 @@ func (p *basePlan) Schema() *expression.Schema { return p.schema } -// ReplaceParent means replace a parent for another one. -func (p *basePlan) ReplaceParent(parent, newPar Plan) error { - for i, par := range p.parents { - if par.ID() == parent.ID() { - p.parents[i] = newPar - return nil - } - } - return SystemInternalErrorType.Gen("ReplaceParent Failed: parent \"%s\" not found", parent.ExplainID()) -} - -// ReplaceChild means replace a child with another one. -func (p *basePlan) ReplaceChild(child, newChild Plan) error { - for i, ch := range p.children { - if ch.ID() == child.ID() { - p.children[i] = newChild - return nil - } - } - return SystemInternalErrorType.Gen("ReplaceChildren Failed: child \"%s\" not found", child.ExplainID()) -} - // Parents implements Plan Parents interface. func (p *basePlan) Parents() []Plan { return p.parents diff --git a/plan/predicate_push_down.go b/plan/predicate_push_down.go index 535b68678c5f1..a39761aa1f99f 100644 --- a/plan/predicate_push_down.go +++ b/plan/predicate_push_down.go @@ -13,7 +13,6 @@ package plan import ( - "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" @@ -25,16 +24,8 @@ import ( type ppdSolver struct{} func (s *ppdSolver) optimize(lp LogicalPlan, _ context.Context, _ *idAllocator) (LogicalPlan, error) { - _, p, err := lp.PredicatePushDown(nil) - return p, errors.Trace(err) -} - -func replaceChild(p, child, replace Plan) { - for i, ch := range p.Children() { - if ch.ID() == child.ID() { - p.Children()[i] = replace - } - } + _, p := lp.PredicatePushDown(nil) + return p, nil } func addSelection(p Plan, child LogicalPlan, conditions []expression.Expression, allocator *idAllocator) { @@ -47,40 +38,44 @@ func addSelection(p Plan, child LogicalPlan, conditions []expression.Expression, selection.SetParents(p) } -// PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *Selection) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { - retConditions, child, err := p.children[0].(LogicalPlan).PredicatePushDown(append(p.Conditions, predicates...)) - if err != nil { - return nil, nil, errors.Trace(err) +// PredicatePushDown implements LogicalPlan interface. +func (p *baseLogicalPlan) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { + if len(p.basePlan.children) == 0 { + return predicates, p.basePlan.self.(LogicalPlan) } + child := p.basePlan.children[0].(LogicalPlan) + rest, _ := child.PredicatePushDown(predicates) + if len(rest) > 0 { + addSelection(p.basePlan.self, child, rest, p.basePlan.allocator) + } + return nil, p.basePlan.self.(LogicalPlan) +} + +// PredicatePushDown implements LogicalPlan PredicatePushDown interface. +func (p *Selection) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { + retConditions, child := p.children[0].(LogicalPlan).PredicatePushDown(append(p.Conditions, predicates...)) if len(retConditions) > 0 { p.Conditions = expression.PropagateConstant(p.ctx, retConditions) - return nil, p, nil - } - err = RemovePlan(p) - if err != nil { - return nil, nil, errors.Trace(err) + return nil, p } - return nil, child, nil + removePlan(p) + return nil, child } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *DataSource) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { +func (p *DataSource) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { _, p.pushedDownConds, predicates = expression.ExpressionsToPB(p.ctx.GetSessionVars().StmtCtx, predicates, p.ctx.GetClient()) - return predicates, p, nil + return predicates, p } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *TableDual) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { - return predicates, p, nil +func (p *TableDual) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { + return predicates, p } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { - err = outerJoinSimplify(p, predicates) - if err != nil { - return nil, nil, errors.Trace(err) - } +func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan) { + outerJoinSimplify(p, predicates) groups, valid := tryToGetJoinGroup(p) if valid { e := joinReOrderSolver{allocator: p.allocator, ctx: p.ctx} @@ -139,14 +134,8 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret leftCond = leftPushCond rightCond = rightPushCond } - leftRet, _, err1 := leftPlan.PredicatePushDown(leftCond) - if err1 != nil { - return nil, nil, errors.Trace(err1) - } - rightRet, _, err2 := rightPlan.PredicatePushDown(rightCond) - if err2 != nil { - return nil, nil, errors.Trace(err2) - } + leftRet, _ := leftPlan.PredicatePushDown(leftCond) + rightRet, _ := rightPlan.PredicatePushDown(rightCond) if len(leftRet) > 0 { addSelection(p, leftPlan, leftRet, p.allocator) } @@ -231,7 +220,7 @@ func (p *LogicalJoin) getProj(idx int) *Projection { } // outerJoinSimplify simplifies outer join. -func outerJoinSimplify(p *LogicalJoin, predicates []expression.Expression) error { +func outerJoinSimplify(p *LogicalJoin, predicates []expression.Expression) { var innerTable, outerTable LogicalPlan child1 := p.children[0].(LogicalPlan) child2 := p.children[1].(LogicalPlan) @@ -243,37 +232,28 @@ func outerJoinSimplify(p *LogicalJoin, predicates []expression.Expression) error innerTable = child1 outerTable = child2 } else { - return nil + return } // first simplify embedded outer join. // When trying to simplify an embedded outer join operation in a query, // we must take into account the join condition for the embedding outer join together with the WHERE condition. if innerPlan, ok := innerTable.(*LogicalJoin); ok { fullConditions = concatOnAndWhereConds(p, predicates) - err := outerJoinSimplify(innerPlan, fullConditions) - if err != nil { - return errors.Trace(err) - } + outerJoinSimplify(innerPlan, fullConditions) } if outerPlan, ok := outerTable.(*LogicalJoin); ok { if fullConditions != nil { fullConditions = concatOnAndWhereConds(p, predicates) } - err := outerJoinSimplify(outerPlan, fullConditions) - if err != nil { - return errors.Trace(err) - } + outerJoinSimplify(outerPlan, fullConditions) } if p.JoinType == InnerJoin { - return nil + return } // then simplify embedding outer join. canBeSimplified := false for _, expr := range predicates { - isOk, err := isNullRejected(p.ctx, innerTable.Schema(), expr) - if err != nil { - return errors.Trace(err) - } + isOk := isNullRejected(p.ctx, innerTable.Schema(), expr) if isOk { canBeSimplified = true break @@ -282,7 +262,6 @@ func outerJoinSimplify(p *LogicalJoin, predicates []expression.Expression) error if canBeSimplified { p.JoinType = InnerJoin } - return nil } // isNullRejected check whether a condition is null-rejected @@ -290,22 +269,19 @@ func outerJoinSimplify(p *LogicalJoin, predicates []expression.Expression) error // If it is a predicate containing a reference to an inner table that evaluates to UNKNOWN or FALSE when one of its arguments is NULL. // If it is a conjunction containing a null-rejected condition as a conjunct. // If it is a disjunction of null-rejected conditions. -func isNullRejected(ctx context.Context, schema *expression.Schema, expr expression.Expression) (bool, error) { - result, err := expression.EvaluateExprWithNull(ctx, schema, expr) - if err != nil { - return false, errors.Trace(err) - } +func isNullRejected(ctx context.Context, schema *expression.Schema, expr expression.Expression) bool { + result := expression.EvaluateExprWithNull(ctx, schema, expr) x, ok := result.(*expression.Constant) if !ok { - return false, nil + return false } sc := ctx.GetSessionVars().StmtCtx if x.Value.IsNull() { - return true, nil + return true } else if isTrue, err := x.Value.ToBool(sc); err != nil || isTrue == 0 { - return true, errors.Trace(err) + return true } - return false, nil + return false } // concatOnAndWhereConds concatenate ON conditions with WHERE conditions. @@ -323,17 +299,14 @@ func concatOnAndWhereConds(join *LogicalJoin, predicates []expression.Expression } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *Projection) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, _ error) { +func (p *Projection) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan) { retPlan = p var push = make([]expression.Expression, 0, p.Schema().Len()) for _, cond := range predicates { push = append(push, expression.ColumnSubstitute(cond, p.Schema(), p.Exprs)) } child := p.children[0].(LogicalPlan) - restConds, _, err := child.PredicatePushDown(push) - if err != nil { - return nil, nil, errors.Trace(err) - } + restConds, _ := child.PredicatePushDown(push) if len(restConds) > 0 { addSelection(p, child, restConds, p.allocator) } @@ -341,7 +314,7 @@ func (p *Projection) PredicatePushDown(predicates []expression.Expression) (ret } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *Union) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { +func (p *Union) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan) { retPlan = p for _, proj := range p.children { newExprs := make([]expression.Expression, 0, len(predicates)) @@ -349,10 +322,7 @@ func (p *Union) PredicatePushDown(predicates []expression.Expression) (ret []exp newCond := expression.ColumnSubstitute(cond, p.Schema(), expression.Column2Exprs(proj.Schema().Columns)) newExprs = append(newExprs, newCond) } - retCond, _, err := proj.(LogicalPlan).PredicatePushDown(newExprs) - if err != nil { - return nil, nil, errors.Trace(err) - } + retCond, _ := proj.(LogicalPlan).PredicatePushDown(newExprs) if len(retCond) != 0 { addSelection(p, proj.(LogicalPlan), retCond, p.allocator) } @@ -366,7 +336,7 @@ func (p *LogicalAggregation) getGbyColIndex(col *expression.Column) int { } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *LogicalAggregation) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan, err error) { +func (p *LogicalAggregation) PredicatePushDown(predicates []expression.Expression) (ret []expression.Expression, retPlan LogicalPlan) { retPlan = p var exprsOriginal []expression.Expression var condsToPush []expression.Expression @@ -400,20 +370,20 @@ func (p *LogicalAggregation) PredicatePushDown(predicates []expression.Expressio ret = append(ret, cond) } } - _, _, err = p.baseLogicalPlan.PredicatePushDown(condsToPush) - return ret, retPlan, errors.Trace(err) + _, _ = p.baseLogicalPlan.PredicatePushDown(condsToPush) + return ret, retPlan } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *Limit) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { +func (p *Limit) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { // Limit forbids any condition to push down. - _, _, err := p.baseLogicalPlan.PredicatePushDown(nil) - return predicates, p, errors.Trace(err) + p.baseLogicalPlan.PredicatePushDown(nil) + return predicates, p } // PredicatePushDown implements LogicalPlan PredicatePushDown interface. -func (p *MaxOneRow) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan, error) { +func (p *MaxOneRow) PredicatePushDown(predicates []expression.Expression) ([]expression.Expression, LogicalPlan) { // MaxOneRow forbids any condition to push down. - _, _, err := p.baseLogicalPlan.PredicatePushDown(nil) - return predicates, p, errors.Trace(err) + p.baseLogicalPlan.PredicatePushDown(nil) + return predicates, p } diff --git a/plan/util.go b/plan/util.go index f59b6adbbdc18..1d1f188a83b99 100644 --- a/plan/util.go +++ b/plan/util.go @@ -45,3 +45,37 @@ func (a *AggregateFuncExtractor) Leave(n ast.Node) (ast.Node, bool) { } return n, true } + +// replaceChild replaces p's child with some plan else. +func replaceChild(p, child, replace Plan) { + for i, ch := range p.Children() { + if ch.ID() == child.ID() { + p.Children()[i] = replace + } + } +} + +// setParentAndChildren sets parent and children relationship. +func setParentAndChildren(parent Plan, children ...Plan) { + if children == nil || parent == nil { + return + } + for _, child := range children { + child.SetParents(parent) + } + parent.SetChildren(children...) +} + +// removePlan removes a plan from its parent and child. +func removePlan(p Plan) { + parents := p.Parents() + children := p.Children() + if len(parents) == 0 { + child := children[0] + child.SetParents() + return + } + parent, child := parents[0], children[0] + replaceChild(parent, p, child) + child.SetParents(parent) +}