Skip to content

Commit

Permalink
Refactor!: create SetOperation class (tobymao#3661)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jun 16, 2024
1 parent 868f30d commit 468123e
Show file tree
Hide file tree
Showing 13 changed files with 53 additions and 46 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def _parse_bracket(
return bracket

class Generator(generator.Generator):
EXPLICIT_UNION = True
EXPLICIT_SET_OP = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
QUERY_HINTS = False
Expand Down
13 changes: 6 additions & 7 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class Parser(parser.Parser):
# Tested in ClickHouse's playground, it seems that the following two queries do the same thing
# * select x from t1 union all select x from t2 limit 1;
# * select x from t1 union all (select x from t2 limit 1);
MODIFIERS_ATTACHED_TO_UNION = False
MODIFIERS_ATTACHED_TO_SET_OP = False
INTERVAL_SPANS = False

FUNCTIONS = {
Expand Down Expand Up @@ -657,6 +657,11 @@ class Generator(generator.Generator):
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_SET_OP = True
GROUPINGS_SEP = ""
SET_OP_MODIFIERS = False

STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
Expand Down Expand Up @@ -773,12 +778,6 @@ class Generator(generator.Generator):
exp.OnCluster: exp.Properties.Location.POST_NAME,
}

JOIN_HINTS = False
TABLE_HINTS = False
EXPLICIT_UNION = True
GROUPINGS_SEP = ""
OUTER_UNION_MODIFIERS = False

# there's no list in docs, but it can be found in Clickhouse code
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
ON_CLUSTER_TARGETS = {
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ class Generator(generator.Generator):
exp.Insert,
exp.Select,
exp.Subquery,
exp.Union,
exp.SetOperation,
}

SUPPORTED_JSON_PATH_PARTS = {
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,12 +759,14 @@ class Generator(generator.Generator):
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_TO_NUMBER = False
OUTER_UNION_MODIFIERS = False
SET_OP_MODIFIERS = False
COPY_PARAMS_EQ_REQUIRED = True

EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
exp.Insert,
exp.Intersect,
exp.Except,
exp.Merge,
exp.Select,
exp.Subquery,
Expand Down
17 changes: 11 additions & 6 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from sqlglot.dialects.dialect import DialectType

Q = t.TypeVar("Q", bound="Query")
S = t.TypeVar("S", bound="SetOperation")


class _Expression(type):
Expand Down Expand Up @@ -3066,7 +3067,7 @@ def to_column(self, copy: bool = True) -> Alias | Column | Dot:
return col


class Union(Query):
class SetOperation(Query):
arg_types = {
"with": False,
"this": True,
Expand All @@ -3077,13 +3078,13 @@ class Union(Query):
}

def select(
self,
self: S,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Union:
) -> S:
this = maybe_copy(self, copy)
this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts)
this.expression.unnest().select(
Expand Down Expand Up @@ -3112,11 +3113,15 @@ def right(self) -> Expression:
return self.expression


class Except(Union):
class Union(SetOperation):
pass


class Except(SetOperation):
pass


class Intersect(Union):
class Intersect(SetOperation):
pass


Expand Down Expand Up @@ -3728,7 +3733,7 @@ def selects(self) -> t.List[Expression]:
return self.expressions


UNWRAPPED_QUERIES = (Select, Union)
UNWRAPPED_QUERIES = (Select, SetOperation)


class Subquery(DerivedTable, Query):
Expand Down
24 changes: 12 additions & 12 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ class Generator(metaclass=_Generator):
# Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False

# Always do union distinct or union all
EXPLICIT_UNION = False
# Always do <set op> distinct or <set op> all
EXPLICIT_SET_OP = False

# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
Expand Down Expand Up @@ -339,10 +339,10 @@ class Generator(metaclass=_Generator):
# Whether the function TO_NUMBER is supported
SUPPORTS_TO_NUMBER = True

# Whether or not union modifiers apply to the outer union or select.
# Whether or not set op modifiers apply to the outer set op or select.
# SELECT * FROM x UNION SELECT * FROM y LIMIT 1
# True means limit 1 happens after the union, False means it it happens on y.
OUTER_UNION_MODIFIERS = True
# True means limit 1 happens after the set op, False means it it happens on y.
SET_OP_MODIFIERS = True

# Whether parameters from COPY statement are wrapped in parentheses
COPY_PARAMS_ARE_WRAPPED = True
Expand Down Expand Up @@ -506,7 +506,7 @@ class Generator(metaclass=_Generator):
exp.Insert,
exp.Join,
exp.Select,
exp.Union,
exp.SetOperation,
exp.Update,
exp.Where,
exp.With,
Expand All @@ -515,7 +515,7 @@ class Generator(metaclass=_Generator):
# Expressions that should not have their comments generated in maybe_comment
EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Binary,
exp.Union,
exp.SetOperation,
)

# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
Expand Down Expand Up @@ -2395,8 +2395,8 @@ def qualify_sql(self, expression: exp.Qualify) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"

def set_operations(self, expression: exp.Union) -> str:
if not self.OUTER_UNION_MODIFIERS:
def set_operations(self, expression: exp.SetOperation) -> str:
if not self.SET_OP_MODIFIERS:
limit = expression.args.get("limit")
order = expression.args.get("order")

Expand All @@ -2415,7 +2415,7 @@ def set_operations(self, expression: exp.Union) -> str:
while stack:
node = stack.pop()

if isinstance(node, exp.Union):
if isinstance(node, exp.SetOperation):
stack.append(node.expression)
stack.append(
self.maybe_comment(
Expand All @@ -2435,8 +2435,8 @@ def set_operations(self, expression: exp.Union) -> str:
def union_sql(self, expression: exp.Union) -> str:
return self.set_operations(expression)

def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
def union_op(self, expression: exp.SetOperation) -> str:
kind = " DISTINCT" if self.EXPLICIT_SET_OP else ""
kind = kind if expression.args.get("distinct") else " ALL"
by_name = " BY NAME" if expression.args.get("by_name") else ""
return f"UNION{kind}{by_name}"
Expand Down
5 changes: 3 additions & 2 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ def to_node(
reference_node_name=reference_node_name,
trim_selects=trim_selects,
)
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
if isinstance(scope.expression, exp.SetOperation):
name = type(scope.expression).__name__.upper()
upstream = upstream or Node(name=name, source=scope.expression, expression=select)

index = (
column
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/pushdown_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
if scope.expression.args.get("distinct"):
parent_selections = {SELECT_ALL}

if isinstance(scope.expression, exp.Union):
if isinstance(scope.expression, exp.SetOperation):
left, right = scope.union_scopes
referenced_columns[left] = parent_selections

Expand Down
12 changes: 6 additions & 6 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Scope:
Selection scope.
Attributes:
expression (exp.Select|exp.Union): Root expression of this scope
expression (exp.Select|exp.SetOperation): Root expression of this scope
sources (dict[str, exp.Table|Scope]): Mapping of source name to either
a Table expression or another Scope instance. For example:
SELECT * FROM x {"x": Table(this="x")}
Expand Down Expand Up @@ -233,7 +233,7 @@ def subqueries(self):
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
list[exp.Select | exp.SetOperation]: subqueries
"""
self._ensure_collected()
return self._subqueries
Expand Down Expand Up @@ -339,7 +339,7 @@ def external_columns(self):
sources in the current scope.
"""
if self._external_columns is None:
if isinstance(self.expression, exp.Union):
if isinstance(self.expression, exp.SetOperation):
left, right = self.union_scopes
self._external_columns = left.external_columns + right.external_columns
else:
Expand Down Expand Up @@ -535,7 +535,7 @@ def _traverse_scope(scope):

if isinstance(expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(expression, exp.Union):
elif isinstance(expression, exp.SetOperation):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
Expand Down Expand Up @@ -588,7 +588,7 @@ def _traverse_union(scope):
scope_type=ScopeType.UNION,
)

if isinstance(expression, exp.Union):
if isinstance(expression, exp.SetOperation):
yield from _traverse_ctes(new_scope)

union_scope_stack.append(new_scope)
Expand Down Expand Up @@ -620,7 +620,7 @@ def _traverse_ctes(scope):
if with_ and with_.recursive:
union = cte.this

if isinstance(union, exp.Union):
if isinstance(union, exp.SetOperation):
sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

child_scope = None
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/unnest_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def unnest(select, parent_select, next_alias_name):
):
return

if isinstance(select, exp.Union):
if isinstance(select, exp.SetOperation):
select = exp.select(*select.selects).from_(select.subquery(next_alias_name()))

alias = next_alias_name()
Expand Down
10 changes: 5 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,8 +1185,8 @@ class Parser(metaclass=_Parser):
STRING_ALIASES = False

# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
MODIFIERS_ATTACHED_TO_SET_OP = True
SET_OP_MODIFIERS = {"order", "limit", "offset"}

# Whether to parse IF statements that aren't followed by a left parenthesis as commands
NO_PAREN_IF_COMMANDS = True
Expand Down Expand Up @@ -3963,7 +3963,7 @@ def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[
token_type = self._prev.token_type

if token_type == TokenType.UNION:
operation = exp.Union
operation: t.Type[exp.SetOperation] = exp.Union
elif token_type == TokenType.EXCEPT:
operation = exp.Except
else:
Expand All @@ -3983,11 +3983,11 @@ def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[
expression=expression,
)

if isinstance(this, exp.Union) and self.MODIFIERS_ATTACHED_TO_UNION:
if isinstance(this, exp.SetOperation) and self.MODIFIERS_ATTACHED_TO_SET_OP:
expression = this.expression

if expression:
for arg in self.UNION_MODIFIERS:
for arg in self.SET_OP_MODIFIERS:
expr = expression.args.get(arg)
if expr:
this.set(arg, expr.pop())
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def from_expression(

if isinstance(expression, exp.Select) and from_:
step = Scan.from_expression(from_.this, ctes)
elif isinstance(expression, exp.Union):
elif isinstance(expression, exp.SetOperation):
step = SetOperation.from_expression(expression, ctes)
else:
step = Scan()
Expand Down Expand Up @@ -426,7 +426,7 @@ def __init__(
def from_expression(
cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
) -> SetOperation:
assert isinstance(expression, exp.Union)
assert isinstance(expression, exp.SetOperation)

left = Step.from_expression(expression.left, ctes)
# SELECT 1 UNION SELECT 2 <-- these subqueries don't have names
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
for cte in expression.expressions:
if not cte.args["alias"].columns:
query = cte.this
if isinstance(query, exp.Union):
if isinstance(query, exp.SetOperation):
query = query.this

cte.args["alias"].set(
Expand Down

0 comments on commit 468123e

Please sign in to comment.