Skip to content

Commit

Permalink
*: support time constant push down in mocktikv (pingcap#4176)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanfei1991 authored Aug 18, 2017
1 parent 2349974 commit 4321511
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 39 deletions.
35 changes: 6 additions & 29 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,35 +497,6 @@ func (s *testSuite) TestSelectLimit(c *C) {
c.Assert(err, NotNil)
}

func (s *testSuite) TestDAG(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer func() {
s.cleanEnv(c)
testleak.AfterTest(c)()
}()
tk.MustExec("use test")
tk.MustExec("create table select_dag(id int not null default 1, name varchar(255));")

tk.MustExec("insert INTO select_dag VALUES (1, \"hello\");")
tk.MustExec("insert INTO select_dag VALUES (2, \"hello\");")
tk.MustExec("insert INTO select_dag VALUES (3, \"hello\");")
tk.CheckExecResult(1, 0)

r := tk.MustQuery("select * from select_dag;")
r.Check(testkit.Rows("1 hello", "2 hello", "3 hello"))

r = tk.MustQuery("select * from select_dag where id > 1;")
r.Check(testkit.Rows("2 hello", "3 hello"))

// for limit
r = tk.MustQuery("select * from select_dag limit 1;")
r.Check(testkit.Rows("1 hello"))
r = tk.MustQuery("select * from select_dag limit 0;")
r.Check(testkit.Rows())
r = tk.MustQuery("select * from select_dag limit 5;")
r.Check(testkit.Rows("1 hello", "2 hello", "3 hello"))
}

func (s *testSuite) TestSelectOrderBy(c *C) {
defer func() {
s.cleanEnv(c)
Expand Down Expand Up @@ -1618,6 +1589,12 @@ func (s *testSuite) TestSimpleDAG(c *C) {
tk.MustQuery("select * from t where b = 2").Check(testkit.Rows("4 2 3"))
tk.MustQuery("select count(*) from t where b = 1").Check(testkit.Rows("3"))
tk.MustQuery("select * from t where b = 1 and a > 1 limit 1").Check(testkit.Rows("2 1 1"))

// Test time push down.
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (id int, c1 datetime);")
tk.MustExec("insert into t values (1, '2015-06-07 12:12:12')")
tk.MustQuery("select id from t where c1 = '2015-06-07 12:12:12'").Check(testkit.Rows("1"))
}

func (s *testSuite) TestTimestampTimeZone(c *C) {
Expand Down
31 changes: 31 additions & 0 deletions expression/distsql_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ func PBToExpr(expr *tipb.Expr, tps []*types.FieldType, sc *variable.StatementCon
return convertDecimal(expr.Val)
case tipb.ExprType_MysqlDuration:
return convertDuration(expr.Val)
case tipb.ExprType_MysqlTime:
return convertTime(expr.Val, expr.FieldType, sc.TimeZone)
}
// Then it must be a scalar function.
args := make([]Expression, 0, len(expr.Children))
Expand All @@ -146,6 +148,35 @@ func PBToExpr(expr *tipb.Expr, tps []*types.FieldType, sc *variable.StatementCon
return newDistSQLFunction(sc, expr.Tp, args)
}

func fieldTypeFromPB(ft *tipb.FieldType) *types.FieldType {
return &types.FieldType{
Tp: byte(ft.GetTp()),
Flag: uint(ft.GetFlag()),
Flen: int(ft.GetFlen()),
Decimal: int(ft.GetDecimal()),
Collate: mysql.Collations[uint8(ft.GetCollate())],
}
}

func convertTime(data []byte, ftPB *tipb.FieldType, tz *time.Location) (*Constant, error) {
ft := fieldTypeFromPB(ftPB)
_, v, err := codec.DecodeUint(data)
if err != nil {
return nil, errors.Trace(nil)
}
var t types.Time
t.Type = ft.Tp
t.Fsp = ft.Decimal
err = t.FromPackedUint(v)
if err != nil {
return nil, errors.Trace(err)
}
if ft.Tp == mysql.TypeTimestamp && !t.IsZero() {
t.ConvertTimeZone(time.UTC, tz)
}
return &Constant{Value: types.NewTimeDatum(t), RetType: ft}, nil
}

func decodeValueList(data []byte) ([]Expression, error) {
if len(data) == 0 {
return nil, nil
Expand Down
52 changes: 47 additions & 5 deletions expression/expr_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package expression

import (
"time"

"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/kv"
Expand Down Expand Up @@ -64,7 +66,7 @@ type pbConverter struct {
func (pc pbConverter) exprToPB(expr Expression) *tipb.Expr {
switch x := expr.(type) {
case *Constant:
return pc.datumToPBExpr(x.Value)
return pc.constantToPBExpr(x)
case *Column:
return pc.columnToPBExpr(x)
case *ScalarFunction:
Expand All @@ -73,9 +75,14 @@ func (pc pbConverter) exprToPB(expr Expression) *tipb.Expr {
return nil
}

func (pc pbConverter) datumToPBExpr(d types.Datum) *tipb.Expr {
var tp tipb.ExprType
var val []byte
func (pc pbConverter) constantToPBExpr(con *Constant) *tipb.Expr {
var (
tp tipb.ExprType
val []byte
d = con.Value
ft = con.GetType()
)

switch d.Kind() {
case types.KindNull:
tp = tipb.ExprType_Null
Expand Down Expand Up @@ -103,6 +110,23 @@ func (pc pbConverter) datumToPBExpr(d types.Datum) *tipb.Expr {
case types.KindMysqlDecimal:
tp = tipb.ExprType_MysqlDecimal
val = codec.EncodeDecimal(nil, d)
case types.KindMysqlTime:
if pc.client.IsRequestTypeSupported(kv.ReqTypeDAG, int64(tipb.ExprType_MysqlTime)) {
tp = tipb.ExprType_MysqlTime
loc := pc.sc.TimeZone
t := d.GetMysqlTime()
if t.Type == mysql.TypeTimestamp && loc != time.UTC {
t.ConvertTimeZone(loc, time.UTC)
}
v, err := t.ToPackedUint()
if err != nil {
log.Errorf("Fail to encode value, err: %s", err.Error())
return nil
}
val = codec.EncodeUint(nil, v)
return &tipb.Expr{Tp: tp, Val: val, FieldType: toPBFieldType(ft)}
}
return nil
default:
return nil
}
Expand All @@ -112,6 +136,24 @@ func (pc pbConverter) datumToPBExpr(d types.Datum) *tipb.Expr {
return &tipb.Expr{Tp: tp, Val: val}
}

func toPBFieldType(ft *types.FieldType) *tipb.FieldType {
return &tipb.FieldType{
Tp: int32(ft.Tp),
Flag: uint32(ft.Flag),
Flen: int32(ft.Flen),
Decimal: int32(ft.Decimal),
Collate: collationToProto(ft.Collate),
}
}

func collationToProto(c string) int32 {
v, ok := mysql.CollationNames[c]
if ok {
return int32(v)
}
return int32(mysql.DefaultCollationID)
}

func (pc pbConverter) columnToPBExpr(column *Column) *tipb.Expr {
if !pc.client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tipb.ExprType_ColumnRef)) {
return nil
Expand Down Expand Up @@ -309,7 +351,7 @@ func (pc pbConverter) constListToPB(list []Expression) *tipb.Expr {
if !ok {
return nil
}
d := pc.datumToPBExpr(v.Value)
d := pc.constantToPBExpr(v)
if d == nil {
return nil
}
Expand Down
10 changes: 8 additions & 2 deletions plan/cbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type testAnalyzeSuite struct {
}

func constructInsertSQL(i, n int) string {
sql := "insert into t values "
sql := "insert into t (a,b,c)values "
for j := 0; j < n; j++ {
sql += fmt.Sprintf("(%d, %d, '%d')", i*n+j, i, i+j)
if j != n-1 {
Expand All @@ -57,8 +57,9 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) {
}()
testKit.MustExec("use test")
testKit.MustExec("drop table if exists t")
testKit.MustExec("create table t (a int primary key, b int, c varchar(200))")
testKit.MustExec("create table t (a int primary key, b int, c varchar(200), d datetime DEFAULT CURRENT_TIMESTAMP)")
testKit.MustExec("create index b on t (b)")
testKit.MustExec("create index d on t (d)")
for i := 0; i < 100; i++ {
testKit.MustExec(constructInsertSQL(i, 100))
}
Expand Down Expand Up @@ -103,6 +104,11 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) {
sql: "select * from t use index(b) where b = 1 order by a",
best: "IndexLookUp(Index(t.b)[[1,1]], Table(t))->Sort",
},
// test datetime
{
sql: "select * from t where d < cast('1991-09-05' as datetime)",
best: "IndexLookUp(Index(t.d)[[-inf,1991-09-05 00:00:00)], Table(t))",
},
}
for _, tt := range tests {
ctx := testKit.Se.(context.Context)
Expand Down
6 changes: 3 additions & 3 deletions store/tikv/coprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ func (c *CopClient) IsRequestTypeSupported(reqType, subType int64) bool {
case kv.ReqSubTypeGroupBy, kv.ReqSubTypeBasic, kv.ReqSubTypeTopN:
return true
default:
return supportExpr(tipb.ExprType(subType))
return c.supportExpr(tipb.ExprType(subType))
}
case kv.ReqTypeDAG:
return true
return c.supportExpr(tipb.ExprType(subType))
}
return false
}

func supportExpr(exprType tipb.ExprType) bool {
func (c *CopClient) supportExpr(exprType tipb.ExprType) bool {
switch exprType {
case tipb.ExprType_Null, tipb.ExprType_Int64, tipb.ExprType_Uint64, tipb.ExprType_String, tipb.ExprType_Bytes,
tipb.ExprType_MysqlDuration, tipb.ExprType_MysqlTime, tipb.ExprType_MysqlDecimal,
Expand Down

0 comments on commit 4321511

Please sign in to comment.