Skip to content

Commit

Permalink
*: Support stream aggregation. (pingcap#1735)
Browse files Browse the repository at this point in the history
* *: Support stream aggregation.
  • Loading branch information
hanfei1991 authored and zimulala committed Sep 20, 2016
1 parent 1fca947 commit b8994bb
Show file tree
Hide file tree
Showing 22 changed files with 895 additions and 644 deletions.
101 changes: 22 additions & 79 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,19 @@ package executor

import (
"math"
"strings"

"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/sessionctx/autocommit"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/types"
"github.com/pingcap/tipb/go-tipb"
)

// executorBuilder builds an Executor from a Plan.
Expand Down Expand Up @@ -101,7 +96,7 @@ func (b *executorBuilder) build(p plan.Plan) Executor {
return b.buildSemiJoin(v)
case *plan.Selection:
return b.buildSelection(v)
case *plan.Aggregation:
case *plan.PhysicalAggregation:
return b.buildAggregation(v)
case *plan.Projection:
return b.buildProjection(v)
Expand Down Expand Up @@ -469,86 +464,26 @@ func (b *executorBuilder) buildSemiJoin(v *plan.PhysicalHashSemiJoin) Executor {
return e
}

func (b *executorBuilder) buildAggregation(v *plan.Aggregation) Executor {
func (b *executorBuilder) buildAggregation(v *plan.PhysicalAggregation) Executor {
src := b.build(v.GetChildByIndex(0))
e := &AggregationExec{
if v.AggType == plan.StreamedAgg {
return &StreamAggExec{
Src: src,
schema: v.GetSchema(),
ctx: b.ctx,
AggFuncs: v.AggFuncs,
GroupByItems: v.GroupByItems,
}
}
return &HashAggExec{
Src: src,
schema: v.GetSchema(),
ctx: b.ctx,
AggFuncs: v.AggFuncs,
GroupByItems: v.GroupByItems,
aggType: v.AggType,
hasGby: v.HasGby,
}
// Check if the underlying is distsql executor, we should try to push aggregate function down.
xSrc, ok := src.(XExecutor)
if !ok {
return e
}
client := b.ctx.GetClient()
if len(v.GroupByItems) > 0 && !client.SupportRequestType(kv.ReqTypeSelect, kv.ReqSubTypeGroupBy) {
return e
}
// Convert aggregate function exprs to pb.
pbAggFuncs := make([]*tipb.Expr, 0, len(v.AggFuncs))
for _, af := range v.AggFuncs {
if af.IsDistinct() {
// We do not support distinct push down.
return e
}
pbAggFunc := b.AggFuncToPBExpr(client, af, xSrc.GetTable())
if pbAggFunc == nil {
return e
}
pbAggFuncs = append(pbAggFuncs, pbAggFunc)
}
pbByItems := make([]*tipb.ByItem, 0, len(v.GroupByItems))
// Convert groupby to pb
for _, item := range v.GroupByItems {
pbByItem := b.GroupByItemToPB(client, item, xSrc.GetTable())
if pbByItem == nil {
return e
}
pbByItems = append(pbByItems, pbByItem)
}
// compose aggregate info
// We should infer fields type.
// Each agg item will be splitted into two datums: count and value
// The first field should be group key.
fields := make([]*types.FieldType, 0, 1+2*len(v.AggFuncs))
gk := types.NewFieldType(mysql.TypeBlob)
gk.Charset = charset.CharsetBin
gk.Collate = charset.CollationBin
fields = append(fields, gk)
// There will be one or two fields in the result row for each AggregateFuncExpr.
// Count needs count partial result field.
// Sum, FirstRow, Max, Min, GroupConcat need value partial result field.
// Avg needs both count and value partial result field.
for i, agg := range v.AggFuncs {
name := strings.ToLower(agg.GetName())
if needCount(name) {
// count partial result field
ft := types.NewFieldType(mysql.TypeLonglong)
ft.Flen = 21
ft.Charset = charset.CharsetBin
ft.Collate = charset.CollationBin
fields = append(fields, ft)
}
if needValue(name) {
// value partial result field
col := v.GetSchema()[i]
fields = append(fields, col.GetType())
}
}
xSrc.AddAggregate(pbAggFuncs, pbByItems, fields)
hasGroupBy := len(v.GroupByItems) > 0
xe := &XAggregateExec{
Src: src,
ctx: b.ctx,
AggFuncs: v.AggFuncs,
hasGroupBy: hasGroupBy,
schema: v.GetSchema(),
}
log.Debugf("Use XAggregateExec with %d aggs", len(v.AggFuncs))
return xe
}

func (b *executorBuilder) buildSelection(v *plan.Selection) Executor {
Expand Down Expand Up @@ -641,6 +576,10 @@ func (b *executorBuilder) buildTableScan(v *plan.PhysicalTableScan, s *plan.Sele
limitCount: v.LimitCount,
keepOrder: v.KeepOrder,
where: v.ConditionPBExpr,
aggregate: v.Aggregated,
aggFuncs: v.AggFuncs,
aggFields: v.AggFields,
byItems: v.GbyItems,
}
return st
}
Expand Down Expand Up @@ -683,6 +622,10 @@ func (b *executorBuilder) buildIndexScan(v *plan.PhysicalIndexScan, s *plan.Sele
indexPlan: v,
startTS: startTS,
where: v.ConditionPBExpr,
aggregate: v.Aggregated,
aggFuncs: v.AggFuncs,
aggFields: v.AggFields,
byItems: v.GbyItems,
}
return st
}
Expand Down
33 changes: 21 additions & 12 deletions executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -988,13 +988,15 @@ func (e *HashSemiJoinExec) Next() (*Row, error) {
}
}

// AggregationExec deals with all the aggregate functions.
// HashAggExec deals with all the aggregate functions.
// It is built from Aggregate Plan. When Next() is called, it reads all the data from Src and updates all the items in AggFuncs.
type AggregationExec struct {
type HashAggExec struct {
Src Executor
schema expression.Schema
ResultFields []*ast.ResultField
executed bool
hasGby bool
aggType plan.AggregationType
ctx context.Context
AggFuncs []expression.AggregationFunction
groupMap map[string]bool
Expand All @@ -1004,7 +1006,7 @@ type AggregationExec struct {
}

// Close implements Executor Close interface.
func (e *AggregationExec) Close() error {
func (e *HashAggExec) Close() error {
e.executed = false
e.groups = nil
e.currentGroupIndex = 0
Expand All @@ -1015,17 +1017,17 @@ func (e *AggregationExec) Close() error {
}

// Schema implements Executor Schema interface.
func (e *AggregationExec) Schema() expression.Schema {
func (e *HashAggExec) Schema() expression.Schema {
return e.schema
}

// Fields implements Executor Fields interface.
func (e *AggregationExec) Fields() []*ast.ResultField {
func (e *HashAggExec) Fields() []*ast.ResultField {
return e.ResultFields
}

// Next implements Executor Next interface.
func (e *AggregationExec) Next() (*Row, error) {
func (e *HashAggExec) Next() (*Row, error) {
// In this stage we consider all data from src as a single group.
if !e.executed {
e.groupMap = make(map[string]bool)
Expand All @@ -1039,7 +1041,7 @@ func (e *AggregationExec) Next() (*Row, error) {
}
}
e.executed = true
if (len(e.groups) == 0) && (len(e.GroupByItems) == 0) {
if (len(e.groups) == 0) && !e.hasGby {
// If no groupby and no data, we should add an empty group.
// For example:
// "select count(c) from t;" should return one row [0]
Expand All @@ -1059,8 +1061,15 @@ func (e *AggregationExec) Next() (*Row, error) {
return retRow, nil
}

func (e *AggregationExec) getGroupKey(row *Row) ([]byte, error) {
if len(e.GroupByItems) == 0 {
func (e *HashAggExec) getGroupKey(row *Row) ([]byte, error) {
if e.aggType == plan.FinalAgg {
val, err := e.GroupByItems[0].Eval(row.Data, e.ctx)
if err != nil {
return nil, errors.Trace(err)
}
return val.GetBytes(), nil
}
if !e.hasGby {
return []byte{}, nil
}
vals := make([]types.Datum, 0, len(e.GroupByItems))
Expand All @@ -1080,7 +1089,7 @@ func (e *AggregationExec) getGroupKey(row *Row) ([]byte, error) {

// Fetch a single row from src and update each aggregate function.
// If the first return value is false, it means there is no more data from src.
func (e *AggregationExec) innerNext() (ret bool, err error) {
func (e *HashAggExec) innerNext() (ret bool, err error) {
var srcRow *Row
if e.Src != nil {
srcRow, err = e.Src.Next()
Expand Down Expand Up @@ -1112,7 +1121,8 @@ func (e *AggregationExec) innerNext() (ret bool, err error) {
}

// StreamAggExec deals with all the aggregate functions.
// It is built from Aggregate Plan. When Next() is called, it reads all the data from Src and updates all the items in AggFuncs.
// It assumes all the input datas is sorted by group by key.
// When Next() is called, it will return a result for the same group.
type StreamAggExec struct {
Src Executor
schema expression.Schema
Expand Down Expand Up @@ -1173,7 +1183,6 @@ func (e *StreamAggExec) Next() (*Row, error) {
for _, af := range e.AggFuncs {
retRow.Data = append(retRow.Data, af.GetStreamResult())
}

}
if e.executed {
break
Expand Down
46 changes: 45 additions & 1 deletion executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,7 @@ func (s *testSuite) TestNewTableDual(c *C) {
result.Check(testkit.Rows("1"))
}

func (s *testSuite) TestNewTableScan(c *C) {
func (s *testSuite) TestTableScan(c *C) {
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use information_schema")
Expand Down Expand Up @@ -2199,11 +2199,55 @@ func (s *testSuite) TestAggregation(c *C) {
tk.MustExec("insert into t1 (a) values (2), (11), (8)")
result = tk.MustQuery("select min(a), min(case when 1=1 then a else NULL end), min(case when 1!=1 then NULL else a end) from t1 where b=3 group by b")
result.Check(testkit.Rows("2 2 2"))
// The following cases use streamed aggregation.
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(a int, index(a))")
tk.MustExec("insert into t1 (a) values (1),(2),(3),(4),(5)")
result = tk.MustQuery("select count(a) from t1 where a < 3")
result.Check(testkit.Rows("2"))
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(a int, b int, index(a))")
result = tk.MustQuery("select sum(b) from (select * from t1) t group by a")
result.Check(testkit.Rows())
result = tk.MustQuery("select sum(b) from (select * from t1) t")
result.Check(testkit.Rows("<nil>"))
tk.MustExec("insert into t1 (a, b) values (1, 1),(2, 2),(3, 3),(1, 4),(3, 5)")
result = tk.MustQuery("select avg(b) from (select * from t1) t group by a")
result.Check(testkit.Rows("2.5000", "2.0000", "4.0000"))
result = tk.MustQuery("select sum(b) from (select * from t1) t group by a")
result.Check(testkit.Rows("5", "2", "8"))
result = tk.MustQuery("select count(b) from (select * from t1) t group by a")
result.Check(testkit.Rows("2", "1", "2"))
result = tk.MustQuery("select max(b) from (select * from t1) t group by a")
result.Check(testkit.Rows("4", "2", "5"))
result = tk.MustQuery("select min(b) from (select * from t1) t group by a")
result.Check(testkit.Rows("1", "2", "3"))
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(a int, b int, index(a,b))")
tk.MustExec("insert into t1 (a, b) values (1, 1),(2, 2),(3, 3),(1, 4), (1,1),(3, 5), (2,2), (3,5), (3,3)")
result = tk.MustQuery("select avg(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("2.5000", "2.0000", "4.0000"))
result = tk.MustQuery("select sum(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("5", "2", "8"))
result = tk.MustQuery("select count(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("2", "1", "2"))
result = tk.MustQuery("select max(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("4", "2", "5"))
result = tk.MustQuery("select min(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("1", "2", "3"))
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(a int, b int, index(b, a))")
tk.MustExec("insert into t1 (a, b) values (1, 1),(2, 2),(3, 3),(1, 4), (1,1),(3, 5), (2,2), (3,5), (3,3)")
result = tk.MustQuery("select avg(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("2.5000", "2.0000", "4.0000"))
result = tk.MustQuery("select sum(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("5", "2", "8"))
result = tk.MustQuery("select count(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("2", "1", "2"))
result = tk.MustQuery("select max(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("4", "2", "5"))
result = tk.MustQuery("select min(distinct b) from (select * from t1) t group by a")
result.Check(testkit.Rows("1", "2", "3"))
}

func (s *testSuite) TestAdapterStatement(c *C) {
Expand Down
2 changes: 1 addition & 1 deletion executor/union_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type UnionScanExec struct {
ctx context.Context
Src Executor
dirty *dirtyTable
// srcUsedIndex is the column offsets of the index which Src executor has used.
// usedIndex is the column offsets of the index which Src executor has used.
usedIndex []int
desc bool
condition expression.Expression
Expand Down
Loading

0 comments on commit b8994bb

Please sign in to comment.