Skip to content

Commit

Permalink
executor, statistics: use FM sketch to estimate NDV (pingcap#2966)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored and coocood committed Apr 1, 2017
1 parent 82285af commit 16cefef
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 73 deletions.
32 changes: 25 additions & 7 deletions executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type AnalyzeExec struct {

const (
maxSampleCount = 10000
maxSketchSize = 1000
defaultBucketCount = 256
)

Expand All @@ -65,10 +66,11 @@ func (e *AnalyzeExec) Next() (*Row, error) {
ae := src.(*AnalyzeExec)
var count int64 = -1
var sampleRows []*ast.Row
var colNDVs []int64
if ae.colOffsets != nil {
rs := &recordSet{executor: ae.Srcs[len(ae.Srcs)-1]}
var err error
count, sampleRows, err = collectSamples(rs)
count, sampleRows, colNDVs, err = CollectSamplesAndEstimateNDVs(rs, len(ae.colOffsets))
if err != nil {
return nil, errors.Trace(err)
}
Expand All @@ -86,22 +88,23 @@ func (e *AnalyzeExec) Next() (*Row, error) {
for i := range ae.idxOffsets {
idxRS = append(idxRS, &recordSet{executor: ae.Srcs[i]})
}
err := ae.buildStatisticsAndSaveToKV(count, columnSamples, idxRS, pkRS)
err := ae.buildStatisticsAndSaveToKV(count, columnSamples, colNDVs, idxRS, pkRS)
if err != nil {
return nil, errors.Trace(err)
}
}
return nil, nil
}

func (e *AnalyzeExec) buildStatisticsAndSaveToKV(count int64, columnSamples [][]types.Datum, idxRS []ast.RecordSet, pkRS ast.RecordSet) error {
func (e *AnalyzeExec) buildStatisticsAndSaveToKV(count int64, columnSamples [][]types.Datum, colNDVs []int64, idxRS []ast.RecordSet, pkRS ast.RecordSet) error {
statBuilder := &statistics.Builder{
Ctx: e.ctx,
TblInfo: e.tblInfo,
Count: count,
NumBuckets: defaultBucketCount,
ColumnSamples: columnSamples,
ColOffsets: e.colOffsets,
ColNDVs: colNDVs,
IdxRecords: idxRS,
IdxOffsets: e.idxOffsets,
PkRecords: pkRS,
Expand All @@ -115,17 +118,29 @@ func (e *AnalyzeExec) buildStatisticsAndSaveToKV(count int64, columnSamples [][]
return errors.Trace(err)
}

// collectSamples collects sample from the result set, using Reservoir Sampling algorithm.
// CollectSamplesAndEstimateNDVs collects sample from the result set using Reservoir Sampling algorithm,
// and estimates NDVs using FM Sketch during the collecting process.
// See https://en.wikipedia.org/wiki/Reservoir_sampling
func collectSamples(e ast.RecordSet) (count int64, samples []*ast.Row, err error) {
// Exported for test.
func CollectSamplesAndEstimateNDVs(e ast.RecordSet, numCols int) (count int64, samples []*ast.Row, ndvs []int64, err error) {
var sketches []*statistics.FMSketch
for i := 0; i < numCols; i++ {
sketches = append(sketches, statistics.NewFMSketch(maxSketchSize))
}
for {
row, err := e.Next()
if err != nil {
return count, samples, errors.Trace(err)
return count, samples, ndvs, errors.Trace(err)
}
if row == nil {
break
}
for i, val := range row.Data {
err = sketches[i].InsertValue(val)
if err != nil {
return count, samples, ndvs, errors.Trace(err)
}
}
if len(samples) < maxSampleCount {
samples = append(samples, row)
} else {
Expand All @@ -137,7 +152,10 @@ func collectSamples(e ast.RecordSet) (count int64, samples []*ast.Row, err error
}
count++
}
return count, samples, nil
for _, sketch := range sketches {
ndvs = append(ndvs, sketch.NDV())
}
return count, samples, ndvs, nil
}

func rowsToColumnSamples(rows []*ast.Row) [][]types.Datum {
Expand Down
50 changes: 50 additions & 0 deletions executor/analyze_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ import (
"strings"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
)

func (s *testSuite) TestAnalyzeTable(c *C) {
Expand All @@ -40,3 +43,50 @@ func (s *testSuite) TestAnalyzeTable(c *C) {
rowStr = fmt.Sprintf("%s", result.Rows())
c.Check(strings.Split(rowStr, "{")[0], Equals, "[[TableScan_4 ")
}

type recordSet struct {
data []types.Datum
count int
cursor int
}

func (r *recordSet) Fields() ([]*ast.ResultField, error) {
return nil, nil
}

func (r *recordSet) Next() (*ast.Row, error) {
if r.cursor == r.count {
return nil, nil
}
r.cursor++
return &ast.Row{Data: []types.Datum{r.data[r.cursor-1]}}, nil
}

func (r *recordSet) Close() error {
r.cursor = 0
return nil
}

func (s *testSuite) TestCollectSamplesAndEstimateNDVs(c *C) {
count := 10000
rs := &recordSet{
data: make([]types.Datum, count),
count: count,
cursor: 0,
}
start := 1000 // 1000 values is null
for i := start; i < rs.count; i++ {
rs.data[i].SetInt64(int64(i))
}
for i := start; i < rs.count; i += 3 {
rs.data[i].SetInt64(rs.data[i].GetInt64() + 1)
}
for i := start; i < rs.count; i += 5 {
rs.data[i].SetInt64(rs.data[i].GetInt64() + 2)
}

cnt, _, ndvs, err := executor.CollectSamplesAndEstimateNDVs(rs, 1)
c.Assert(err, IsNil)
c.Assert(cnt, Equals, int64(rs.count))
c.Assert(ndvs[0], Equals, int64(6624))
}
41 changes: 4 additions & 37 deletions statistics/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Builder struct {
NumBuckets int64 // NumBuckets is the number of buckets a column histogram has.
ColumnSamples [][]types.Datum // ColumnSamples is the sample of columns.
ColOffsets []int // ColOffsets is the offset of columns in the table.
ColNDVs []int64 // ColNDVs is the NDV of columns.
IdxRecords []ast.RecordSet // IdxRecords is the record set of index columns.
IdxOffsets []int // IdxOffsets is the offset of indices in the table.
PkRecords ast.RecordSet // PkRecords is the record set of primary key of integer type.
Expand All @@ -48,7 +49,7 @@ func (b *Builder) buildMultiColumns(t *Table, offsets []int, baseOffset int, isS
if isSorted {
err = t.build4SortedColumn(b.Ctx.GetSessionVars().StmtCtx, offset, b.IdxRecords[i+baseOffset], b.NumBuckets, false)
} else {
err = t.buildColumn(b.Ctx.GetSessionVars().StmtCtx, offset, b.ColumnSamples[i+baseOffset], b.NumBuckets)
err = t.buildColumn(b.Ctx.GetSessionVars().StmtCtx, offset, b.ColNDVs[i+baseOffset], b.ColumnSamples[i+baseOffset], b.NumBuckets)
}
if err != nil {
done <- err
Expand Down Expand Up @@ -242,19 +243,15 @@ func (t *Table) build4SortedColumn(sc *variable.StatementContext, offset int, re
}

// buildColumn builds column statistics from samples.
func (t *Table) buildColumn(sc *variable.StatementContext, offset int, samples []types.Datum, bucketCount int64) error {
func (t *Table) buildColumn(sc *variable.StatementContext, offset int, ndv int64, samples []types.Datum, bucketCount int64) error {
err := types.SortDatums(sc, samples)
if err != nil {
return errors.Trace(err)
}
estimatedNDV, err := estimateNDV(sc, t.Count, samples)
if err != nil {
return errors.Trace(err)
}
ci := t.Info.Columns[offset]
col := &Column{
ID: ci.ID,
NDV: estimatedNDV,
NDV: ndv,
Numbers: make([]int64, 1, bucketCount),
Values: make([]types.Datum, 1, bucketCount),
Repeats: make([]int64, 1, bucketCount),
Expand Down Expand Up @@ -294,36 +291,6 @@ func (t *Table) buildColumn(sc *variable.StatementContext, offset int, samples [
return nil
}

// estimateNDV estimates the number of distinct value given a count and samples.
// It implements a simplified Good–Turing frequency estimation algorithm.
// See https://en.wikipedia.org/wiki/Good%E2%80%93Turing_frequency_estimation
func estimateNDV(sc *variable.StatementContext, count int64, samples []types.Datum) (int64, error) {
lastValue := samples[0]
occurrence := 1
sampleDistinct := 1
occurredOnceCount := 0
for i := 1; i < len(samples); i++ {
cmp, err := lastValue.CompareDatum(sc, samples[i])
if err != nil {
return 0, errors.Trace(err)
}
if cmp == 0 {
occurrence++
} else {
if occurrence == 1 {
occurredOnceCount++
}
sampleDistinct++
occurrence = 1
}
lastValue = samples[i]
}
newValueProbability := float64(occurredOnceCount) / float64(len(samples))
unsampledCount := float64(count) - float64(len(samples))
estimatedDistinct := float64(sampleDistinct) + unsampledCount*newValueProbability
return int64(estimatedDistinct), nil
}

func copyFromIndexColumns(ind *Column, id, numBuckets int64) (*Column, error) {
col := &Column{
ID: id,
Expand Down
55 changes: 35 additions & 20 deletions statistics/fmsketch.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package statistics

import (
"hash"
"hash/fnv"

"github.com/juju/errors"
Expand All @@ -23,18 +24,26 @@ import (

// FMSketch is used to count the number of distinct elements in a set.
type FMSketch struct {
hashset map[uint64]bool
mask uint64
maxSize int
hashset map[uint64]bool
mask uint64
maxSize int
hashFunc hash.Hash64
}

func newFMSketch(maxSize int) *FMSketch {
// NewFMSketch returns a new FM sketch.
func NewFMSketch(maxSize int) *FMSketch {
return &FMSketch{
hashset: make(map[uint64]bool),
maxSize: maxSize,
hashset: make(map[uint64]bool),
maxSize: maxSize,
hashFunc: fnv.New64a(),
}
}

// NDV returns the ndv of the sketch.
func (s *FMSketch) NDV() int64 {
return int64(s.mask+1) * int64(len(s.hashset))
}

func (s *FMSketch) insertHashValue(hashVal uint64) {
if (hashVal & s.mask) != 0 {
return
Expand All @@ -50,27 +59,34 @@ func (s *FMSketch) insertHashValue(hashVal uint64) {
}
}

// InsertValue inserts a value into the FM sketch.
func (s *FMSketch) InsertValue(value types.Datum) error {
bytes, err := codec.EncodeValue(nil, value)
if err != nil {
return errors.Trace(err)
}
s.hashFunc.Reset()
_, err = s.hashFunc.Write(bytes)
if err != nil {
return errors.Trace(err)
}
s.insertHashValue(s.hashFunc.Sum64())
return nil
}

func buildFMSketch(values []types.Datum, maxSize int) (*FMSketch, int64, error) {
s := newFMSketch(maxSize)
h := fnv.New64a()
s := NewFMSketch(maxSize)
for _, value := range values {
bytes, err := codec.EncodeValue(nil, value)
if err != nil {
return nil, 0, errors.Trace(err)
}
h.Reset()
_, err = h.Write(bytes)
err := s.InsertValue(value)
if err != nil {
return nil, 0, errors.Trace(err)
}
s.insertHashValue(h.Sum64())
}
ndv := int64((s.mask + 1)) * int64(len(s.hashset))
return s, ndv, nil
return s, s.NDV(), nil
}

func mergeFMSketches(sketches []*FMSketch, maxSize int) (*FMSketch, int64) {
s := newFMSketch(maxSize)
s := NewFMSketch(maxSize)
for _, sketch := range sketches {
if s.mask < sketch.mask {
s.mask = sketch.mask
Expand All @@ -81,6 +97,5 @@ func mergeFMSketches(sketches []*FMSketch, maxSize int) (*FMSketch, int64) {
s.insertHashValue(key)
}
}
ndv := int64((s.mask + 1)) * int64(len(s.hashset))
return s, ndv
return s, s.NDV()
}
2 changes: 1 addition & 1 deletion statistics/fmsketch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (s *testStatisticsSuite) TestSketch(c *C) {
c.Check(ndv, Equals, int64(99968))

maxSize = 2
sketch := newFMSketch(maxSize)
sketch := NewFMSketch(maxSize)
sketch.insertHashValue(1)
sketch.insertHashValue(2)
c.Check(len(sketch.hashset), Equals, maxSize)
Expand Down
11 changes: 3 additions & 8 deletions statistics/statistics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,6 @@ func (s *testStatisticsSuite) SetUpSuite(c *C) {
s.pk = pk
}

func (s *testStatisticsSuite) TestEstimateNDV(c *C) {
sc := new(variable.StatementContext)
ndv, err := estimateNDV(sc, s.count, s.samples)
c.Check(err, IsNil)
c.Check(ndv, Equals, int64(49792))
}

func (s *testStatisticsSuite) TestTable(c *C) {
tblInfo := &model.TableInfo{
ID: 1,
Expand Down Expand Up @@ -156,6 +149,7 @@ func (s *testStatisticsSuite) TestTable(c *C) {
tblInfo.Indices = indices
timestamp := int64(10)
bucketCount := int64(256)
_, ndv, _ := buildFMSketch(s.rc.(*recordSet).data, 1000)
builder := &Builder{
Ctx: mock.NewContext(),
TblInfo: tblInfo,
Expand All @@ -164,6 +158,7 @@ func (s *testStatisticsSuite) TestTable(c *C) {
NumBuckets: bucketCount,
ColumnSamples: [][]types.Datum{s.samples},
ColOffsets: []int{0},
ColNDVs: []int64{ndv},
IdxRecords: []ast.RecordSet{s.rc},
IdxOffsets: []int{0},
PkRecords: ast.RecordSet(s.pk),
Expand All @@ -176,7 +171,7 @@ func (s *testStatisticsSuite) TestTable(c *C) {
col := t.Columns[0]
count, err := col.EqualRowCount(sc, types.NewIntDatum(1000))
c.Check(err, IsNil)
c.Check(count, Equals, int64(2))
c.Check(count, Equals, int64(1))
count, err = col.LessRowCount(sc, types.NewIntDatum(2000))
c.Check(err, IsNil)
c.Check(count, Equals, int64(19955))
Expand Down

0 comments on commit 16cefef

Please sign in to comment.