plan/core: support mpp group by expressions. (pingcap#23133)
* planner/core: make mpp support grouping by expressions

* add tests

* fix typo

* fix

* fix test

Co-authored-by: Ti Chi Robot <[email protected]>
hanfei1991 and ti-chi-bot authored Mar 10, 2021
1 parent f8da635 commit f6a61bc
Showing 7 changed files with 68 additions and 40 deletions.
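In short: the MPP planner previously required at least one plain group-by column before proposing an MPP hash aggregation, so a query that grouped only by expressions never ran its aggregation in TiFlash. With this commit, the partial aggregation evaluates each group-by expression into a column first, and the exchange hash-partitions on that column. A minimal sketch in the repository's testkit style (the schema, replica setup, and session variable are assumptions for illustration, not taken from this diff):

tk.MustExec("create table t(id int)")
tk.MustExec("alter table t set tiflash replica 1")
tk.MustExec("set @@session.tidb_allow_mpp = 1")
// Grouping by an expression can now be planned as a 2-phase MPP hash agg:
tk.MustQuery("select count(*), id + 1 from t group by id + 1")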
6 changes: 5 additions & 1 deletion executor/tiflash_test.go
@@ -177,7 +177,11 @@ func (s *tiflashTestSuite) TestMppExecution(c *C) {
tk.MustQuery("select count(*) from t1 , t, t2 where t1.a = t.a and t2.a = t.a").Check(testkit.Rows("3"))

tk.MustExec("insert into t1 values(4,0)")
tk.MustQuery("select count(*), t2.b from t1 left join t2 on t1.a = t2.a group by t2.b").Check(testkit.Rows("3 0", "1 <nil>"))
tk.MustQuery("select count(*) k, t2.b from t1 left join t2 on t1.a = t2.a group by t2.b order by k").Check(testkit.Rows("1 <nil>", "3 0"))
tk.MustQuery("select count(*) k, t2.b+1 from t1 left join t2 on t1.a = t2.a group by t2.b+1 order by k").Check(testkit.Rows("1 <nil>", "3 1"))
tk.MustQuery("select count(*) k, t2.b * t2.a from t2 group by t2.b * t2.a").Check(testkit.Rows("3 0"))
tk.MustQuery("select count(*) k, t2.a/2 m from t2 group by t2.a / 2 order by m").Check(testkit.Rows("1 0.5000", "1 1.0000", "1 1.5000"))
tk.MustQuery("select count(*) k, t2.a div 2 from t2 group by t2.a div 2 order by k").Check(testkit.Rows("1 0", "2 1"))
}
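The new assertions exercise arithmetic over one column (t2.b + 1), an expression over two columns (t2.b * t2.a), decimal division, and integer division; the order by k clauses pin the row order, since rows from parallel MPP tasks otherwise arrive in no fixed order. The division results follow MySQL semantics: / widens the result scale by div_precision_increment (default 4), hence 0.5000, while div truncates to an integer. A standalone sanity check of the operator pair (illustrative, not part of this commit):

tk.MustQuery("select 3/2, 3 div 2").Check(testkit.Rows("1.5000 1"))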

func (s *tiflashTestSuite) TestPartitionTable(c *C) {
26 changes: 16 additions & 10 deletions planner/core/exhaust_physical_plans.go
@@ -2268,8 +2268,9 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
if prop.PartitionTp == property.BroadcastType {
return nil
}
partitionCols := la.GetGroupByCols()
if len(partitionCols) != 0 {
if len(la.GroupByItems) > 0 {
partitionCols := la.GetGroupByCols()
// trying to match the required partitions.
if prop.PartitionTp == property.HashType {
if matches := prop.IsSubsetOf(partitionCols); len(matches) != 0 {
partitionCols = chooseSubsetOfJoinKeys(partitionCols, matches)
@@ -2280,18 +2281,23 @@
}
// TODO: permute various partition columns from group-by columns
// 1-phase agg
childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.HashType, PartitionCols: partitionCols, CanAddEnforcer: true}
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
agg.SetSchema(la.schema.Clone())
agg.MppRunMode = Mpp1Phase
hashAggs = append(hashAggs, agg)
// If there are no available partition cols but there are still group by items, then the group by items are all expressions or constants.
// To avoid mess, we don't do any one-phase aggregation in this case.
if len(partitionCols) != 0 {
childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.HashType, PartitionCols: partitionCols, CanAddEnforcer: true}
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
agg.SetSchema(la.schema.Clone())
agg.MppRunMode = Mpp1Phase
hashAggs = append(hashAggs, agg)
}

// 2-phase agg
childProp = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.AnyType}
agg = NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.AnyType}
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
agg.SetSchema(la.schema.Clone())
agg.MppRunMode = Mpp2Phase
agg.MppPartitionCols = partitionCols
hashAggs = append(hashAggs, agg)

// agg runs on TiDB with a partial agg on TiFlash if possible
if prop.TaskTp == property.RootTaskType {
childProp := &property.PhysicalProperty{TaskTp: property.RootTaskType, ExpectedCnt: math.MaxFloat64}
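The rewritten block distinguishes group-by items (everything in the GROUP BY list) from group-by columns (the items that are plain columns, the only usable hash keys for the aggregation's input). A condensed restatement of the decision it now encodes, using stand-in parameters rather than the planner's real API:

// Condensed sketch, not the actual planner code. 1-phase aggregation must
// hash-partition the raw input, so it needs real group-by columns; 2-phase
// aggregation only needs the partial output to be partitionable, and the
// partial agg turns every group-by expression into an output column, so any
// group-by item at all is enough.
func mppHashAggModes(numGroupByItems, numGroupByCols int) (onePhase, twoPhase bool) {
	if numGroupByItems == 0 {
		return false, false // scalar aggregation takes the other branches
	}
	return numGroupByCols > 0, true
}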
7 changes: 3 additions & 4 deletions planner/core/physical_plans.go
@@ -935,10 +935,9 @@ const (
type basePhysicalAgg struct {
physicalSchemaProducer

AggFuncs []*aggregation.AggFuncDesc
GroupByItems []expression.Expression
MppRunMode AggMppRunMode
MppPartitionCols []*expression.Column
AggFuncs []*aggregation.AggFuncDesc
GroupByItems []expression.Expression
MppRunMode AggMppRunMode
}

func (p *basePhysicalAgg) isFinalAgg() bool {
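basePhysicalAgg drops the MppPartitionCols field because the 2-phase path no longer needs planner-side bookkeeping: attach2TaskForMpp in task.go (next file) re-derives the hash keys from the final aggregation's own group-by items. A hypothetical helper mirroring the loop added there (the name partitionKeys is illustrative, not from the commit):

// After the partial agg runs, each group-by expression has been materialized
// as an output column, so those columns are the only candidates for the
// exchange's hash keys; any non-column item means the plan is rejected.
func partitionKeys(finalGroupBy []expression.Expression) ([]*expression.Column, bool) {
	cols := make([]*expression.Column, 0, len(finalGroupBy))
	for _, item := range finalGroupBy {
		col, ok := item.(*expression.Column)
		if !ok {
			return nil, false
		}
		cols = append(cols, col)
	}
	return cols, true
}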
24 changes: 13 additions & 11 deletions planner/core/task.go
@@ -1591,28 +1591,29 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
mpp.addCost(p.GetCost(inputRows, false))
return mpp
case Mpp2Phase:
// 2-phase agg: partial + final agg for hash partition
if len(p.MppPartitionCols) == 0 {
return invalidTask
}
prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.HashType, PartitionCols: p.MppPartitionCols}
// if mpp does not need to enforce exchange, i.e., the child is properly partitioned, then this 2-phase agg is invalid
if !mpp.needEnforce(prop) {
return invalidTask
}
proj := p.convertAvgForMPP()
partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true)
if partialAgg == nil {
return invalidTask
}
attachPlan2Task(partialAgg, mpp)
items := finalAgg.(*PhysicalHashAgg).GroupByItems
partitionCols := make([]*expression.Column, 0, len(items))
for _, expr := range items {
col, ok := expr.(*expression.Column)
if !ok {
return invalidTask
}
partitionCols = append(partitionCols, col)
}
prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, PartitionTp: property.HashType, PartitionCols: partitionCols}
newMpp := mpp.enforceExchangerImpl(prop)
attachPlan2Task(finalAgg, newMpp)
if proj != nil {
attachPlan2Task(proj, newMpp)
}
// TODO: how to set 2-phase cost?
newMpp.addCost(p.GetCost(inputRows/2, false))
newMpp.addCost(p.GetCost(inputRows, false))
return newMpp
case MppTiDB:
partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, false)
@@ -1824,9 +1825,10 @@ func (t *mppTask) enforceExchangerImpl(prop *property.PhysicalProperty) *mppTask
ChildPf: f,
}.Init(ctx, t.p.statsInfo())
receiver.SetChildren(sender)
cst := t.cst + t.count()*ctx.GetSessionVars().NetworkFactor
return &mppTask{
p: receiver,
cst: t.cst,
cst: cst,
partTp: prop.PartitionTp,
hashCols: prop.PartitionCols,
receivers: []*PhysicalExchangeReceiver{receiver},
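Taken together, the Mpp2Phase branch now assembles this pipeline for a query such as select count(*) from t group by id + 1 (a sketch that matches the EXPLAIN output in the test data below; the column numbers are illustrative):

HashAgg (partial, TiFlash): group by plus(id, 1)->Column#10, count(1)->Column#11
ExchangeSender: HashPartition on Column#10
ExchangeReceiver: the enforced exchange now carries a network cost (second hunk)
HashAgg (final, TiFlash): group by Column#10, sum(Column#11)->count(*)

Two cost adjustments ride along: enforceExchangerImpl adds t.count()*ctx.GetSessionVars().NetworkFactor for shipping rows through the exchange, and the 2-phase aggregation is costed on the full inputRows rather than the earlier inputRows/2 estimate; the TODO above that line notes the 2-phase cost model is still an open question.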
1 change: 1 addition & 0 deletions planner/core/testdata/integration_serial_suite_in.json
@@ -174,6 +174,7 @@
"desc format = 'brief' select /*+ hash_agg()*/ sum(b) from (select id + 1 as b from t)A",
"desc format = 'brief' select count(*) from t",
"desc format = 'brief' select count(*), id from t group by id",
"desc format = 'brief' select count(*), id + 1 from t group by id + 1",
"desc format = 'brief' select * from t join ( select count(*), id from t group by id) as A on A.id = t.id",
"desc format = 'brief' select * from t join ( select /*+ hash_agg()*/ count(*) as a from t) as A on A.a = t.id",
"desc format = 'brief' select avg(value) as b,id from t group by id",
37 changes: 24 additions & 13 deletions planner/core/testdata/integration_serial_suite_out.json
@@ -1495,6 +1495,20 @@
" └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo"
]
},
{
"SQL": "desc format = 'brief' select count(*), id + 1 from t group by id + 1",
"Plan": [
"Projection 8000.00 root Column#4, plus(test.t.id, 1)->Column#5",
"└─TableReader 8000.00 root data:ExchangeSender",
" └─ExchangeSender 8000.00 batchCop[tiflash] ExchangeType: PassThrough",
" └─Projection 8000.00 batchCop[tiflash] Column#4, test.t.id",
" └─HashAgg 8000.00 batchCop[tiflash] group by:Column#10, funcs:sum(Column#11)->Column#4, funcs:firstrow(Column#12)->test.t.id",
" └─ExchangeReceiver 8000.00 batchCop[tiflash] ",
" └─ExchangeSender 8000.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: Column#10",
" └─HashAgg 8000.00 batchCop[tiflash] group by:plus(test.t.id, 1), funcs:count(1)->Column#11, funcs:firstrow(test.t.id)->Column#12",
" └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo"
]
},
{
"SQL": "desc format = 'brief' select * from t join ( select count(*), id from t group by id) as A on A.id = t.id",
"Plan": [
@@ -1604,19 +1618,16 @@
"TableReader 7992.00 root data:ExchangeSender",
"└─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: PassThrough",
" └─Projection 7992.00 batchCop[tiflash] Column#7",
" └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(Column#8)->Column#7",
" └─ExchangeReceiver 7992.00 batchCop[tiflash] ",
" └─ExchangeSender 7992.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id",
" └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.id)->Column#8",
" └─HashJoin 12487.50 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]",
" ├─ExchangeReceiver(Build) 9990.00 batchCop[tiflash] ",
" │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id",
" │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))",
" │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo",
" └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ",
" └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id",
" └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))",
" └─TableFullScan 10000.00 batchCop[tiflash] table:t1 keep order:false, stats:pseudo"
" └─HashAgg 7992.00 batchCop[tiflash] group by:test.t.id, funcs:sum(test.t.id)->Column#7",
" └─HashJoin 12487.50 batchCop[tiflash] inner join, equal:[eq(test.t.id, test.t.id)]",
" ├─ExchangeReceiver(Build) 9990.00 batchCop[tiflash] ",
" │ └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id",
" │ └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))",
" │ └─TableFullScan 10000.00 batchCop[tiflash] table:t keep order:false, stats:pseudo",
" └─ExchangeReceiver(Probe) 9990.00 batchCop[tiflash] ",
" └─ExchangeSender 9990.00 batchCop[tiflash] ExchangeType: HashPartition, Hash Cols: test.t.id",
" └─Selection 9990.00 batchCop[tiflash] not(isnull(test.t.id))",
" └─TableFullScan 10000.00 batchCop[tiflash] table:t1 keep order:false, stats:pseudo"
]
},
{
7 changes: 6 additions & 1 deletion store/mockstore/unistore/cophandler/mpp.go
@@ -15,6 +15,7 @@ package cophandler

import (
"context"
"fmt"
"sync"
"time"

@@ -327,7 +328,7 @@ func HandleMPPDAGReq(dbReader *dbreader.DBReader, req *coprocessor.Request, mppC
}
mppExec, err := builder.buildMPPExecutor(dagReq.RootExecutor)
if err != nil {
return &coprocessor.Response{OtherError: err.Error()}
panic("build error: " + err.Error())
}
err = mppExec.open()
if err != nil {
@@ -402,6 +403,10 @@ type ExchangerTunnel struct {
ErrCh chan error
}

func (tunnel *ExchangerTunnel) debugString() string {
return fmt.Sprintf("(%d->%d)", tunnel.sourceTask.TaskId, tunnel.targetTask.TaskId)
}

// RecvChunk receives a tipb chunk.
func (tunnel *ExchangerTunnel) RecvChunk() (tipbChunk *tipb.Chunk, err error) {
tipbChunk = <-tunnel.DataCh
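The unistore changes are mock-store plumbing rather than planner logic: a failed buildMPPExecutor now panics, so a malformed MPP plan fails loudly under test instead of being folded into a coprocessor error response, and ExchangerTunnel gains a debugString helper that identifies a tunnel by its source and target task IDs. The intended use is plain logging while debugging exchanges (an illustrative call, not from the diff):

fmt.Printf("closing tunnel %s\n", tunnel.debugString())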
