From f9c31d913055f4469808a55dc33cbe445cb17c8c Mon Sep 17 00:00:00 2001 From: Han Fei Date: Wed, 1 Jun 2016 13:12:55 +0800 Subject: [PATCH] Hanfei/rewrite plan (#1272) --- ast/functions.go | 22 +- evaluator/builtin.go | 33 ++ executor/builder.go | 15 +- executor/executor_test.go | 2 +- executor/new_executor.go | 37 +- expression/aggregation.go | 82 +++++ expression/expression.go | 350 +++++++++++++++++++ optimizer/new_plan_test.go | 39 ++- optimizer/optimizer.go | 7 +- optimizer/plan/column_pruning.go | 84 +++++ optimizer/plan/new_plans.go | 63 +++- optimizer/plan/newplanbuilder.go | 466 +++++++++++++++++++++----- optimizer/plan/plan.go | 22 ++ optimizer/plan/planbuilder.go | 14 +- optimizer/plan/predicate_push_down.go | 115 +++---- optimizer/plan/stringer.go | 17 +- parser/opcode/opcode.go | 5 +- 17 files changed, 1148 insertions(+), 225 deletions(-) create mode 100644 expression/aggregation.go create mode 100644 expression/expression.go create mode 100644 optimizer/plan/column_pruning.go diff --git a/ast/functions.go b/ast/functions.go index 76f4ec26b8870..0182f9987b8a4 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -217,7 +217,7 @@ func (n *AggregateFuncExpr) GetContext() *AggEvaluateContext { if _, ok := n.contextPerGroupMap[n.CurrentGroup]; !ok { c := &AggEvaluateContext{} if n.Distinct { - c.distinctChecker = distinct.CreateDistinctChecker() + c.DistinctChecker = distinct.CreateDistinctChecker() } n.contextPerGroupMap[n.CurrentGroup] = c } @@ -235,7 +235,7 @@ func (n *AggregateFuncExpr) updateCount() error { vals = append(vals, value) } if n.Distinct { - d, err := ctx.distinctChecker.Check(vals) + d, err := ctx.DistinctChecker.Check(vals) if err != nil { return errors.Trace(err) } @@ -249,14 +249,14 @@ func (n *AggregateFuncExpr) updateCount() error { func (n *AggregateFuncExpr) updateFirstRow() error { ctx := n.GetContext() - if ctx.evaluated { + if ctx.Evaluated { return nil } if len(n.Args) != 1 { return errors.New("Wrong number of args for AggFuncFirstRow") } ctx.Value = n.Args[0].GetValue() - ctx.evaluated = true + ctx.Evaluated = true return nil } @@ -266,9 +266,9 @@ func (n *AggregateFuncExpr) updateMaxMin(max bool) error { return errors.New("Wrong number of args for AggFuncFirstRow") } v := n.Args[0].GetValue() - if !ctx.evaluated { + if !ctx.Evaluated { ctx.Value = v - ctx.evaluated = true + ctx.Evaluated = true return nil } c, err := types.Compare(ctx.Value, v) @@ -296,7 +296,7 @@ func (n *AggregateFuncExpr) updateSum() error { return nil } if n.Distinct { - d, err := ctx.distinctChecker.Check([]interface{}{value}) + d, err := ctx.DistinctChecker.Check([]interface{}{value}) if err != nil { return errors.Trace(err) } @@ -324,7 +324,7 @@ func (n *AggregateFuncExpr) updateGroupConcat() error { vals = append(vals, value) } if n.Distinct { - d, err := ctx.distinctChecker.Check(vals) + d, err := ctx.DistinctChecker.Check(vals) if err != nil { return errors.Trace(err) } @@ -394,11 +394,11 @@ func (a *AggregateFuncExtractor) Leave(n Node) (node Node, ok bool) { return n, true } -// AggEvaluateContext is used to store intermediate result when caculation aggregate functions. +// AggEvaluateContext is used to store intermediate result when calculating aggregate functions. type AggEvaluateContext struct { - distinctChecker *distinct.Checker + DistinctChecker *distinct.Checker Count int64 Value interface{} Buffer *bytes.Buffer // Buffer is used for group_concat. - evaluated bool + Evaluated bool } diff --git a/evaluator/builtin.go b/evaluator/builtin.go index 8648f0c67a7a5..bfbee9e79afcf 100644 --- a/evaluator/builtin.go +++ b/evaluator/builtin.go @@ -20,6 +20,7 @@ package evaluator import ( "strings" + "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/util/types" ) @@ -113,6 +114,38 @@ var Funcs = map[string]Func{ "if": {builtinIf, 3, 3}, "ifnull": {builtinIfNull, 2, 2}, "nullif": {builtinNullIf, 2, 2}, + + // only used by new plan + "&&": {builtinEmpty, 2, 2}, + "<<": {builtinEmpty, 2, 2}, + ">>": {builtinEmpty, 2, 2}, + "||": {builtinEmpty, 2, 2}, + ">=": {builtinEmpty, 2, 2}, + "<=": {builtinEmpty, 2, 2}, + "=": {builtinEmpty, 2, 2}, + "!=": {builtinEmpty, 2, 2}, + "<": {builtinEmpty, 2, 2}, + ">": {builtinEmpty, 2, 2}, + "+": {builtinEmpty, 2, 2}, + "-": {builtinEmpty, 2, 2}, + "&": {builtinEmpty, 2, 2}, + "|": {builtinEmpty, 2, 2}, + "%": {builtinEmpty, 2, 2}, + "^": {builtinEmpty, 2, 2}, + "/": {builtinEmpty, 2, 2}, + "*": {builtinEmpty, 2, 2}, + "DIV": {builtinEmpty, 2, 2}, + "XOR": {builtinEmpty, 2, 2}, + "<=>": {builtinEmpty, 2, 2}, + "not": {builtinEmpty, 1, 1}, + "bitneg": {builtinEmpty, 1, 1}, + "unaryplus": {builtinEmpty, 1, 1}, + "unaryminus": {builtinEmpty, 1, 1}, +} + +// TODO: remove this when implementing executor. +func builtinEmpty(args []types.Datum, ctx context.Context) (d types.Datum, err error) { + return d, errors.New("Not implemented yet.") } // See: http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce diff --git a/executor/builder.go b/executor/builder.go index 749a9903a5acb..cb923921daadf 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -20,6 +20,7 @@ import ( "github.com/ngaut/log" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" @@ -112,14 +113,15 @@ func (b *executorBuilder) build(p plan.Plan) Executor { } // compose CNF items into a balance deep CNF tree, which benefits a lot for pb decoder/encoder. -func composeCondition(conditions []ast.ExprNode) ast.ExprNode { +func composeCondition(conditions []expression.Expression) expression.Expression { length := len(conditions) if length == 0 { return nil } else if length == 1 { return conditions[0] } else { - return &ast.BinaryOperationExpr{Op: opcode.AndAnd, L: composeCondition(conditions[:length/2]), R: composeCondition(conditions[length/2:])} + eqStr, _ := opcode.Ops[opcode.EQ] + return expression.NewFunction(model.NewCIStr(eqStr), []expression.Expression{composeCondition(conditions[0 : length/2]), composeCondition(conditions[length/2:])}) } } @@ -131,12 +133,11 @@ func (b *executorBuilder) buildJoin(v *plan.Join) Executor { fields: v.Fields(), ctx: b.ctx, } - var leftHashKey, rightHashKey []ast.ExprNode + var leftHashKey, rightHashKey []*expression.Column for _, eqCond := range v.EqualConditions { - binop, ok := eqCond.(*ast.BinaryOperationExpr) - if ok && binop.Op == opcode.EQ { - ln, lOK := binop.L.(*ast.ColumnNameExpr) - rn, rOK := binop.R.(*ast.ColumnNameExpr) + if eqCond.FuncName.L == "eq" { + ln, lOK := eqCond.Args[0].(*expression.Column) + rn, rOK := eqCond.Args[1].(*expression.Column) if lOK && rOK { leftHashKey = append(leftHashKey, ln) rightHashKey = append(rightHashKey, rn) diff --git a/executor/executor_test.go b/executor/executor_test.go index 9e2306e8473a3..e536b75a71c83 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1101,7 +1101,7 @@ func (s *testSuite) TestJoin(c *C) { } func (s *testSuite) TestNewJoin(c *C) { - plan.UseNewPlanner = true + plan.UseNewPlanner = false defer testleak.AfterTest(c)() tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/executor/new_executor.go b/executor/new_executor.go index ea03c6789039a..c7af9386fdb86 100644 --- a/executor/new_executor.go +++ b/executor/new_executor.go @@ -17,7 +17,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/evaluator" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/types" ) @@ -25,20 +25,21 @@ import ( // HashJoinExec implements the hash join algorithm. type HashJoinExec struct { hashTable map[string][]*Row - smallHashKey []ast.ExprNode - bigHashKey []ast.ExprNode + smallHashKey []*expression.Column + bigHashKey []*expression.Column smallExec Executor bigExec Executor prepared bool - fields []*ast.ResultField ctx context.Context - smallFilter ast.ExprNode - bigFilter ast.ExprNode - otherFilter ast.ExprNode - outter bool - leftSmall bool - matchedRows []*Row - cursor int + smallFilter expression.Expression + bigFilter expression.Expression + otherFilter expression.Expression + //TODO: remove fields when abandon old plan. + fields []*ast.ResultField + outter bool + leftSmall bool + matchedRows []*Row + cursor int } func joinTwoRow(a *Row, b *Row) *Row { @@ -53,10 +54,10 @@ func joinTwoRow(a *Row, b *Row) *Row { return ret } -func (e *HashJoinExec) getHashKey(exprs []ast.ExprNode) ([]byte, error) { +func (e *HashJoinExec) getHashKey(exprs []*expression.Column, row *Row) ([]byte, error) { vals := make([]types.Datum, 0, len(exprs)) for _, expr := range exprs { - v, err := evaluator.Eval(e.ctx, expr) + v, err := expr.Eval(row.Data, e.ctx) if err != nil { return nil, errors.Trace(err) } @@ -96,7 +97,7 @@ func (e *HashJoinExec) prepare() error { matched := true if e.smallFilter != nil { - matched, err = evaluator.EvalBool(e.ctx, e.smallFilter) + matched, err = expression.EvalBool(e.smallFilter, row.Data, e.ctx) if err != nil { return errors.Trace(err) } @@ -104,7 +105,7 @@ func (e *HashJoinExec) prepare() error { continue } } - hashcode, err := e.getHashKey(e.smallHashKey) + hashcode, err := e.getHashKey(e.smallHashKey, row) if err != nil { return errors.Trace(err) } @@ -120,7 +121,7 @@ func (e *HashJoinExec) prepare() error { } func (e *HashJoinExec) constructMatchedRows(bigRow *Row) (matchedRows []*Row, err error) { - hashcode, err := e.getHashKey(e.bigHashKey) + hashcode, err := e.getHashKey(e.bigHashKey, bigRow) if err != nil { return nil, errors.Trace(err) } @@ -141,7 +142,7 @@ func (e *HashJoinExec) constructMatchedRows(bigRow *Row) (matchedRows []*Row, er for i, data := range smallRow.Data { e.fields[i+startKey].Expr.SetValue(data.GetValue()) } - otherMatched, err = evaluator.EvalBool(e.ctx, e.otherFilter) + otherMatched, err = expression.EvalBool(e.otherFilter, bigRow.Data, e.ctx) } if err != nil { return nil, errors.Trace(err) @@ -215,7 +216,7 @@ func (e *HashJoinExec) Next() (*Row, error) { var matchedRows []*Row bigMatched := true if e.bigFilter != nil { - bigMatched, err = evaluator.EvalBool(e.ctx, e.bigFilter) + bigMatched, err = expression.EvalBool(e.bigFilter, row.Data, e.ctx) if err != nil { return nil, errors.Trace(err) } diff --git a/expression/aggregation.go b/expression/aggregation.go new file mode 100644 index 0000000000000..2097e50e42855 --- /dev/null +++ b/expression/aggregation.go @@ -0,0 +1,82 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "github.com/pingcap/tidb/ast" +) + +// AggregationFunction stands for aggregate functions. +type AggregationFunction interface { + // GetArgs stands for getting all arguments. + GetArgs() []Expression +} + +// NewAggrFunction creates a new AggregationFunction. +func NewAggrFunction(funcType string, funcArgs []Expression) AggregationFunction { + switch funcType { + case ast.AggFuncSum: + return &sumFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}} + case ast.AggFuncCount: + return &countFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}} + case ast.AggFuncAvg: + return &avgFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}} + case ast.AggFuncGroupConcat: + return &concatFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}} + case ast.AggFuncMax: + return &maxMinFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}, isMax: true} + case ast.AggFuncMin: + return &maxMinFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}, isMax: false} + case ast.AggFuncFirstRow: + return &firstRowFunction{aggrFunction: aggrFunction{Args: funcArgs, resultMapper: make(aggrCtxMapper, 0)}} + } + return nil +} + +type aggrCtxMapper map[string]*ast.AggEvaluateContext + +type aggrFunction struct { + Args []Expression + resultMapper aggrCtxMapper +} + +// GetArgs implements AggregationFunction interface. +func (af *aggrFunction) GetArgs() []Expression { + return af.Args +} + +type sumFunction struct { + aggrFunction +} + +type countFunction struct { + aggrFunction +} + +type avgFunction struct { + aggrFunction +} + +type concatFunction struct { + aggrFunction +} + +type maxMinFunction struct { + aggrFunction + isMax bool +} + +type firstRowFunction struct { + aggrFunction +} diff --git a/expression/expression.go b/expression/expression.go new file mode 100644 index 0000000000000..7d69b274427dc --- /dev/null +++ b/expression/expression.go @@ -0,0 +1,350 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/evaluator" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +// Rewrite rewrites ast to Expression. +func Rewrite(expr ast.ExprNode, schema Schema, AggrMapper map[*ast.AggregateFuncExpr]int) (newExpr Expression, err error) { + er := &expressionRewriter{schema: schema, aggrMap: AggrMapper} + expr.Accept(er) + if er.err != nil { + return nil, errors.Trace(er.err) + } + if len(er.ctxStack) != 1 { + return nil, errors.Errorf("context len %v is invalid", len(er.ctxStack)) + } + return er.ctxStack[0], nil +} + +type expressionRewriter struct { + ctxStack []Expression + schema Schema + err error + aggrMap map[*ast.AggregateFuncExpr]int +} + +// Expression represents all scalar expression in SQL. +type Expression interface { + // Eval evaluates an expression through a row. + Eval(row []types.Datum, ctx context.Context) (types.Datum, error) + + // Get the expression return type. + GetType() *types.FieldType + + // DeepCopy copies an expression totally. + DeepCopy() Expression + + // ToString converts an expression into a string. + ToString() string +} + +// EvalBool evaluates expression to a boolean value. +func EvalBool(expr Expression, row []types.Datum, ctx context.Context) (bool, error) { + data, err := expr.Eval(row, ctx) + if err != nil { + return false, errors.Trace(err) + } + if data.Kind() == types.KindNull { + return false, nil + } + + i, err := data.ToBool() + if err != nil { + return false, errors.Trace(err) + } + return i != 0, nil +} + +// Column represents a column. +type Column struct { + FromID string + ColName model.CIStr + DbName model.CIStr + TblName model.CIStr + RetType *types.FieldType + + // only used during execution + Index int +} + +// ToString implements Expression interface. +func (col *Column) ToString() string { + result := col.ColName.L + if col.TblName.L != "" { + result = col.TblName.L + "." + result + } + if col.DbName.L != "" { + result = col.DbName.L + "." + result + } + return result +} + +// GetType implements Expression interface. +func (col *Column) GetType() *types.FieldType { + return col.RetType +} + +// Eval implements Expression interface. +func (col *Column) Eval(row []types.Datum, _ context.Context) (d types.Datum, err error) { + return row[col.Index], nil +} + +// DeepCopy implements Expression interface. +func (col *Column) DeepCopy() Expression { + newCol := *col + return &newCol +} + +// Schema stands for the row schema get from input. +type Schema []*Column + +// FindColumn replaces an ast column with an expression column. +func (s Schema) FindColumn(astCol *ast.ColumnName) (*Column, error) { + dbName, tblName, colName := astCol.Schema, astCol.Table, astCol.Name + idx := -1 + for i, col := range s { + if (dbName.L == "" || dbName.L == col.DbName.L) && (tblName.L == "" || tblName.L == col.TblName.L) && (colName.L == col.ColName.L) { + if idx != -1 { + return nil, errors.Errorf("Column '%s' is ambiguous", colName.L) + } + idx = i + } + } + if idx == -1 { + return nil, errors.Errorf("Unknown column %s %s %s.", dbName.L, tblName.L, colName.L) + } + return s[idx], nil +} + +// GetIndex finds the index for a column. +func (s Schema) GetIndex(col *Column) int { + for i, c := range s { + if c.FromID == col.FromID && c.ColName.L == col.ColName.L { + return i + } + } + return -1 +} + +// ScalarFunction is the function that returns a value. +type ScalarFunction struct { + Args []Expression + FuncName model.CIStr + // TODO: Implement type inference here, now we use ast's return type temporarily. + retType *types.FieldType + function evaluator.BuiltinFunc +} + +// ToString implements Expression interface. +func (sf *ScalarFunction) ToString() string { + result := sf.FuncName.L + "(" + for _, arg := range sf.Args { + result += arg.ToString() + result += "," + } + result += ")" + return result +} + +// NewFunction creates a new scalar function. +func NewFunction(funcName model.CIStr, args []Expression) *ScalarFunction { + return &ScalarFunction{Args: args, FuncName: funcName, function: evaluator.Funcs[funcName.L].F} +} + +//Schema2Exprs converts []*Column to []Expression. +func Schema2Exprs(schema Schema) []Expression { + result := make([]Expression, 0, len(schema)) + for _, col := range schema { + result = append(result, col) + } + return result +} + +//ScalarFuncs2Exprs converts []*ScalarFunction to []Expression. +func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression { + result := make([]Expression, 0, len(funcs)) + for _, col := range funcs { + result = append(result, col) + } + return result +} + +// DeepCopy implements Expression interface. +func (sf *ScalarFunction) DeepCopy() Expression { + newFunc := &ScalarFunction{FuncName: sf.FuncName, function: sf.function, retType: sf.retType} + for _, arg := range sf.Args { + newFunc.Args = append(newFunc.Args, arg.DeepCopy()) + } + return newFunc +} + +// GetType implements Expression interface. +func (sf *ScalarFunction) GetType() *types.FieldType { + return sf.retType +} + +// Eval implements Expression interface. +func (sf *ScalarFunction) Eval(row []types.Datum, ctx context.Context) (types.Datum, error) { + args := make([]types.Datum, len(sf.Args)) + for _, arg := range sf.Args { + result, err := arg.Eval(row, ctx) + if err != nil { + args = append(args, result) + } else { + return types.Datum{}, errors.Trace(err) + } + } + return sf.function(args, ctx) +} + +// Constant stands for a constant value. +type Constant struct { + value types.Datum + retType *types.FieldType +} + +// ToString implements Expression interface. +func (c *Constant) ToString() string { + return fmt.Sprintf("%v", c.value.GetValue()) +} + +// DeepCopy implements Expression interface. +func (c *Constant) DeepCopy() Expression { + con := *c + return &con +} + +// GetType implements Expression interface. +func (c *Constant) GetType() *types.FieldType { + return c.retType +} + +// Eval implements Expression interface. +func (c *Constant) Eval(_ []types.Datum, _ context.Context) (types.Datum, error) { + return c.value, nil +} + +// Enter implements Visitor interface. +func (er *expressionRewriter) Enter(inNode ast.Node) (retNode ast.Node, skipChildren bool) { + switch v := inNode.(type) { + case *ast.AggregateFuncExpr: + index, ok := -1, false + if er.aggrMap != nil { + index, ok = er.aggrMap[v] + } + if !ok { + er.err = errors.New("Can't appear aggrFunctions") + return inNode, true + } + er.ctxStack = append(er.ctxStack, er.schema[index]) + return inNode, true + } + return inNode, false +} + +// Leave implements Visitor interface. +func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool) { + length := len(er.ctxStack) + switch v := inNode.(type) { + case *ast.AggregateFuncExpr: + case *ast.FuncCallExpr: + function := &ScalarFunction{FuncName: v.FnName} + for i := length - len(v.Args); i < length; i++ { + function.Args = append(function.Args, er.ctxStack[i]) + } + f := evaluator.Funcs[v.FnName.L] + if len(function.Args) < f.MinArgs || (f.MaxArgs != -1 && len(function.Args) > f.MaxArgs) { + er.err = evaluator.ErrInvalidOperation.Gen("number of function arguments must in [%d, %d].", f.MinArgs, f.MaxArgs) + return retNode, false + } + function.function = f.F + function.retType = v.Type + er.ctxStack = er.ctxStack[:length-len(v.Args)] + er.ctxStack = append(er.ctxStack, function) + case *ast.ColumnName: + column, err := er.schema.FindColumn(v) + if err != nil { + er.err = errors.Trace(err) + return retNode, false + } + er.ctxStack = append(er.ctxStack, column) + case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause: + case *ast.ValueExpr: + value := &Constant{value: v.Datum, retType: v.Type} + er.ctxStack = append(er.ctxStack, value) + case *ast.IsNullExpr: + function := &ScalarFunction{ + Args: []Expression{er.ctxStack[length-1]}, + FuncName: model.NewCIStr("isnull"), + retType: v.Type, + } + f, ok := evaluator.Funcs[function.FuncName.L] + if !ok { + er.err = errors.New("Can't find function!") + return retNode, false + } + function.function = f.F + er.ctxStack = er.ctxStack[:length-1] + er.ctxStack = append(er.ctxStack, function) + case *ast.BinaryOperationExpr: + function := &ScalarFunction{Args: []Expression{er.ctxStack[length-2], er.ctxStack[length-1]}, retType: v.Type} + funcName, ok := opcode.Ops[v.Op] + if !ok { + er.err = errors.Errorf("Unknown opcode %v", v.Op) + return retNode, false + } + function.FuncName = model.NewCIStr(funcName) + f, ok := evaluator.Funcs[function.FuncName.L] + if !ok { + er.err = errors.New("Can't find function!") + return retNode, false + } + function.function = f.F + er.ctxStack = er.ctxStack[:length-2] + er.ctxStack = append(er.ctxStack, function) + case *ast.UnaryOperationExpr: + function := &ScalarFunction{Args: []Expression{er.ctxStack[length-1]}, retType: v.Type} + switch v.Op { + case opcode.Not: + function.FuncName = model.NewCIStr("not") + case opcode.BitNeg: + function.FuncName = model.NewCIStr("bitneg") + case opcode.Plus: + function.FuncName = model.NewCIStr("unaryplus") + case opcode.Minus: + function.FuncName = model.NewCIStr("unaryminus") + } + f, ok := evaluator.Funcs[function.FuncName.L] + if !ok { + er.err = errors.New("Can't find function!") + return retNode, false + } + function.function = f.F + er.ctxStack = er.ctxStack[:length-1] + er.ctxStack = append(er.ctxStack, function) + default: + er.err = errors.Errorf("UnkownType: %T", v) + } + return inNode, true +} diff --git a/optimizer/new_plan_test.go b/optimizer/new_plan_test.go index 87f44047dcddf..7fb294d29ea2a 100644 --- a/optimizer/new_plan_test.go +++ b/optimizer/new_plan_test.go @@ -18,6 +18,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" @@ -88,33 +89,43 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { }{ { sql: "select a from (select a from t where d = 0) k where k.a = 5", - first: "Table(t)->Filter->Fields->Filter->Fields", - best: "Range(t)->Fields->Fields", + first: "DataScan(t)->Selection->Projection->Selection->Projection", + best: "DataScan(t)->Selection->Projection->Projection", }, { sql: "select a from (select 1+2 as a from t where d = 0) k where k.a = 5", - first: "Table(t)->Filter->Fields->Filter->Fields", - best: "Table(t)->Fields->Filter->Fields", + first: "DataScan(t)->Selection->Projection->Selection->Projection", + best: "DataScan(t)->Selection->Projection->Selection->Projection", }, { sql: "select a from (select d as a from t where d = 0) k where k.a = 5", - first: "Table(t)->Filter->Fields->Filter->Fields", - best: "Table(t)->Fields->Fields", + first: "DataScan(t)->Selection->Projection->Selection->Projection", + best: "DataScan(t)->Selection->Projection->Projection", }, { sql: "select * from t ta join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0", - first: "Join{Table(t)->Table(t)}->Filter->Fields", - best: "Join{Table(t)->Range(t)}->Fields", + first: "Join{DataScan(t)->DataScan(t)}->Selection->Projection", + best: "Join{DataScan(t)->Selection->DataScan(t)->Selection}->Projection", + }, + { + sql: "select * from t ta join t tb on ta.d = tb.d where ta.d > 1 and tb.a = 0", + first: "Join{DataScan(t)->DataScan(t)}->Selection->Projection", + best: "Join{DataScan(t)->Selection->DataScan(t)->Selection}->Projection", }, { sql: "select * from t ta left outer join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0", - first: "Join{Table(t)->Table(t)}->Filter->Fields", - best: "Join{Table(t)->Table(t)}->Filter->Fields", + first: "Join{DataScan(t)->DataScan(t)}->Selection->Projection", + best: "Join{DataScan(t)->DataScan(t)}->Selection->Projection", }, { sql: "select * from t ta right outer join t tb on ta.d = tb.d and ta.a > 1 where tb.a = 0", - first: "Join{Table(t)->Table(t)}->Filter->Fields", - best: "Join{Range(t)->Range(t)}->Fields", + first: "Join{DataScan(t)->DataScan(t)}->Selection->Projection", + best: "Join{DataScan(t)->Selection->DataScan(t)->Selection}->Projection", + }, + { + sql: "select a, d from (select * from t union all select * from t union all select * from t) z where a < 10", + first: "UnionAll{DataScan(t)->Projection->DataScan(t)->Projection->DataScan(t)->Projection}->Selection->Projection", + best: "UnionAll{DataScan(t)->Selection->Projection->DataScan(t)->Selection->Projection->DataScan(t)->Selection->Projection}->Projection", }, } for _, ca := range cases { @@ -130,9 +141,9 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { c.Assert(err, IsNil) c.Assert(plan.ToString(p), Equals, ca.first, Commentf("for %s", ca.sql)) - _, err = plan.PredicatePushDown(p, []ast.ExprNode{}) + _, err = plan.PredicatePushDown(p, []expression.Expression{}) c.Assert(err, IsNil) - err = plan.Refine(p) + err = plan.PruneColumnsAndResolveIndices(p) c.Assert(err, IsNil) c.Assert(plan.ToString(p), Equals, ca.best, Commentf("for %s", ca.sql)) } diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 41ff17dc8e6db..327d910116508 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -17,6 +17,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/optimizer/plan" @@ -38,7 +39,11 @@ func Optimize(ctx context.Context, node ast.Node, sb plan.SubQueryBuilder) (plan return nil, errors.Trace(err) } if plan.UseNewPlanner { - _, err = plan.PredicatePushDown(p, []ast.ExprNode{}) + _, err = plan.PredicatePushDown(p, []expression.Expression{}) + if err != nil { + return nil, errors.Trace(err) + } + err = plan.PruneColumnsAndResolveIndices(p) if err != nil { return nil, errors.Trace(err) } diff --git a/optimizer/plan/column_pruning.go b/optimizer/plan/column_pruning.go new file mode 100644 index 0000000000000..be0c990ad882d --- /dev/null +++ b/optimizer/plan/column_pruning.go @@ -0,0 +1,84 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/expression" +) + +func retrieveIndex(cols []*expression.Column, schema expression.Schema) { + for _, col := range cols { + idx := schema.GetIndex(col) + if col.Index == -1 { + col.Index = idx + } + } +} + +// PruneColumnsAndResolveIndices prunes unused columns and resolves index for columns. +func PruneColumnsAndResolveIndices(p Plan) error { + //TODO: Currently we only implement index resolving, column pruning will be implemented later. + var cols []*expression.Column + switch v := p.(type) { + case *Projection: + for _, expr := range v.exprs { + cols = extractColumn(expr, cols) + } + case *Selection: + for _, cond := range v.Conditions { + cols = extractColumn(cond, cols) + } + case *Aggregation: + for _, aggrFunc := range v.AggFuncs { + for _, arg := range aggrFunc.GetArgs() { + cols = extractColumn(arg, cols) + } + } + for _, expr := range v.GroupByItems { + cols = extractColumn(expr, cols) + } + case *Union, *NewTableScan: + case *Join: + for _, eqCond := range v.EqualConditions { + cols = extractColumn(eqCond, cols) + } + for _, eqCond := range v.LeftConditions { + cols = extractColumn(eqCond, cols) + } + for _, eqCond := range v.RightConditions { + cols = extractColumn(eqCond, cols) + } + for _, eqCond := range v.OtherConditions { + cols = extractColumn(eqCond, cols) + } + default: + return errors.Errorf("Unknown Type %T", v) + } + for _, child := range p.GetChildren() { + retrieveIndex(cols, child.GetSchema()) + } + for _, col := range cols { + if col.Index == -1 { + return errors.Errorf("Can't find column %s", col.ColName.L) + } + } + for _, child := range p.GetChildren() { + err := PruneColumnsAndResolveIndices(child) + if err != nil { + return errors.Trace(err) + } + } + return nil +} diff --git a/optimizer/plan/new_plans.go b/optimizer/plan/new_plans.go index 76317f5c3608d..8c5f9454fa410 100644 --- a/optimizer/plan/new_plans.go +++ b/optimizer/plan/new_plans.go @@ -15,7 +15,8 @@ package plan import ( "github.com/juju/errors" - "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/model" ) // JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, FullOuterJoin, SemiJoin. @@ -39,10 +40,62 @@ type Join struct { JoinType JoinType - EqualConditions []ast.ExprNode - LeftConditions []ast.ExprNode - RightConditions []ast.ExprNode - OtherConditions []ast.ExprNode + EqualConditions []*expression.ScalarFunction + LeftConditions []expression.Expression + RightConditions []expression.Expression + OtherConditions []expression.Expression +} + +// Projection represents a select fields plan. +type Projection struct { + basePlan + exprs []expression.Expression +} + +// Aggregation represents an aggregate plan. +type Aggregation struct { + basePlan + AggFuncs []expression.AggregationFunction + GroupByItems []expression.Expression +} + +// Selection means a filter. +type Selection struct { + basePlan + + // Originally the WHERE or ON condition is parsed into a single expression, + // but after we converted to CNF(Conjunctive normal form), it can be + // split into a list of AND conditions. + Conditions []expression.Expression +} + +// NewTableScan represents a tablescan without condition push down. +type NewTableScan struct { + basePlan + + Table *model.TableInfo + Desc bool + Ranges []TableRange + + // RefAccess indicates it references a previous joined table, used in explain. + RefAccess bool + + // AccessConditions can be used to build index range. + AccessConditions []expression.Expression + + // FilterConditions can be used to filter result. + FilterConditions []expression.Expression + + TableAsName *model.CIStr + + LimitCount *int64 +} + +func (ts *NewTableScan) attachCondition(conditions []expression.Expression) { + for _, con := range conditions { + //TODO: implement refiner for expression. + ts.FilterConditions = append(ts.FilterConditions, con) + } } // AddChild for parent. diff --git a/optimizer/plan/newplanbuilder.go b/optimizer/plan/newplanbuilder.go index e0a018b16943b..32068d0652766 100644 --- a/optimizer/plan/newplanbuilder.go +++ b/optimizer/plan/newplanbuilder.go @@ -14,7 +14,10 @@ package plan import ( + "fmt" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" @@ -23,19 +26,58 @@ import ( // UseNewPlanner means if use the new planner. var UseNewPlanner = false -func (b *planBuilder) buildNewSinglePathPlan(node ast.ResultSetNode) Plan { +func (b *planBuilder) allocID(p Plan) string { + b.id++ + return fmt.Sprintf("%T_%d", p, b.id) +} + +func (b *planBuilder) buildAggregation(p Plan, aggrFuncList []*ast.AggregateFuncExpr, gby *ast.GroupByClause) Plan { + newAggrFuncList := make([]expression.AggregationFunction, 0, len(aggrFuncList)) + gbyExprList := make([]expression.Expression, 0, len(gby.Items)) + aggr := &Aggregation{AggFuncs: newAggrFuncList, GroupByItems: gbyExprList} + aggr.id = b.allocID(aggr) + + schema := make([]*expression.Column, 0, len(aggrFuncList)) + for i, aggrFunc := range aggrFuncList { + var newArgList []expression.Expression + for _, arg := range aggrFunc.Args { + newArg, err := expression.Rewrite(arg, p.GetSchema(), nil) + if err != nil { + b.err = errors.Trace(err) + return nil + } + newArgList = append(newArgList, newArg) + } + newAggrFuncList = append(newAggrFuncList, expression.NewAggrFunction(aggrFunc.F, newArgList)) + schema = append(schema, &expression.Column{FromID: aggr.id, ColName: model.NewCIStr(fmt.Sprintf("%s_col_%d", aggr.id, i))}) + } + + for _, gbyItem := range gby.Items { + gbyExpr, err := expression.Rewrite(gbyItem.Expr, p.GetSchema(), nil) + if err != nil { + b.err = errors.Trace(err) + return nil + } + gbyExprList = append(gbyExprList, gbyExpr) + } + aggr.SetSchema(schema) + return aggr +} + +func (b *planBuilder) buildResultSetNode(node ast.ResultSetNode) Plan { switch x := node.(type) { case *ast.Join: return b.buildNewJoin(x) case *ast.TableSource: + asName := x.AsName switch v := x.Source.(type) { case *ast.SelectStmt: - return b.buildNewSelect(v) + return b.buildNewSelect(v, asName) case *ast.UnionStmt: - return b.buildUnion(v) + return b.buildNewUnion(v, asName) case *ast.TableName: - //TODO: select physical algorithm during cbo phase. - return b.buildNewTableScanPlan(v) + // TODO: select physical algorithm during cbo phase. + return b.buildNewTableScanPlan(v, asName) default: b.err = ErrUnsupportedType.Gen("unsupported table source type %T", v) return nil @@ -46,53 +88,40 @@ func (b *planBuilder) buildNewSinglePathPlan(node ast.ResultSetNode) Plan { } } -func fromFields(col *ast.ColumnNameExpr, fields []*ast.ResultField) bool { - for _, field := range fields { - if field == col.Refer { - return true - } - } - return false -} - -type columnsExtractor struct { - result []*ast.ColumnNameExpr -} - -func (ce *columnsExtractor) Enter(expr ast.Node) (ret ast.Node, skipChildren bool) { +func extractColumn(expr expression.Expression, cols []*expression.Column) (result []*expression.Column) { switch v := expr.(type) { - case *ast.ColumnNameExpr: - ce.result = append(ce.result, v) + case *expression.Column: + return append(cols, v) + case *expression.ScalarFunction: + for _, arg := range v.Args { + cols = extractColumn(arg, cols) + } + return cols } - return expr, false + return cols } -func (ce *columnsExtractor) Leave(expr ast.Node) (ret ast.Node, skipChildren bool) { - return expr, true -} - -func extractOnCondition(conditions []ast.ExprNode, left Plan, right Plan) (eqCond []ast.ExprNode, leftCond []ast.ExprNode, rightCond []ast.ExprNode, otherCond []ast.ExprNode) { +func extractOnCondition(conditions []expression.Expression, left Plan, right Plan) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) { for _, expr := range conditions { - binop, ok := expr.(*ast.BinaryOperationExpr) - if ok && binop.Op == opcode.EQ { - ln, lOK := binop.L.(*ast.ColumnNameExpr) - rn, rOK := binop.R.(*ast.ColumnNameExpr) + binop, ok := expr.(*expression.ScalarFunction) + eqStr, _ := opcode.Ops[opcode.EQ] + if ok && binop.FuncName.L == eqStr { + ln, lOK := binop.Args[0].(*expression.Column) + rn, rOK := binop.Args[1].(*expression.Column) if lOK && rOK { - if fromFields(ln, left.Fields()) && fromFields(rn, right.Fields()) { - eqCond = append(eqCond, expr) + if left.GetSchema().GetIndex(ln) != -1 && right.GetSchema().GetIndex(rn) != -1 { + eqCond = append(eqCond, binop) continue - } else if fromFields(rn, left.Fields()) && fromFields(ln, right.Fields()) { - eqCond = append(eqCond, &ast.BinaryOperationExpr{Op: opcode.EQ, L: rn, R: ln}) + } else if left.GetSchema().GetIndex(rn) != -1 && right.GetSchema().GetIndex(ln) != -1 { + eqCond = append(eqCond, expression.NewFunction(model.NewCIStr(eqStr), []expression.Expression{rn, ln})) continue } } } - ce := &columnsExtractor{} - expr.Accept(ce) - columns := ce.result + columns := extractColumn(expr, make([]*expression.Column, 0)) allFromLeft, allFromRight := true, true for _, col := range columns { - if fromFields(col, left.Fields()) { + if left.GetSchema().GetIndex(col) != -1 { allFromRight = false } else { allFromLeft = false @@ -106,18 +135,40 @@ func extractOnCondition(conditions []ast.ExprNode, left Plan, right Plan) (eqCon otherCond = append(otherCond, expr) } } - return eqCond, leftCond, rightCond, otherCond + return +} + +// CNF means conjunctive normal form, e.g. a and b and c. +func splitCNFItems(onExpr expression.Expression) []expression.Expression { + switch v := onExpr.(type) { + case *expression.ScalarFunction: + andandStr, _ := opcode.Ops[opcode.AndAnd] + if v.FuncName.L == andandStr { + var ret []expression.Expression + for _, arg := range v.Args { + ret = append(ret, splitCNFItems(arg)...) + } + return ret + } + } + return []expression.Expression{onExpr} } func (b *planBuilder) buildNewJoin(join *ast.Join) Plan { if join.Right == nil { - return b.buildNewSinglePathPlan(join.Left) + return b.buildResultSetNode(join.Left) } - leftPlan := b.buildNewSinglePathPlan(join.Left) - rightPlan := b.buildNewSinglePathPlan(join.Right) - var eqCond, leftCond, rightCond, otherCond []ast.ExprNode + leftPlan := b.buildResultSetNode(join.Left) + rightPlan := b.buildResultSetNode(join.Right) + var eqCond []*expression.ScalarFunction + var leftCond, rightCond, otherCond []expression.Expression if join.On != nil { - onCondition := splitWhere(join.On.Expr) + onExpr, err := expression.Rewrite(join.On.Expr, append(leftPlan.GetSchema(), rightPlan.GetSchema()...), nil) + if err != nil { + b.err = err + return nil + } + onCondition := splitCNFItems(onExpr) eqCond, leftCond, rightCond, otherCond = extractOnCondition(onCondition, leftPlan, rightPlan) } joinPlan := &Join{EqualConditions: eqCond, LeftConditions: leftCond, RightConditions: rightCond, OtherConditions: otherCond} @@ -130,32 +181,258 @@ func (b *planBuilder) buildNewJoin(join *ast.Join) Plan { } addChild(joinPlan, leftPlan) addChild(joinPlan, rightPlan) - joinPlan.SetFields(append(leftPlan.Fields(), rightPlan.Fields()...)) + joinPlan.SetSchema(append(leftPlan.GetSchema(), rightPlan.GetSchema()...)) return joinPlan } -func (b *planBuilder) buildFilter(p Plan, where ast.ExprNode) Plan { +func (b *planBuilder) buildSelection(p Plan, where ast.ExprNode, mapper map[*ast.AggregateFuncExpr]int) Plan { conditions := splitWhere(where) - filter := &Filter{Conditions: conditions} - addChild(filter, p) - filter.SetFields(p.Fields()) - return filter + expressions := make([]expression.Expression, 0, len(conditions)) + for _, cond := range conditions { + expr, err := expression.Rewrite(cond, p.GetSchema(), mapper) + if err != nil { + b.err = err + return nil + } + expressions = append(expressions, expr) + } + selection := &Selection{Conditions: expressions} + selection.SetSchema(p.GetSchema()) + addChild(selection, p) + return selection } -func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt) Plan { - var aggFuncs []*ast.AggregateFuncExpr +func (b *planBuilder) buildProjection(src Plan, fields []*ast.SelectField, asName model.CIStr, mapper map[*ast.AggregateFuncExpr]int) Plan { + proj := &Projection{exprs: make([]expression.Expression, 0, len(fields))} + proj.id = b.allocID(proj) + schema := make(expression.Schema, 0, len(fields)) + for _, field := range fields { + var tblName, colName model.CIStr + if asName.L != "" { + tblName = asName + } + if field.WildCard != nil { + dbName := field.WildCard.Schema + colTblName := field.WildCard.Table + for _, col := range src.GetSchema() { + if (dbName.L == "" || dbName.L == col.DbName.L) && (colTblName.L == "" || colTblName.L == col.TblName.L) { + newExpr := col.DeepCopy() + proj.exprs = append(proj.exprs, newExpr) + schemaCol := &expression.Column{FromID: proj.id, TblName: tblName, ColName: col.ColName, RetType: newExpr.GetType()} + schema = append(schema, schemaCol) + } + } + } else { + newExpr, err := expression.Rewrite(field.Expr, src.GetSchema(), mapper) + if err != nil { + b.err = errors.Trace(err) + return nil + } + proj.exprs = append(proj.exprs, newExpr) + if field.AsName.L != "" { + colName = field.AsName + } else if c, ok := newExpr.(*expression.Column); ok { + colName = c.ColName + } else { + colName = model.NewCIStr(field.Expr.Text()) + } + schemaCol := &expression.Column{FromID: proj.id, TblName: tblName, ColName: colName, RetType: newExpr.GetType()} + schema = append(schema, schemaCol) + } + } + proj.SetSchema(schema) + addChild(proj, src) + return proj +} + +func (b *planBuilder) buildNewDistinct(src Plan) Plan { + d := &Distinct{} + addChild(d, src) + d.SetSchema(src.GetSchema()) + return d +} + +func (b *planBuilder) buildNewUnion(union *ast.UnionStmt, asName model.CIStr) (p Plan) { + sels := make([]Plan, len(union.SelectList.Selects)) + for i, sel := range union.SelectList.Selects { + sels[i] = b.buildNewSelect(sel, model.NewCIStr("")) + } + u := &Union{ + Selects: sels, + } + u.id = b.allocID(u) + p = u + firstSchema := make(expression.Schema, 0, len(sels[0].GetSchema())) + firstSchema = append(firstSchema, sels[0].GetSchema()...) + for _, sel := range sels { + if len(firstSchema) != len(sel.GetSchema()) { + b.err = errors.New("The used SELECT statements have a different number of columns") + return nil + } + for i, col := range sel.GetSchema() { + /* + * The lengths of the columns in the UNION result take into account the values retrieved by all of the SELECT statements + * SELECT REPEAT('a',1) UNION SELECT REPEAT('b',10); + * +---------------+ + * | REPEAT('a',1) | + * +---------------+ + * | a | + * | bbbbbbbbbb | + * +---------------+ + */ + if col.RetType.Flen > firstSchema[i].RetType.Flen { + firstSchema[i].RetType.Flen = col.RetType.Flen + } + // For select nul union select "abc", we should not convert "abc" to nil. + // And the result field type should be VARCHAR. + if firstSchema[i].RetType.Tp == 0 || firstSchema[i].RetType.Tp == mysql.TypeNull { + firstSchema[i].RetType.Tp = col.RetType.Tp + } + } + addChild(p, sel) + } + for _, v := range firstSchema { + v.FromID = u.id + v.TblName = asName + v.DbName = model.NewCIStr("") + } + + p.SetSchema(firstSchema) + if union.Distinct { + p = b.buildNewDistinct(p) + } + if union.OrderBy != nil { + p = b.buildNewSort(p, union.OrderBy.Items, nil) + } + if union.Limit != nil { + p = b.buildNewLimit(p, union.Limit) + } + return p +} + +// ByItems wraps a "by" item. +type ByItems struct { + expr expression.Expression + desc bool +} + +// NewSort stands for the order by plan. +type NewSort struct { + basePlan + + ByItems []expression.Expression +} + +func (b *planBuilder) buildNewSort(src Plan, byItems []*ast.ByItem, mapper map[*ast.AggregateFuncExpr]int) Plan { + var exprs []expression.Expression + for _, item := range byItems { + it, err := expression.Rewrite(item.Expr, src.GetSchema(), mapper) + if err != nil { + b.err = err + } + exprs = append(exprs, it) + } + sort := &NewSort{ + ByItems: exprs, + } + addChild(sort, src) + sort.SetSchema(src.GetSchema()) + return sort +} + +func (b *planBuilder) buildNewLimit(src Plan, limit *ast.Limit) Plan { + li := &Limit{ + Offset: limit.Offset, + Count: limit.Count, + } + if s, ok := src.(*Sort); ok { + s.ExecLimit = li + return s + } + addChild(li, src) + li.SetSchema(src.GetSchema()) + return li +} + +func (b *planBuilder) extractAggrFunc(sel *ast.SelectStmt) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int) { + extractor := &ast.AggregateFuncExtractor{AggFuncs: make([]*ast.AggregateFuncExpr, 0)} + // Extract agg funcs from having clause. + if sel.Having != nil { + n, ok := sel.Having.Expr.Accept(extractor) + if !ok { + b.err = errors.New("Failed to extract agg expr from having clause") + return nil, nil, nil, nil + } + sel.Having.Expr = n.(ast.ExprNode) + } + havingAggrFuncs := extractor.AggFuncs + extractor.AggFuncs = make([]*ast.AggregateFuncExpr, 0) + + havingMapper := make(map[*ast.AggregateFuncExpr]int) + for _, aggr := range havingAggrFuncs { + havingMapper[aggr] = len(sel.Fields.Fields) + field := &ast.SelectField{Expr: aggr, AsName: model.NewCIStr(fmt.Sprintf("sel_aggr_%d", len(sel.Fields.Fields)))} + sel.Fields.Fields = append(sel.Fields.Fields, field) + } + + // Extract agg funcs from orderby clause. + if sel.OrderBy != nil { + for _, item := range sel.OrderBy.Items { + _, ok := item.Expr.Accept(extractor) + if !ok { + b.err = errors.New("Failed to extract agg expr from orderby clause") + return nil, nil, nil, nil + } + // TODO: support position error. + } + } + orderByAggrFuncs := extractor.AggFuncs + extractor.AggFuncs = make([]*ast.AggregateFuncExpr, 0) + orderByMapper := make(map[*ast.AggregateFuncExpr]int) + for _, aggr := range orderByAggrFuncs { + orderByMapper[aggr] = len(sel.Fields.Fields) + field := &ast.SelectField{Expr: aggr, AsName: model.NewCIStr(fmt.Sprintf("sel_aggr_%d", len(sel.Fields.Fields)))} + sel.Fields.Fields = append(sel.Fields.Fields, field) + } + + for _, f := range sel.Fields.Fields { + _, ok := f.Expr.Accept(extractor) + if !ok { + b.err = errors.New("Failed to extract agg expr!") + return nil, nil, nil, nil + } + } + aggrList := extractor.AggFuncs + aggrList = append(aggrList, havingAggrFuncs...) + aggrList = append(aggrList, orderByAggrFuncs...) + totalAggrMapper := make(map[*ast.AggregateFuncExpr]int) + + for i, aggr := range aggrList { + totalAggrMapper[aggr] = i + } + return aggrList, havingMapper, orderByMapper, totalAggrMapper +} + +func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt, asName model.CIStr) Plan { + oldLen := len(sel.Fields.Fields) hasAgg := b.detectSelectAgg(sel) + var aggFuncs []*ast.AggregateFuncExpr + var havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int if hasAgg { - aggFuncs = b.extractSelectAgg(sel) + aggFuncs, havingMap, orderMap, totalMap = b.extractAggrFunc(sel) } // Build subquery // Convert subquery to expr with plan - b.buildSubquery(sel) + // TODO: add subquery support. + //b.buildSubquery(sel) var p Plan if sel.From != nil { - p = b.buildNewSinglePathPlan(sel.From.TableRefs) + p = b.buildResultSetNode(sel.From.TableRefs) + if b.err != nil { + return nil + } if sel.Where != nil { - p = b.buildFilter(p, sel.Where) + p = b.buildSelection(p, sel.Where, nil) } if b.err != nil { return nil @@ -167,9 +444,9 @@ func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt) Plan { } } if hasAgg { - p = b.buildAggregate(p, aggFuncs, sel.GroupBy) + p = b.buildAggregation(p, aggFuncs, sel.GroupBy) } - p = b.buildSelectFields(p, sel.GetResultFields()) + p = b.buildProjection(p, sel.Fields.Fields, asName, totalMap) if b.err != nil { return nil } @@ -178,15 +455,15 @@ func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt) Plan { p = b.buildTableDual(sel) } if hasAgg { - p = b.buildAggregate(p, aggFuncs, nil) + p = b.buildAggregation(p, aggFuncs, nil) } - p = b.buildSelectFields(p, sel.GetResultFields()) + p = b.buildProjection(p, sel.Fields.Fields, asName, totalMap) if b.err != nil { return nil } } if sel.Having != nil { - p = b.buildFilter(p, sel.Having.Expr) + p = b.buildSelection(p, sel.Having.Expr, havingMap) if b.err != nil { return nil } @@ -197,8 +474,9 @@ func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt) Plan { return nil } } - if sel.OrderBy != nil && !pushOrder(p, sel.OrderBy.Items) { - p = b.buildSort(p, sel.OrderBy.Items) + // TODO: implement push order during cbo + if sel.OrderBy != nil { + p = b.buildNewSort(p, sel.OrderBy.Items, orderMap) if b.err != nil { return nil } @@ -209,40 +487,46 @@ func (b *planBuilder) buildNewSelect(sel *ast.SelectStmt) Plan { return nil } } - return p -} - -func (ts *TableScan) attachCondition(conditions []ast.ExprNode) { - var pkName model.CIStr - if ts.Table.PKIsHandle { - for _, colInfo := range ts.Table.Columns { - if mysql.HasPriKeyFlag(colInfo.Flag) { - pkName = colInfo.Name - } + if oldLen != len(sel.Fields.Fields) { + proj := &Projection{} + proj.id = b.allocID(proj) + oldSchema := p.GetSchema() + proj.exprs = make([]expression.Expression, 0, oldLen) + newSchema := make([]*expression.Column, 0, oldLen) + newSchema = append(newSchema, oldSchema[:oldLen]...) + for _, col := range oldSchema[:oldLen] { + proj.exprs = append(proj.exprs, col) } - } - for _, con := range conditions { - if pkName.L != "" { - checker := conditionChecker{tableName: ts.Table.Name, pkName: pkName} - if checker.check(con) { - ts.AccessConditions = append(ts.AccessConditions, con) - } else { - ts.FilterConditions = append(ts.FilterConditions, con) - } - } else { - ts.FilterConditions = append(ts.FilterConditions, con) + for _, s := range newSchema { + s.FromID = proj.id } + proj.SetSchema(newSchema) + addChild(proj, p) } + return p } -func (b *planBuilder) buildNewTableScanPlan(tn *ast.TableName) Plan { - p := &TableScan{ - Table: tn.TableInfo, - TableName: tn, +func (b *planBuilder) buildNewTableScanPlan(tn *ast.TableName, asName model.CIStr) Plan { + p := &NewTableScan{ + Table: tn.TableInfo, } + p.id = b.allocID(p) // Equal condition contains a column from previous joined table. p.RefAccess = false - p.SetFields(tn.GetResultFields()) - p.TableAsName = getTableAsName(p.Fields()) + rfs := tn.GetResultFields() + schema := make([]*expression.Column, 0, len(rfs)) + for _, rf := range rfs { + var dbName, colName, tblName model.CIStr + if asName.L != "" { + tblName = asName + } else { + tblName = rf.Table.Name + dbName = rf.DBName + } + colName = rf.Column.Name + schema = append(schema, &expression.Column{FromID: p.id, ColName: colName, TblName: tblName, DbName: dbName, RetType: &rf.Column.FieldType}) + } + p.SetSchema(schema) + p.TableAsName = &asName return p } diff --git a/optimizer/plan/plan.go b/optimizer/plan/plan.go index 69091de7fe43a..dda0ed1d200e8 100644 --- a/optimizer/plan/plan.go +++ b/optimizer/plan/plan.go @@ -17,6 +17,7 @@ import ( "math" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" ) // Plan is a description of an execution flow. @@ -51,6 +52,12 @@ type Plan interface { GetParents() []Plan // Get all the children. GetChildren() []Plan + // Set the schema. + SetSchema(schema expression.Schema) + // Get the schema. + GetSchema() expression.Schema + // Get ID. + GetID() string } // basePlan implements base Plan interface. @@ -64,6 +71,21 @@ type basePlan struct { parents []Plan children []Plan + + schema expression.Schema + id string +} + +func (p *basePlan) GetID() string { + return p.id +} + +func (p *basePlan) SetSchema(schema expression.Schema) { + p.schema = schema +} + +func (p *basePlan) GetSchema() expression.Schema { + return p.schema } // StartupCost implements Plan StartupCost interface. diff --git a/optimizer/plan/planbuilder.go b/optimizer/plan/planbuilder.go index 037c27b0c85a7..fd8275696b5c1 100644 --- a/optimizer/plan/planbuilder.go +++ b/optimizer/plan/planbuilder.go @@ -56,6 +56,7 @@ type planBuilder struct { hasAgg bool sb SubQueryBuilder obj interface{} + id int32 } func (b *planBuilder) build(node ast.Node) Plan { @@ -90,7 +91,7 @@ func (b *planBuilder) build(node ast.Node) Plan { return b.buildPrepare(x) case *ast.SelectStmt: if UseNewPlanner { - return b.buildNewSelect(x) + return b.buildNewSelect(x, model.NewCIStr("")) } return b.buildSelect(x) case *ast.UnionStmt: @@ -821,7 +822,7 @@ func (se *subqueryVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) switch x := in.(type) { case *ast.SubqueryExpr: p := se.builder.build(x.Query) - // The expr pointor is copyed into ResultField when running name resolver. + // The expr pointer is copied into ResultField when running name resolver. // So we can not just replace the expr node in AST. We need to put SubQuery into the expr. // See: optimizer.nameResolver.createResultFields() x.SubqueryExec = se.builder.sb.Build(p) @@ -840,18 +841,11 @@ func (se *subqueryVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { func (b *planBuilder) buildUnion(union *ast.UnionStmt) Plan { sels := make([]Plan, len(union.SelectList.Selects)) for i, sel := range union.SelectList.Selects { - if UseNewPlanner { - sels[i] = b.buildNewSelect(sel) - } else { - sels[i] = b.buildSelect(sel) - } + sels[i] = b.buildSelect(sel) } var p Plan p = &Union{ Selects: sels, - basePlan: basePlan{ - children: sels, - }, } unionFields := union.GetResultFields() for _, sel := range sels { diff --git a/optimizer/plan/predicate_push_down.go b/optimizer/plan/predicate_push_down.go index 2f8c498322299..86fcae822d594 100644 --- a/optimizer/plan/predicate_push_down.go +++ b/optimizer/plan/predicate_push_down.go @@ -15,43 +15,36 @@ package plan import ( "github.com/juju/errors" - "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" ) -func addFilter(p Plan, child Plan, conditions []ast.ExprNode) error { - filter := &Filter{Conditions: conditions} +func addFilter(p Plan, child Plan, conditions []expression.Expression) error { + filter := &Selection{Conditions: conditions} + filter.SetSchema(child.GetSchema()) return InsertPlan(p, child, filter) } -// columnSubstituor substitutes the columns in filter to expressions in select fields. +// columnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. -type columnSubstitutor struct { - fields []*ast.ResultField -} - -func (cl *columnSubstitutor) Enter(inNode ast.Node) (node ast.Node, skipChild bool) { - return inNode, false -} - -func (cl *columnSubstitutor) Leave(inNode ast.Node) (node ast.Node, ok bool) { - switch v := inNode.(type) { - case *ast.ColumnNameExpr: - for _, field := range cl.fields { - if v.Refer == field { - return field.Expr, true - } +func columnSubstitute(expr expression.Expression, schema expression.Schema, newExprs []expression.Expression) expression.Expression { + switch v := expr.(type) { + case *expression.Column: + id := schema.GetIndex(v) + return newExprs[id] + case *expression.ScalarFunction: + for i, arg := range v.Args { + v.Args[i] = columnSubstitute(arg, schema, newExprs) } } - return inNode, true + return expr } // PredicatePushDown applies predicate push down to all kinds of plans, except aggregation and union. -func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, err error) { +func PredicatePushDown(p Plan, predicates []expression.Expression) (ret []expression.Expression, err error) { switch v := p.(type) { - case *TableScan: - v.attachCondition(predicates) - return ret, nil - case *Filter: + case *NewTableScan: + return predicates, nil + case *Selection: conditions := v.Conditions retConditions, err1 := PredicatePushDown(p.GetChildByIndex(0), append(conditions, predicates...)) if err1 != nil { @@ -68,22 +61,22 @@ func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, e return nil, errors.Trace(err1) } } - return ret, nil + return case *Join: - //TODO: add null rejecter - var leftCond, rightCond []ast.ExprNode + //TODO: add null rejecter. + var leftCond, rightCond []expression.Expression leftPlan := v.GetChildByIndex(0) rightPlan := v.GetChildByIndex(1) equalCond, leftPushCond, rightPushCond, otherCond := extractOnCondition(predicates, leftPlan, rightPlan) if v.JoinType == LeftOuterJoin { rightCond = v.RightConditions leftCond = leftPushCond - ret = append(equalCond, otherCond...) + ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, rightPushCond...) } else if v.JoinType == RightOuterJoin { leftCond = v.LeftConditions rightCond = rightPushCond - ret = append(equalCond, otherCond...) + ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, leftPushCond...) } else { leftCond = append(v.LeftConditions, leftPushCond...) @@ -113,37 +106,24 @@ func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, e v.EqualConditions = append(v.EqualConditions, equalCond...) v.OtherConditions = append(v.OtherConditions, otherCond...) } - return ret, nil - case *SelectFields: + return + case *Projection: if len(v.GetChildren()) == 0 { return predicates, nil } - cs := &columnSubstitutor{fields: v.Fields()} - var push []ast.ExprNode + var push []expression.Expression for _, cond := range predicates { - ce := &columnsExtractor{} - ok := true - cond.Accept(ce) - for _, col := range ce.result { - match := false - for _, field := range v.Fields() { - if col.Refer == field { - switch field.Expr.(type) { - case *ast.ColumnNameExpr: - match = true - } - break - } - } - if !match { - ok = false + canSubstitute := true + extractedCols := extractColumn(cond, make([]*expression.Column, 0)) + for _, col := range extractedCols { + id := v.GetSchema().GetIndex(col) + if _, ok := v.exprs[id].(*expression.ScalarFunction); ok { + canSubstitute = false break } } - if ok { - cond1, _ := cond.Accept(cs) - cond = cond1.(ast.ExprNode) - push = append(push, cond) + if canSubstitute { + push = append(push, columnSubstitute(cond, v.GetSchema(), v.exprs)) } else { ret = append(ret, cond) } @@ -158,7 +138,7 @@ func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, e return nil, errors.Trace(err1) } } - return ret, nil + return case *Sort, *Limit, *Distinct: rest, err1 := PredicatePushDown(p.GetChildByIndex(0), predicates) if err1 != nil { @@ -170,18 +150,25 @@ func PredicatePushDown(p Plan, predicates []ast.ExprNode) (ret []ast.ExprNode, e return nil, errors.Trace(err1) } } - return ret, nil - default: - if len(v.GetChildren()) == 0 { - return predicates, nil - } - //TODO: support union and sub queries when abandon result field. - for _, child := range v.GetChildren() { - _, err = PredicatePushDown(child, []ast.ExprNode{}) + return + case *Union: + for _, proj := range v.Selects { + newExprs := make([]expression.Expression, 0, len(predicates)) + for _, cond := range predicates { + newCond := columnSubstitute(cond.DeepCopy(), v.GetSchema(), expression.Schema2Exprs(proj.GetSchema())) + newExprs = append(newExprs, newCond) + } + retCond, err := PredicatePushDown(proj, newExprs) if err != nil { return nil, errors.Trace(err) } + if len(retCond) != 0 { + addFilter(v, proj, retCond) + } } - return predicates, nil + return + //TODO: support aggregation. + default: + return predicates, errors.Errorf("Unkown type %T.", v) } } diff --git a/optimizer/plan/stringer.go b/optimizer/plan/stringer.go index f9bb915f168b5..e741e7f0d2e7b 100644 --- a/optimizer/plan/stringer.go +++ b/optimizer/plan/stringer.go @@ -27,7 +27,7 @@ func ToString(p Plan) string { func toString(in Plan, strs []string, idxs []int) ([]string, []int) { switch in.(type) { - case *JoinOuter, *JoinInner, *Join: + case *JoinOuter, *JoinInner, *Join, *Union: idxs = append(idxs, len(strs)) } @@ -94,6 +94,21 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { strs = strs[:idx] str = "Join{" + strings.Join(children, "->") + "}" idxs = idxs[:last] + case *Union: + last := len(idxs) - 1 + idx := idxs[last] + children := strs[idx:] + strs = strs[:idx] + str = "UnionAll{" + strings.Join(children, "->") + "}" + idxs = idxs[:last] + case *NewTableScan: + str = fmt.Sprintf("DataScan(%v)", x.Table.Name.L) + case *Selection: + str = "Selection" + case *Projection: + str = "Projection" + case *Aggregation: + str = "Aggr" case *Aggregate: str = "Aggregate" case *Distinct: diff --git a/parser/opcode/opcode.go b/parser/opcode/opcode.go index f112845c55936..54d494adc0242 100644 --- a/parser/opcode/opcode.go +++ b/parser/opcode/opcode.go @@ -45,7 +45,8 @@ const ( NullEQ ) -var ops = map[Op]string{ +// Ops maps opcode to string. +var Ops = map[Op]string{ AndAnd: "&&", LeftShift: "<<", RightShift: ">>", @@ -73,7 +74,7 @@ var ops = map[Op]string{ // String implements Stringer interface. func (o Op) String() string { - str, ok := ops[o] + str, ok := Ops[o] if !ok { panic(fmt.Sprintf("%d", o)) }