Skip to content

Commit

Permalink
Merge pull request pingcap#111 from pingcap/siddontang/check-subquery…
Browse files Browse the repository at this point in the history
…-column

check subquery column
  • Loading branch information
siddontang committed Sep 11, 2015
2 parents 15eaba8 + 6ecc7cc commit fe78877
Show file tree
Hide file tree
Showing 18 changed files with 120 additions and 64 deletions.
6 changes: 1 addition & 5 deletions expression/expressions/between.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (b *Between) String() string {

// Eval implements the Expression Eval interface.
func (b *Between) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
if err := CheckAllOneColumns(b.Expr, b.Left, b.Right); err != nil {
if err := CheckAllOneColumns(ctx, b.Expr, b.Left, b.Right); err != nil {
return nil, errors.Trace(err)
}

Expand Down Expand Up @@ -153,10 +153,6 @@ func (b *Between) convert() expression.Expression {
// a between 10 and 15 -> a >= 10 && a <= 15
// a not between 10 and 15 -> a < 10 || b > 15
func NewBetween(expr, lo, hi expression.Expression, not bool) (expression.Expression, error) {
if err := CheckAllOneColumns(expr, lo, hi); err != nil {
return nil, errors.Trace(err)
}

e, err := staticExpr(expr)
if err != nil {
return nil, errors.Trace(err)
Expand Down
4 changes: 1 addition & 3 deletions expression/expressions/between_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,9 @@ func (s *testBetweenSuite) TestBetween(c *C) {
}

for _, t := range tableError {
_, err := NewBetween(t.Expr, t.Left, t.Right, false)
c.Assert(err, NotNil)
b := Between{Expr: t.Expr, Left: t.Left, Right: t.Right, Not: false}

_, err = b.Eval(nil, nil)
_, err := b.Eval(nil, nil)
c.Assert(err, NotNil)
}
}
4 changes: 2 additions & 2 deletions expression/expressions/binop.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ func (o *BinaryOperation) Eval(ctx context.Context, args map[interface{}]interfa
}()

// all operands must have same column.
if err := hasSameColumnCount(o.L, o.R); err != nil {
if err := hasSameColumnCount(ctx, o.L, o.R); err != nil {
return nil, o.traceErr(err)
}

// row constructor only supports comparison operation.
switch o.Op {
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE:
default:
if err := CheckOneColumn(o.L); err != nil {
if err := CheckOneColumn(ctx, o.L); err != nil {
return nil, o.traceErr(err)
}
}
Expand Down
46 changes: 32 additions & 14 deletions expression/expressions/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,12 @@ func EvalBoolExpr(ctx context.Context, expr expression.Expression, m map[interfa

// CheckOneColumn checks whether expression e has only one column for the evaluation result.
// Now most of the expressions have one column except Row expression.
func CheckOneColumn(e expression.Expression) error {
n := columnCount(e)
func CheckOneColumn(ctx context.Context, e expression.Expression) error {
n, err := columnCount(ctx, e)
if err != nil {
return errors.Trace(err)
}

if n != 1 {
return errors.Errorf("Operand should contain 1 column(s)")
}
Expand All @@ -488,30 +492,44 @@ func CheckOneColumn(e expression.Expression) error {
}

// CheckAllOneColumns checks all expressions have one column.
func CheckAllOneColumns(args ...expression.Expression) error {
func CheckAllOneColumns(ctx context.Context, args ...expression.Expression) error {
for _, e := range args {
if err := CheckOneColumn(e); err != nil {
if err := CheckOneColumn(ctx, e); err != nil {
return err
}
}

return nil
}

func columnCount(e expression.Expression) int {
v, ok := e.(*Row)
if ok {
// TODO: add check, row constructor must have >= 2 columns
return len(v.Values)
func columnCount(ctx context.Context, e expression.Expression) (int, error) {
switch x := e.(type) {
case *Row:
n := len(x.Values)
if n <= 1 {
return 0, errors.Errorf("Operand should contain >= 2 columns for Row")
}
return n, nil
case *SubQuery:
return x.ColumnCount(ctx)
default:
return 1, nil
}

return 1
}

func hasSameColumnCount(e expression.Expression, args ...expression.Expression) error {
l := columnCount(e)
func hasSameColumnCount(ctx context.Context, e expression.Expression, args ...expression.Expression) error {
l, err := columnCount(ctx, e)
if err != nil {
return errors.Trace(err)
}
var n int
for _, arg := range args {
if l != columnCount(arg) {
n, err = columnCount(ctx, arg)
if err != nil {
return errors.Trace(err)
}

if n != l {
return errors.Errorf("Operand should contain %d column(s)", l)
}
}
Expand Down
11 changes: 7 additions & 4 deletions expression/expressions/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,15 @@ func (s *testHelperSuite) TestBase(c *C) {
v, err = EvalBoolExpr(nil, mockExpr{err: errors.New("must error")}, nil)
c.Assert(err, NotNil)

err = CheckOneColumn(&Row{})
err = CheckOneColumn(nil, &Row{})
c.Assert(err, NotNil)

err = CheckOneColumn(Value{nil})
err = CheckOneColumn(nil, Value{nil})
c.Assert(err, IsNil)

err = CheckOneColumn(nil, newMockSubQuery([][]interface{}{}, []string{"id", "name"}))
c.Assert(err, NotNil)

columns := []struct {
lhs expression.Expression
rhs expression.Expression
Expand All @@ -145,10 +148,10 @@ func (s *testHelperSuite) TestBase(c *C) {
}

for _, t := range columns {
err = hasSameColumnCount(t.lhs, t.rhs)
err = hasSameColumnCount(nil, t.lhs, t.rhs)
c.Assert(err, t.checker)

err = hasSameColumnCount(t.rhs, t.lhs)
err = hasSameColumnCount(nil, t.rhs, t.lhs)
c.Assert(err, t.checker)
}
}
Expand Down
9 changes: 7 additions & 2 deletions expression/expressions/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{})
}

if n.Sel == nil {
if err := hasSameColumnCount(n.Expr, n.List...); err != nil {
if err := hasSameColumnCount(ctx, n.Expr, n.List...); err != nil {
return nil, errors.Trace(err)
}

Expand All @@ -172,7 +172,12 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{})
return nil, err
}

if g, e := len(r.GetFields()), columnCount(n.Expr); g != e {
count, countErr := columnCount(ctx, n.Expr)
if countErr != nil {
return nil, errors.Trace(countErr)
}

if g, e := len(r.GetFields()), count; g != e {
return false, errors.Errorf("IN (%s): mismatched field count, have %d, need %d", n.Sel, g, e)
}

Expand Down
2 changes: 1 addition & 1 deletion expression/expressions/isnull.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (is *IsNull) String() string {

// Eval implements the Expression Eval interface.
func (is *IsNull) Eval(ctx context.Context, args map[interface{}]interface{}) (v interface{}, err error) {
if err := CheckOneColumn(is.Expr); err != nil {
if err := CheckOneColumn(ctx, is.Expr); err != nil {
return nil, errors.Trace(err)
}

Expand Down
2 changes: 1 addition & 1 deletion expression/expressions/istruth.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (is *IsTruth) String() string {

// Eval implements the Expression Eval interface.
func (is *IsTruth) Eval(ctx context.Context, args map[interface{}]interface{}) (v interface{}, err error) {
if err := CheckOneColumn(is.Expr); err != nil {
if err := CheckOneColumn(ctx, is.Expr); err != nil {
return nil, errors.Trace(err)
}

Expand Down
4 changes: 0 additions & 4 deletions plan/plans/select_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie

wildcardNum := 0
for _, v := range selectFields {
if err := expressions.CheckOneColumn(v.Expr); err != nil {
return nil, errors.Trace(err)
}

// Check metioned field.
names := expressions.MentionedColumns(v.Expr)
if len(names) == 0 {
Expand Down
4 changes: 0 additions & 4 deletions rset/rsets/groupby.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) {
aggFields := r.SelectList.AggFields

for i, e := range r.By {
if err := expressions.CheckOneColumn(e); err != nil {
return nil, errors.Trace(err)
}

if v, ok := e.(expressions.Value); ok {
var position int
switch u := v.Val.(type) {
Expand Down
4 changes: 0 additions & 4 deletions rset/rsets/groupby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,6 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) {

_, err = s.r.Plan(nil)
c.Assert(err, NotNil)

s.r.By[0] = &expressions.Row{Values: []expression.Expression{expressions.Value{Val: 1}, expressions.Value{Val: 1}}}
_, err = s.r.Plan(nil)
c.Assert(err, NotNil)
}

func (s *testGroupByRsetSuite) TestGroupByHasAmbiguousField(c *C) {
Expand Down
4 changes: 0 additions & 4 deletions rset/rsets/having.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ type HavingRset struct {

// CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList.
func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, groupBy []expression.Expression, tableFields []*field.ResultField) error {
if err := expressions.CheckOneColumn(r.Expr); err != nil {
return errors.Trace(err)
}

if expressions.ContainAggregateFunc(r.Expr) {
expr, err := selectList.UpdateAggFields(r.Expr, tableFields)
if err != nil {
Expand Down
4 changes: 0 additions & 4 deletions rset/rsets/having_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) {

err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, NotNil)

s.r.Expr = &expressions.Row{Values: []expression.Expression{expressions.Value{Val: 1}, expressions.Value{Val: 1}}}
err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields)
c.Assert(err, NotNil)
}

func (s *testHavingRsetSuite) TestHavingRsetPlan(c *C) {
Expand Down
4 changes: 0 additions & 4 deletions rset/rsets/orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ func (r *OrderByRset) String() string {
// CheckAndUpdateSelectList checks order by fields validity and set hidden fields to selectList.
func (r *OrderByRset) CheckAndUpdateSelectList(selectList *plans.SelectList, tableFields []*field.ResultField) error {
for i, v := range r.By {
if err := expressions.CheckOneColumn(v.Expr); err != nil {
return errors.Trace(err)
}

if expressions.ContainAggregateFunc(v.Expr) {
expr, err := selectList.UpdateAggFields(v.Expr, tableFields)
if err != nil {
Expand Down
8 changes: 0 additions & 8 deletions rset/rsets/orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,6 @@ func (s *testOrderByRsetSuite) TestOrderByRsetCheckAndUpdateSelectList(c *C) {

err = r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, NotNil)

// `select id from t order by row(1, 1)`
r.By[0].Expr = &expressions.Row{Values: []expression.Expression{expressions.Value{Val: 1}, expressions.Value{Val: 1}}}
selectList.Fields[1].Name = "name"
selectList.ResultFields[1].Name = "name"

err = r.CheckAndUpdateSelectList(selectList, resultFields)
c.Assert(err, NotNil)
}

func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) {
Expand Down
41 changes: 41 additions & 0 deletions stmt/stmts/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/ngaut/log"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/expressions"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/plan"
Expand Down Expand Up @@ -79,6 +80,42 @@ func (s *SelectStmt) SetText(text string) {
s.Text = text
}

func (s *SelectStmt) checkOneColumn(ctx context.Context) error {
// check select fields
for _, f := range s.Fields {
if err := expressions.CheckOneColumn(ctx, f.Expr); err != nil {
return errors.Trace(err)
}
}

// check group by
if s.GroupBy != nil {
for _, f := range s.GroupBy.By {
if err := expressions.CheckOneColumn(ctx, f); err != nil {
return errors.Trace(err)
}
}
}

// check order by
if s.OrderBy != nil {
for _, f := range s.OrderBy.By {
if err := expressions.CheckOneColumn(ctx, f.Expr); err != nil {
return errors.Trace(err)
}
}
}

// check having
if s.Having != nil {
if err := expressions.CheckOneColumn(ctx, s.Having.Expr); err != nil {
return errors.Trace(err)
}
}

return nil
}

// Plan implements the plan.Planner interface.
// The whole phase for select is
// `from -> where -> lock -> group by -> having -> select fields -> distinct -> order by -> limit -> final`
Expand Down Expand Up @@ -122,6 +159,10 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) {
return nil, err
}

if err := s.checkOneColumn(ctx); err != nil {
return nil, errors.Trace(err)
}

// Get select list for futher field values evaluation.
selectList, err := plans.ResolveSelectList(s.Fields, r.GetFields())
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions stmt/stmts/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,17 @@ func (s *testStmtSuite) TestSelectErrorRow(c *C) {
_, err = tx.Query("select * from test having row(1, 1);")
c.Assert(err, NotNil)

_, err = tx.Query("select (select 1, 1) from test;")
c.Assert(err, NotNil)

_, err = tx.Query("select * from test group by (select 1, 1);")
c.Assert(err, NotNil)

_, err = tx.Query("select * from test order by (select 1, 1);")
c.Assert(err, NotNil)

_, err = tx.Query("select * from test having (select 1, 1);")
c.Assert(err, NotNil)

mustCommit(c, tx)
}
15 changes: 15 additions & 0 deletions tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,21 @@ func (s *testSessionSuite) TestRow(c *C) {
row, err = r.FirstRow()
c.Assert(err, IsNil)
match(c, row, 1)

r = mustExecSQL(c, se, "select row(1, 1) > (select 1, 0)")
row, err = r.FirstRow()
c.Assert(err, IsNil)
match(c, row, 1)

r = mustExecSQL(c, se, "select 1 > (select 1)")
row, err = r.FirstRow()
c.Assert(err, IsNil)
match(c, row, 0)

r = mustExecSQL(c, se, "select (select 1)")
row, err = r.FirstRow()
c.Assert(err, IsNil)
match(c, row, 1)
}

func newSession(c *C, store kv.Storage, dbName string) Session {
Expand Down

0 comments on commit fe78877

Please sign in to comment.