Skip to content

Commit

Permalink
*: support using clause in join statement. (pingcap#3372)
Browse files Browse the repository at this point in the history
  • Loading branch information
bobotu authored and shenli committed Jun 12, 2017
1 parent 703f704 commit 4bc3cf7
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ type Join struct {
Tp JoinType
// On represents join on condition.
On *OnCondition
// Using represents join using clause.
Using []*ColumnName
}

// Accept implements Node Accept interface.
Expand Down
39 changes: 39 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,45 @@ func (s *testSuite) TestJoin(c *C) {

}

// TestUsing exercises joins written with a USING clause: the column set and
// order produced by `select *`, qualified access to the coalesced columns,
// chained joins, and a multi-column USING list.
func (s *testSuite) TestUsing(c *C) {
	defer func() {
		s.cleanEnv(c)
		testleak.AfterTest(c)()
	}()
	tk := testkit.NewTestKit(c, s.store)

	// Schema and data shared by every query below. t4 never receives any rows.
	for _, stmt := range []string{
		"use test",
		"drop table if exists t1, t2, t3, t4",
		"create table t1 (a int, c int)",
		"create table t2 (a int, d int)",
		"create table t3 (a int)",
		"create table t4 (a int)",
		"insert t1 values (2, 4), (1, 3)",
		"insert t2 values (2, 5), (3, 6)",
		"insert t3 values (1)",
	} {
		tk.MustExec(stmt)
	}

	type usingCase struct {
		sql  string
		rows []string
	}
	// verify runs each query in order and compares its result with the expected rows.
	verify := func(cases []usingCase) {
		for _, uc := range cases {
			tk.MustQuery(uc.sql).Check(testkit.Rows(uc.rows...))
		}
	}

	verify([]usingCase{
		// Inner join.
		{"select * from t1 join t2 using (a)", []string{"2 4 5"}},
		{"select t1.a, t2.a from t1 join t2 using (a)", []string{"2 2"}},

		// Right join.
		{"select * from t1 right join t2 using (a) order by a", []string{"2 5 4", "3 6 <nil>"}},
		{"select t1.a, t2.a from t1 right join t2 using (a) order by t2.a", []string{"2 2", "<nil> 3"}},

		// Left join.
		{"select * from t1 left join t2 using (a) order by a", []string{"1 3 <nil>", "2 4 5"}},
		{"select t1.a, t2.a from t1 left join t2 using (a) order by t1.a", []string{"1 <nil>", "2 2"}},

		// A USING join nested under another join.
		{"select * from t1 join t2 using (a) right join t3 using (a)", []string{"1 <nil> <nil>"}},
		{"select * from t1 join t2 using (a) right join t3 on (t2.a = t3.a)", []string{"<nil> <nil> <nil> 1"}},
		{"select t2.a from t1 join t2 using (a) right join t3 on (t1.a = t3.a)", []string{"<nil>"}},
		{"select t1.a, t2.a, t3.a from t1 join t2 using (a) right join t3 using (a)", []string{"<nil> <nil> 1"}},
		{"select t1.c, t2.d from t1 join t2 using (a) right join t3 using (a)", []string{"<nil> <nil>"}},
	})

	// Add a second shared column so that USING can list columns in an order
	// different from their order in the tables.
	tk.MustExec("alter table t1 add column b int default 1 after a")
	tk.MustExec("alter table t2 add column b int default 1 after a")
	verify([]usingCase{
		{"select * from t1 join t2 using (b, a)", []string{"2 1 4 5"}},
	})

	// Parenthesized USING joins on both sides of an ON join; no result rows are compared here.
	tk.MustExec("select * from (t1 join t2 using (a)) join (t3 join t4 using (a)) on (t2.a = t4.a and t1.a = t3.a)")
}

func (s *testSuite) TestMultiJoin(c *C) {
defer func() {
s.cleanEnv(c)
Expand Down
9 changes: 9 additions & 0 deletions expression/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,15 @@ func (s *Schema) ColumnsIndices(cols []*Column) (ret []int) {

// MergeSchema will merge two schema into one schema.
func MergeSchema(lSchema, rSchema *Schema) *Schema {
if lSchema == nil && rSchema == nil {
return nil
}
if lSchema == nil {
return rSchema.Clone()
}
if rSchema == nil {
return lSchema.Clone()
}
tmpL := lSchema.Clone()
tmpR := rSchema.Clone()
ret := NewSchema(append(tmpL.Columns, tmpR.Columns...)...)
Expand Down
11 changes: 9 additions & 2 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ import (
/* A dummy token to force the priority of TableRef production in a join. */
%left tableRefPriority
%precedence lowerThanOn
%precedence on
%precedence on using
%right assignmentEq
%left oror or
%left xor
Expand Down Expand Up @@ -4502,12 +4502,19 @@ JoinTable:
on := &ast.OnCondition{Expr: $5.(ast.ExprNode)}
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on}
}
| TableRef CrossOpt TableRef "USING" '(' ColumnNameList ')'
{
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, Using: $6.([]*ast.ColumnName)}
}
| TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression
{
on := &ast.OnCondition{Expr: $7.(ast.ExprNode)}
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on}
}
/* Support Using */
| TableRef JoinType OuterOpt "JOIN" TableRef "USING" '(' ColumnNameList ')'
{
$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), Using: $8.([]*ast.ColumnName)}
}

JoinType:
"LEFT"
Expand Down
3 changes: 3 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
{"select * from t1 join t2 left join t3 on t2.id = t3.id", true},
{"select * from t1 right join t2 on t1.id = t2.id left join t3 on t3.id = t2.id", true},
{"select * from t1 right join t2 on t1.id = t2.id left join t3", false},
{"select * from t1 join t2 left join t3 using (id)", true},
{"select * from t1 right join t2 using (id) left join t3 using (id)", true},
{"select * from t1 right join t2 using (id) left join t3", false},

// for admin
{"admin show ddl;", true},
Expand Down
11 changes: 11 additions & 0 deletions plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1003,5 +1003,16 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
return
}
}
if join, ok := er.p.(*LogicalJoin); ok && join.redundantSchema != nil {
column, err := join.redundantSchema.FindColumn(v)
if err != nil {
er.err = errors.Trace(err)
return
}
if column != nil {
er.ctxStack = append(er.ctxStack, column.Clone())
return
}
}
er.err = errors.Errorf("Unknown column %s %s %s.", v.Schema.L, v.Table.L, v.Name.L)
}
95 changes: 94 additions & 1 deletion plan/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package plan

import (
"fmt"
"sort"

"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
Expand Down Expand Up @@ -227,6 +228,18 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
addChild(joinPlan, rightPlan)
joinPlan.SetSchema(newSchema)

// Merge the sub joins' redundantSchema into this join plan. When handling a query like
// select t2.a from (t1 join t2 using (a)) join t3 using (a);
// we can then simply search the top-level join plan to find the redundant column.
var lRedundant, rRedundant *expression.Schema
if left, ok := leftPlan.(*LogicalJoin); ok && left.redundantSchema != nil {
lRedundant = left.redundantSchema
}
if right, ok := rightPlan.(*LogicalJoin); ok && right.redundantSchema != nil {
rRedundant = right.redundantSchema
}
joinPlan.redundantSchema = expression.MergeSchema(lRedundant, rRedundant)

if b.TableHints() != nil {
joinPlan.preferMergeJoin = b.TableHints().ifPreferMergeJoin(leftAlias, rightAlias)
if b.TableHints().ifPreferINLJ(leftAlias) {
Expand All @@ -240,7 +253,12 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
}
}

if join.On != nil {
if join.Using != nil {
if err := b.buildUsingClause(joinPlan, leftPlan, rightPlan, join); err != nil {
b.err = err
return nil
}
} else if join.On != nil {
onExpr, _, err := b.rewrite(join.On.Expr, joinPlan, nil, false)
if err != nil {
b.err = err
Expand All @@ -266,6 +284,81 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
return joinPlan
}

// buildUsingClause does redundant column elimination and column ordering for join plan p
// based on the USING clause of join.
// According to standard SQL, it produces this display order:
// First, coalesced common columns of the two joined tables, in the order in which they occur in the first table.
// Second, columns unique to the first table, in the order in which they occur in that table.
// Third, columns unique to the second table, in the order in which they occur in that table.
// For a RIGHT JOIN the two sides swap roles: the right child's columns are the ones kept and listed first.
//
// On success it replaces p's schema with the reordered columns, prepends one equality condition
// per USING column to p.EqualConditions, and merges the eliminated columns into p.redundantSchema.
// It returns ErrUnknownColumn when a USING column is not found in both children, and passes on
// (traced) any error from the column lookups or from building the equality function.
func (b *planBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error {
	lsc := leftPlan.Schema().Clone()
	rsc := rightPlan.Schema().Clone()

	// The USING list is user input and may hold more entries than both schemas together
	// (duplicated or unknown names). Clamp the capacity hint so that make does not panic
	// with a negative capacity before the loop below can report a proper error.
	schemaCap := len(lsc.Columns) + len(rsc.Columns) - len(join.Using)
	if schemaCap < 0 {
		schemaCap = 0
	}
	schemaCols := make([]*expression.Column, 0, schemaCap)
	redundantCols := make([]*expression.Column, 0, len(join.Using))
	conds := make([]*expression.ScalarFunction, 0, len(join.Using))

	// redundant records the lower-cased names listed in USING; columns with these names
	// are skipped when the remaining columns of each side are appended below.
	redundant := make(map[string]bool, len(join.Using))
	for _, col := range join.Using {
		var (
			err    error
			lc, rc *expression.Column
			cond   expression.Expression
		)

		if lc, err = lsc.FindColumn(col); err != nil {
			return errors.Trace(err)
		}
		if rc, err = rsc.FindColumn(col); err != nil {
			return errors.Trace(err)
		}
		redundant[col.Name.L] = true
		if lc == nil || rc == nil {
			// Same as MySQL.
			return ErrUnknownColumn.GenByArgs(col.Name, "from clause")
		}

		// Each USING column contributes the join condition `left.col = right.col`.
		if cond, err = expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc); err != nil {
			return errors.Trace(err)
		}
		conds = append(conds, cond.(*expression.ScalarFunction))

		// Keep the column from the preserved side in the output schema and remember the
		// other one as redundant, so it can still be resolved by its qualified name.
		if join.Tp == ast.RightJoin {
			schemaCols = append(schemaCols, rc)
			redundantCols = append(redundantCols, lc)
		} else {
			schemaCols = append(schemaCols, lc)
			redundantCols = append(redundantCols, rc)
		}
	}

	// Columns in the USING clause may not be listed in the order in which they occur in the
	// first table, so reorder them by their Position.
	sort.Slice(schemaCols, func(i, j int) bool {
		return schemaCols[i].Position < schemaCols[j].Position
	})

	// For a right join the right child plays the role of the "first table".
	if join.Tp == ast.RightJoin {
		lsc, rsc = rsc, lsc
	}
	for _, col := range lsc.Columns {
		if !redundant[col.ColName.L] {
			schemaCols = append(schemaCols, col)
		}
	}
	for _, col := range rsc.Columns {
		if !redundant[col.ColName.L] {
			schemaCols = append(schemaCols, col)
		}
	}

	p.SetSchema(expression.NewSchema(schemaCols...))
	p.EqualConditions = append(conds, p.EqualConditions...)

	// p.redundantSchema may contain columns which were merged from sub joins, so merge it with redundantCols.
	p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(redundantCols...))

	return nil
}

func (b *planBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) LogicalPlan {
b.optFlag = b.optFlag | flagPredicatePushDown
conditions := splitWhere(where)
Expand Down
4 changes: 4 additions & 0 deletions plan/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ type LogicalJoin struct {
// DefaultValues is only used for outer join, which stands for the default values when the outer table cannot find join partner
// instead of null padding.
DefaultValues []types.Datum

// redundantSchema contains columns which are eliminated in join.
// For select * from a join b using (c); a.c will be in the output schema, and b.c will be in redundantSchema.
redundantSchema *expression.Schema
}

func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) {
Expand Down

0 comments on commit 4bc3cf7

Please sign in to comment.