Skip to content

Commit

Permalink
executor/cte_test.go: migrate test-infra to testify (pingcap#27103)
Browse files Browse the repository at this point in the history
  • Loading branch information
unconsolable authored Aug 23, 2021
1 parent 002318f commit 4cc2423
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 131 deletions.
85 changes: 85 additions & 0 deletions executor/cte_serial_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package executor_test

import (
"fmt"
"math/rand"
"sort"
"testing"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
)

func TestSpillToDisk(t *testing.T) {
defer config.RestoreFunc()()
config.UpdateGlobal(func(conf *config.Config) {
conf.OOMUseTmpStorage = true
})

store, close := testkit.CreateMockStore(t)
defer close()

tk := testkit.NewTestKit(t, store)
tk.MustExec("use test;")

require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testCTEStorageSpill", "return(true)"))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testCTEStorageSpill"))
tk.MustExec("set tidb_mem_quota_query = 1073741824;")
}()
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testSortedRowContainerSpill", "return(true)"))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testSortedRowContainerSpill"))
}()

// Use duplicated rows to test UNION DISTINCT.
tk.MustExec("set tidb_mem_quota_query = 1073741824;")
insertStr := "insert into t1 values(0)"
rowNum := 1000
vals := make([]int, rowNum)
vals[0] = 0
for i := 1; i < rowNum; i++ {
v := rand.Intn(100)
vals[i] = v
insertStr += fmt.Sprintf(", (%d)", v)
}
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(c1 int);")
tk.MustExec(insertStr)
tk.MustExec("set tidb_mem_quota_query = 40000;")
tk.MustExec("set cte_max_recursion_depth = 500000;")
sql := fmt.Sprintf("with recursive cte1 as ( "+
"select c1 from t1 "+
"union "+
"select c1 + 1 c1 from cte1 where c1 < %d) "+
"select c1 from cte1 order by c1;", rowNum)
rows := tk.MustQuery(sql)

memTracker := tk.Session().GetSessionVars().StmtCtx.MemTracker
diskTracker := tk.Session().GetSessionVars().StmtCtx.DiskTracker
require.Greater(t, memTracker.MaxConsumed(), int64(0))
require.Greater(t, diskTracker.MaxConsumed(), int64(0))

sort.Ints(vals)
resRows := make([]string, 0, rowNum)
for i := vals[0]; i <= rowNum; i++ {
resRows = append(resRows, fmt.Sprintf("%d", i))
}
rows.Check(testkit.Rows(resRows...))
}
165 changes: 35 additions & 130 deletions executor/cte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,68 +15,20 @@
package executor_test

import (
"context"
"fmt"
"math/rand"
"sort"
"testing"

"github.com/pingcap/check"

"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/session"
"github.com/pingcap/tidb/sessionctx"

"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/store/mockstore"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
)

var _ = check.Suite(&CTETestSuite{&baseCTETestSuite{}})
var _ = check.SerialSuites(&CTESerialTestSuite{&baseCTETestSuite{}})

type baseCTETestSuite struct {
store kv.Storage
dom *domain.Domain
sessionCtx sessionctx.Context
session session.Session
ctx context.Context
}

type CTETestSuite struct {
*baseCTETestSuite
}

type CTESerialTestSuite struct {
*baseCTETestSuite
}

func (test *baseCTETestSuite) SetUpSuite(c *check.C) {
var err error
test.store, err = mockstore.NewMockStore()
c.Assert(err, check.IsNil)

test.dom, err = session.BootstrapSession(test.store)
c.Assert(err, check.IsNil)

test.sessionCtx = mock.NewContext()

test.session, err = session.CreateSession4Test(test.store)
c.Assert(err, check.IsNil)
test.session.SetConnectionID(0)

test.ctx = context.Background()
}
func TestBasicCTE(t *testing.T) {
t.Parallel()

func (test *baseCTETestSuite) TearDownSuite(c *check.C) {
test.dom.Close()
test.store.Close()
}
store, close := testkit.CreateMockStore(t)
defer close()

func (test *CTETestSuite) TestBasicCTE(c *check.C) {
tk := testkit.NewTestKit(c, test.store)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")

rows := tk.MustQuery("with recursive cte1 as (" +
Expand Down Expand Up @@ -121,63 +73,13 @@ func (test *CTETestSuite) TestBasicCTE(c *check.C) {
rows.Check(testkit.Rows("1 1", "2 1", "3 1", "4 1", "5 1"))
}

func (test *CTESerialTestSuite) TestSpillToDisk(c *check.C) {
defer config.RestoreFunc()()
config.UpdateGlobal(func(conf *config.Config) {
conf.OOMUseTmpStorage = true
})

tk := testkit.NewTestKit(c, test.store)
tk.MustExec("use test;")
func TestUnionDistinct(t *testing.T) {
t.Parallel()

c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/testCTEStorageSpill", "return(true)"), check.IsNil)
defer func() {
c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/testCTEStorageSpill"), check.IsNil)
tk.MustExec("set tidb_mem_quota_query = 1073741824;")
}()
c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/testSortedRowContainerSpill", "return(true)"), check.IsNil)
defer func() {
c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/testSortedRowContainerSpill"), check.IsNil)
}()

// Use duplicated rows to test UNION DISTINCT.
tk.MustExec("set tidb_mem_quota_query = 1073741824;")
insertStr := "insert into t1 values(0)"
rowNum := 1000
vals := make([]int, rowNum)
vals[0] = 0
for i := 1; i < rowNum; i++ {
v := rand.Intn(100)
vals[i] = v
insertStr += fmt.Sprintf(", (%d)", v)
}
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(c1 int);")
tk.MustExec(insertStr)
tk.MustExec("set tidb_mem_quota_query = 40000;")
tk.MustExec("set cte_max_recursion_depth = 500000;")
sql := fmt.Sprintf("with recursive cte1 as ( "+
"select c1 from t1 "+
"union "+
"select c1 + 1 c1 from cte1 where c1 < %d) "+
"select c1 from cte1 order by c1;", rowNum)
rows := tk.MustQuery(sql)

memTracker := tk.Se.GetSessionVars().StmtCtx.MemTracker
diskTracker := tk.Se.GetSessionVars().StmtCtx.DiskTracker
c.Assert(memTracker.MaxConsumed(), check.Greater, int64(0))
c.Assert(diskTracker.MaxConsumed(), check.Greater, int64(0))

sort.Ints(vals)
resRows := make([]string, 0, rowNum)
for i := vals[0]; i <= rowNum; i++ {
resRows = append(resRows, fmt.Sprintf("%d", i))
}
rows.Check(testkit.Rows(resRows...))
}
store, close := testkit.CreateMockStore(t)
defer close()

func (test *CTETestSuite) TestUnionDistinct(c *check.C) {
tk := testkit.NewTestKit(c, test.store)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test;")

// Basic test. UNION/UNION ALL intersects.
Expand All @@ -200,14 +102,18 @@ func (test *CTETestSuite) TestUnionDistinct(c *check.C) {
rows.Check(testkit.Rows("1", "2", "3", "4"))
}

func (test *CTETestSuite) TestCTEMaxRecursionDepth(c *check.C) {
tk := testkit.NewTestKit(c, test.store)
func TestCTEMaxRecursionDepth(t *testing.T) {
t.Parallel()

store, close := testkit.CreateMockStore(t)
defer close()

tk := testkit.NewTestKit(t, store)
tk.MustExec("use test;")

tk.MustExec("set @@cte_max_recursion_depth = -1;")
err := tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 100) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
// If there is no recursive part, query runs ok.
rows := tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;")
rows.Check(testkit.Rows("1", "2"))
Expand All @@ -216,11 +122,9 @@ func (test *CTETestSuite) TestCTEMaxRecursionDepth(c *check.C) {

tk.MustExec("set @@cte_max_recursion_depth = 0;")
err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 0) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 1) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
// If there is no recursive part, query runs ok.
rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;")
rows.Check(testkit.Rows("1", "2"))
Expand All @@ -233,17 +137,21 @@ func (test *CTETestSuite) TestCTEMaxRecursionDepth(c *check.C) {
rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 1) select * from cte1;")
rows.Check(testkit.Rows("1"))
err = tk.QueryToErr("with recursive cte1(c1) as (select 1 union select c1 + 1 c1 from cte1 where c1 < 2) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 2 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 2 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
// If there is no recursive part, query runs ok.
rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;")
rows.Check(testkit.Rows("1", "2"))
rows = tk.MustQuery("with cte1(c1) as (select 1 union select 2) select * from cte1 order by c1;")
rows.Check(testkit.Rows("1", "2"))
}

func (test *CTETestSuite) TestCTEWithLimit(c *check.C) {
tk := testkit.NewTestKit(c, test.store)
func TestCTEWithLimit(t *testing.T) {
t.Parallel()

store, close := testkit.CreateMockStore(t)
defer close()

tk := testkit.NewTestKit(t, store)
tk.MustExec("use test;")

// Basic recursive tests.
Expand All @@ -268,16 +176,14 @@ func (test *CTETestSuite) TestCTEWithLimit(c *check.C) {
rows.Check(testkit.Rows("2"))

err := tk.QueryToErr("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 1 offset 3) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 3 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 3 iterations. Try increasing @@cte_max_recursion_depth to a larger value")

tk.MustExec("set cte_max_recursion_depth=1000;")
rows = tk.MustQuery("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 5 offset 996) select * from cte1;")
rows.Check(testkit.Rows("996", "997", "998", "999", "1000"))

err = tk.QueryToErr("with recursive cte1(c1) as (select 0 union select c1 + 1 from cte1 limit 5 offset 997) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1001 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 1001 iterations. Try increasing @@cte_max_recursion_depth to a larger value")

rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select c1 + 1 from cte1 limit 0 offset 1) select * from cte1")
rows.Check(testkit.Rows())
Expand Down Expand Up @@ -312,7 +218,7 @@ func (test *CTETestSuite) TestCTEWithLimit(c *check.C) {
// Error: ERROR 1221 (HY000): Incorrect usage of UNION and LIMIT.
// Limit can only be at the end of SQL stmt.
err = tk.ExecToErr("with recursive cte1(c1) as (select c1 from t1 limit 1 offset 1 union select c1 + 1 from cte1 limit 0 offset 1) select * from cte1")
c.Assert(err.Error(), check.Equals, "[planner:1221]Incorrect usage of UNION and LIMIT")
require.EqualError(t, err, "[planner:1221]Incorrect usage of UNION and LIMIT")

// Basic non-recusive tests.
rows = tk.MustQuery("with recursive cte1(c1) as (select 1 union select 2 order by 1 limit 1 offset 1) select * from cte1")
Expand Down Expand Up @@ -375,8 +281,7 @@ func (test *CTETestSuite) TestCTEWithLimit(c *check.C) {
rows.Check(testkit.Rows())
// MySQL err: ERROR 1365 (22012): Division by 0. Because it gives error when computing 1/c1.
err = tk.QueryToErr("with recursive cte1 as (select 1/c1 c1 from t1 union select c1 + 1 c1 from cte1 where c1 < 2 limit 1) select * from cte1;")
c.Assert(err, check.NotNil)
c.Assert(err.Error(), check.Equals, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")
require.EqualError(t, err, "[executor:3636]Recursive query aborted after 1 iterations. Try increasing @@cte_max_recursion_depth to a larger value")

tk.MustExec("set cte_max_recursion_depth = 1000;")
tk.MustExec("drop table if exists t1;")
Expand Down
11 changes: 10 additions & 1 deletion testkit/testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (tk *TestKit) QueryToErr(sql string, args ...interface{}) error {
tk.require.NoError(err, comment)
tk.require.NotNil(res, comment)
_, resErr := session.GetRows4Test(context.Background(), tk.session, res)
tk.require.Nil(res.Close())
tk.require.NoError(res.Close())
return resErr
}

Expand Down Expand Up @@ -149,6 +149,15 @@ func (tk *TestKit) Exec(sql string, args ...interface{}) (sqlexec.RecordSet, err
return rs, nil
}

// ExecToErr executes a sql statement and discard results.
func (tk *TestKit) ExecToErr(sql string, args ...interface{}) error {
res, err := tk.Exec(sql, args...)
if res != nil {
tk.require.NoError(res.Close())
}
return err
}

func newSession(t *testing.T, store kv.Storage) session.Session {
se, err := session.CreateSession4Test(store)
require.Nil(t, err)
Expand Down

0 comments on commit 4cc2423

Please sign in to comment.