Skip to content

Commit

Permalink
[parser] parser: support set operator EXCEPT and INTERSECT (pingcap#916)
Browse files Browse the repository at this point in the history
  • Loading branch information
lzmhhh123 authored and ti-chi-bot committed Oct 9, 2021
1 parent c307c0d commit 8f8575d
Show file tree
Hide file tree
Showing 9 changed files with 7,969 additions and 7,857 deletions.
69 changes: 43 additions & 26 deletions parser/ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
var (
_ DMLNode = &DeleteStmt{}
_ DMLNode = &InsertStmt{}
_ DMLNode = &UnionStmt{}
_ DMLNode = &SetOprStmt{}
_ DMLNode = &UpdateStmt{}
_ DMLNode = &SelectStmt{}
_ DMLNode = &ShowStmt{}
Expand All @@ -44,7 +44,7 @@ var (
_ Node = &TableName{}
_ Node = &TableRefsClause{}
_ Node = &TableSource{}
_ Node = &UnionSelectList{}
_ Node = &SetOprSelectList{}
_ Node = &WildCardField{}
_ Node = &WindowSpec{}
_ Node = &PartitionByClause{}
Expand Down Expand Up @@ -379,7 +379,7 @@ type TableSource struct {
node

// Source is the source of the data, can be a TableName,
// a SelectStmt, a UnionStmt, or a JoinNode.
// a SelectStmt, a SetOprStmt, or a JoinNode.
Source ResultSetNode

// AsName is the alias name of the table source.
Expand All @@ -390,7 +390,7 @@ type TableSource struct {
func (n *TableSource) Restore(ctx *format.RestoreCtx) error {
needParen := false
switch n.Source.(type) {
case *SelectStmt, *UnionStmt:
case *SelectStmt, *SetOprStmt:
needParen = true
}

Expand Down Expand Up @@ -797,8 +797,8 @@ type SelectStmt struct {
LockTp SelectLockType
// TableHints represents the table level Optimizer Hint for join type
TableHints []*TableOptimizerHint
// IsAfterUnionDistinct indicates whether it's a stmt after "union distinct".
IsAfterUnionDistinct bool
// AfterSetOperator indicates the SelectStmt after which type of set operator
AfterSetOperator *SetOprType
// IsInBraces indicates whether it's a stmt in brace.
IsInBraces bool
// QueryBlockOffset indicates the order of this SelectStmt if counted from left to right in the sql text.
Expand Down Expand Up @@ -1021,27 +1021,33 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) {
return v.Leave(n)
}

// UnionSelectList represents the select list in a union statement.
type UnionSelectList struct {
// SetOprSelectList represents the select list in a union statement.
type SetOprSelectList struct {
node

Selects []*SelectStmt
}

// Restore implements Node interface.
func (n *UnionSelectList) Restore(ctx *format.RestoreCtx) error {
func (n *SetOprSelectList) Restore(ctx *format.RestoreCtx) error {
for i, selectStmt := range n.Selects {
if i != 0 {
ctx.WriteKeyWord(" UNION ")
if !selectStmt.IsAfterUnionDistinct {
ctx.WriteKeyWord("ALL ")
switch *selectStmt.AfterSetOperator {
case Union:
ctx.WriteKeyWord(" UNION ")
case UnionAll:
ctx.WriteKeyWord(" UNION ALL ")
case Except:
ctx.WriteKeyWord(" EXCEPT ")
case Intersect:
ctx.WriteKeyWord(" INTERSECT ")
}
}
if selectStmt.IsInBraces {
ctx.WritePlain("(")
}
if err := selectStmt.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore UnionSelectList.SelectStmt")
return errors.Annotate(err, "An error occurred while restore SetOprSelectList.SelectStmt")
}
if selectStmt.IsInBraces {
ctx.WritePlain(")")
Expand All @@ -1051,12 +1057,12 @@ func (n *UnionSelectList) Restore(ctx *format.RestoreCtx) error {
}

// Accept implements Node Accept interface.
func (n *UnionSelectList) Accept(v Visitor) (Node, bool) {
func (n *SetOprSelectList) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*UnionSelectList)
n = newNode.(*SetOprSelectList)
for i, sel := range n.Selects {
node, ok := sel.Accept(v)
if !ok {
Expand All @@ -1067,52 +1073,63 @@ func (n *UnionSelectList) Accept(v Visitor) (Node, bool) {
return v.Leave(n)
}

// UnionStmt represents "union statement"
type SetOprType uint8

const (
Union SetOprType = iota
UnionAll
Except
Intersect
)

// SetOprStmt represents "union/except/intersect statement"
// See https://dev.mysql.com/doc/refman/5.7/en/union.html
type UnionStmt struct {
// See https://mariadb.com/kb/en/intersect/
// See https://mariadb.com/kb/en/except/
type SetOprStmt struct {
dmlNode
resultSetNode

SelectList *UnionSelectList
SelectList *SetOprSelectList
OrderBy *OrderByClause
Limit *Limit
}

// Restore implements Node interface.
func (n *UnionStmt) Restore(ctx *format.RestoreCtx) error {
func (n *SetOprStmt) Restore(ctx *format.RestoreCtx) error {
if err := n.SelectList.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore UnionStmt.SelectList")
return errors.Annotate(err, "An error occurred while restore SetOprStmt.SelectList")
}

if n.OrderBy != nil {
ctx.WritePlain(" ")
if err := n.OrderBy.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore UnionStmt.OrderBy")
return errors.Annotate(err, "An error occurred while restore SetOprStmt.OrderBy")
}
}

if n.Limit != nil {
ctx.WritePlain(" ")
if err := n.Limit.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore UnionStmt.Limit")
return errors.Annotate(err, "An error occurred while restore SetOprStmt.Limit")
}
}
return nil
}

// Accept implements Node Accept interface.
func (n *UnionStmt) Accept(v Visitor) (Node, bool) {
func (n *SetOprStmt) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*UnionStmt)
n = newNode.(*SetOprStmt)
if n.SelectList != nil {
node, ok := n.SelectList.Accept(v)
if !ok {
return n, false
}
n.SelectList = node.(*UnionSelectList)
n.SelectList = node.(*SetOprSelectList)
}
if n.OrderBy != nil {
node, ok := n.OrderBy.Accept(v)
Expand Down Expand Up @@ -1450,7 +1467,7 @@ func (n *InsertStmt) Restore(ctx *format.RestoreCtx) error {
if n.Select != nil {
ctx.WritePlain(" ")
switch v := n.Select.(type) {
case *SelectStmt, *UnionStmt:
case *SelectStmt, *SetOprStmt:
if err := v.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore InsertStmt.Select")
}
Expand Down
10 changes: 5 additions & 5 deletions parser/ast/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ func (ts *testDMLSuite) TestDMLVisitorCover(c *C) {

// TODO: cover childrens
{&InsertStmt{Table: tableRefsClause}, 1, 1},
{&UnionStmt{}, 0, 0},
{&SetOprStmt{}, 0, 0},
{&UpdateStmt{TableRefs: tableRefsClause}, 1, 1},
{&SelectStmt{}, 0, 0},
{&FieldList{}, 0, 0},
{&UnionSelectList{}, 0, 0},
{&SetOprSelectList{}, 0, 0},
{&WindowSpec{}, 0, 0},
{&PartitionByClause{}, 0, 0},
{&FrameClause{}, 0, 0},
Expand Down Expand Up @@ -299,10 +299,10 @@ func (tc *testDMLSuite) TestOrderByClauseRestore(c *C) {
}
RunNodeRestoreTest(c, testCases, "SELECT 1 FROM t1 %s", extractNodeFunc)

extractNodeFromUnionStmtFunc := func(node Node) Node {
return node.(*UnionStmt).OrderBy
extractNodeFromSetOprStmtFunc := func(node Node) Node {
return node.(*SetOprStmt).OrderBy
}
RunNodeRestoreTest(c, testCases, "SELECT 1 FROM t1 UNION SELECT 2 FROM t2 %s", extractNodeFromUnionStmtFunc)
RunNodeRestoreTest(c, testCases, "SELECT 1 FROM t1 UNION SELECT 2 FROM t2 %s", extractNodeFromSetOprStmtFunc)
}

func (tc *testDMLSuite) TestAssignmentRestore(c *C) {
Expand Down
4 changes: 2 additions & 2 deletions parser/ast/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func IsReadOnly(node Node) bool {
return !st.Analyze || IsReadOnly(st.Stmt)
case *DoStmt, *ShowStmt:
return true
case *UnionStmt:
for _, sel := range node.(*UnionStmt).SelectList.Selects {
case *SetOprStmt:
for _, sel := range node.(*SetOprStmt).SelectList.Selects {
if !IsReadOnly(sel) {
return false
}
Expand Down
26 changes: 13 additions & 13 deletions parser/ast/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,27 @@ func (s *testCacheableSuite) TestUnionReadOnly(c *C) {
LockTp: SelectLockForUpdateNoWait,
}

unionStmt := &UnionStmt{
SelectList: &UnionSelectList{
setOprStmt := &SetOprStmt{
SelectList: &SetOprSelectList{
Selects: []*SelectStmt{selectReadOnly, selectReadOnly},
},
}
c.Assert(IsReadOnly(unionStmt), IsTrue)
c.Assert(IsReadOnly(setOprStmt), IsTrue)

unionStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectReadOnly, selectReadOnly}
c.Assert(IsReadOnly(unionStmt), IsTrue)
setOprStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectReadOnly, selectReadOnly}
c.Assert(IsReadOnly(setOprStmt), IsTrue)

unionStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectForUpdate}
c.Assert(IsReadOnly(unionStmt), IsFalse)
setOprStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectForUpdate}
c.Assert(IsReadOnly(setOprStmt), IsFalse)

unionStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectForUpdateNoWait}
c.Assert(IsReadOnly(unionStmt), IsFalse)
setOprStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectForUpdateNoWait}
c.Assert(IsReadOnly(setOprStmt), IsFalse)

unionStmt.SelectList.Selects = []*SelectStmt{selectForUpdate, selectForUpdateNoWait}
c.Assert(IsReadOnly(unionStmt), IsFalse)
setOprStmt.SelectList.Selects = []*SelectStmt{selectForUpdate, selectForUpdateNoWait}
c.Assert(IsReadOnly(setOprStmt), IsFalse)

unionStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectForUpdate, selectForUpdateNoWait}
c.Assert(IsReadOnly(unionStmt), IsFalse)
setOprStmt.SelectList.Selects = []*SelectStmt{selectReadOnly, selectForUpdate, selectForUpdateNoWait}
c.Assert(IsReadOnly(setOprStmt), IsFalse)
}

// CleanNodeText set the text of node and all child node empty.
Expand Down
1 change: 1 addition & 0 deletions parser/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ var tokenMap = map[string]int{
"INT8": int8Type,
"INTEGER": integerType,
"INTERNAL": internal,
"INTERSECT": intersect,
"INTERVAL": interval,
"INTO": into,
"INVISIBLE": invisible,
Expand Down
4 changes: 2 additions & 2 deletions parser/model/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ const (
FlagIgnoreZeroInDate = 1 << 7
// FlagDividedByZeroAsWarning indicates if DividedByZero should be returned as warning.
FlagDividedByZeroAsWarning = 1 << 8
// FlagInUnionStmt indicates if this is a UNION statement.
FlagInUnionStmt = 1 << 9
// FlagInSetOprStmt indicates if this is a UNION/EXCEPT/INTERSECT statement.
FlagInSetOprStmt = 1 << 9
// FlagInLoadDataStmt indicates if this is a LOAD DATA statement.
FlagInLoadDataStmt = 1 << 10
)
Loading

0 comments on commit 8f8575d

Please sign in to comment.