diff --git a/ast/dml.go b/ast/dml.go
index e4e8a7f6e2dd7..6556e6e56e684 100644
--- a/ast/dml.go
+++ b/ast/dml.go
@@ -68,6 +68,8 @@ type Join struct {
 	Tp JoinType
 	// On represents join on condition.
 	On *OnCondition
+	// Using represents the column list of a join's USING clause.
+	Using []*ColumnName
 }
 
 // Accept implements Node Accept interface.
diff --git a/executor/join_test.go b/executor/join_test.go
index ff1ebe6e503bf..272157b175581 100644
--- a/executor/join_test.go
+++ b/executor/join_test.go
@@ -233,6 +233,45 @@ func (s *testSuite) TestJoin(c *C) {
 
 }
 
+func (s *testSuite) TestUsing(c *C) {
+	defer func() {
+		s.cleanEnv(c)
+		testleak.AfterTest(c)()
+	}()
+	tk := testkit.NewTestKit(c, s.store)
+
+	tk.MustExec("use test")
+	tk.MustExec("drop table if exists t1, t2, t3, t4")
+	tk.MustExec("create table t1 (a int, c int)")
+	tk.MustExec("create table t2 (a int, d int)")
+	tk.MustExec("create table t3 (a int)")
+	tk.MustExec("create table t4 (a int)")
+	tk.MustExec("insert t1 values (2, 4), (1, 3)")
+	tk.MustExec("insert t2 values (2, 5), (3, 6)")
+	tk.MustExec("insert t3 values (1)")
+
+	tk.MustQuery("select * from t1 join t2 using (a)").Check(testkit.Rows("2 4 5"))
+	tk.MustQuery("select t1.a, t2.a from t1 join t2 using (a)").Check(testkit.Rows("2 2"))
+
+	tk.MustQuery("select * from t1 right join t2 using (a) order by a").Check(testkit.Rows("2 5 4", "3 6 "))
+	tk.MustQuery("select t1.a, t2.a from t1 right join t2 using (a) order by t2.a").Check(testkit.Rows("2 2", " 3"))
+
+	tk.MustQuery("select * from t1 left join t2 using (a) order by a").Check(testkit.Rows("1 3 ", "2 4 5"))
+	tk.MustQuery("select t1.a, t2.a from t1 left join t2 using (a) order by t1.a").Check(testkit.Rows("1 ", "2 2"))
+
+	tk.MustQuery("select * from t1 join t2 using (a) right join t3 using (a)").Check(testkit.Rows("1 "))
+	tk.MustQuery("select * from t1 join t2 using (a) right join t3 on (t2.a = t3.a)").Check(testkit.Rows(" 1"))
+	tk.MustQuery("select t2.a from t1 join t2 using (a) right join t3 on (t1.a = t3.a)").Check(testkit.Rows(""))
+	tk.MustQuery("select t1.a, t2.a, t3.a from t1 join t2 using (a) right join t3 using (a)").Check(testkit.Rows(" 1"))
+	tk.MustQuery("select t1.c, t2.d from t1 join t2 using (a) right join t3 using (a)").Check(testkit.Rows(" "))
+
+	tk.MustExec("alter table t1 add column b int default 1 after a")
+	tk.MustExec("alter table t2 add column b int default 1 after a")
+	tk.MustQuery("select * from t1 join t2 using (b, a)").Check(testkit.Rows("2 1 4 5"))
+
+	tk.MustExec("select * from (t1 join t2 using (a)) join (t3 join t4 using (a)) on (t2.a = t4.a and t1.a = t3.a)")
+}
+
 func (s *testSuite) TestMultiJoin(c *C) {
 	defer func() {
 		s.cleanEnv(c)
diff --git a/expression/schema.go b/expression/schema.go
index 98b8e84a6e5fb..cd331f12ab9ae 100644
--- a/expression/schema.go
+++ b/expression/schema.go
@@ -159,6 +159,15 @@ func (s *Schema) ColumnsIndices(cols []*Column) (ret []int) {
 
 // MergeSchema will merge two schema into one schema.
 func MergeSchema(lSchema, rSchema *Schema) *Schema {
+	if lSchema == nil && rSchema == nil {
+		return nil
+	}
+	if lSchema == nil {
+		return rSchema.Clone()
+	}
+	if rSchema == nil {
+		return lSchema.Clone()
+	}
 	tmpL := lSchema.Clone()
 	tmpR := rSchema.Clone()
 	ret := NewSchema(append(tmpL.Columns, tmpR.Columns...)...)
diff --git a/parser/parser.y b/parser/parser.y
index 438081ff25bc2..6ff17a4672a9e 100644
--- a/parser/parser.y
+++ b/parser/parser.y
@@ -882,7 +882,7 @@ import (
 /* A dummy token to force the priority of TableRef production in a join. */
 %left tableRefPriority
 %precedence lowerThanOn
-%precedence on
+%precedence on using
 %right assignmentEq
 %left oror or
 %left xor
@@ -4502,12 +4502,19 @@ JoinTable:
 		on := &ast.OnCondition{Expr: $5.(ast.ExprNode)}
 		$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on}
 	}
+|	TableRef CrossOpt TableRef "USING" '(' ColumnNameList ')'
+	{
+		$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, Using: $6.([]*ast.ColumnName)}
+	}
 |	TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression
 	{
 		on := &ast.OnCondition{Expr: $7.(ast.ExprNode)}
 		$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on}
 	}
-	/* Support Using */
+|	TableRef JoinType OuterOpt "JOIN" TableRef "USING" '(' ColumnNameList ')'
+	{
+		$$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), Using: $8.([]*ast.ColumnName)}
+	}
 
 JoinType:
 	"LEFT"
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 5cf9f9ea70930..58e0466f0b234 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -309,6 +309,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
 		{"select * from t1 join t2 left join t3 on t2.id = t3.id", true},
 		{"select * from t1 right join t2 on t1.id = t2.id left join t3 on t3.id = t2.id", true},
 		{"select * from t1 right join t2 on t1.id = t2.id left join t3", false},
+		{"select * from t1 join t2 left join t3 using (id)", true},
+		{"select * from t1 right join t2 using (id) left join t3 using (id)", true},
+		{"select * from t1 right join t2 using (id) left join t3", false},
 
 		// for admin
 		{"admin show ddl;", true},
diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go
index ac0481f148ea4..3826332a217b3 100644
--- a/plan/expression_rewriter.go
+++ b/plan/expression_rewriter.go
@@ -1003,5 +1003,16 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
 			return
 		}
 	}
+	if join, ok := er.p.(*LogicalJoin); ok && join.redundantSchema != nil {
+		column, err := join.redundantSchema.FindColumn(v)
+		if err != nil {
+			er.err = errors.Trace(err)
+			return
+		}
+		if column != nil {
+			er.ctxStack = append(er.ctxStack, column.Clone())
+			return
+		}
+	}
 	er.err = errors.Errorf("Unknown column %s %s %s.", v.Schema.L, v.Table.L, v.Name.L)
 }
diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go
index 4882d32e247e7..e72933a93c938 100644
--- a/plan/logical_plan_builder.go
+++ b/plan/logical_plan_builder.go
@@ -15,6 +15,7 @@ package plan
 
 import (
 	"fmt"
+	"sort"
 
 	"github.com/juju/errors"
 	"github.com/pingcap/tidb/ast"
@@ -227,6 +228,18 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
 	addChild(joinPlan, rightPlan)
 	joinPlan.SetSchema(newSchema)
 
+	// Merge the sub joins' redundantSchema into this join plan. When handling a query like
+	// select t2.a from (t1 join t2 using (a)) join t3 using (a);
+	// we can then simply search the top level join plan to find a redundant column.
+	var lRedundant, rRedundant *expression.Schema
+	if left, ok := leftPlan.(*LogicalJoin); ok && left.redundantSchema != nil {
+		lRedundant = left.redundantSchema
+	}
+	if right, ok := rightPlan.(*LogicalJoin); ok && right.redundantSchema != nil {
+		rRedundant = right.redundantSchema
+	}
+	joinPlan.redundantSchema = expression.MergeSchema(lRedundant, rRedundant)
+
 	if b.TableHints() != nil {
 		joinPlan.preferMergeJoin = b.TableHints().ifPreferMergeJoin(leftAlias, rightAlias)
 		if b.TableHints().ifPreferINLJ(leftAlias) {
@@ -240,7 +253,12 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
 		}
 	}
 
-	if join.On != nil {
+	if join.Using != nil {
+		if err := b.buildUsingClause(joinPlan, leftPlan, rightPlan, join); err != nil {
+			b.err = err
+			return nil
+		}
+	} else if join.On != nil {
 		onExpr, _, err := b.rewrite(join.On.Expr, joinPlan, nil, false)
 		if err != nil {
 			b.err = err
@@ -266,6 +284,81 @@ func (b *planBuilder) buildJoin(join *ast.Join) LogicalPlan {
 	return joinPlan
 }
 
+// buildUsingClause does redundant column elimination and column ordering based on the using clause: it looks up every using column in both children, prepends one equality condition per column to p.EqualConditions, replaces p's schema, and merges the eliminated columns into p.redundantSchema. It returns ErrUnknownColumn when a using column is missing from either child.
+// According to standard SQL, the output columns are produced in this display order:
+// First, coalesced common columns of the two joined tables, in the order in which they occur in the first table.
+// Second, columns unique to the first table, in the order in which they occur in that table.
+// Third, columns unique to the second table, in the order in which they occur in that table.
+func (b *planBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan LogicalPlan, join *ast.Join) error {
+	lsc := leftPlan.Schema().Clone()
+	rsc := rightPlan.Schema().Clone()
+
+	schemaCols := make([]*expression.Column, 0, len(lsc.Columns)+len(rsc.Columns)) // Upper bound only: subtracting len(join.Using) could go negative (e.g. a repeated using column) and make would panic.
+	redundantCols := make([]*expression.Column, 0, len(join.Using))
+	conds := make([]*expression.ScalarFunction, 0, len(join.Using))
+
+	redundant := make(map[string]bool, len(join.Using))
+	for _, col := range join.Using {
+		var (
+			err    error
+			lc, rc *expression.Column
+			cond   expression.Expression
+		)
+
+		if lc, err = lsc.FindColumn(col); err != nil {
+			return errors.Trace(err)
+		}
+		if rc, err = rsc.FindColumn(col); err != nil {
+			return errors.Trace(err)
+		}
+		redundant[col.Name.L] = true
+		if lc == nil || rc == nil {
+			// Same as MySQL.
+			return ErrUnknownColumn.GenByArgs(col.Name, "from clause")
+		}
+
+		if cond, err = expression.NewFunction(b.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc); err != nil {
+			return errors.Trace(err)
+		}
+		conds = append(conds, cond.(*expression.ScalarFunction))
+
+		if join.Tp == ast.RightJoin {
+			schemaCols = append(schemaCols, rc)
+			redundantCols = append(redundantCols, lc)
+		} else {
+			schemaCols = append(schemaCols, lc)
+			redundantCols = append(redundantCols, rc)
+		}
+	}
+
+	// Columns in the using clause may not be listed in the order in which they occur in the first table, so reorder them by Position.
+	sort.Slice(schemaCols, func(i, j int) bool {
+		return schemaCols[i].Position < schemaCols[j].Position
+	})
+
+	if join.Tp == ast.RightJoin { // For a right join the right child's columns are emitted before the left child's.
+		lsc, rsc = rsc, lsc
+	}
+	for _, col := range lsc.Columns {
+		if !redundant[col.ColName.L] {
+			schemaCols = append(schemaCols, col)
+		}
+	}
+	for _, col := range rsc.Columns {
+		if !redundant[col.ColName.L] {
+			schemaCols = append(schemaCols, col)
+		}
+	}
+
+	p.SetSchema(expression.NewSchema(schemaCols...))
+	p.EqualConditions = append(conds, p.EqualConditions...)
+
+	// p.redundantSchema may contain columns which were merged from a sub join, so merge it with redundantCols.
+	p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(redundantCols...))
+
+	return nil
+}
+
 func (b *planBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) LogicalPlan {
 	b.optFlag = b.optFlag | flagPredicatePushDown
 	conditions := splitWhere(where)
diff --git a/plan/logical_plans.go b/plan/logical_plans.go
index c0963c6190bc8..04274fefa3891 100644
--- a/plan/logical_plans.go
+++ b/plan/logical_plans.go
@@ -84,6 +84,10 @@ type LogicalJoin struct {
 	// DefaultValues is only used for outer join, which stands for the default values when the outer table cannot find join partner
 	// instead of null padding.
 	DefaultValues []types.Datum
+
+	// redundantSchema contains the columns which are eliminated from the output schema of the join.
+	// For select * from a join b using (c); a.c is kept in the output schema, and b.c is kept in redundantSchema.
+	redundantSchema *expression.Schema
 }
 
 func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expression.Expression) {